Skip to content

Commit

Permalink
Enable automatic detection of distributed variables with any configura…
Browse files Browse the repository at this point in the history
…tion of MPI, as long as mpi4py is available
  • Loading branch information
coreyjadams committed Mar 7, 2024
1 parent f0afc1b commit 1992969
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/_src/clusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# available one from the list will be picked.
from .ompi_cluster import OmpiCluster
from .slurm_cluster import SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster
from .cloud_tpu_cluster import MultisliceGceTpuCluster
from .cloud_tpu_cluster import SingleSliceGceTpuCluster
116 changes: 116 additions & 0 deletions jax/_src/clusters/mpi4py_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from jax._src import clusters
import socket

from numpy import unique, where

from importlib.util import find_spec

class Mpi4pyCluster(clusters.ClusterEnv):

@classmethod
def is_env_present(cls) -> bool:
# in many HPC clusters, the variables `https_proxy` and `http_proxy`
# are set to enable access to normally unreachable network locations.
# For example, `pip install ...` fails on compute nodes without them.

# Unfortunately, these variables break the jax distributed init.
# The user needs to unset them, but I don't want to modify the global
# python os.environ here for them (bad practice)

# And I also don't know what the right way to raise a complaint here is

# Relies on mpi4py:
return find_spec("mpi4py") is not None

@classmethod
def get_coordinator_address(cls) -> str:

# Using mpi4py, figure out rank 0 and it's hostname.
# Then broadcast the hostname and port.


from mpi4py import MPI
# Get the global communicator:
COMM_WORLD = MPI.COMM_WORLD

# On rank 0, get the hostname:

if COMM_WORLD.Get_rank() == 0:
# Order all the hostnames, and find unique ones
hostname = socket.gethostname()

# Apparently, we want to pick a port in an ephemeral range...
port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1)

hostname = f'{hostname}:{port_id}'

else:
hostname = None



# Broadcast the host_ip to all ranks:
hostname = COMM_WORLD.bcast(hostname, root=0)
# host_ip = COMM_WORLD.bcast(host_ip, root=0)


return hostname


@classmethod
def get_process_count(cls) -> int:
from mpi4py import MPI
return int(MPI.COMM_WORLD.Get_size())

@classmethod
def get_process_id(cls) -> int:
from mpi4py import MPI
return int(MPI.COMM_WORLD.Get_rank())

@classmethod
def get_local_process_id(cls) -> int | None:

# Using mpi4py, split the global communicator into sub communicators
# based on hostname. mpi will assign them ranks and that will allow
# a selection of the local process ID.
from mpi4py import MPI
COMM_WORLD = MPI.COMM_WORLD

hostname = socket.gethostname()
# host_key = host_key %
all_hostnames = COMM_WORLD.gather(hostname, root=0)

if COMM_WORLD.Get_rank() == 0:
# Order all the hostnames, and find unique ones
unique_hosts = unique(all_hostnames)
# Numpy automatically sorts them.
else:
unique_hosts = None

# Broadcast the list of hostnames:
unique_hosts = COMM_WORLD.bcast(unique_hosts, root=0)

# Find the integer for this host in the list of hosts:
i = int(where(unique_hosts == hostname)[0])

new_comm = COMM_WORLD.Split(color=i)


# The rank in the new communicator - which is host-local only - IS the local rank:
return int(new_comm.Get_rank())

0 comments on commit 1992969

Please sign in to comment.