Proposal for generic mpi4py initialization of jax distributed module #19409

Closed
coreyjadams opened this issue Jan 18, 2024 · 13 comments
Labels
enhancement New feature or request

Comments

@coreyjadams

jax.distributed.initialize() works, without arguments, on several but not all common MPI / Slurm parallel job launchers. Unfortunately, the environment variables used in the ompi_cluster.py class are not standardized. I'd like to submit a PR that uses mpi4py in a generic way to initialize the distributed mode automatically. A proposed file in jax/_src/clusters/ would look like this:

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from jax._src import clusters
import hashlib
import socket

from numpy import unique, where

from importlib.util import find_spec

class Mpi4pyCluster(clusters.ClusterEnv):
  
  @classmethod
  def is_env_present(cls) -> bool:
    # Relies on mpi4py:
    return find_spec("mpi4py") is not None

  @classmethod
  def get_coordinator_address(cls) -> str:

    # Using mpi4py, figure out rank 0 and its hostname.
    # Then broadcast the host IP to every rank and derive the port from it.

    from mpi4py import MPI
    # Get the global communicator:
    COMM_WORLD = MPI.COMM_WORLD

    # On rank 0, get the hostname:

    if COMM_WORLD.Get_rank() == 0:
        # Resolve the IP address of rank 0's host:
        hostname = socket.gethostname()
        host_ip  = socket.gethostbyname(hostname)

    else:
        host_ip = None

    # Broadcast the host_ip to all ranks:
    host_ip = COMM_WORLD.bcast(host_ip, root=0)


    # Pick a port in the ephemeral range, derived deterministically from the
    # host IP so that every rank computes the same value.  Python's built-in
    # hash() is salted per process, so use a stable hash instead.
    port_id = int(hashlib.md5(host_ip.encode()).hexdigest(), 16) % 2**12 + (65535 - 2**12 + 1)

    return f'{host_ip}:{port_id}'


  @classmethod
  def get_process_count(cls) -> int:
    from mpi4py import MPI
    return MPI.COMM_WORLD.Get_size()

  @classmethod
  def get_process_id(cls) -> int:
    from mpi4py import MPI
    return MPI.COMM_WORLD.Get_rank()
  
  @classmethod
  def get_local_process_id(cls) -> int | None:
    # Using mpi4py, split the global communicator into sub communicators
    # based on hostname.  mpi will assign them ranks and that will allow
    # a selection of the local process ID.
    from mpi4py import MPI
    COMM_WORLD = MPI.COMM_WORLD
    
    hostname = socket.gethostname()
    all_hostnames = COMM_WORLD.gather(hostname, root=0)

    if COMM_WORLD.Get_rank() == 0:
        # Order all the hostnames, and find unique ones
        unique_hosts = unique(all_hostnames)
        # Numpy automatically sorts them.
    else:
        unique_hosts = None

    # Broadcast the list of hostnames:
    unique_hosts = COMM_WORLD.bcast(unique_hosts, root=0)

    # Find the integer index of this host in the sorted array of unique hosts:
    i = int(where(unique_hosts == hostname)[0][0])

    new_comm = COMM_WORLD.Split(color=i)

    # The rank in the new communicator - which is host-local only - IS the local rank:
    return int(new_comm.Get_rank())

Assuming you'd welcome this pull request, I have a few questions:

  1. How often are the functions get_coordinator_address and get_local_process_id called? Each of them does some MPI ops which, at very large scales, might be detrimental to performance. Caching the results of those functions would help (see the sketch after this list).
  2. Assuming this proceeds, would it be preferable to have this as a separate class from the ompi_cluster.py implementation that uses env variables? Or as a fallback?
  3. Does the distributed initialization happen over TCP? If so, is it possible to bypass any of that? I expect that at the largest scales, this may become a performance bottleneck.
  4. I wrap all of the from mpi4py import MPI calls inside the functions of the class. That import will raise an error if mpi4py is not installed. I'm assuming those functions will never get called if is_env_present returns False?
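
A minimal sketch of the caching idea from question 1, assuming the cluster methods might be called repeatedly; the module-level helper and its name are mine, not part of the proposal:

import functools
import socket

@functools.lru_cache(maxsize=None)
def _coordinator_ip() -> str:
    # The MPI broadcast runs at most once per process; later calls hit the cache.
    from mpi4py import MPI
    host_ip = None
    if MPI.COMM_WORLD.Get_rank() == 0:
        host_ip = socket.gethostbyname(socket.gethostname())
    return MPI.COMM_WORLD.bcast(host_ip, root=0)

# Mpi4pyCluster.get_coordinator_address could then simply return
# f'{_coordinator_ip()}:{port}' on every call after the first.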

For some context, I work at Argonne National Laboratory and run JAX on our supercomputers. Currently we're running jobs with O(1000) JAX processes on A100 GPUs, but in the (near) future this will hopefully become O(100000) processes on Intel GPUs. I have been using mpi4jax for scaling (it's also great; you probably know about it already), but there are use cases for JAX's distributed package as well.

@coreyjadams coreyjadams added the enhancement New feature or request label Jan 18, 2024
@PhilipVinc
Contributor

I fully support this effort!

This would greatly simplify setup for several academic users. I've had to fight with setting up jax.distributed on unusual clusters several times, and this seems like a great idea.

@coreyjadams
Author

For what it's worth, I have done this and tested it on our Polaris supercomputer at Argonne National Lab. The changes are pretty small: one additional file (Mpi4pyCluster.py) and one modified file.

https://github.com/coreyjadams/jax

One open question I'm not clear on, related to #9582: the HTTP proxy environment variables probably have to be unset to use the distributed.initialize functionality. It would be nice to at least include a warning in the timeout notifications suggesting users try again with them unset. More aggressive options (unsetting the variables ourselves, raising warnings or exceptions preemptively) all have problems. But I don't know where exactly the timeout error originates, when it happens.

Would this PR be welcomed by the JAX team?

@hawkinsp
Member

hawkinsp commented Mar 7, 2024

Ultimately, we (the JAX maintainers) aren't MPI users. So the MPI-using community will be the best judge of whether this approach works well! It looks plausible to me, and you should send a PR.

@nvcastet originally contributed that file, I think. Perhaps they have comments.

Would the mpi4py approach be superior in all cases? Could we just have the mpi4py version?

To improve the error message, I'd probably just stick a Python-level try block around the .connect() call that looks for an HTTP proxy environment variable and warns. Or we could just warn right before the connect call, if we think that's likely always an error. The error itself originates from deep in the gRPC stack, so it's probably easiest to provide more information in the Python caller. Send a PR?
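
A rough sketch of that suggestion, not the actual jax/_src/distributed.py code; client.connect() stands in for whatever call actually times out, and the proxy variable names are just the usual shell conventions:

import os
import warnings

_PROXY_VARS = ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY")

def connect_with_proxy_hint(client):
    try:
        client.connect()  # placeholder for the real coordinator connection call
    except Exception:
        set_vars = [v for v in _PROXY_VARS if os.environ.get(v)]
        if set_vars:
            warnings.warn(
                "jax.distributed.initialize() could not reach the coordinator and "
                f"these proxy variables are set: {set_vars}. Proxies often break the "
                "coordinator connection; consider unsetting them and retrying.")
        raise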

@hawkinsp
Member

hawkinsp commented Mar 7, 2024

For some context, I work at Argonne National Laboratory and run JAX on our supercomputers. Currently we're running jobs with O(1000) JAX processes on A100 GPUs, but in the (near) future this will hopefully become O(100000) processes on Intel GPUs.

By the way, we'd love to hear more about your workload, always great to hear about people using JAX at scale. Do you use JAX's distributed jit or shard_map, or pmap?

@coreyjadams
Author

OK, I'll put a PR together and send it in. I thought of another question: are these cluster functions likely to be called more than once or twice at startup? At large scales, the MPI broadcast and Comm.Split incur some overhead that is worth paying if it only happens at startup. If they are called often, that would be inefficient and I should cache their output.

The mpi4py approach is more generic but introduces an additional dependency (mpi4py, of course). I think it's worthwhile to keep the cluster implementations for the most common job managers (Slurm, etc.) so as not to force mpi4py on users who are already happy. Since those approaches just check for the existence of env variables, the overhead is not significant.

Perhaps a workable solution, since the MPI users are mostly at the big supercomputer facilities where this HTTP proxy error will show up, is to emit a one-off warning from the Mpi4pyCluster file. This has the advantage of happening after MPI init, so we can restrict it to rank 0.
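
A minimal sketch of that idea, assuming it lives in the proposed Mpi4pyCluster file; the trigger condition and the wording of the message are placeholders:

import os
import warnings

_already_warned = False

def _maybe_warn_about_proxies():
    global _already_warned
    from mpi4py import MPI  # importing mpi4py.MPI initializes MPI by default
    # Only rank 0 emits the warning, and only once per process.
    if _already_warned or MPI.COMM_WORLD.Get_rank() != 0:
        return
    if any(os.environ.get(v) for v in ("http_proxy", "https_proxy")):
        warnings.warn("HTTP proxy variables are set; jax.distributed.initialize() "
                      "may time out unless they are unset.")
    _already_warned = True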

@coreyjadams
Author

By the way, we'd love to hear more about your workload, always great to hear about people using JAX at scale. Do you use JAX's distributed jit or shard_map, or pmap?

Happy to share! We actually just released a public version this week: https://github.com/Nuclear-Physics-with-Machine-Learning/JAX_QMC_Public

The workload is variational Monte Carlo for quantum many-body systems. We've actually tried this in torch, libtorch, Julia, and TensorFlow too! JAX is the winner for multiple reasons:

  • We need 2nd derivatives and jacobians that were horribly inefficient in torch when we started this. The situation might be different now that torch has vmap, but I haven't checked.
  • TensorFlow was really good until we needed more complicated functions traced with XLA via @tf.function - then I was seeing compile times longer than an hour and couldn't actually run jobs!
  • libtorch was better than torch, but even with some work optimizing it I couldn't improve the concurrency of the jacobian calculation that we need.
  • Julia was just really frustrating to get performance out of. Static arrays were just ... not working how I expected them to. Certainly a user error :)

For scale out, I actually use mpi4jax in that repository above - it gives good performance and I've scaled this out to 2000+ A100s on our supercomputer. Our bottleneck at scale is an allreduce inside a conjugate gradient solve; we call it a few hundred times per algorithm iteration and it starts to dominate. I actually presented this last fall; you can see the scaling plots around slide 24 here: https://www.alcf.anl.gov/sites/default/files/2023-10/ALCF-HandsOnWorkshop-VariationalQMC-Nuclei.pdf

My only real complaint about scaling out JAX was the inability to do MPI reductions in place: before the C.G. implementation for part of our workload, we used a Cholesky solve that needed a big (15+ GB) allreduce, and having a separate buffer doubled our GPU VRAM usage. It's a big part of why we switched to C.G.

I'm working on a shard_map alternative to mpi4jax too. I'm not sure which will be more efficient - it's worth testing both - but at the moment we're blocked in mpi4jax by an unexpected bug: mpi4jax/mpi4jax#229

Thanks!
Corey

@nvcastet
Collaborator

nvcastet commented Mar 7, 2024

jax.distributed.initialize() works, without arguments, on several but not all common MPI / Slurm parallel job launchers.

From what I remember, slurm_cluster.py should work with all Slurm jobs independent of MPI/PMI* usage, since Slurm itself sets those env variables. Is it because you are using mpirun/mpiexec instead of srun to launch your MPI application?
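
For illustration, a rough sketch of that kind of environment-variable detection (not the actual slurm_cluster.py; these are standard variables that srun exports for every task):

import os

def slurm_detected() -> bool:
    # SLURM_JOB_ID is present for every process launched inside a Slurm job.
    return "SLURM_JOB_ID" in os.environ

def slurm_process_info() -> tuple[int, int, int]:
    # (global rank, world size, node-local rank), with no MPI library involved.
    return (int(os.environ["SLURM_PROCID"]),
            int(os.environ["SLURM_NTASKS"]),
            int(os.environ["SLURM_LOCALID"]))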

On the other hand, ompi_cluster.py was made for detecting OpenMPI (ORTE runtime, the default up to v5) applications launched via mpirun/mpiexec.
See #14576 for v5.

Do you know which MPI distribution you are running on your cluster?

We originally did not go down the mpi4py path because we did not want to initialize MPI under the covers without the user's permission.
Applications can initialize MPI differently than mpi4py's default init.

  def is_env_present(cls) -> bool:
    # Relies on mpi4py:
    return find_spec("mpi4py") is not None

Here, the detection could be an issue, since having the package installed in the environment does not mean you are launching an MPI application; the job could have been launched with a non-MPI-compatible launcher (e.g. srun --mpi=none ...).

I like the simplicity of mpi4py, but I am wary of enabling it fully transparently in JAX for the reasons above.
A compromise would be to put Mpi4pyCluster in a utils Python file that the user can opt into just by importing it.
Importing it before calling initialize() will register it automatically with JAX as a known env:

from myutils import Mpi4pyCluster
....
jax.distributed.initialize()
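
A minimal sketch of why the import alone is enough, assuming the registry works roughly the way jax._src.clusters does today (names here are illustrative, not the exact JAX internals):

class ClusterEnv:
    _cluster_types: list = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Defining (i.e. importing) any subclass appends it to the registry
        # that jax.distributed.initialize() later walks during auto-detection.
        ClusterEnv._cluster_types.append(cls)

# So `from myutils import Mpi4pyCluster` executes the class definition,
# which is all that is needed to make it a candidate environment.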

@nvcastet
Collaborator

nvcastet commented Mar 7, 2024

Another potential solution for mpi4py users is to have mpi4jax define Mpi4pyCluster at import time of the mpi4jax module, since mpi4jax already has a hard dependency on mpi4py anyway.

@coreyjadams
Author

On our systems we're using MPICH from HPE, but in my experience the vendor-optimized MPI implementations from HPE/Cray/Intel/IBM/etc. each set different env variables. So the OMPI variables are fine but not necessarily generic.

And yes - we're not launching with Slurm; the job scheduler on this cluster is PBSPro. I suspect we're running on very different clusters!

Do you think there is much overlap in the use cases of mpi4jax vs. jax.distributed? If users want to use mpi4jax, they are likely to just use that; users of pmap/shard_map etc. are not likely to call mpi4jax reductions on sharded tensors - unless there is some use case I'm missing? I agree it would be bad to limit users to one OR the other, though; better to maintain both as viable options, separate or together.

What do you think about an optional argument to jax.distributed.initialize that allows the user to select an auto-init method? Legacy methods that just read env variables could leave it blank, but it could also allow the user to prioritize one init method over another.

For example, a user on a slurm system with mpi4py installed could call jax.distributed.initialize(auto_init_method="slurm") and know that mpi4py will not be initialized. Or, they could call jax.distributed.initialize(auto_init_method="mpi") and force the use of mpi4py. Perhaps jax.distributed.initialize(auto_init_method="any") could let JAX take the wheel and figure out the initialization parameters in whatever way works first?
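
Roughly, something like this hypothetical sketch; neither the parameter name nor the helpers below are real JAX API:

def initialize(auto_init_method: str = "any", **kwargs):
    # Hypothetical registry; SlurmCluster and Mpi4pyCluster stand in for the
    # real environment classes, and _initialize_from is a made-up helper.
    detectors = {"slurm": SlurmCluster, "mpi": Mpi4pyCluster}
    if auto_init_method == "any":
        candidates = list(detectors.values())       # let JAX take the wheel
    else:
        candidates = [detectors[auto_init_method]]  # user forces one method
    for env in candidates:
        if env.is_env_present():
            return _initialize_from(env, **kwargs)
    raise RuntimeError("No usable auto-initialization method was detected.")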

Anyway, maybe the real answer is to just update the documentation with "If you're launching JAX on a cluster with MPI, here is a technique to pick suitable initialization parameters via mpi4py ..."

@nvcastet
Collaborator

nvcastet commented Mar 8, 2024

Hi Corey,

I agree mpi4jax and jax.distributed may be used together or separate.

I really like your mpi4py approach for catching all the vendor-specific MPI implementations. Figuring out all the potential env variables that the different vendors set (or don't set) is a nightmare, even though that approach would avoid the extra downsides mentioned (extra dependency, initializing MPI, and extra communication).

The original ideal goal of jax.distributed.initialize() [with no args] was to be able to run the same script in different environments without any code change (not even an argument change) by doing auto-detection.
This implies that when the is_env_present(cls) method of an environment class returns True, we know for sure that the application is running in that environment and we can get the job parameters from it to initialize jax.distributed.

For mpi4py, I think we agree the user will need to opt in and say "I launched my job with an MPI-compatible launcher and I am fine with leveraging mpi4py to collect the info needed for jax.distributed".
For this new "opt-in" scenario, the user would need to be able to specify the method, as you mentioned, with something like jax.distributed.initialize(spec_detection_method="mpi4py").
@hawkinsp and @skye what are your thoughts?

@coreyjadams
Author

coreyjadams commented Mar 11, 2024

Just so we have something concrete to discuss: I opened PR #20174 based on what we've talked about here, using an exclusively opt-in method. Hopefully it proves useful!

@nouiz
Collaborator

nouiz commented Jul 18, 2024

#20174 is merged. Should we close this issue?

@hawkinsp
Member

I think so. The documentation could perhaps be improved (how will users discover this?), but that can be addressed separately.
