Proposal for generic mpi4py initialization of jax distributed module #19409
I fully support this effort! This would remarkably simplify setup for many academic users. I've had to fight with setting up jax.distributed on weird clusters several times, and this seems like a great idea.
For what it's worth, I have done this and tested it on our Polaris supercomputer at Argonne National Lab. The changes are pretty small: one additional file (Mpi4pyCluster.py) and one modified file. https://github.com/coreyjadams/jax

There are some open questions I'm not clear on. Related to #9582, the HTTP proxy variables probably have to be unset to use the distributed.initialize functionality. It'd be nice to at least include a warning in the timeout notifications suggesting a retry with them unset. More aggressive options (unsetting the variables, raising warnings or exceptions preemptively) all have problems, but I don't know where exactly the timeout error emanates from, if it happens. Would this PR be welcomed by the JAX team?
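For reference, a minimal sketch of that workaround, assuming the standard proxy environment variable names (nothing here is JAX-specific):

```python
# Drop HTTP proxy variables before bringing up the distributed service,
# so the coordinator connection is not routed through the proxy.
import os

for var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
    os.environ.pop(var, None)

import jax

jax.distributed.initialize()
```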
Ultimately, we (the JAX maintainers) aren't MPI users. So the MPI-using community will be the best judge of whether this approach works well! It looks plausible to me, and you should send a PR. @nvcastet originally contributed that file, I think. Perhaps they have comments. Would the mpi4py approach be superior in all cases? Could we just have the mpi4py version? To improve the error message, I'd probably just stick a Python-level ...
By the way, we'd love to hear more about your workload; it's always great to hear about people using JAX at scale. Do you use JAX's distributed ...
OK, I'll put a PR together and send it in. I thought of another question: are these cluster functions likely to be called more than once or twice at startup? At large scales, doing an MPI broadcast and COMM_Split will incur some overhead that is worth it if it only happens at startup. If it happens often, it would be inefficient and I should cache the output of those functions.

The mpi4py approach is more generic but introduces an additional dependency (mpi4py). Perhaps a workable solution - since the MPI users are mostly at the big supercomputer facilities where this HTTP error will show up - is to emit a one-off warning from the Mpi4pyCluster file. This has the advantage of happening after MPI init, so we can control it to emit only from rank 0.
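If caching turns out to be needed, a small sketch of what that could look like (functools.cache on a module-level helper; purely illustrative):

```python
# Memoize the MPI-derived values so repeated calls to the cluster hooks
# don't re-run collectives; functools.cache is in the standard library.
import functools


@functools.cache
def _mpi_rank_and_size() -> tuple[int, int]:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    return comm.Get_rank(), comm.Get_size()
```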
Happy to share! We actually just released a public version this week: https://github.com/Nuclear-Physics-with-Machine-Learning/JAX_QMC_Public

The workload is doing variational Monte Carlo with quantum many-body systems. We've actually tried this in torch, libtorch, Julia, and TensorFlow too! JAX is the winner for multiple reasons:
For scale-out, I actually use mpi4jax in that repository above - it gives good performance and I've scaled this out to 2000+ A100s on our supercomputer. Our bottleneck at scale is an allreduce inside a conjugate gradient solve; we call it a few hundred times per algorithm iteration and it starts to dominate. I actually presented this last fall; you can see the scaling plots around slide 24 here: https://www.alcf.anl.gov/sites/default/files/2023-10/ALCF-HandsOnWorkshop-VariationalQMC-Nuclei.pdf

My only real complaint about scaling out JAX was the inability to do MPI reductions in place - before the C.G. implementation, for part of our workload we used a Cholesky solve that needed a big (15+ GB) allreduce, and having a separate buffer doubled our GPU VRAM usage. It's a big part of why we switched to C.G. I'm working on a ...

Thanks!
From what I remember, slurm_cluster.py should work with all Slurm jobs independent of MPI/PMI* usage, since Slurm will set those env variables. Is it because you are using a different launcher? On the other hand, ompi_cluster.py relies on the env variables set by Open MPI's mpirun. Do you know which MPI distribution you are running on your cluster?

We originally did not go the mpi4py route. Here, the detection could be an issue: having the package installed in the environment does not mean you are launching an MPI application, so the job could have been launched with a non-MPI-compatible launcher (e.g. ...). I like the simplicity of ...
Another potential solution for ...
On our systems we're using MPICH from HPE, but in my experience, when you are dealing with the vendor-optimized MPI implementations from HPE/Cray/Intel/IBM/etc., the env variables they set are different. So the OMPI variables are fine but not necessarily generic. And yes - we're not launching with Slurm; the job scheduler on this cluster is PBSPro. I suspect we're running on very different clusters!

Do you think there is much overlap in the use cases of the Slurm detection and this mpi4py detection? What do you think about an optional argument to distributed.initialize? For example, a user on a Slurm system with mpi4py installed could call something like the sketch below.

Anyways, maybe the real answer is to just update the documentation with "If you're launching JAX on a cluster with MPI, here is a technique to pick suitable initialization parameters via mpi4py ..."
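To make the opt-in concrete, something like the following, where the keyword name is purely illustrative (the final spelling would be decided in the PR):

```python
# Hypothetical opt-in: explicitly request mpi4py-based detection instead
# of relying on environment-variable sniffing. The argument name here is
# illustrative, not a settled API.
import jax

jax.distributed.initialize(cluster_detection_method="mpi4py")
```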
Hi Corey, I agree - I really like your approach. The original ideal goal of ... For ...
Just so we can have something clear to discuss: I opened PR #20174 based on what we've talked about here, using an exclusively opt-in method. Hopefully it proves useful!
#20174 is merged. Should we close this issue?
I think so. The documentation could perhaps be improved (how will users discover this?), but that can be addressed separately.
jax.distributed.initialize() works, without arguments, on several but not all common MPI / Slurm parallel job launchers. Unfortunately, the environment variables used in the ompi_cluster.py class are not standardized. I'd like to submit a PR that uses mpi4py in a generic way to initialize the distributed mode automatically. A proposed file in jax/_src/clusters/ would look like this:
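A rough sketch of the idea (the hook names are assumed from the existing cluster classes discussed in this thread; the exact code in the eventual PR may differ):

```python
# Sketch of a Mpi4pyCluster (hook names assumed from the existing
# jax/_src/clusters classes; not the exact code that was merged).
import socket

from jax._src import clusters


class Mpi4pyCluster(clusters.ClusterEnv):

    @classmethod
    def is_env_present(cls) -> bool:
        # Only activate when mpi4py is importable; never hard-require it.
        from importlib.util import find_spec
        return find_spec("mpi4py") is not None

    @classmethod
    def get_coordinator_address(cls) -> str:
        # Rank 0 picks a free port on its host and broadcasts "host:port".
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        if comm.Get_rank() == 0:
            with socket.socket() as s:
                s.bind(("", 0))
                port = s.getsockname()[1]
            address = f"{socket.gethostname()}:{port}"
        else:
            address = None
        return comm.bcast(address, root=0)

    @classmethod
    def get_process_count(cls) -> int:
        from mpi4py import MPI
        return MPI.COMM_WORLD.Get_size()

    @classmethod
    def get_process_id(cls) -> int:
        from mpi4py import MPI
        return MPI.COMM_WORLD.Get_rank()

    @classmethod
    def get_local_process_id(cls) -> int:
        # Rank within the node, via a shared-memory communicator split.
        from mpi4py import MPI
        local = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
        return local.Get_rank()
```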
Assuming you'd welcome this pull request, I had a question or two:

1. How often are get_coordinator_address and get_local_process_id called? Each of them does some MPI ops which, at very large scales, might be detrimental to performance. Caching the results of the functions would help.
2. Would you want this to replace the ompi_cluster.py implementation that uses env variables? Or serve as a fallback?
3. I put the "from mpi4py import MPI" calls inside the functions of the class. This line will raise an error if mpi4py is not installed; I'm assuming those functions will never get called if is_env_present returns False?

For some context, I work at Argonne National Laboratory and run JAX on our supercomputers. Currently we're running jobs with O(1000) JAX processes on A100 GPUs, but in the (near) future this will hopefully become O(100000) processes on Intel GPUs. I have been using mpi4jax for scaling (it's also great, probably you know about it already), but there are use cases for JAX's distributed package as well.