Skip to content

Commit

Permalink
Update mpi4py_cluster.py
Browse files Browse the repository at this point in the history
Remove unneeded commented code.
  • Loading branch information
coreyjadams committed Jun 27, 2024
1 parent 6701bd1 commit 6cc07a9
Showing 1 changed file with 1 addition and 22 deletions.
23 changes: 1 addition & 22 deletions jax/_src/clusters/mpi4py_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,28 +100,7 @@ def get_local_process_id(cls) -> int | None:
# a selection of the local process ID.
from mpi4py import MPI
COMM_WORLD = MPI.COMM_WORLD

# This is a previous method that is replaced with a different mpi split:

# hostname = socket.gethostname()
# # host_key = host_key %
# all_hostnames = COMM_WORLD.gather(hostname, root=0)

# if COMM_WORLD.Get_rank() == 0:
# # Order all the hostnames, and find unique ones
# unique_hosts = unique(all_hostnames)
# # Numpy automatically sorts them.
# else:
# unique_hosts = None

# # Broadcast the list of hostnames:
# unique_hosts = COMM_WORLD.bcast(unique_hosts, root=0)

# # Find the integer for this host in the list of hosts:
# i = int(where(unique_hosts == hostname)[0])

# new_comm = COMM_WORLD.Split(color=i)


# This is the alternative method that is simpler:
new_comm = COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)

Expand Down

0 comments on commit 6cc07a9

Please sign in to comment.