Commit

Add jax.make_array_from_process_local_data to create a distributed tensor from host data, and supporting scaffolding in sharding to figure out the dimensions of the host data required.
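
A minimal usage sketch of the new API, adapted from the docstring example in this change (it assumes a run with an even number of devices; the local data generation is an illustrative stand-in for a real per-process input pipeline):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Shard the leading (batch) dimension across all devices; each process
# supplies only the rows that its local devices will hold.
mesh = Mesh(np.array(jax.devices()).reshape(2, jax.device_count() // 2), ('x', 'y'))
sharding = NamedSharding(mesh, P(('x', 'y'),))

rows_per_device, feature_length = 2, 32
local_rows = rows_per_device * len(mesh.local_devices)
local_data = np.arange(local_rows * feature_length).reshape(
    local_rows, feature_length)  # stand-in for a per-process data pipeline
global_shape = (rows_per_device * jax.device_count(), feature_length)

arr = jax.make_array_from_process_local_data(sharding, local_data, global_shape)
assert arr.shape == global_shape
assert arr.addressable_data(0).shape == (rows_per_device, feature_length)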

PiperOrigin-RevId: 634205261
marksandler2 authored and jax authors committed May 16, 2024
1 parent cd41b4f commit 8f045ca
Showing 5 changed files with 304 additions and 32 deletions.
1 change: 1 addition & 0 deletions jax/__init__.py
@@ -134,6 +134,7 @@
from jax._src.array import (
    make_array_from_single_device_arrays as make_array_from_single_device_arrays,
    make_array_from_callback as make_array_from_callback,
    make_array_from_process_local_data as make_array_from_process_local_data,
)

from jax._src.tree_util import (
130 changes: 121 additions & 9 deletions jax/_src/array.py
@@ -15,13 +15,12 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Sequence
import enum
import functools
import math
import operator as op
import numpy as np
import functools
from typing import Any, Callable, cast, TYPE_CHECKING
from collections.abc import Sequence
from typing import Any, Callable, TYPE_CHECKING, cast

from jax._src import abstract_arrays
from jax._src import api
@@ -35,18 +34,19 @@
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
    SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
    device_replica_id_map, hashed_index)
from jax._src.layout import DeviceLocalLayout, Layout, AutoLayout
    PmapSharding, SingleDeviceSharding, XLACompatibleSharding,
    device_replica_id_map, hashed_index, num_addressable_indices)  # pyformat: disable
from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
import numpy as np


Shape = tuple[int, ...]
@@ -627,9 +627,12 @@ def _value(self) -> np.ndarray:
setattr(ArrayImpl, "__hash__", None)
setattr(ArrayImpl, "__array_priority__", 100)

# TODO(yashkatariya): Remove None from callback input type.

def make_array_from_callback(
    shape: Shape, sharding: Sharding | Layout,
    data_callback: Callable[[Index | None], ArrayLike]) -> ArrayImpl:
  # pyformat: disable
  """Returns a ``jax.Array`` via data fetched from ``data_callback``.

  ``data_callback`` is used to fetch the data for each addressable shard of the
@@ -667,6 +670,7 @@ def make_array_from_callback(
  >>> arr.addressable_data(0).shape
  (4, 2)
  """
  # pyformat: enable
  dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
  if isinstance(dll, AutoLayout):
    raise TypeError(
@@ -725,6 +729,114 @@ def make_array_from_callback(
  return ArrayImpl(aval, sharding, arrays, committed=True)


def make_array_from_process_local_data(
    sharding: Sharding,
    local_data: np.ndarray,
    global_shape: tuple[int, ...],
) -> ArrayImpl:
  # pyformat: disable
  """Creates a distributed tensor from the data available in each process.

  This function is a common special case of `make_array_from_callback`. It
  assumes that the data is available in the process and takes care of the
  index wrangling.

  Note that if two hosts are replicas, `local_data` should be identical on
  both of them.

  Each dimension of the shape of `local_data` should either match
  `global_shape` or the number of indices the devices on this process need to
  address. For example, if dimension `i` is fully sharded, this size would be
  `per_device_shape[i] * jax.local_device_count()`.

  If a dimension matches `global_shape`, each device slice will simply look up
  its slice in `local_data`. Otherwise, the global slice of each device will be
  mapped into a local slice of the `local_data` array. For example, if a given
  process only addresses slices (8, 12) and (24, 28), then these slices will be
  mapped into (0, 4) and (4, 8) of `local_data`.

  This function can be used to create tensors from dataset feeding pipelines.
  The most common case is when the sharding is fully sharded across the batch
  dimension and each host just loads its corresponding sub-batch. This function
  supports more general cases as well, such as multi-host replication, but then
  you would need to compute the size and the contents of the process-local data
  correctly to satisfy the replication constraints.

  Examples:
    >>> from jax.sharding import PartitionSpec as P
    >>> mesh_rows = 2
    >>> mesh_cols = jax.device_count() // 2
    ...
    >>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
    >>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
    >>> rows_per_device = 2
    >>> feature_length = 32
    >>> per_device_shape = (rows_per_device, feature_length)
    >>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
    >>> per_host_generator = lambda: np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
    >>> per_host_data = per_host_generator()  # replace with your own per-host data pipeline that outputs numpy arrays
    >>> global_shape = (rows_per_device * len(sharding.device_set),) + per_device_shape[1:]
    >>> output_global_array = jax.make_array_from_process_local_data(sharding, per_host_data, global_shape)
    ...
    >>> assert output_global_array.addressable_data(0).shape == per_device_shape
    >>> assert output_global_array.shape == global_shape

  Args:
    sharding: Sharding of the global tensor.
    local_data: Data on the host to be placed on local devices. Each dimension
      should either match `global_shape` or match
      `num_addressable_indices(dim)`.
    global_shape: The target shape of the global tensor. In some cases this
      parameter can be inferred from `sharding` and `local_data`, however it is
      useful to catch common sharding errors.

  Returns:
    Tensor that will have ``sharding=sharding``.
  """
  # pyformat: enable
  shard_shape = sharding.shard_shape(global_shape)
  full_dim = []
  for i, (data_dim, global_dim) in enumerate(
      zip(local_data.shape, global_shape)
  ):
    full_dim.append(data_dim == global_dim)
    if data_dim != global_dim:
      process_slice = num_addressable_indices(sharding, i, global_shape)
      if process_slice != data_dim:
        raise ValueError(
            "Invalid host data: each dimension should match either the global "
            f"or the process shape. In dimension {i=}, the process data has "
            f"{data_dim} elements. Process addresses {process_slice} elements "
            f"and {global_shape=}."
        )
  addressable_shards = sharding.addressable_devices_indices_map(global_shape)
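  # Collect the sorted unique global start offsets of this process's shards
  # along each dimension; a shard's position in this list determines where
  # its data lives in `local_data`.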
  slices_for_each_dim: list[list[int]] = [[] for _ in global_shape]
  for shard_index in addressable_shards.values():
    assert shard_index is not None
    for i, slc in enumerate(shard_index):
      slices_for_each_dim[i].append(slc.start or 0)
  for i in range(len(global_shape)):
    slices_for_each_dim[i] = sorted(set(slices_for_each_dim[i]))

  def local_slice(i, slc):
    # Looks up the index of this slice in the list of slices for this
    # dimension; that position determines the slice into `local_data`.
    start = slices_for_each_dim[i].index(slc.start or 0) * shard_shape[i]
    end = start + shard_shape[i]
    return slice(start, end)

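  # Translate each addressable shard's global index into the corresponding
  # region of `local_data`; dimensions that span the full global extent are
  # passed through unchanged.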
  def cb(index: Index | None) -> ArrayLike:
    assert index is not None
    data_slice = [
        slc if full_dim[i] else local_slice(i, slc)
        for i, slc in enumerate(index)
    ]
    return local_data[tuple(data_slice)]

  return make_array_from_callback(global_shape, sharding, cb)


def make_array_from_single_device_arrays(
    shape: Shape, sharding: Sharding, arrays: Sequence[basearray.Array]
) -> ArrayImpl:
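
To illustrate the index mapping implemented by local_slice above, here is a standalone sketch in plain Python (independent of JAX): a process whose shards cover the global row-slices (8, 12) and (24, 28) with shard size 4 stores them at rows (0, 4) and (4, 8) of its local buffer, exactly as in the docstring example.

shard_rows = 4
# Sorted unique global start offsets of this process's shards; in the real
# function these come from sharding.addressable_devices_indices_map(...).
global_starts = sorted({8, 24})

def local_slice(global_start):
  # A shard's rank among this process's shards, scaled by the shard extent,
  # gives its offset in the process-local buffer.
  start = global_starts.index(global_start) * shard_rows
  return slice(start, start + shard_rows)

assert local_slice(8) == slice(0, 4)
assert local_slice(24) == slice(4, 8)
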
60 changes: 55 additions & 5 deletions jax/_src/sharding_impls.py
@@ -25,18 +25,17 @@
from typing import Any, NamedTuple, Union, cast

from jax._src import mesh as mesh_lib
from jax._src.op_shardings import (
    is_op_sharding_replicated, are_op_shardings_equal, get_num_ways_dim_sharded,
    op_sharding_to_indices)
from jax._src import sharding
from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded,
    is_op_sharding_replicated,
    op_sharding_to_indices)  # pyformat: disable
from jax._src.partition_spec import PartitionSpec

from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
import numpy as np


@@ -1376,3 +1375,54 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
        ParsedPartitionSpec('<internally generated spec>', partitions))]
  else:
    raise AssertionError("Unhandled OpSharding type. Please open a bug report!")


def _slice_as_tuple(s: slice):
  assert s.step is None
  return (s.start, s.stop)


def num_addressable_indices(
    tensor_sharding: sharding.Sharding,
    dim: int,
    global_shape: Shape,
) -> int:
"""Returns the number of indices for given dimension this host has access to.
Each host can have multiple number of devices that are spanning
possibly discontiguous slices of data. This function computes the
total number of unique indices for dimension `dim` that any of its
addressable devices hold.
In most cases the addressable indices form a sparse grid (and in some
cases a subcube), and thus each host will hold the same of number of
indices for each dimension. However, it is possible to design a mesh that
addressable shards form a complicated pattern. In that case, the returned
value is the number of indices that are addressable by at least one device.
For example, suppose the sharding looks like this: (number indicates
the host index)
1221
1221
0000
Then on host 1 and 2, both dim 0 (rows), and dim=1 (cols) will have size 2,
while on host 0, dim 0 will have size 1, and dim 1 will have size 4.
Args:
tensor_sharding: Sharding of the tensor.
dim: dimension along which to compute the number of addressable indices.
global_shape: global shape of the tensor.
Returns:
The number of indices for dimension `dim` that this host holds.
"""
  # TODO(sandler, yashkatariya): Consider making this function public.
  addressables = tensor_sharding.addressable_devices_indices_map(global_shape)
  addressables = cast(Mapping[sharding.Device, Index], addressables)
  num_unique_slices = len({
      _slice_as_tuple(addressable[dim]) for addressable in addressables.values()
  })
  shard_size = tensor_sharding.shard_shape(global_shape)[dim]
  return shard_size * num_unique_slices
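
For a concrete feel of the counting logic, here is a standalone sketch in plain Python, applied to the 1221/1221/0000 layout from the docstring with global shape (3, 4) and per-device shard shape (1, 1). The hand-written index list below is an illustrative stand-in for sharding.addressable_devices_indices_map(global_shape):

# Shards addressable by host 1: rows 0 and 1, columns 0 and 3.
host1_indices = [
    (slice(0, 1), slice(0, 1)),
    (slice(0, 1), slice(3, 4)),
    (slice(1, 2), slice(0, 1)),
    (slice(1, 2), slice(3, 4)),
]

def num_addressable_indices_sketch(indices, dim, shard_size):
  # Count distinct slices along `dim`, then scale by the shard extent,
  # mirroring the set-of-slices logic in num_addressable_indices above.
  unique = {(s[dim].start, s[dim].stop) for s in indices}
  return shard_size * len(unique)

assert num_addressable_indices_sketch(host1_indices, dim=0, shard_size=1) == 2
assert num_addressable_indices_sketch(host1_indices, dim=1, shard_size=1) == 2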
