Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable opt-in autodetection of distributed configuration for mpi4py, attempt 2 #22235

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
1992969
Enable automatic detection of distrbuted variables with any configura…
coreyjadams Mar 7, 2024
900a037
Auto-detect of mpi4py-based configuration is now strictly opt-in.
coreyjadams Mar 11, 2024
4f22d86
Merge branch 'google:main' into main
coreyjadams Mar 11, 2024
01f4709
Address feedback and comments on PR 20174; fix typo in documentation.
coreyjadams Apr 16, 2024
af1a4f0
update documentation and elaborate on spec_detect_method variable
coreyjadams Apr 16, 2024
85bcf42
Merge branch 'main' of https://github.com/google/jax
coreyjadams Apr 16, 2024
a697299
Merge branch 'google:main' into main
coreyjadams May 13, 2024
e4fd97e
Merge branch 'google:main' into main
coreyjadams May 13, 2024
3a96e73
Merge branch 'google:main' into main
coreyjadams May 28, 2024
72fe093
Remove unmerged code
coreyjadams May 28, 2024
19e6694
Unify variable naming and fix function argument ordering
coreyjadams May 28, 2024
301bbc6
Add test to verify mpi4py based distributed initialization
coreyjadams May 28, 2024
5a91ac3
Merge branch 'google:main' into main
coreyjadams May 29, 2024
6701bd1
Merge branch 'google:main' into main
coreyjadams Jun 27, 2024
6cc07a9
Update mpi4py_cluster.py
coreyjadams Jun 27, 2024
ef3a2e2
Update mpi4py_cluster.py
coreyjadams Jun 27, 2024
6235eb3
Update distributed.py
coreyjadams Jun 27, 2024
f7086cb
Update build file to include mpi4py cluster.
coreyjadams Jun 27, 2024
10edc86
Change copyright year to the year this was authored
coreyjadams Jun 27, 2024
79b8cbf
Fix mypy issues; change variable name to more universally known name
coreyjadams Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,7 @@ pytype_strict_library(
"_src/clusters/cluster.py",
"_src/clusters/ompi_cluster.py",
"_src/clusters/slurm_cluster.py",
"_src/clusters/mpi4py_cluster.py",
"_src/distributed.py",
"_src/xla_bridge.py",
],
Expand Down
1 change: 1 addition & 0 deletions jax/_src/clusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
# available one from the list will be picked.
from .ompi_cluster import OmpiCluster
from .slurm_cluster import SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster
9 changes: 9 additions & 0 deletions jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def has_megascale_address():
return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None

class BaseTpuCluster(clusters.ClusterEnv):

name: str = "tpu"

"""Abstract cluster supports both single and multislice TPU environments.

If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
Expand Down Expand Up @@ -169,6 +172,9 @@ def _get_worker_list_in_slice() -> list[str]:
raise NotImplementedError()

class GceTpuCluster(BaseTpuCluster):

name: str = "gcetpu"

@classmethod
def is_env_present(cls) -> bool:
if not running_in_cloud_tpu_vm:
Expand All @@ -194,6 +200,9 @@ def _get_worker_list_in_slice() -> list[str]:
return [worker.split(':')[2] for worker in workers]

class GkeTpuCluster(BaseTpuCluster):

name: str = "gketpu"

@classmethod
def is_env_present(cls) -> bool:
if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None:
Expand Down
23 changes: 22 additions & 1 deletion jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,47 @@ class ClusterEnv:
"""

_cluster_types: list[type[ClusterEnv]] = []
opt_in_only_method: bool = False # Override this in derived classes if necessary

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._cluster_types.append(cls)


@classmethod
# pytype: disable=bad-return-type
def auto_detect_unset_distributed_params(cls,
coordinator_address: str | None,
num_processes: int | None,
process_id: int | None,
local_device_ids: Sequence[int] | None,
cluster_detection_method: str | None,
initialization_timeout: int | None,
) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]:

if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)):
return (coordinator_address, num_processes, process_id,
local_device_ids)
env = next((env for env in cls._cluster_types if env.is_env_present()), None)

# First, we check the cluster detection method because it will ignore submitted values
# if it succeeds.
if cluster_detection_method is not None:
env = next( (env for env in cls._cluster_types if env.name == cluster_detection_method), None )
if env is None:
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {cluster_detection_method} is not supported.")
elif not env.is_env_present():
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {cluster_detection_method} is supported but not functional in this environment.")
else:
env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None)

# Above: I have wrapped the env selection in a conditional to go through
# opt-in methods first (currently only mpi4py) but to check all possible options
# otherwise. Passing no cluster_detection_method results in the default, original behavior.

if env:
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
if coordinator_address is None:
Expand Down
93 changes: 93 additions & 0 deletions jax/_src/clusters/mpi4py_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from jax._src import clusters
import socket

from importlib.util import find_spec


class Mpi4pyCluster(clusters.ClusterEnv):
  """Cluster environment that derives distributed settings through mpi4py.

  All MPI queries go through ``MPI.COMM_WORLD``. ``mpi4py`` is imported lazily
  inside each method, so importing this module never requires it.
  """

  # Identifier matched against the requested cluster detection method.
  name: str = "mpi4py"
  # This method is opt-in: it is excluded from default auto-detection.
  opt_in_only_method: bool = True

  @classmethod
  def is_env_present(cls) -> bool:
    """Return True when the ``mpi4py`` package can be located.

    Only the importability of ``mpi4py`` is checked; this does not verify
    that the process was started by an MPI launcher.
    """
    return find_spec("mpi4py") is not None

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None) -> str:
    """Return the ``host:port`` coordinator address chosen by rank 0.

    Rank 0 builds the address from its own hostname and broadcasts it, so
    every rank returns the same string.

    Args:
      timeout_secs: Unused by this implementation.

    Returns:
      The coordinator address as ``"<hostname>:<port>"``.
    """
    from mpi4py import MPI  # type: ignore

    comm = MPI.COMM_WORLD

    address = None
    if comm.Get_rank() == 0:
      hostname = socket.gethostname()
      # Map the hostname into the top 2**12 ports (61440-65535). Only rank 0
      # evaluates hash(), so per-process string-hash salting cannot make the
      # ranks disagree on the port.
      port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1)
      address = f'{hostname}:{port_id}'

    # Every rank receives rank 0's value; the local placeholder is discarded.
    return comm.bcast(address, root=0)

  @classmethod
  def get_process_count(cls) -> int:
    """Return the number of processes in ``MPI.COMM_WORLD``."""
    from mpi4py import MPI
    return int(MPI.COMM_WORLD.Get_size())

  @classmethod
  def get_process_id(cls) -> int:
    """Return this process's rank in ``MPI.COMM_WORLD``."""
    from mpi4py import MPI
    return int(MPI.COMM_WORLD.Get_rank())

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Return this process's rank among the processes sharing its node.

    The world communicator is split with ``MPI.COMM_TYPE_SHARED``; the rank
    within the resulting sub-communicator is used as the local process id.
    """
    from mpi4py import MPI

    local_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
    try:
      return int(local_comm.Get_rank())
    finally:
      # The sub-communicator is only needed for the rank lookup; release it
      # explicitly so it does not outlive this call.
      local_comm.Free()
3 changes: 3 additions & 0 deletions jax/_src/clusters/ompi_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'

class OmpiCluster(clusters.ClusterEnv):

name: str = "ompi"

@classmethod
def is_env_present(cls) -> bool:
return _ORTE_URI in os.environ
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/clusters/slurm_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
_NUM_NODES = 'SLURM_STEP_NUM_NODES'

class SlurmCluster(clusters.ClusterEnv):

name: str = "slurm"

@classmethod
def is_env_present(cls) -> bool:
return _JOBID_PARAM in os.environ
Expand Down
34 changes: 33 additions & 1 deletion jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,22 @@ def initialize(self,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
cluster_detection_method: str | None = None,
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]


(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address,
num_processes,
process_id,
local_device_ids,
cluster_detection_method,
initialization_timeout,
)
)
Expand Down Expand Up @@ -84,6 +87,18 @@ def initialize(self,

self.process_id = process_id

# Emit a warning about PROXY variables if they are in the user's env:
proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()]

if len(proxy_vars) > 0:
vars = " ".join(proxy_vars) + ". "
warning = (
f'JAX detected proxy variable(s) in the environment as distributed setup: {vars}'
'On some systems, this may cause a hang of distributed.initialize and '
'you may need to unset these ENV variable(s)'
)
logger.warning(warning)

if process_id == 0:
if self.service is not None:
raise RuntimeError('distributed.initialize should only be called once.')
Expand Down Expand Up @@ -130,6 +145,7 @@ def initialize(coordinator_address: str | None = None,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
cluster_detection_method: str | None = None,
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
"""Initializes the JAX distributed system.
Expand All @@ -147,9 +163,20 @@ def initialize(coordinator_address: str | None = None,
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
will be chosen automatically.

The ``cluster_detection_method`` may be used to choose a specific method for detecting those
distributed arguments. You may pass the name of any of the automatic detection methods to this
argument, though it is not necessary in the TPU, Slurm, or Open MPI cases. For other MPI
installations, if you have a functional ``mpi4py`` installed, you may pass
``cluster_detection_method="mpi4py"`` to bootstrap the required arguments.

Otherwise, you must provide the ``coordinator_address``,
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.

Please note: on some systems, particularly HPC clusters that only access external networks
through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
:func:`~jax.distributed.initialize` may timeout. You may need to unset these variables
prior to application launch.

Args:
coordinator_address: the IP address of process `0` and a port on which that
process should launch a coordinator service. The choice of
Expand All @@ -166,6 +193,10 @@ def initialize(coordinator_address: str | None = None,
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
If ``None``, defaults to all local devices being visible to the process except when processes
are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
cluster_detection_method: An optional string naming the method used to autodetect the configuration of the distributed
run. Note that the "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
and to launch the application with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
Legacy auto-detect options (OMPI, Slurm) remain enabled.
initialization_timeout: Time period (in seconds) for which connection will
be retried. If the initialization takes more than the timeout specified,
the initialization will error. Defaults to 300 secs i.e. 5 mins.
Expand Down Expand Up @@ -197,7 +228,8 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
local_device_ids, initialization_timeout, coordinator_bind_address)
local_device_ids, cluster_detection_method,
initialization_timeout, coordinator_bind_address)
atexit.register(shutdown)


Expand Down
43 changes: 43 additions & 0 deletions tests/multiprocess_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
from jax.experimental import pjit
import jax.numpy as jnp

# Used to test for mpi4py installation and skip tests if not installed
import importlib.util

try:
import portpicker
except ImportError:
Expand Down Expand Up @@ -218,6 +221,46 @@ def test_gpu_ompi_distributed_initialize(self):
finally:
proc.kill()

def test_gpu_mpi4py_distributed_initialize(self):
  """Check that ``jax.distributed.initialize`` can bootstrap itself via mpi4py.

  Launches ``num_gpus`` processes with ``mpirun``; each one initializes
  through ``cluster_detection_method="mpi4py"``. Process 0 prints
  ``<local_device_count>,<device_count>``, which is compared with the
  expected values. Skipped unless a GPU backend, ``mpirun`` and ``mpi4py``
  are all available.
  """
  if not jtu.test_device_matches(['gpu']):
    raise unittest.SkipTest('Tests only for GPU.')
  if shutil.which('mpirun') is None:
    raise unittest.SkipTest('Tests only for MPI (mpirun not found).')
  if importlib.util.find_spec("mpi4py") is None:
    raise unittest.SkipTest('Test of mpi4py initialize only possible with mpi4py installed.')

  num_gpus = 4
  num_gpus_per_task = 1

  with contextlib.ExitStack() as exit_stack:
    args = [
        'mpirun',
        '--oversubscribe',
        '--allow-run-as-root',
        '-n',
        str(num_gpus),
        sys.executable,
        '-c',
        # The keyword must match the parameter of jax.distributed.initialize
        # (``cluster_detection_method``); any other name raises TypeError in
        # every rank before initialization starts.
        ('import jax, os; '
         'jax.distributed.initialize(cluster_detection_method="mpi4py"); '
         'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
        )
    ]
    env = os.environ.copy()
    # In case the job was launched via Slurm,
    # prevent OpenMPI from detecting Slurm environment
    env.pop('SLURM_JOBID', None)
    proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, universal_newlines=True)
    proc = exit_stack.enter_context(proc)

    try:
      out, _ = proc.communicate()
      self.assertEqual(proc.returncode, 0)
      self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
    finally:
      proc.kill()


@unittest.skipIf(
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
Expand Down