Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 79b8cbf
Author: Corey Adams <[email protected]>
Date:   Mon Jul 1 14:14:15 2024 -0500

    Fix mypy issues; change variable name to more universally known name

commit 10edc86
Author: Corey Adams <[email protected]>
Date:   Thu Jun 27 13:25:32 2024 -0500

    Change copyright year to the year this was authored

commit f7086cb
Author: Corey Adams <[email protected]>
Date:   Thu Jun 27 13:15:32 2024 -0500

    Update build file to include mpi4py cluster.

commit 6235eb3
Author: Corey adams <[email protected]>
Date:   Thu Jun 27 12:11:48 2024 -0500

    Update distributed.py

    Clean up documentation slightly.

commit ef3a2e2
Author: Corey adams <[email protected]>
Date:   Thu Jun 27 12:09:37 2024 -0500

    Update mpi4py_cluster.py

    Further clean up unneeded comments.

commit 6cc07a9
Author: Corey adams <[email protected]>
Date:   Thu Jun 27 12:08:38 2024 -0500

    Update mpi4py_cluster.py

    Remove unneeded commented code.

commit 6701bd1
Merge: 5a91ac3 98b8754
Author: Corey adams <[email protected]>
Date:   Thu Jun 27 12:07:25 2024 -0500

    Merge branch 'google:main' into main

commit 5a91ac3
Merge: 301bbc6 6c51234
Author: Corey adams <[email protected]>
Date:   Tue May 28 22:14:08 2024 -0500

    Merge branch 'google:main' into main

commit 301bbc6
Author: Corey Adams <[email protected]>
Date:   Tue May 28 11:34:51 2024 -0500

    Add test to verify mpi4py based distributed initialization

commit 19e6694
Author: Corey Adams <[email protected]>
Date:   Tue May 28 11:14:40 2024 -0500

    Unify variable naming and fix function argument ordering

commit 72fe093
Author: Corey Adams <[email protected]>
Date:   Tue May 28 10:56:25 2024 -0500

    Remove unmerged code

commit 3a96e73
Merge: e4fd97e ff3db9b
Author: Corey adams <[email protected]>
Date:   Tue May 28 10:51:41 2024 -0500

    Merge branch 'google:main' into main

commit e4fd97e
Merge: a697299 72a81e5
Author: Corey adams <[email protected]>
Date:   Mon May 13 16:01:35 2024 -0500

    Merge branch 'google:main' into main

commit a697299
Merge: 85bcf42 1e48adc
Author: Corey adams <[email protected]>
Date:   Mon May 13 14:21:32 2024 -0500

    Merge branch 'google:main' into main

commit 85bcf42
Merge: af1a4f0 06cd05d
Author: Corey Adams <[email protected]>
Date:   Tue Apr 16 09:09:31 2024 -0500

    Merge branch 'main' of https://github.com/google/jax

commit af1a4f0
Author: Corey Adams <[email protected]>
Date:   Tue Apr 16 08:58:33 2024 -0500

    update documentation and elaborate on spec_detect_method variable

commit 01f4709
Author: Corey Adams <[email protected]>
Date:   Tue Apr 16 08:45:38 2024 -0500

    Address feedback and comments on PR 20174; fix typo in documentation.

commit 4f22d86
Merge: 900a037 71ec6e3
Author: Corey adams <[email protected]>
Date:   Mon Mar 11 11:51:30 2024 -0500

    Merge branch 'google:main' into main

commit 900a037
Author: Corey Adams <[email protected]>
Date:   Mon Mar 11 11:50:48 2024 -0500

    Auto-detect of mpi4py-based configuration is now strictly opt-in.

commit 1992969
Author: Corey Adams <[email protected]>
Date:   Thu Mar 7 12:27:43 2024 -0600

    Enable automatic detection of distrbuted variables with any configuration of MPI, as long as mpi4py is available
  • Loading branch information
felker committed Jul 2, 2024
1 parent 484d09f commit ffc9292
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 2 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,7 @@ pytype_strict_library(
"_src/clusters/cluster.py",
"_src/clusters/ompi_cluster.py",
"_src/clusters/slurm_cluster.py",
"_src/clusters/mpi4py_cluster.py",
"_src/distributed.py",
"_src/xla_bridge.py",
],
Expand Down
1 change: 1 addition & 0 deletions jax/_src/clusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
# available one from the list will be picked.
from .ompi_cluster import OmpiCluster
from .slurm_cluster import SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster
9 changes: 9 additions & 0 deletions jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def has_megascale_address():
return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None

class BaseTpuCluster(clusters.ClusterEnv):

name: str = "tpu"

"""Abstract cluster supports both single and multislice TPU environments.
If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
Expand Down Expand Up @@ -169,6 +172,9 @@ def _get_worker_list_in_slice() -> list[str]:
raise NotImplementedError()

class GceTpuCluster(BaseTpuCluster):

name: str = "gcetpu"

@classmethod
def is_env_present(cls) -> bool:
if not running_in_cloud_tpu_vm:
Expand All @@ -194,6 +200,9 @@ def _get_worker_list_in_slice() -> list[str]:
return [worker.split(':')[2] for worker in workers]

class GkeTpuCluster(BaseTpuCluster):

name: str = "gketpu"

@classmethod
def is_env_present(cls) -> bool:
if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None:
Expand Down
23 changes: 22 additions & 1 deletion jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,47 @@ class ClusterEnv:
"""

_cluster_types: list[type[ClusterEnv]] = []
opt_in_only_method: bool = False # Override this in derived classes if necessary

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._cluster_types.append(cls)


@classmethod
# pytype: disable=bad-return-type
def auto_detect_unset_distributed_params(cls,
coordinator_address: str | None,
num_processes: int | None,
process_id: int | None,
local_device_ids: Sequence[int] | None,
cluster_detection_method: str | None,
initialization_timeout: int | None,
) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]:

if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)):
return (coordinator_address, num_processes, process_id,
local_device_ids)
env = next((env for env in cls._cluster_types if env.is_env_present()), None)

# First, we check the cluster detection method because it will ignore submitted values
# if it succeeds.
if cluster_detection_method is not None:
env = next( (env for env in cls._cluster_types if env.name == cluster_detection_method), None )
if env is None:
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {cluster_detection_method} is not supported.")
elif not env.is_env_present():
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {cluster_detection_method} is supported but not functional in this environment.")
else:
env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None)

# Above: I have wrapped the env selection in a conditional to go through
# opt-in methods first (currently only mpi4py) but to check all possible options
# otherwise. Passing no cluster_detection_method results in the default, original behavior.

if env:
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
if coordinator_address is None:
Expand Down
93 changes: 93 additions & 0 deletions jax/_src/clusters/mpi4py_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from jax._src import clusters
import socket

from importlib.util import find_spec


class Mpi4pyCluster(clusters.ClusterEnv):
  """Cluster environment that derives distributed settings through mpi4py.

  The process count, process id, local process id and coordinator address
  are all obtained from ``MPI.COMM_WORLD``. mpi4py is imported lazily inside
  each method so that merely importing this module does not require mpi4py
  to be installed.
  """

  # Identifier used to select this detection method by name.
  name: str = "mpi4py"
  # This method is only used when explicitly requested; it is excluded from
  # the default auto-detection pass.
  opt_in_only_method: bool = True

  @classmethod
  def is_env_present(cls) -> bool:
    """Returns True if the ``mpi4py`` package can be located.

    Only the import spec is looked up; mpi4py itself is not imported here.
    """
    return find_spec("mpi4py") is not None

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None) -> str:
    """Returns ``"<hostname>:<port>"`` of rank 0, as agreed on by all ranks.

    Rank 0 builds the address from its own hostname and broadcasts it over
    ``MPI.COMM_WORLD``; every rank must therefore call this method.

    Args:
      timeout_secs: unused by this implementation.

    Returns:
      The coordinator address string chosen by rank 0.
    """
    from mpi4py import MPI  # type: ignore

    comm_world = MPI.COMM_WORLD

    if comm_world.Get_rank() == 0:
      hostname = socket.gethostname()
      # Derive a port from the hostname, mapped into the top 2**12 ports
      # (61440-65535), i.e. the upper end of the ephemeral range.
      port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1)
      coordinator_address = f'{hostname}:{port_id}'
    else:
      # Placeholder only; it is overwritten by the broadcast below.
      coordinator_address = "None"

    # Broadcast rank 0's address to all ranks.
    coordinator_address = comm_world.bcast(coordinator_address, root=0)
    return coordinator_address

  @classmethod
  def get_process_count(cls) -> int:
    """Returns the number of ranks in ``MPI.COMM_WORLD``."""
    from mpi4py import MPI
    return int(MPI.COMM_WORLD.Get_size())

  @classmethod
  def get_process_id(cls) -> int:
    """Returns this process's rank in ``MPI.COMM_WORLD``."""
    from mpi4py import MPI
    return int(MPI.COMM_WORLD.Get_rank())

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Returns this process's rank among the processes on the same host.

    ``MPI.COMM_WORLD`` is split with ``MPI.COMM_TYPE_SHARED`` into host-local
    sub-communicators; the rank inside that sub-communicator is the local
    process id. The split is a collective call, so every rank must call this
    method.
    """
    from mpi4py import MPI

    local_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
    try:
      # The rank in the host-local communicator IS the local rank.
      return int(local_comm.Get_rank())
    finally:
      # The sub-communicator is only needed to read the rank; release the
      # MPI handle explicitly instead of leaking it.
      local_comm.Free()
3 changes: 3 additions & 0 deletions jax/_src/clusters/ompi_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'

class OmpiCluster(clusters.ClusterEnv):

name: str = "ompi"

@classmethod
def is_env_present(cls) -> bool:
return _ORTE_URI in os.environ
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/clusters/slurm_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
_NUM_NODES = 'SLURM_STEP_NUM_NODES'

class SlurmCluster(clusters.ClusterEnv):

name: str = "slurm"

@classmethod
def is_env_present(cls) -> bool:
return _JOBID_PARAM in os.environ
Expand Down
34 changes: 33 additions & 1 deletion jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,22 @@ def initialize(self,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
cluster_detection_method: str | None = None,
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]


(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address,
num_processes,
process_id,
local_device_ids,
cluster_detection_method,
initialization_timeout,
)
)
Expand Down Expand Up @@ -84,6 +87,18 @@ def initialize(self,

self.process_id = process_id

# Emit a warning about PROXY variables if they are in the user's env:
proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()]

if len(proxy_vars) > 0:
vars = " ".join(proxy_vars) + ". "
warning = (
f'JAX detected proxy variable(s) in the environment as distributed setup: {vars}'
'On some systems, this may cause a hang of distributed.initialize and '
'you may need to unset these ENV variable(s)'
)
logger.warning(warning)

if process_id == 0:
if self.service is not None:
raise RuntimeError('distributed.initialize should only be called once.')
Expand Down Expand Up @@ -130,6 +145,7 @@ def initialize(coordinator_address: str | None = None,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
cluster_detection_method: str | None = None,
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
"""Initializes the JAX distributed system.
Expand All @@ -147,9 +163,20 @@ def initialize(coordinator_address: str | None = None,
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
will be chosen automatically.
The ``cluster_detection_method`` may be used to choose a specific method for detecting those
distributed arguments. You may pass the name of any automatic detection method (for example
``"ompi"``, ``"slurm"`` or ``"mpi4py"``) to this argument, though it is not necessary in the
TPU, Slurm, or Open MPI cases. For other MPI installations, if you have a functional ``mpi4py``
installed, you may pass ``cluster_detection_method="mpi4py"`` to bootstrap the required arguments.
Otherwise, you must provide the ``coordinator_address``,
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
Please note: on some systems, particularly HPC clusters that only access external networks
through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
:func:`~jax.distributed.initialize` may time out. You may need to unset these variables
prior to application launch.
Args:
coordinator_address: the IP address of process `0` and a port on which that
process should launch a coordinator service. The choice of
Expand All @@ -166,6 +193,10 @@ def initialize(coordinator_address: str | None = None,
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
If ``None``, defaults to all local devices being visible to the process except when processes
are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed
run. Note that the "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
and to launch the application with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
Legacy auto-detect options (OMPI, Slurm) remain enabled.
initialization_timeout: Time period (in seconds) for which connection will
be retried. If the initialization takes more than the timeout specified,
the initialization will error. Defaults to 300 secs i.e. 5 mins.
Expand Down Expand Up @@ -197,7 +228,8 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
local_device_ids, initialization_timeout, coordinator_bind_address)
local_device_ids, cluster_detection_method,
initialization_timeout, coordinator_bind_address)
atexit.register(shutdown)


Expand Down
43 changes: 43 additions & 0 deletions tests/multiprocess_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
from jax.experimental import pjit
import jax.numpy as jnp

# Used to test for mpi4py installation and skip tests if not installed
import importlib.util

try:
import portpicker
except ImportError:
Expand Down Expand Up @@ -218,6 +221,46 @@ def test_gpu_ompi_distributed_initialize(self):
finally:
proc.kill()

def test_gpu_mpi4py_distributed_initialize(self):
  """Checks that ``jax.distributed.initialize`` can bootstrap itself via mpi4py.

  Launches ``num_gpus`` processes with ``mpirun``; each one calls
  ``jax.distributed.initialize(cluster_detection_method="mpi4py")`` and
  process 0 prints ``"<local_device_count>,<device_count>"``. The test is
  skipped when no GPU, no ``mpirun`` or no ``mpi4py`` is available.
  """
  if not jtu.test_device_matches(['gpu']):
    raise unittest.SkipTest('Tests only for GPU.')
  if shutil.which('mpirun') is None:
    raise unittest.SkipTest('Tests only for MPI (mpirun not found).')
  if importlib.util.find_spec("mpi4py") is None:
    raise unittest.SkipTest('Test of mpi4py initialize only possible with mpi4py installed.')

  num_gpus = 4
  num_gpus_per_task = 1

  with contextlib.ExitStack() as exit_stack:
    args = [
      'mpirun',
      '--oversubscribe',
      '--allow-run-as-root',
      '-n',
      str(num_gpus),
      sys.executable,
      '-c',
      # The keyword must match the ``cluster_detection_method`` parameter of
      # jax.distributed.initialize; any other name raises TypeError in the
      # child processes.
      ('import jax, os; '
       'jax.distributed.initialize(cluster_detection_method="mpi4py"); '
       'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
      )
    ]
    env = os.environ.copy()
    # In case the job was launched via Slurm,
    # prevent OpenMPI from detecting Slurm environment
    env.pop('SLURM_JOBID', None)
    proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, universal_newlines=True)
    proc = exit_stack.enter_context(proc)

    try:
      out, _ = proc.communicate()
      self.assertEqual(proc.returncode, 0)
      self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
    finally:
      proc.kill()


@unittest.skipIf(
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
Expand Down

0 comments on commit ffc9292

Please sign in to comment.