
Add parallelization over multiple devices #39

Open
kazewong opened this issue Sep 19, 2022 · 3 comments

kazewong (Owner) commented Sep 19, 2022
Currently the code runs on one device, which doesn't allow scaling to larger computational networks such as TPU pods.

Parallelizing the local sampler should be relatively simple, since it does not require communication between devices. Note that if a single evaluation of the likelihood demands more RAM than what's available on a chip (TPUv4 has 8GB RAM per core; gradients of functions may cause problems), the computation may need to be sharded across multiple devices, but that should be taken care of separately.
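
A minimal sketch of what this could look like, assuming the local sampler step can be written as a pure per-chain update (`log_prob` and `local_step` below are placeholders, not flowMC's actual API): chains are sharded along a leading device axis and stepped with `pmap`, with no cross-device communication.

```python
# Sketch: independent chains sharded across devices with pmap.
# log_prob, local_step, and the shapes are hypothetical placeholders.
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
n_chains_per_device = 100
n_dim = 5

def log_prob(x):
    # Placeholder target density (standard Gaussian).
    return -0.5 * jnp.sum(x ** 2)

def local_step(key, position, step_size=0.1):
    # One random-walk Metropolis step, vectorized over the chains on a device.
    key_prop, key_accept = jax.random.split(key)
    proposal = position + step_size * jax.random.normal(key_prop, position.shape)
    log_ratio = jax.vmap(log_prob)(proposal) - jax.vmap(log_prob)(position)
    accept = jnp.log(jax.random.uniform(key_accept, (position.shape[0],))) < log_ratio
    return jnp.where(accept[:, None], proposal, position)

# Leading axis = device axis. Each device evolves its own chains independently,
# so no communication between devices is needed.
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)
positions = jnp.zeros((n_devices, n_chains_per_device, n_dim))
positions = jax.pmap(local_step)(keys, positions)
```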

Evaluation of the global sampler should be similar to the local sampler.

Training the normalizing flow requires collecting data from multiple devices and updating the weights in a synchronized fashion. Have a look at pmap to see how to deal with that: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html
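
A minimal sketch of the synchronized update, following the pattern in the linked tutorial: gradients are averaged across devices with `lax.pmean` inside a `pmap`-ed training step, so every replica applies the same update. The `loss_fn`, parameter pytree, and optax optimizer below are placeholders, not flowMC's actual training code.

```python
# Sketch: data-parallel training step with pmap + lax.pmean.
# loss_fn, params, and the optimizer are hypothetical placeholders.
import functools
import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-3)

def loss_fn(params, batch):
    # Placeholder for the flow's negative log-likelihood on a batch of samples.
    return jnp.mean((batch - params["shift"]) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients (and loss) across devices so the replicated weights
    # stay in sync after the update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Replicate parameters and optimizer state on every device, and shard the
# samples collected from the chains along the leading device axis.
n_devices = jax.local_device_count()
params = {"shift": jnp.zeros(5)}
opt_state = optimizer.init(params)
params = jax.device_put_replicated(params, jax.local_devices())
opt_state = jax.device_put_replicated(opt_state, jax.local_devices())
batch = jnp.zeros((n_devices, 128, 5))  # (devices, batch_per_device, n_dim)
params, opt_state, loss = train_step(params, opt_state, batch)
```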

@kazewong kazewong self-assigned this Sep 19, 2022
ahnitz commented Mar 21, 2023

@kazewong Is there any thought on parallelization over multi-node CPU resources, such as through MPI?

ahnitz commented Mar 21, 2023

Whoops, I see you have an issue for that already: #61
