diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index a978cf4beff7..c0d64918ef0f 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -83,6 +83,9 @@ def get_gce_worker_endpoints() -> str: return get_metadata('worker-network-endpoints').split(',') class SingleSliceGceTpuCluster(clusters.ClusterEnv): + + name: str = "singleslicegcetpu" + @classmethod def is_env_present(cls) -> bool: return running_in_cloud_tpu_vm and is_gce_env() and not is_multislice_gce_env() @@ -104,6 +107,9 @@ def get_local_process_id(cls) -> int | None: return None class MultisliceGceTpuCluster(clusters.ClusterEnv): + + name: str = "multislicegcetpu" + @classmethod def is_env_present(cls) -> bool: return running_in_cloud_tpu_vm and is_multislice_gce_env() @@ -160,6 +166,10 @@ def _get_process_id_in_slice() -> int: return int(get_metadata('agent-worker-number')) class GkeTpuCluster(MultisliceGceTpuCluster): + + + name: str = "gketpu" + # This class handles both single and multislice GKE as the environment # variables are set the same in both cases. @classmethod diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 17a28b817205..8a83216be11d 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -31,6 +31,7 @@ class ClusterEnv: """ _cluster_types: list[type[ClusterEnv]] = [] + opt_in_only_method: bool = False # Override this in derived classes if necessary def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -54,12 +55,16 @@ def auto_detect_unset_distributed_params(cls, # First, we check the spec detection method because it will ignore submitted values # If if succeeds. - if spec_detection_method == "mpi4py": - # We directly select the Mpi4pyCluster environment with an override here to opt in: - from jax._src.clusters.mpi4py_cluster import Mpi4pyCluster - env = Mpi4pyCluster if Mpi4pyCluster.is_env_present(opt_in=True) else None + if spec_detection_method is not None: + env = next( (env for env in cls._cluster_types if env.name == spec_detection_method), None ) + if env is None: + logger.error(f"Automatic Distributed initialization can not proceed:" + f" {spec_detection_method} is not supported.") + if not env.is_env_present(): + logger.error(f"Automatic Distributed initialization can not proceed:" + f" {spec_detection_method} is supported but not functional in this environment.") else: - env = next((env for env in cls._cluster_types if env.is_env_present()), None) + env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None) # Above: I have wrapped the env selection in a conditional to go through # opt-in methods first (currently only mpi4py) but to check all possible options diff --git a/jax/_src/clusters/mpi4py_cluster.py b/jax/_src/clusters/mpi4py_cluster.py index 8c51abf77835..d57e16162bc7 100644 --- a/jax/_src/clusters/mpi4py_cluster.py +++ b/jax/_src/clusters/mpi4py_cluster.py @@ -23,8 +23,12 @@ class Mpi4pyCluster(clusters.ClusterEnv): + + name: str = "mpi4py" + opt_in_only_method: bool = True + @classmethod - def is_env_present(cls, opt_in=False) -> bool: + def is_env_present(cls) -> bool: # Why include and opt_in? Enables this class to conform to # every other ClusterEnv subclass while always being rejected # as viable, except in the express case where we request to check @@ -41,7 +45,7 @@ def is_env_present(cls, opt_in=False) -> bool: # And I also don't know what the right way to raise a complaint here is # Relies on mpi4py: - return find_spec("mpi4py") is not None and opt_in == True + return find_spec("mpi4py") is not None @classmethod def get_coordinator_address(cls) -> str: diff --git a/jax/_src/clusters/ompi_cluster.py b/jax/_src/clusters/ompi_cluster.py index 908af28a027b..3ae5024bef5b 100644 --- a/jax/_src/clusters/ompi_cluster.py +++ b/jax/_src/clusters/ompi_cluster.py @@ -25,6 +25,9 @@ _LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK' class OmpiCluster(clusters.ClusterEnv): + + name: str = "ompi" + @classmethod def is_env_present(cls) -> bool: return _ORTE_URI in os.environ diff --git a/jax/_src/clusters/slurm_cluster.py b/jax/_src/clusters/slurm_cluster.py index 5edacb4f5d7a..7c9fbb7f1e78 100644 --- a/jax/_src/clusters/slurm_cluster.py +++ b/jax/_src/clusters/slurm_cluster.py @@ -25,6 +25,9 @@ _NUM_NODES = 'SLURM_STEP_NUM_NODES' class SlurmCluster(clusters.ClusterEnv): + + name: str = "slurm" + @classmethod def is_env_present(cls) -> bool: return _JOBID_PARAM in os.environ diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 9649e949359b..86bfde2338b1 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -48,8 +48,6 @@ def initialize(self, if isinstance(local_device_ids, int): local_device_ids = [local_device_ids] - if spec_detection_method is not None and spec_detection_method not in ["mpi4py"]: - raise ValueError("spec_detection method should only be None, or \"mpi4py\".") (coordinator_address, num_processes, process_id, local_device_ids) = ( clusters.ClusterEnv.auto_detect_unset_distributed_params( @@ -75,7 +73,7 @@ def initialize(self, self.process_id = process_id # Emit a warning about PROXY variables if they are in the user's env: - proxy_vars = [ key for key in os.environ.keys() if 'proxy' in key.lower()] + proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()] if len(proxy_vars) > 0: vars = " ".join(proxy_vars) + ". " @@ -152,7 +150,7 @@ def initialize(coordinator_address: str | None = None, Please note: on some systems, particularly HPC clusters that only access external networks through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to - :func`~jax.distributed.initialize` may timeout. You may need to unset these variables + :func:`~jax.distributed.initialize` may timeout. You may need to unset these variables prior to application launch. Args: