From ffc9292365b600ad6e7dbd31f4ef365d12421ce2 Mon Sep 17 00:00:00 2001
From: Kyle Gerard Felker
Date: Tue, 2 Jul 2024 13:18:05 -0500
Subject: [PATCH] Squashed commit of the following:

commit 79b8cbf0cb47e32743e0970bc1abeb6a673866a8
Author: Corey Adams
Date:   Mon Jul 1 14:14:15 2024 -0500

    Fix mypy issues; change variable name to more universally known name

commit 10edc866f568908e536e5c7bd6b59b4e5351781e
Author: Corey Adams
Date:   Thu Jun 27 13:25:32 2024 -0500

    Change copyright year to the year this was authored

commit f7086cb44cc98d58a96ae804dcd1787bc31470f7
Author: Corey Adams
Date:   Thu Jun 27 13:15:32 2024 -0500

    Update build file to include mpi4py cluster.

commit 6235eb311b9fca2bd81fe1c49456d164b7332753
Author: Corey adams
Date:   Thu Jun 27 12:11:48 2024 -0500

    Update distributed.py

    Clean up documentation slightly.

commit ef3a2e220945b2158cf20edeb1e04bbbf8f290ff
Author: Corey adams
Date:   Thu Jun 27 12:09:37 2024 -0500

    Update mpi4py_cluster.py

    Further clean up unneeded comments.

commit 6cc07a9a52fc202ecc65c04c513096391c27d02d
Author: Corey adams
Date:   Thu Jun 27 12:08:38 2024 -0500

    Update mpi4py_cluster.py

    Remove unneeded commented code.

commit 6701bd1a9d645a0e08d95df1692f43946f0a5eb8
Merge: 5a91ac342 98b87540a
Author: Corey adams
Date:   Thu Jun 27 12:07:25 2024 -0500

    Merge branch 'google:main' into main

commit 5a91ac34248afa6f65af3cae66df7d0d122c1d26
Merge: 301bbc67f 6c51234f9
Author: Corey adams
Date:   Tue May 28 22:14:08 2024 -0500

    Merge branch 'google:main' into main

commit 301bbc67f938bc30c543cf300cec8a9c75f3eef8
Author: Corey Adams
Date:   Tue May 28 11:34:51 2024 -0500

    Add test to verify mpi4py based distributed initialization

commit 19e66949a36bb0edb4cd66b0f170f42b326928ec
Author: Corey Adams
Date:   Tue May 28 11:14:40 2024 -0500

    Unify variable naming and fix function argument ordering

commit 72fe093042519e48d9c26b7ede3b266c7a850be6
Author: Corey Adams
Date:   Tue May 28 10:56:25 2024 -0500

    Remove unmerged code

commit 3a96e738a3cdf9b6ed194cb764fa5640a37f6b95
Merge: e4fd97e19 ff3db9b3a
Author: Corey adams
Date:   Tue May 28 10:51:41 2024 -0500

    Merge branch 'google:main' into main

commit e4fd97e197211921fb6911054592041015af94ef
Merge: a69729900 72a81e58e
Author: Corey adams
Date:   Mon May 13 16:01:35 2024 -0500

    Merge branch 'google:main' into main

commit a6972990070d5d2f405d5ede9f82d35c7e6d157a
Merge: 85bcf42bd 1e48adc69
Author: Corey adams
Date:   Mon May 13 14:21:32 2024 -0500

    Merge branch 'google:main' into main

commit 85bcf42bdd36ad88a3d287c357cd12fde74c7fc0
Merge: af1a4f0a1 06cd05d1d
Author: Corey Adams
Date:   Tue Apr 16 09:09:31 2024 -0500

    Merge branch 'main' of https://github.com/google/jax

commit af1a4f0a12008780e9507d1bdd91e9d11ec35916
Author: Corey Adams
Date:   Tue Apr 16 08:58:33 2024 -0500

    update documentation and elaborate on spec_detect_method variable

commit 01f4709d5ecd4af675f4fb23d02d6a69b927adac
Author: Corey Adams
Date:   Tue Apr 16 08:45:38 2024 -0500

    Address feedback and comments on PR 20174; fix typo in documentation.

commit 4f22d86e7358c29ed588267a7d91fe55fb94f143
Merge: 900a0372f 71ec6e33c
Author: Corey adams
Date:   Mon Mar 11 11:51:30 2024 -0500

    Merge branch 'google:main' into main

commit 900a0372f6147d3c9ab53c95b6a4262e5cfe4457
Author: Corey Adams
Date:   Mon Mar 11 11:50:48 2024 -0500

    Auto-detect of mpi4py-based configuration is now strictly opt-in.

commit 1992969da6164e456492fe0f9cd4287f6d8f03cf
Author: Corey Adams
Date:   Thu Mar 7 12:27:43 2024 -0600

    Enable automatic detection of distributed variables with any
    configuration of MPI, as long as mpi4py is available
---
 jax/BUILD                              |  1 +
 jax/_src/clusters/__init__.py          |  1 +
 jax/_src/clusters/cloud_tpu_cluster.py |  9 +++
 jax/_src/clusters/cluster.py           | 23 ++++++-
 jax/_src/clusters/mpi4py_cluster.py    | 93 ++++++++++++++++++++++++++
 jax/_src/clusters/ompi_cluster.py      |  3 +
 jax/_src/clusters/slurm_cluster.py     |  3 +
 jax/_src/distributed.py                | 34 +++++++++-
 tests/multiprocess_gpu_test.py         | 43 ++++++++++++
 9 files changed, 208 insertions(+), 2 deletions(-)
 create mode 100644 jax/_src/clusters/mpi4py_cluster.py

diff --git a/jax/BUILD b/jax/BUILD
index a0de4f19ce5c..a4de4550c092 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -952,6 +952,7 @@ pytype_strict_library(
         "_src/clusters/cluster.py",
         "_src/clusters/ompi_cluster.py",
         "_src/clusters/slurm_cluster.py",
+        "_src/clusters/mpi4py_cluster.py",
         "_src/distributed.py",
         "_src/xla_bridge.py",
     ],
diff --git a/jax/_src/clusters/__init__.py b/jax/_src/clusters/__init__.py
index 3cc76a2b9d1b..73e4ac9412f7 100644
--- a/jax/_src/clusters/__init__.py
+++ b/jax/_src/clusters/__init__.py
@@ -22,5 +22,6 @@
 # available one from the list will be picked.
 from .ompi_cluster import OmpiCluster
 from .slurm_cluster import SlurmCluster
+from .mpi4py_cluster import Mpi4pyCluster
 from .cloud_tpu_cluster import GkeTpuCluster
 from .cloud_tpu_cluster import GceTpuCluster
diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py
index 4a12445856f6..c85abb2f83de 100644
--- a/jax/_src/clusters/cloud_tpu_cluster.py
+++ b/jax/_src/clusters/cloud_tpu_cluster.py
@@ -74,6 +74,9 @@ def has_megascale_address():
   return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None

 class BaseTpuCluster(clusters.ClusterEnv):
+
+  name: str = "tpu"
+
   """Abstract cluster supports both single and multislice TPU environments.

   If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
@@ -169,6 +172,9 @@ def _get_worker_list_in_slice() -> list[str]:
     raise NotImplementedError()

 class GceTpuCluster(BaseTpuCluster):
+
+  name: str = "gcetpu"
+
   @classmethod
   def is_env_present(cls) -> bool:
     if not running_in_cloud_tpu_vm:
@@ -194,6 +200,9 @@ def _get_worker_list_in_slice() -> list[str]:
     return [worker.split(':')[2] for worker in workers]

 class GkeTpuCluster(BaseTpuCluster):
+
+  name: str = "gketpu"
+
   @classmethod
   def is_env_present(cls) -> bool:
     if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None:
diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py
index 7d4fcfd43a3a..4c7df0617403 100644
--- a/jax/_src/clusters/cluster.py
+++ b/jax/_src/clusters/cluster.py
@@ -31,11 +31,13 @@ class ClusterEnv:
   """

   _cluster_types: list[type[ClusterEnv]] = []
+  opt_in_only_method: bool = False  # Override this in derived classes if necessary

   def __init_subclass__(cls, **kwargs):
     super().__init_subclass__(**kwargs)
     cls._cluster_types.append(cls)

+
   @classmethod
   # pytype: disable=bad-return-type
   def auto_detect_unset_distributed_params(cls,
@@ -43,14 +45,33 @@ def auto_detect_unset_distributed_params(cls,
                                            num_processes: int | None,
                                            process_id: int | None,
                                            local_device_ids: Sequence[int] | None,
+                                           cluster_detection_method: str | None,
                                            initialization_timeout: int | None,
                                            ) -> tuple[str | None, int | None, int | None, Sequence[int] | None]:
+
     if all(p is not None for p in (coordinator_address, num_processes, process_id, local_device_ids)):
       return (coordinator_address, num_processes, process_id, local_device_ids)
-    env = next((env for env in cls._cluster_types if env.is_env_present()), None)
+
+    # Check the requested detection method first; if it succeeds, it takes
+    # precedence over auto-detected values. Opt-in methods (currently only
+    # mpi4py) are consulted only when requested explicitly; passing no
+    # cluster_detection_method preserves the original auto-detection behavior.
+    if cluster_detection_method is not None:
+      env = next((env for env in cls._cluster_types if env.name == cluster_detection_method), None)
+      if env is None:
+        logger.error("Automatic distributed initialization cannot proceed: "
+                     f"{cluster_detection_method} is not supported.")
+      elif not env.is_env_present():
+        logger.error("Automatic distributed initialization cannot proceed: "
+                     f"{cluster_detection_method} is supported but not functional in this environment.")
+        env = None  # Do not proceed with a non-functional detection method.
+    else:
+      env = next((env for env in cls._cluster_types
+                  if not env.opt_in_only_method and env.is_env_present()), None)
+
     if env:
       logger.debug('Initializing distributed JAX environment via %s', env.__name__)
       if coordinator_address is None:
diff --git a/jax/_src/clusters/mpi4py_cluster.py b/jax/_src/clusters/mpi4py_cluster.py
new file mode 100644
index 000000000000..10793778f745
--- /dev/null
+++ b/jax/_src/clusters/mpi4py_cluster.py
@@ -0,0 +1,93 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from jax._src import clusters
+import socket
+
+from importlib.util import find_spec
+
+
+class Mpi4pyCluster(clusters.ClusterEnv):
+
+  name: str = "mpi4py"
+  opt_in_only_method: bool = True
+
+  @classmethod
+  def is_env_present(cls) -> bool:
+    # Relies on mpi4py being importable:
+    return find_spec("mpi4py") is not None
+
+  @classmethod
+  def get_coordinator_address(cls, timeout_secs: int | None) -> str:
+    # Using mpi4py, determine rank 0 and its hostname,
+    # then broadcast that hostname and a port to all ranks.
+    from mpi4py import MPI  # type: ignore
+
+    # Get the global communicator:
+    COMM_WORLD = MPI.COMM_WORLD
+
+    if COMM_WORLD.Get_rank() == 0:
+      hostname = socket.gethostname()
+
+      # Pick a port in the ephemeral range, derived deterministically
+      # from the hostname:
+      port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1)
+
+      hostname = f'{hostname}:{port_id}'
+    else:
+      hostname = "None"
+
+    # Broadcast the coordinator address to all ranks:
+    hostname = COMM_WORLD.bcast(hostname, root=0)
+
+    return hostname
+
+  @classmethod
+  def get_process_count(cls) -> int:
+    from mpi4py import MPI
+    return int(MPI.COMM_WORLD.Get_size())
+
+  @classmethod
+  def get_process_id(cls) -> int:
+    from mpi4py import MPI
+    return int(MPI.COMM_WORLD.Get_rank())
+
+  @classmethod
+  def get_local_process_id(cls) -> int | None:
+    # Using mpi4py, split the global communicator into sub-communicators,
+    # one per host. MPI assigns ranks within each sub-communicator, which
+    # yields the local process ID directly.
+    from mpi4py import MPI
+    COMM_WORLD = MPI.COMM_WORLD
+
+    new_comm = COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
+
+    # The rank in the new, host-local communicator is exactly the local rank:
+    return int(new_comm.Get_rank())
diff --git a/jax/_src/clusters/ompi_cluster.py b/jax/_src/clusters/ompi_cluster.py
index 9249958e2885..151968c1c2bc 100644
--- a/jax/_src/clusters/ompi_cluster.py
+++ b/jax/_src/clusters/ompi_cluster.py
@@ -25,6 +25,9 @@
 _LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'

 class OmpiCluster(clusters.ClusterEnv):
+
+  name: str = "ompi"
+
   @classmethod
   def is_env_present(cls) -> bool:
     return _ORTE_URI in os.environ
diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py
index 933036a64eab..8cec07601094 100644
--- a/jax/_src/clusters/slurm_cluster.py
+++ b/jax/_src/clusters/slurm_cluster.py
@@ -25,6 +25,9 @@
 _NUM_NODES = 'SLURM_STEP_NUM_NODES'

 class SlurmCluster(clusters.ClusterEnv):
+
+  name: str = "slurm"
+
   @classmethod
   def is_env_present(cls) -> bool:
     return _JOBID_PARAM in os.environ
diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index c1a8ec7fe948..5e8e956cf98b 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -41,6 +41,7 @@ def initialize(self,
                  num_processes: int | None = None,
                  process_id: int | None = None,
                  local_device_ids: int | Sequence[int] | None = None,
+                 cluster_detection_method: str | None = None,
                  initialization_timeout: int = 300,
                  coordinator_bind_address: str | None = None):
     coordinator_address = (coordinator_address or
@@ -48,12 +49,14 @@ def initialize(self,
     if isinstance(local_device_ids, int):
       local_device_ids = [local_device_ids]

+
     (coordinator_address, num_processes, process_id, local_device_ids) = (
         clusters.ClusterEnv.auto_detect_unset_distributed_params(
             coordinator_address,
             num_processes,
             process_id,
             local_device_ids,
+            cluster_detection_method,
             initialization_timeout,
         )
     )
@@ -84,6 +87,18 @@ def initialize(self,

     self.process_id = process_id

+    # Emit a warning if proxy variables are present in the user's environment:
+    proxy_vars = [key for key in os.environ.keys() if '_proxy' in key.lower()]
+
+    if len(proxy_vars) > 0:
+      joined_vars = " ".join(proxy_vars)
+      warning = (
+          f'JAX detected proxy variable(s) in the environment during distributed setup: {joined_vars}. '
+          'On some systems, this may cause a hang of distributed.initialize() and '
+          'you may need to unset these environment variable(s).'
+      )
+      logger.warning(warning)
+
     if process_id == 0:
       if self.service is not None:
         raise RuntimeError('distributed.initialize should only be called once.')
@@ -130,6 +145,7 @@ def initialize(coordinator_address: str | None = None,
                num_processes: int | None = None,
                process_id: int | None = None,
                local_device_ids: int | Sequence[int] | None = None,
+               cluster_detection_method: str | None = None,
                initialization_timeout: int = 300,
                coordinator_bind_address: str | None = None):
   """Initializes the JAX distributed system.
@@ -147,9 +163,20 @@ def initialize(coordinator_address: str | None = None,
   If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
   will be chosen automatically.

+  The ``cluster_detection_method`` may be used to choose a specific method for detecting those
+  distributed arguments. You may pass the name of any supported detection method (for example
+  ``"slurm"``, ``"ompi"``, or ``"mpi4py"``) to this argument, though it is not necessary in the
+  TPU, Slurm, or Open MPI cases. For other MPI installations, if you have a functional ``mpi4py``
+  installed, you may pass ``cluster_detection_method="mpi4py"`` to bootstrap the required arguments.
+
   Otherwise, you must provide the ``coordinator_address``, ``num_processes``, and ``process_id`` arguments
   to :func:`~jax.distributed.initialize`.

+  Please note: on some systems, particularly HPC clusters that only access external networks
+  through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
+  :func:`~jax.distributed.initialize` may time out. You may need to unset these variables
+  prior to application launch.
+
   Args:
     coordinator_address: the IP address of process `0` and a port on which that
       process should launch a coordinator service. The choice of
@@ -166,6 +193,10 @@ def initialize(coordinator_address: str | None = None,
     local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
       If ``None``, defaults to all local devices being visible to the process except when processes
      are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
+    cluster_detection_method: An optional string to attempt to autodetect the configuration of the
+      distributed run. Note that the ``"mpi4py"`` method requires a working ``mpi4py`` install in your
+      environment, and that you launch the application with an MPI-compatible job launcher such as
+      ``mpiexec`` or ``mpirun``. The legacy auto-detect options (Open MPI, Slurm) remain enabled.
     initialization_timeout: Time period (in seconds) for which connection will be retried.
       If the initialization takes more than the timeout specified, the initialization will error.
       Defaults to 300 secs, i.e. 5 mins.
@@ -197,7 +228,8 @@ def initialize(coordinator_address: str | None = None,
     raise RuntimeError("jax.distributed.initialize() must be called before "
                        "any JAX computations are executed.")
   global_state.initialize(coordinator_address, num_processes, process_id,
-                          local_device_ids, initialization_timeout, coordinator_bind_address)
+                          local_device_ids, cluster_detection_method,
+                          initialization_timeout, coordinator_bind_address)
   atexit.register(shutdown)
diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py
index bdb26f83e0dd..760e340815af 100644
--- a/tests/multiprocess_gpu_test.py
+++ b/tests/multiprocess_gpu_test.py
@@ -33,6 +33,9 @@
 from jax.experimental import pjit
 import jax.numpy as jnp

+# Used to check for an mpi4py installation and skip the test if it is absent
+import importlib.util
+
 try:
   import portpicker
 except ImportError:
@@ -218,6 +221,46 @@ def test_gpu_ompi_distributed_initialize(self):
       finally:
         proc.kill()

+  def test_gpu_mpi4py_distributed_initialize(self):
+    if not jtu.test_device_matches(['gpu']):
+      raise unittest.SkipTest('Tests only for GPU.')
+    if shutil.which('mpirun') is None:
+      raise unittest.SkipTest('Tests only for MPI (mpirun not found).')
+    if importlib.util.find_spec("mpi4py") is None:
+      raise unittest.SkipTest('Test of mpi4py initialize only possible with mpi4py installed.')
+
+    num_gpus = 4
+    num_gpus_per_task = 1
+
+    with contextlib.ExitStack() as exit_stack:
+      args = [
+          'mpirun',
+          '--oversubscribe',
+          '--allow-run-as-root',
+          '-n',
+          str(num_gpus),
+          sys.executable,
+          '-c',
+          ('import jax, os; '
+           'jax.distributed.initialize(cluster_detection_method="mpi4py"); '
+           'print(f\'{jax.local_device_count()},{jax.device_count()}\' '
+           'if jax.process_index() == 0 else \'\', end="")')
+      ]
+      env = os.environ.copy()
+      # In case the job was launched via Slurm,
+      # prevent Open MPI from detecting the Slurm environment:
+      env.pop('SLURM_JOBID', None)
+      proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
+                              stderr=subprocess.PIPE, universal_newlines=True)
+      proc = exit_stack.enter_context(proc)
+
+      try:
+        out, _ = proc.communicate()
+        self.assertEqual(proc.returncode, 0)
+        self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
+      finally:
+        proc.kill()
+
 @unittest.skipIf(
     os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
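
--

For reference, a minimal usage sketch of the opt-in detection added by this patch,
assuming a working mpi4py installation; the script name train.py and the process
count are illustrative, not part of the patch:

    # train.py -- launch with: mpirun -n 4 python train.py
    import jax

    # Opt in to mpi4py-based bootstrapping: coordinator_address, num_processes,
    # and process_id are then derived from MPI.COMM_WORLD rather than passed by hand.
    jax.distributed.initialize(cluster_detection_method="mpi4py")

    if jax.process_index() == 0:
        print(jax.local_device_count(), jax.device_count())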