Skip to content

Commit

Permalink
Unify variable naming and fix function argument ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
coreyjadams committed May 28, 2024
1 parent 72fe093 commit 19e6694
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 25 deletions.
3 changes: 2 additions & 1 deletion jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._cluster_types.append(cls)


@classmethod
# pytype: disable=bad-return-type
def auto_detect_unset_distributed_params(cls,
Expand All @@ -61,7 +62,7 @@ def auto_detect_unset_distributed_params(cls,
if env is None:
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {spec_detection_method} is not supported.")
if not env.is_env_present():
elif not env.is_env_present():
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {spec_detection_method} is supported but not functional in this environment.")
else:
Expand Down
33 changes: 19 additions & 14 deletions jax/_src/clusters/mpi4py_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,24 +101,29 @@ def get_local_process_id(cls) -> int | None:
from mpi4py import MPI
COMM_WORLD = MPI.COMM_WORLD

hostname = socket.gethostname()
# host_key = host_key %
all_hostnames = COMM_WORLD.gather(hostname, root=0)
# This is a previous method that is replaced with a different mpi split:

if COMM_WORLD.Get_rank() == 0:
# Order all the hostnames, and find unique ones
unique_hosts = unique(all_hostnames)
# Numpy automatically sorts them.
else:
unique_hosts = None
# hostname = socket.gethostname()
# # host_key = host_key %
# all_hostnames = COMM_WORLD.gather(hostname, root=0)

# if COMM_WORLD.Get_rank() == 0:
# # Order all the hostnames, and find unique ones
# unique_hosts = unique(all_hostnames)
# # Numpy automatically sorts them.
# else:
# unique_hosts = None

# # Broadcast the list of hostnames:
# unique_hosts = COMM_WORLD.bcast(unique_hosts, root=0)

# Broadcast the list of hostnames:
unique_hosts = COMM_WORLD.bcast(unique_hosts, root=0)
# # Find the integer for this host in the list of hosts:
# i = int(where(unique_hosts == hostname)[0])

# Find the integer for this host in the list of hosts:
i = int(where(unique_hosts == hostname)[0])
# new_comm = COMM_WORLD.Split(color=i)

new_comm = COMM_WORLD.Split(color=i)
# This is the alternative method that is simpler:
new_comm = COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)


# The rank in the new communicator - which is host-local only - IS the local rank:
Expand Down
17 changes: 7 additions & 10 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def initialize(self,
num_processes,
process_id,
local_device_ids,
initialization_timeout,
spec_detection_method,
initialization_timeout,
)
)

Expand Down Expand Up @@ -143,7 +143,7 @@ def initialize(coordinator_address: str | None = None,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
spec_detect_method: str | None = None,
spec_detection_method: str | None = None,
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
"""Initializes the JAX distributed system.
Expand All @@ -161,10 +161,10 @@ def initialize(coordinator_address: str | None = None,
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
will be chosen automatically.
The ``spec_detect_method`` may be used to intentionally, automatically select the values for
The ``spec_detection_method`` may be used to intentionally, automatically select the values for
arguments. You may pass any of the automatic ``spec_detect_methods`` to this argument though
it is not necessary in the TPU, Slurm, or Open MPI cases. For other MPI installations,
if you have a functional ``mpi4py`` installed, you may pass ``spec_detect_method="mpi4py"``
if you have a functional ``mpi4py`` installed, you may pass ``spec_detection_method="mpi4py"``
to bootstrap the required arguments.
Otherwise, you must provide the ``coordinator_address``,
Expand All @@ -191,7 +191,7 @@ def initialize(coordinator_address: str | None = None,
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
If ``None``, defaults to all local devices being visible to the process except when processes
are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
spec_detect_method: An optional string to attempt to autodetect the configuration of the distributed
spec_detection_method: An optional string to attempt to autodetect the configuration of the distributed
run. Note that "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
and launch the applicatoin with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
Legacy auto-detect options (OMPI, Slurm) remain enabled.
Expand Down Expand Up @@ -226,11 +226,8 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
<<<<<<< HEAD
local_device_ids, spec_detect_method, initialization_timeout)
=======
local_device_ids, initialization_timeout, coordinator_bind_address)
>>>>>>> 06cd05d1d6722e77744556983e99396d0c208774
local_device_ids, spec_detection_method,
initialization_timeout, coordinator_bind_address)
atexit.register(shutdown)


Expand Down

0 comments on commit 19e6694

Please sign in to comment.