Skip to content

Commit

Permalink
Address feedback and comments on PR 20174; fix typo in documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
coreyjadams committed Apr 16, 2024
1 parent 4f22d86 commit 01f4709
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 11 deletions.
10 changes: 10 additions & 0 deletions jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def get_gce_worker_endpoints() -> str:
return get_metadata('worker-network-endpoints').split(',')

class SingleSliceGceTpuCluster(clusters.ClusterEnv):

name: str = "singleslicegcetpu"

@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm and is_gce_env() and not is_multislice_gce_env()
Expand All @@ -104,6 +107,9 @@ def get_local_process_id(cls) -> int | None:
return None

class MultisliceGceTpuCluster(clusters.ClusterEnv):

name: str = "multislicegcetpu"

@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm and is_multislice_gce_env()
Expand Down Expand Up @@ -160,6 +166,10 @@ def _get_process_id_in_slice() -> int:
return int(get_metadata('agent-worker-number'))

class GkeTpuCluster(MultisliceGceTpuCluster):


name: str = "gketpu"

# This class handles both single and multislice GKE as the environment
# variables are set the same in both cases.
@classmethod
Expand Down
15 changes: 10 additions & 5 deletions jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ClusterEnv:
"""

_cluster_types: list[type[ClusterEnv]] = []
opt_in_only_method: bool = False # Override this in derived classes if necessary

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
Expand All @@ -54,12 +55,16 @@ def auto_detect_unset_distributed_params(cls,

# First, we check the spec detection method because it will ignore submitted values
# if it succeeds.
if spec_detection_method == "mpi4py":
# We directly select the Mpi4pyCluster environment with an override here to opt in:
from jax._src.clusters.mpi4py_cluster import Mpi4pyCluster
env = Mpi4pyCluster if Mpi4pyCluster.is_env_present(opt_in=True) else None
if spec_detection_method is not None:
env = next( (env for env in cls._cluster_types if env.name == spec_detection_method), None )
if env is None:
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {spec_detection_method} is not supported.")
if not env.is_env_present():
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {spec_detection_method} is supported but not functional in this environment.")
else:
env = next((env for env in cls._cluster_types if env.is_env_present()), None)
env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None)

# Above: I have wrapped the env selection in a conditional to go through
# opt-in methods first (currently only mpi4py) but to check all possible options
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/clusters/mpi4py_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@

class Mpi4pyCluster(clusters.ClusterEnv):


name: str = "mpi4py"
opt_in_only_method: bool = True

@classmethod
def is_env_present(cls, opt_in=False) -> bool:
def is_env_present(cls) -> bool:
# Why include an opt_in? Enables this class to conform to
# every other ClusterEnv subclass while always being rejected
# as viable, except in the express case where we request to check
Expand All @@ -41,7 +45,7 @@ def is_env_present(cls, opt_in=False) -> bool:
# And I also don't know what the right way to raise a complaint here is

# Relies on mpi4py:
return find_spec("mpi4py") is not None and opt_in == True
return find_spec("mpi4py") is not None

@classmethod
def get_coordinator_address(cls) -> str:
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/clusters/ompi_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'

class OmpiCluster(clusters.ClusterEnv):

name: str = "ompi"

@classmethod
def is_env_present(cls) -> bool:
return _ORTE_URI in os.environ
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/clusters/slurm_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
_NUM_NODES = 'SLURM_STEP_NUM_NODES'

class SlurmCluster(clusters.ClusterEnv):

name: str = "slurm"

@classmethod
def is_env_present(cls) -> bool:
return _JOBID_PARAM in os.environ
Expand Down
6 changes: 2 additions & 4 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def initialize(self,
if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]

if spec_detection_method is not None and spec_detection_method not in ["mpi4py"]:
raise ValueError("spec_detection method should only be None, or \"mpi4py\".")

(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
Expand All @@ -75,7 +73,7 @@ def initialize(self,
self.process_id = process_id

# Emit a warning about PROXY variables if they are in the user's env:
proxy_vars = [ key for key in os.environ.keys() if 'proxy' in key.lower()]
proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()]

if len(proxy_vars) > 0:
vars = " ".join(proxy_vars) + ". "
Expand Down Expand Up @@ -152,7 +150,7 @@ def initialize(coordinator_address: str | None = None,
Please note: on some systems, particularly HPC clusters that only access external networks
through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
:func`~jax.distributed.initialize` may timeout. You may need to unset these variables
:func:`~jax.distributed.initialize` may time out. You may need to unset these variables
prior to application launch.
Args:
Expand Down

0 comments on commit 01f4709

Please sign in to comment.