Merge pull request #143 from ZaberKo/dev-fix2
Bug Fix & Distributed Training Improvement
BillHuang2001 committed Jul 8, 2024
2 parents c67e704 + 51f5bad commit de6956e
Showing 99 changed files with 1,121 additions and 1,302 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -24,3 +24,4 @@ src/evox/algorithms/mo/_rm_meda.py
tests/test_.py
tests/log.txt
.ipynb_checkpoints/
/*.py
4 changes: 2 additions & 2 deletions docs/source/api/workflows/index.rst
@@ -6,5 +6,5 @@ Workflows
:maxdepth: 1

standard
distributed
non_jit
.. distributed
.. non_jit
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -59,6 +59,7 @@
"ray",
"tensorflow_datasets",
"gpjax",
"orbax-checkpoint"
]

# -- Options for HTML output -------------------------------------------------
137 changes: 137 additions & 0 deletions docs/source/guide/user/3-distributed-old.md
@@ -0,0 +1,137 @@
# Distribute the workflow

EvoX provides two distributed workflow implementations: one is based on Ray, and the other is based on `jax.distributed`.

## RayDistributedWorkflow

RayDistributedWorkflow is built upon Ray and can be used on any Ray cluster. The Ray cluster should be set up before running the EvoX program.

### Setup Ray cluster

Please refer to [Ray's official documentation](https://docs.ray.io/en/latest/cluster/getting-started.html) for a guide on setting up a Ray cluster.

Here is a simple way to set up the cluster locally.

- On the head node
```bash
ray start --head
```
- On worker nodes
```bash
ray start --address="<your head node's ip>:6379"
```

If you only have one machine with multiple devices, there is nothing to be done: Ray will set itself up automatically in this case.

### Setup EvoX

To scale the workflow using multiple machines through Ray, use the {class}`RayDistributedWorkflow <evox.workflows.RayDistributedWorkflow>` instead of StdWorkflow.

First, import `workflows` from evox

```python
from evox import workflows
```

Then create your algorithm, problem, and monitor objects as usual.

```python
algorithm = ...
problem = ...
monitor = ...
```

Now create the workflow with `RayDistributedWorkflow`:
```python
workflow = workflows.RayDistributedWorkflow(
    algorithm=algorithm,
    problem=problem,
    monitors=[monitor],
    num_workers=4,  # the number of machines
    options={  # the options passed to Ray
        "num_gpus": 1
    },
)
```

The `RayDistributedWorkflow` also uses the `workflow.step` function to execute iterations. However, under the hood, it employs a distinct approach that allows for the utilization of multiple devices across different machines.
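
For example, a minimal sketch of the driver loop (assuming the usual EvoX pattern of `state = workflow.init(key)` followed by repeated `workflow.step(state)` calls; exact return values may differ between versions) could look like this:

```python
import jax

key = jax.random.PRNGKey(0)
state = workflow.init(key)  # init as usual

for i in range(100):
    # each step runs one generation; under the hood, the evaluation
    # is dispatched to the Ray workers across the machines
    state = workflow.step(state)
```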

```{tip}
It is recommended to set the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false`.
By default, JAX preallocates 75% of each device's memory, so running multiple JAX processes with preallocation enabled may cause OOM errors.
Setting this variable disables GPU memory preallocation.
For more information, please refer to [JAX's documentation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) on this matter.
```
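
For example, to set this variable from Python rather than from the shell (a minimal sketch), set it before JAX is imported:

```python
import os

# must be set before JAX initializes the GPU backend
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # imported only after the environment variable is set
```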

## StdWorkflow

`StdWorkflow` is EvoX's standard workflow, which aims to use pure JAX to build a workflow that fits any requirement.
Since `StdWorkflow` is written in pure JAX, it has less overhead and doesn't need any additional dependencies.

### Setup EvoX
Use `StdWorkflow` to create a workflow,
then call `enable_distributed` on the state to enable this feature.
```python
key = jax.random.PRNGKey(0) # a PRNGKey
workflow = workflows.StdWorkflow(
    algorithm,
    problem,
    monitors=[monitor],
)
state = workflow.init(key) # init as usual
# important: enable this feature
state = workflow.enable_distributed(state)
```
Then, at the start of your program, before any JAX function is called, do this:
```python
jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)
```
In this system, the `coordinator` serves as the primary or head node. The total number of participating processes is given by `num_processes`, and the process with `process_id=0` acts as the coordinator.
For more information, please refer to [jax.distributed.initialize](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html) and [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html).
### Run in a cluster
Unlike Ray, JAX doesn't have the concept of a cluster or scheduler.
Instead, it offers tools for enabling distributed interactions among multiple JAX instances. JAX follows the SPMD (single program, multiple data) paradigm: to launch a distributed program, you simply run the same script on different machines. For instance, if your program is named `main.py`, execute the following command on all participating machines, each with a different `process_id` argument in `jax.distributed.initialize`:
```bash
python main.py
```
```{tip}
To pass `process_id` as an argument, one can use `argparse` to parse it from the command line.
For example:
```python
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('process_id', type=int)
args = parser.parse_args()
jax.distributed.initialize(
    coordinator_address=...,
    num_processes=...,
    process_id=args.process_id,
)
```
Then call `python main.py 0` on the first machine, `python main.py 1` on the second machine, and so on.
```
### Run on a single machine
In addition to distributed execution across multiple machines, `StdWorkflow` also supports running on a single machine with multiple GPUs. In this scenario, communication between different devices is facilitated by `nccl`, which is considerably more efficient than cross-machine communication.
The setup process remains the same as above. However, since you are working with only a single machine, the following step for multiple machines is no longer necessary:
```python
jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)
```
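
For example, a minimal single-machine sketch (assuming `algorithm`, `problem`, and `monitor` are defined as above, and assuming the usual `state = workflow.step(state)` loop) might look like this:

```python
import jax
from evox import workflows

workflow = workflows.StdWorkflow(
    algorithm,
    problem,
    monitors=[monitor],
)
state = workflow.init(jax.random.PRNGKey(0))
# no jax.distributed.initialize is needed on a single machine
state = workflow.enable_distributed(state)

for i in range(100):
    state = workflow.step(state)
```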
172 changes: 63 additions & 109 deletions docs/source/guide/user/3-distributed.md
@@ -1,137 +1,91 @@
# Distributed Training

## Parallel Model

All states, including the population, are replicated across all devices. On every device, a shard of the candidate solutions is passed to `problem.evaluate()`, and the resulting fitness values are then shared across all devices (via `all_gather`). This ensures all devices hold the same state data without explicit synchronization. In other words, this parallel model only accelerates the problem evaluation part and does not reduce memory consumption. We use it as our default distributed strategy, as it offers EC algorithms maximum flexibility.
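
As a rough illustration (a toy sketch in plain JAX, not EvoX's actual implementation; the device count, population, and squared-sum fitness below are made up for the example), the pattern looks like this:

```python
from functools import partial

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
pop_size, dim = 4 * n_dev, 2
population = jnp.arange(pop_size * dim, dtype=jnp.float32).reshape(pop_size, dim)
# replicate the full population on every device
replicated = jnp.broadcast_to(population, (n_dev, pop_size, dim))

@partial(jax.pmap, axis_name="pop")
def evaluate(pop):
    shard_size = pop_size // n_dev
    i = jax.lax.axis_index("pop")  # index of the current device
    # each device evaluates only its own slice of the replicated population
    shard = jax.lax.dynamic_slice_in_dim(pop, i * shard_size, shard_size)
    fitness_shard = jnp.sum(shard**2, axis=-1)  # stand-in for problem.evaluate
    # gather every device's shard so all devices see the full fitness vector
    return jax.lax.all_gather(fitness_shard, "pop", tiled=True)

fitness = evaluate(replicated)  # shape (n_dev, pop_size); all rows are identical
```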

## Multiple devices on a single node

Example:

```python
import jax
import jax.numpy as jnp
from evox import algorithms, problems, workflows
from evox.core.distributed import tree_unpmap

# all devices on this machine
devices = jax.local_devices()

cso = algorithms.CSO(
    lb=jnp.full(shape=(2,), fill_value=-32),
    ub=jnp.full(shape=(2,), fill_value=32),
    pop_size=16 * 4,
)
ackley = problems.numerical.Ackley()
workflow = workflows.StdWorkflow(cso, ackley)

key = jax.random.PRNGKey(42)
with jax.default_device(devices[0]):
    state = workflow.init(key)

# enable the parallel model described above on the local devices
state = workflow.enable_multi_devices(state, devices)

for i in range(100):
    train_info, state = workflow.step(state)
    train_info = tree_unpmap(train_info, workflow.pmap_axis_name)
    print(train_info['transformed_fitness'])
```

## Multiple devices on multiple nodes

Example of the script `dist_train.py`:

```python
import argparse

import jax

parser = argparse.ArgumentParser()
parser.add_argument('--addr', type=str, default='127.0.0.1:37233')
parser.add_argument('-n', type=int, required=True)
parser.add_argument('-i', type=int, required=True)
args = parser.parse_args()

# must be called before any other JAX operation
jax.distributed.initialize(
    coordinator_address=args.addr,
    num_processes=args.n,
    process_id=args.i,
    initialization_timeout=30,
)

total_devices = jax.devices()
devices = jax.local_devices()

print(f'total_devices: {total_devices}')
print(f'devices: {devices}')

import jax.numpy as jnp
from evox import algorithms, problems, workflows
from evox.core.distributed import tree_unpmap

cso = algorithms.CSO(
    lb=jnp.full(shape=(2,), fill_value=-32),
    ub=jnp.full(shape=(2,), fill_value=32),
    pop_size=16 * 30,
)
ackley = problems.numerical.Ackley()
workflow = workflows.StdWorkflow(cso, ackley)

key = jax.random.PRNGKey(42)
state = workflow.init(key)
state = workflow.enable_multi_devices(state, devices)

for i in range(10):
    train_info, state = workflow.step(state)
    train_info = tree_unpmap(train_info, workflow.pmap_axis_name)
    print(train_info['transformed_fitness'])

jax.distributed.shutdown()
```

Run the script on each node:

```shell
# node1 with ip 10.233.96.181
python dist_train.py --addr 10.233.96.181:35429 -n 2 -i 0

# node2
python dist_train.py --addr 10.233.96.181:35429 -n 2 -i 1
```
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -33,9 +33,9 @@ classifiers = [
dependencies = [
"jax >= 0.4.16",
"jaxlib >= 0.3.0",
"jax_dataclasses >= 1.6.0",
"optax >= 0.1.0",
"pyarrow >= 10.0.0",
"orbax-checkpoint >= 0.5.0",
]

[project.optional-dependencies]
11 changes: 7 additions & 4 deletions src/evox/__init__.py
@@ -1,8 +1,11 @@
from .core.workflow import Workflow
from .core.algorithm import Algorithm
from .core.module import *
from .core.algorithm import Algorithm, has_init_ask, has_init_tell
from .core.module import use_state, jit_class, jit_method, Stateful
from .core.problem import Problem
from .core.state import State
from .core.state import State, get_state_sharding
from .core.monitor import Monitor
from .core.pytree_dataclass import dataclass, pytree_field, PyTreeNode

from . import algorithms, monitors, operators, workflows, problems, utils
# from .core.distributed import POP_AXIS_NAME, ShardingType

# from . import algorithms, monitors, operators, workflows, problems, utils
2 changes: 1 addition & 1 deletion src/evox/algorithms/containers/__init__.py
@@ -1,3 +1,3 @@
from .clustered_algorithm import ClusterdAlgorithm, RandomMaskAlgorithm
from .tree_algorithm import TreeAlgorithm
from .coevolution import VectorizedCoevolution, Coevolution
from .coevolution import VectorizedCoevolution, Coevolution