From 1992969da6164e456492fe0f9cd4287f6d8f03cf Mon Sep 17 00:00:00 2001 From: Corey Adams Date: Thu, 7 Mar 2024 12:27:43 -0600 Subject: [PATCH] Enable automatic detection of distrbuted variables with any configuration of MPI, as long as mpi4py is available --- jax/_src/clusters/__init__.py | 1 + jax/_src/clusters/mpi4py_cluster.py | 116 ++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 jax/_src/clusters/mpi4py_cluster.py diff --git a/jax/_src/clusters/__init__.py b/jax/_src/clusters/__init__.py index d933af613810..aa2f87867383 100644 --- a/jax/_src/clusters/__init__.py +++ b/jax/_src/clusters/__init__.py @@ -22,6 +22,7 @@ # available one from the list will be picked. from .ompi_cluster import OmpiCluster from .slurm_cluster import SlurmCluster +from .mpi4py_cluster import Mpi4pyCluster from .cloud_tpu_cluster import GkeTpuCluster from .cloud_tpu_cluster import MultisliceGceTpuCluster from .cloud_tpu_cluster import SingleSliceGceTpuCluster diff --git a/jax/_src/clusters/mpi4py_cluster.py b/jax/_src/clusters/mpi4py_cluster.py new file mode 100644 index 000000000000..b966da05f62a --- /dev/null +++ b/jax/_src/clusters/mpi4py_cluster.py @@ -0,0 +1,116 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from jax._src import clusters +import socket + +from numpy import unique, where + +from importlib.util import find_spec + +class Mpi4pyCluster(clusters.ClusterEnv): + + @classmethod + def is_env_present(cls) -> bool: + # in many HPC clusters, the variables `https_proxy` and `http_proxy` + # are set to enable access to normally unreachable network locations. + # For example, `pip install ...` fails on compute nodes without them. + + # Unfortunately, these variables break the jax distributed init. + # The user needs to unset them, but I don't want to modify the global + # python os.environ here for them (bad practice) + + # And I also don't know what the right way to raise a complaint here is + + # Relies on mpi4py: + return find_spec("mpi4py") is not None + + @classmethod + def get_coordinator_address(cls) -> str: + + # Using mpi4py, figure out rank 0 and it's hostname. + # Then broadcast the hostname and port. + + + from mpi4py import MPI + # Get the global communicator: + COMM_WORLD = MPI.COMM_WORLD + + # On rank 0, get the hostname: + + if COMM_WORLD.Get_rank() == 0: + # Order all the hostnames, and find unique ones + hostname = socket.gethostname() + + # Apparently, we want to pick a port in an ephemeral range... + port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1) + + hostname = f'{hostname}:{port_id}' + + else: + hostname = None + + + + # Broadcast the host_ip to all ranks: + hostname = COMM_WORLD.bcast(hostname, root=0) + # host_ip = COMM_WORLD.bcast(host_ip, root=0) + + + return hostname + + + @classmethod + def get_process_count(cls) -> int: + from mpi4py import MPI + return int(MPI.COMM_WORLD.Get_size()) + + @classmethod + def get_process_id(cls) -> int: + from mpi4py import MPI + return int(MPI.COMM_WORLD.Get_rank()) + + @classmethod + def get_local_process_id(cls) -> int | None: + + # Using mpi4py, split the global communicator into sub communicators + # based on hostname. mpi will assign them ranks and that will allow + # a selection of the local process ID. + from mpi4py import MPI + COMM_WORLD = MPI.COMM_WORLD + + hostname = socket.gethostname() + # host_key = host_key % + all_hostnames = COMM_WORLD.gather(hostname, root=0) + + if COMM_WORLD.Get_rank() == 0: + # Order all the hostnames, and find unique ones + unique_hosts = unique(all_hostnames) + # Numpy automatically sorts them. + else: + unique_hosts = None + + # Broadcast the list of hostnames: + unique_hosts = COMM_WORLD.bcast(unique_hosts, root=0) + + # Find the integer for this host in the list of hosts: + i = int(where(unique_hosts == hostname)[0]) + + new_comm = COMM_WORLD.Split(color=i) + + + # The rank in the new communicator - which is host-local only - IS the local rank: + return int(new_comm.Get_rank())