Skip to content

Commit

Permalink
Auto-detect of mpi4py-based configuration is now strictly opt-in.
Browse files Browse the repository at this point in the history
  • Loading branch information
coreyjadams committed Mar 11, 2024
1 parent 1992969 commit 900a037
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 6 deletions.
19 changes: 17 additions & 2 deletions jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,29 @@ def auto_detect_unset_distributed_params(cls,
coordinator_address: str | None,
num_processes: int | None,
process_id: int | None,
local_device_ids: Sequence[int] | None
local_device_ids: Sequence[int] | None,
spec_detection_method: str | None,
) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]:

if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)):
return (coordinator_address, num_processes, process_id,
local_device_ids)
env = next((env for env in cls._cluster_types if env.is_env_present()), None)

# First, we check the spec detection method because it will ignore submitted values
# if it succeeds.
if spec_detection_method == "mpi4py":
# We directly select the Mpi4pyCluster environment with an override here to opt in:
from jax._src.clusters.mpi4py_cluster import Mpi4pyCluster
env = Mpi4pyCluster if Mpi4pyCluster.is_env_present(opt_in=True) else None
else:
env = next((env for env in cls._cluster_types if env.is_env_present()), None)

# Above: I have wrapped the env selection in a conditional to go through
# opt-in methods first (currently only mpi4py) but to check all possible options
# otherwise. Passing no spec_detection_method results in the default, original behavior.

if env:
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
if coordinator_address is None:
Expand Down
9 changes: 7 additions & 2 deletions jax/_src/clusters/mpi4py_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
class Mpi4pyCluster(clusters.ClusterEnv):

@classmethod
def is_env_present(cls) -> bool:
def is_env_present(cls, opt_in=False) -> bool:
# Why include an opt_in? Enables this class to conform to
# every other ClusterEnv subclass while always being rejected
# as viable, except in the express case where we request to check
# it explicitly.

# in many HPC clusters, the variables `https_proxy` and `http_proxy`
# are set to enable access to normally unreachable network locations.
# For example, `pip install ...` fails on compute nodes without them.
Expand All @@ -36,7 +41,7 @@ def is_env_present(cls) -> bool:
# And I also don't know what the right way to raise a complaint here is

# Relies on mpi4py:
return find_spec("mpi4py") is not None
return find_spec("mpi4py") is not None and opt_in == True

@classmethod
def get_coordinator_address(cls) -> str:
Expand Down
31 changes: 29 additions & 2 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,19 @@ def initialize(self,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
spec_detection_method: str | None = None,
initialization_timeout: int = 300):
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]

if spec_detection_method is not None and spec_detection_method not in ["mpi4py"]:
raise ValueError("spec_detection method should only be None, or \"mpi4py\".")

(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address, num_processes, process_id, local_device_ids
coordinator_address, num_processes, process_id, local_device_ids, spec_detection_method
)
)

Expand All @@ -70,6 +74,18 @@ def initialize(self,

self.process_id = process_id

# Emit a warning about PROXY variables if they are in the user's env:
proxy_vars = [ key for key in os.environ.keys() if 'proxy' in key.lower()]

if len(proxy_vars) > 0:
vars = " ".join(proxy_vars) + ". "
warning = (
f'JAX detected proxy variable(s) in the environment as distributed setup: {vars}'
'On some systems, this may cause a hang of distributed.initialize and '
'you may need to unset these ENV variable(s)'
)
logger.warning(warning)

if process_id == 0:
if self.service is not None:
raise RuntimeError('distributed.initialize should only be called once.')
Expand Down Expand Up @@ -114,6 +130,7 @@ def initialize(coordinator_address: str | None = None,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
spec_detect_method: str | None = None,
initialization_timeout: int = 300):
"""Initializes the JAX distributed system.
Expand All @@ -133,6 +150,11 @@ def initialize(coordinator_address: str | None = None,
Otherwise, you must provide the ``coordinator_address``,
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
Please note: on some systems, particularly HPC clusters that only access external networks
through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
:func:`~jax.distributed.initialize` may time out. You may need to unset these variables
prior to application launch.
Args:
coordinator_address: the IP address of process `0` and a port on which that
process should launch a coordinator service. The choice of
Expand All @@ -149,6 +171,11 @@ def initialize(coordinator_address: str | None = None,
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
If ``None``, defaults to all local devices being visible to the process except when processes
are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
spec_detect_method: An optional string to attempt to autodetect the configuration of the distributed
run. Available options are "mpi4py" only at the moment, though more options may be available in the future.
Note that "mpi4py" method requires you to have a working `mpi4py` install in your environment,
and launch the application with an MPI-compatible job launcher such as `mpiexec`.
Legacy auto-detect options (OMPI, Slurm) remain enabled.
initialization_timeout: Time period (in seconds) for which connection will
be retried. If the initialization takes more than the timeout specified,
the initialization will error. Defaults to 300 secs i.e. 5 mins.
Expand All @@ -174,7 +201,7 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
local_device_ids, initialization_timeout)
local_device_ids, spec_detect_method, initialization_timeout)
atexit.register(shutdown)


Expand Down

0 comments on commit 900a037

Please sign in to comment.