diff --git a/.bazelrc b/.bazelrc index f6d60d968848..a08df1017f91 100644 --- a/.bazelrc +++ b/.bazelrc @@ -90,6 +90,9 @@ build:win_clang --compiler=clang-cl build:cuda_plugin --@xla//xla/python:enable_gpu=false build:cuda_plugin --define=xla_python_enable_gpu=false +build:rocm_plugin --@xla//xla/python:enable_gpu=false +build:rocm_plugin --define=xla_python_enable_gpu=false + # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to # point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA @@ -116,10 +119,6 @@ build:cuda_clang --copt=-Wno-gnu-offsetof-extensions # Disable clang extention that rejects unknown arguments. build:cuda_clang --copt=-Qunused-arguments -build:mosaic_gpu --@llvm-project//mlir:enable_cuda=true -build:mosaic_gpu --copt=-DLLVM_HAS_NVPTX_TARGET=1 -build:mosaic_gpu --//jax:build_mosaic_gpu=true - build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true build:rocm --@xla//xla/python:enable_gpu=true @@ -224,8 +223,6 @@ build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylin build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_cpu_linux_py3.9 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9" -build:rbe_cpu_linux_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9" build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" @@ -254,8 +251,6 @@ build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04- build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl" # RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat -build:rbe_linux_cuda12.3_nvcc_py3.9 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9" -build:rbe_linux_cuda12.3_nvcc_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9" build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 3ece6976c9e4..9c5dee48fae9 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,8 +1,8 @@ name: CI # We test all supported Python versions as follows: -# - 3.9 : Documentation build -# - 3.9 : Part of Matrix with NumPy dispatch +# - 3.10 : Documentation build +# - 3.10 : Part of Matrix with NumPy dispatch # - 3.10 : Part of Matrix # - 3.11 : Part of Matrix @@ -45,8 +45,8 @@ jobs: matrix: # Test the oldest and newest supported Python versions here. include: - - name-prefix: "with 3.9" - python-version: "3.9" + - name-prefix: "with 3.10" + python-version: "3.10" os: ubuntu-20.04-16core enable-x64: 1 prng-upgrade: 1 @@ -108,7 +108,7 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: [3.9] + python-version: ['3.10'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 @@ -136,10 +136,11 @@ jobs: - name: Test documentation env: XLA_FLAGS: "--xla_force_host_platform_device_count=8" + JAX_TRACEBACK_FILTERING: "off" JAX_ARRAY: 1 PY_COLORS: 1 run: | - pytest -n auto --tb=short docs + pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py @@ -149,7 +150,7 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: [3.9] + python-version: ['3.10'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 01942e524ba5..3a1dd863c2b0 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -25,7 +25,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'e38ce3466e596ed2b8fa4638f161f5563ded81a8' # Latest commit as of 2024-04-15 + ref: '33f2d2ea2f3dd2b3ceeeb4519d55e08096184149' # Latest commit as of 2024-05-28 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 41f17fd31ee2..6433fb66039e 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -13,7 +13,7 @@ jobs: matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.9', '3.10', '3.11', '3.12'] + pyver: ['3.10', '3.11', '3.12'] name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build runs-on: ${{ matrix.os }} @@ -24,7 +24,7 @@ jobs: access_token: ${{ github.token }} - name: Install LLVM/Clang - run: choco install llvm --version=18.1.4 --yes + run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 @@ -39,8 +39,7 @@ jobs: JAXLIB_RELEASE: true run: | python -m pip install -r build/test-requirements.txt - python -m pip uninstall -y matplotlib - python -m pip install --pre --upgrade numpy==2.0.0rc2 scipy==1.13.0 + python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH python.exe build\build.py ` --bazel_options=--color=yes ` diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 2f04264ef9c8..92f9355ae200 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -3,6 +3,8 @@ on: schedule: - cron: "0 12 * * *" # Daily at 12:00 UTC workflow_dispatch: # allows triggering the workflow run manually + pull_request: + types: [ labeled ] # allow force-windows-run label env: DISTUTILS_USE_SDK: 1 @@ -10,13 +12,14 @@ env: jobs: win-wheels: + if: ${{ (github.event.action != 'labeled') || (github.event.label.name == 'windows:force-run')}} strategy: fail-fast: true matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.9'] - name: ${{ matrix.os }} CI build + pyver: ['3.10'] + name: Windows CI build runs-on: ${{ matrix.os }} steps: @@ -26,17 +29,12 @@ jobs: access_token: ${{ github.token }} - name: Install LLVM/Clang - run: choco install llvm --version=18.1.4 --yes + run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 with: path: jax - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - with: - repository: openxla/xla - path: xla - - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.pyver }} @@ -49,11 +47,9 @@ jobs: run: | cd jax python -m pip install -r build/test-requirements.txt - python -m pip uninstall -y matplotlib - python -m pip install --pre --upgrade numpy==2.0.0rc2 scipy==1.13.0 + python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH python.exe build\build.py ` - ('--bazel_options=--override_repository=xla=${{ github.workspace }}\xla' -replace '\\','\\') ` --bazel_options=--color=yes ` --bazel_options=--config=win_clang diff --git a/CHANGELOG.md b/CHANGELOG.md index d9f937d7ee17..be8f6f8d18c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,12 +6,88 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). Remember to align the itemized text with the first line of an item within a list. --> -## jax 0.4.29 +## jax 0.4.31 -* Breaking changes +* Changes + * The minimum Python version is now 3.10. 3.10 will remain the minimum + supported version until July 2025. + * The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum + supported version until December 2024. + * {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output + of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point. + * `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be + installed either as a part of local CUDA installation, or via NVIDIA's CUDA + pip wheels. + +## jaxlib 0.4.31 + +* Bug fixes + * Fixed a bug that meant that negative static_argnums to a jit were mishandled + by the jit dispatch fast path. + +## jax 0.4.30 (June 18, 2024) + +* Changes + * JAX supports ml_dtypes >= 0.2. In 0.4.29 release, the ml_dtypes version was + bumped to 0.4.0 but this has been rolled back in this release to give users + of both TensorFlow and JAX more time to migrate to a newer TensorFlow + release. + * `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e. + * jax now depends on jaxlib directly. This change was enabled by the CUDA + plugin switch: there are no longer multiple jaxlib variants. You can install + a CPU-only jax with `pip install jax`, no extras required. + * Added an API for exporting and serializing JAX functions. This used + to exist in `jax.experimental.export` (which is being deprecated), + and will now live in `jax.export`. + See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html). + +* Deprecations + * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed + in a future release. + * Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX + release. This previously was the case, but there was an inadvertent regression in + the last several JAX releases. + * `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. + See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export). + * Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays + `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`. + * `jax.xla_computation` is deprecated and will be removed in a future release. + Please use the AOT APIs to get the same functionality as `jax.xla_computation`. + * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with + `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`. + * You can also use `.out_info` property of `jax.stages.Lowered` to get the + output information (like tree structure, shape and dtype). + * For cross-backend lowering, you can replace + `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with + `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. + + +## jaxlib 0.4.30 (June 18, 2024) + + * Support for monolithic CUDA jaxlibs has been dropped. You must use the + plugin-based installation (`pip install jax[cuda12]` or + `pip install jax[cuda12_local]`). + +## jax 0.4.29 (June 10, 2024) + +* Changes + * We anticipate that this will be the last release of JAX and jaxlib + supporting a monolithic CUDA jaxlib. Future releases will use the CUDA + plugin jaxlib (e.g. `pip install jax[cuda12]`). * JAX now requires ml_dtypes version 0.4.0 or newer. + * Removed backwards-compatibility support for old usage of the + `jax.experimental.export` API. It is not possible anymore to use + `from jax.experimental.export import export`, and instead you should use + `from jax.experimental import export`. + The removed functionality has been deprecated since 0.4.24. + * Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`. * Deprecations + * `jax.sharding.XLACompatibleSharding` is deprecated. Please use + `jax.sharding.Sharding`. + * `jax.experimental.Exported.in_shardings` has been renamed as + `jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`. + The old names will be removed after 3 months. * Removed a number of previously-deprecated APIs: * from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape` * from {mod}`jax.lax`: `tie_in` @@ -28,12 +104,29 @@ Remember to align the itemized text with the first line of an item within a list * {mod}`jax.random` APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of {func}`jax.vmap` in such cases. + * In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been + renamed to `a` and `b` for consistency with other `beta` APIs. -## jaxlib 0.4.29 +* New Functionality + * Added {func}`jax.experimental.Exported.in_shardings_jax` to construct + shardings that can be used with the JAX APIs from the HloShardings + that are stored in the `Exported` objects. + +## jaxlib 0.4.29 (June 10, 2024) * Bug fixes - * Fixes a bug where XLA sharded some concatenation operations incorrectly, + * Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403). + * Fixed a bug where XLA:CPU miscompiled certain matmul fusions + (https://github.com/openxla/xla/pull/13301). + * Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396). + +* Deprecations + * `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will + raise an error in a future version of jax. `None` is only a tree-prefix of + itself. To preserve the current behavior, you can ask `jax.tree.map` to + treat `None` as a leaf value by writing: + `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. ## jax 0.4.28 (May 9, 2024) diff --git a/README.md b/README.md index 2f9b30cca8bf..b19d7b9ff128 100644 --- a/README.md +++ b/README.md @@ -396,8 +396,8 @@ Some standouts: | Hardware | Instructions | |------------|-----------------------------------------------------------------------------------------------------------------| -| CPU | `pip install -U "jax[cpu]"` | -| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12]"` | +| CPU | `pip install -U jax` | +| NVIDIA GPU | `pip install -U "jax[cuda12]"` | | Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` | | AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). | | Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | @@ -414,7 +414,8 @@ community-supported conda build, and answers to some frequently-asked questions. Multiple Google research groups develop and share libraries for training neural networks in JAX. If you want a fully featured library for neural network training with examples and how-to guides, try -[Flax](https://github.com/google/flax). +[Flax](https://github.com/google/flax). Check out the new [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) API for a +simplified development experience. Google X maintains the neural network library [Equinox](https://github.com/patrick-kidger/equinox). This is used as the diff --git a/WORKSPACE b/WORKSPACE index a4915763d761..e574bd9f9611 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -9,12 +9,14 @@ python_init_rules() load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") python_init_repositories( requirements = { - "3.9": "//build:requirements_lock_3_9.txt", "3.10": "//build:requirements_lock_3_10.txt", "3.11": "//build:requirements_lock_3_11.txt", "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", }, + local_wheel_workspaces = ["//jaxlib:jax.bzl"], + local_wheel_dist_folder = "../dist", + default_python_version = "system", ) load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 75cd38d10c10..c68dab85dc8e 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -677,11 +677,31 @@ def host_local_array_to_global_array(state): multihost_utils.host_local_array_to_global_array( (input_data, input_data), global_mesh, (in_pspec, in_pspec)) + @google_benchmark.register -def device_put(state): - x = np.array(1, np.int32) +@google_benchmark.option.arg_names(['num_args']) +@google_benchmark.option.args([1]) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([1000]) +def device_put_from_numpy_array(state): + x = [np.array(1, np.int32)] * state.range(0) while state: - _ = jax.device_put(x).block_until_ready() + _ = jax.block_until_ready(jax.device_put(x)) + + +@google_benchmark.register +@google_benchmark.option.arg_names(['num_args']) +@google_benchmark.option.args([1]) +@google_benchmark.option.args([10]) +@google_benchmark.option.args([100]) +@google_benchmark.option.args([1000]) +def device_put_from_jax_array(state): + x = [np.array(1, np.int32)] * state.range(0) + x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0])) + d = jax.devices()[1] + while state: + _ = jax.block_until_ready(jax.device_put(x, device=d)) @google_benchmark.register @@ -854,6 +874,21 @@ def bench_make_array_from_callback_fully_replicated_sharding(state): while state: jax.make_array_from_callback(shape, s, np_arr.__getitem__) + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def bench_make_array_from_callback_sharded(state): + global_mesh = create_mesh((4, 2), ('x', 'y'), state) + input_shape = (8, 2) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) + + def callback(index): + return input_data[index] + + s = jax.NamedSharding(global_mesh, jax.sharding.PartitionSpec('x', 'y')) + while state: + jax.make_array_from_callback((8, 2), s, callback) + @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def benchmark_lorentz63_cache_hits(state): @@ -886,5 +921,23 @@ def loss(x0): jax.make_jaxpr(lambda x: training_step(x, 100, unroll=True))(x) +@google_benchmark.register +def jit_add_chain(state): + SIZE = 100 + + @jax.jit + def g(x, y): + return lax.add(x, y) + + x = jax.random.normal(jax.random.PRNGKey(0), (2, 2)) + while state: + @jax.jit + def f(x): + for i in range(SIZE): + x = g(x, x) + return x + f(x).block_until_ready() + + if __name__ == "__main__": google_benchmark.main() diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD new file mode 100644 index 000000000000..027da12ce6d3 --- /dev/null +++ b/benchmarks/mosaic/BUILD @@ -0,0 +1,56 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "jax_generate_backend_suites", + "jax_test", + "py_deps", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites() + +DISABLED_BACKENDS = [ + "cpu", + "tpu", +] + +DISABLED_CONFIGS = [ + "gpu", + "gpu_a100", + "gpu_p100", + "gpu_p100_x32", + "gpu_x32", + "gpu_pjrt_c_api", +] + +jax_test( + name = "matmul_bench", + srcs = ["matmul_bench.py"], + disable_backends = DISABLED_BACKENDS, + disable_configs = DISABLED_CONFIGS, + tags = ["notap"], + deps = [ + "//third_party/py/google_benchmark", + "//third_party/py/jax:mosaic_gpu", + "//third_party/py/jax/experimental/mosaic/gpu/examples:matmul", + ] + py_deps("absl/testing") + py_deps("numpy"), +) diff --git a/benchmarks/mosaic/matmul_bench.py b/benchmarks/mosaic/matmul_bench.py new file mode 100644 index 000000000000..32c147916407 --- /dev/null +++ b/benchmarks/mosaic/matmul_bench.py @@ -0,0 +1,110 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Microbenchmarks for mosaic gpu matrix mutliplication.""" + +import functools +import sys + +from absl import app +import google_benchmark as benchmark +from jax._src import config +from jax.experimental.mosaic.gpu.examples import matmul +from jax._src import test_util as jtu +import jax.numpy as jnp + +config.update("jax_traceback_filtering", "off") +config.parse_flags_with_absl() + +def _params_name(params): + return ",".join(f"{k}={v}" for k, v in params.items()) + +def matmul_benchmark(*args): + def decorator(get_runtimes): + for test_case in args: + + @benchmark.register(name=f"{get_runtimes.__name__}_{_params_name(test_case)}") + @benchmark.option.unit(benchmark.kMillisecond) + @benchmark.option.use_manual_time() + @benchmark.option.iterations(1) + @functools.wraps(get_runtimes) + def wrapper(state, test_case=test_case): + m, n, k = test_case["m"], test_case["n"], test_case["k"] + runtime, ref_runtime = get_runtimes(**test_case) + state.counters["TFlops"] = ( + float(2 * k * m * n) / (runtime / 1e3) / 1e12 + ) + state.counters["jax_TFlops"] = ( + float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12 + ) + state.counters["speedup"] = ref_runtime / runtime + state.set_iteration_time(runtime / 1e3) + + return decorator + + +@matmul_benchmark( + dict(m=55 * 128, n=95 * 128, k=48 * 128, stages=4, tile_m=128), + dict(m=55 * 128, n=45 * 128, k=48 * 128, stages=4, tile_m=128), + dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64), + dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64), +) +def bf16_i8_matmul(m, k, n, stages, tile_m): + # RHS.element_size==1b so k_tile=128 + if stages * 128 > k: + raise ValueError(f"Too many stages {(stages, k)=}.") + + return matmul.verify( + m, + k, + n, + stages, + tile_m=tile_m, + rhs_transpose=False, + lhs_dtype=jnp.bfloat16, + rhs_dtype=jnp.int8, + ) + +@matmul_benchmark( + dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=256), + dict(m=1024, n=1024, k=1024, stages=4, tile_m=128, tile_n=128), + dict(m=1024, n=1024, k=1024, stages=4, tile_m=64, tile_n=128), +) +def f32_matmul(m, n, k, stages, tile_m, tile_n): + if stages * 32 > k: + raise ValueError(f"Too many stages {(stages, k)=}.") + + return matmul.verify( + m=m, + k=k, + n=n, + stages=stages, + tile_m=tile_m, + tile_n=tile_n, + rhs_transpose=True, + lhs_dtype=jnp.float32, + rhs_dtype=jnp.float32, + ) + + +def main(_): + device = jtu.device_under_test() + if device != "gpu": + raise ValueError(f"Mosaic only work with gpu (got {device})") + + benchmark.run_benchmarks() + + +if __name__ == "__main__": + sys.argv = benchmark.initialize(sys.argv) + app.run(main) diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index b1b6b625ccca..d26801d8dfe5 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -18,7 +18,7 @@ import jax from jax import core from jax._src.numpy import lax_numpy -from jax.experimental import export +from jax import export jax.config.parse_flags_with_absl() diff --git a/build/build.py b/build/build.py index cc95977a5123..2f68222816ac 100755 --- a/build/build.py +++ b/build/build.py @@ -70,8 +70,8 @@ def get_python_version(python_bin_path): return major, minor def check_python_version(python_version): - if python_version < (3, 9): - print("ERROR: JAX requires Python 3.9 or newer, found ", python_version) + if python_version < (3, 10): + print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) sys.exit(-1) @@ -244,7 +244,7 @@ def write_bazelrc(*, remote_build, rocm_amdgpu_targets, target_cpu_features, wheel_cpu, enable_mkl_dnn, use_clang, clang_path, clang_major_version, enable_cuda, enable_nccl, enable_rocm, - build_gpu_plugin, enable_mosaic_gpu, python_version): + build_gpu_plugin, python_version): tf_cuda_paths = [] with open("../.jax_configure.bazelrc", "w") as f: @@ -313,14 +313,15 @@ def write_bazelrc(*, remote_build, if use_clang: f.write("build --config=nvcc_clang\n") f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") - if enable_mosaic_gpu: - f.write("build --config=mosaic_gpu") if enable_rocm: f.write("build --config=rocm\n") if not enable_nccl: f.write("build --config=nonccl\n") if build_gpu_plugin: - f.write("build --config=cuda_plugin\n") + if enable_cuda: + f.write("build --config=cuda_plugin\n") + elif enable_rocm: + f.write("build --config=rocm_plugin\n") if python_version: f.write( "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( @@ -433,21 +434,21 @@ def main(): "plugin is still experimental and is not ready for use yet." ), ) - add_boolean_argument( - parser, - "build_cuda_kernel_plugin", - default=False, - help_str=( - "Are we building the cuda kernel plugin? jaxlib will not be built " - "when this flag is True." + parser.add_argument( + "--build_gpu_kernel_plugin", + choices=["cuda", "rocm"], + default="", + help=( + "Specify 'cuda' or 'rocm' to build the respective kernel plugin." + " When this flag is set, jaxlib will not be built." ), ) add_boolean_argument( parser, - "build_cuda_pjrt_plugin", + "build_gpu_pjrt_plugin", default=False, help_str=( - "Are we building the cuda pjrt plugin? jaxlib will not be built " + "Are we building the cuda/rocm pjrt plugin? jaxlib will not be built " "when this flag is True." ), ) @@ -456,6 +457,11 @@ def main(): choices=["11", "12"], default="12", help="Which CUDA major version the gpu plugin is for.") + parser.add_argument( + "--gpu_plugin_rocm_version", + choices=["60"], + default="60", + help="Which ROCM major version the gpu plugin is for.") add_boolean_argument( parser, "enable_rocm", @@ -527,10 +533,6 @@ def main(): "--python_version", default=None, help="hermetic python version, e.g., 3.10") - add_boolean_argument( - parser, - "enable_mosaic_gpu", - help_str="Should we build with Mosaic GPU? VERY EXPERIMENTAL.") add_boolean_argument( parser, "configure_only", @@ -652,7 +654,6 @@ def main(): enable_nccl=args.enable_nccl, enable_rocm=args.enable_rocm, build_gpu_plugin=args.build_gpu_plugin, - enable_mosaic_gpu=args.enable_mosaic_gpu, python_version=python_version, ) @@ -683,7 +684,7 @@ def main(): *args.bazel_options, ) - if not args.build_cuda_kernel_plugin and not args.build_cuda_pjrt_plugin: + if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: build_cpu_wheel_command = [ *command_base, "//jaxlib/tools:build_wheel", "--", @@ -698,29 +699,44 @@ def main(): print(" ".join(build_cpu_wheel_command)) shell(build_cpu_wheel_command) - if args.build_gpu_plugin or args.build_cuda_kernel_plugin: - build_cuda_kernels_command = [ + if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \ + (args.build_gpu_kernel_plugin == "rocm"): + build_gpu_kernels_command = [ *command_base, - "//jaxlib/tools:build_cuda_kernels_wheel", "--", + "//jaxlib/tools:build_gpu_kernels_wheel", "--", f"--output_path={output_path}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", - f"--cuda_version={args.gpu_plugin_cuda_version}" ] + if args.enable_cuda: + build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}") + build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") + elif args.enable_rocm: + build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}") + build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + else: + raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") if args.editable: - build_cuda_kernels_command.append("--editable") - print(" ".join(build_cuda_kernels_command)) - shell(build_cuda_kernels_command) + build_gpu_kernels_command.append("--editable") + print(" ".join(build_gpu_kernels_command)) + shell(build_gpu_kernels_command) - if args.build_gpu_plugin or args.build_cuda_pjrt_plugin: + if args.build_gpu_plugin or args.build_gpu_pjrt_plugin: build_pjrt_plugin_command = [ *command_base, "//jaxlib/tools:build_gpu_plugin_wheel", "--", f"--output_path={output_path}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", - f"--cuda_version={args.gpu_plugin_cuda_version}" ] + if args.enable_cuda: + build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}") + build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") + elif args.enable_rocm: + build_pjrt_plugin_command.append(f"--enable-rocm={args.enable_rocm}") + build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + else: + raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") if args.editable: build_pjrt_plugin_command.append("--editable") print(" ".join(build_pjrt_plugin_command)) diff --git a/build/requirements.in b/build/requirements.in index cf9750e082c5..add6b8577350 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -11,18 +11,13 @@ matplotlib; python_version>="3.11" # # build deps # -numpy~=1.22.0; python_version<="3.10" -numpy~=1.23.2; python_version=="3.11" -numpy>=1.26.0; python_version>="3.12" -numpy>=1.26.0; python_version=="3.13" +numpy~=2.0.0 # # runtime deps # -scipy~=1.9.0; python_version<"3.12" -scipy>=1.11.1; python_version>="3.12" +scipy~=1.13.1 -importlib_metadata; python_version<"3.10" ml_dtypes>=0.4.0 opt_einsum zstandard diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index c6d73524d221..adabb0dd2e70 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -88,6 +88,10 @@ execnet==2.1.1 \ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt flatbuffers==24.3.25 \ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 @@ -262,36 +266,35 @@ markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.0 \ - --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ - --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ - --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ - --hash=sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888 \ - --hash=sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463 \ - --hash=sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03 \ - --hash=sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56 \ - --hash=sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4 \ - --hash=sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b \ - --hash=sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b \ - --hash=sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85 \ - --hash=sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956 \ - --hash=sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb \ - --hash=sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd \ - --hash=sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7 \ - --hash=sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89 \ - --hash=sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152 \ - --hash=sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be \ - --hash=sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e \ - --hash=sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0 \ - --hash=sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84 \ - --hash=sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674 \ - --hash=sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382 \ - --hash=sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a \ - --hash=sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5 \ - --hash=sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf \ - --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ - --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ - --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 +matplotlib==3.8.4 ; python_version <= "3.10" \ + --hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \ + --hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \ + --hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \ + --hash=sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb \ + --hash=sha256:606e3b90897554c989b1e38a258c626d46c873523de432b1462f295db13de6f9 \ + --hash=sha256:6209e5c9aaccc056e63b547a8152661324404dd92340a6e479b3a7f24b42a5d0 \ + --hash=sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616 \ + --hash=sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa \ + --hash=sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661 \ + --hash=sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a \ + --hash=sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae \ + --hash=sha256:843cbde2f0946dadd8c5c11c6d91847abd18ec76859dc319362a0964493f0ba6 \ + --hash=sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea \ + --hash=sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106 \ + --hash=sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef \ + --hash=sha256:9bb0189011785ea794ee827b68777db3ca3f93f3e339ea4d920315a0e5a78d54 \ + --hash=sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f \ + --hash=sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014 \ + --hash=sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338 \ + --hash=sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25 \ + --hash=sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b \ + --hash=sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35 \ + --hash=sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732 \ + --hash=sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71 \ + --hash=sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10 \ + --hash=sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0 \ + --hash=sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30 \ + --hash=sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc # via -r build/requirements.in mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ @@ -320,52 +323,52 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0rc2 \ - --hash=sha256:01ac116e2f053f283ac5292fcd146f8f223d4b6cd343beab341748818692a2de \ - --hash=sha256:0a2cf839a7d6cc0b854ba81cdfee96aad2c7e4d558c7e23ca82d08e4f7d7daa7 \ - --hash=sha256:0a49e876be11b4409eb3120841c7d2dba1f63549224f85fa4ab7ee83288c3b41 \ - --hash=sha256:0d5cfbf693408cf1ee72d79d36d51f7b63f5e46a5e9cf12f63d4ed07c0f876e0 \ - --hash=sha256:0e146557fdede5a7434a788648e62a9e87db8c6e05136a92483e2c2180ad4bab \ - --hash=sha256:12d3bf0cac2aec23e10b6927ee063aa6cf7ca8deba1d3c5702faa0ea5cfb8049 \ - --hash=sha256:159d9c21a2989afdfebb638f60268becbc3da07eb224d9221a7c37255216feb6 \ - --hash=sha256:1691e64c838d33fdba59ac7043144194f8f847b5fec6f47ecd9e117418cc9bdc \ - --hash=sha256:201c0e05854d25f16b15851380c07d61aab34eef76a2acf1c3fcc4bda0879b0b \ - --hash=sha256:2202abe3e8afb2b88102a75f1beb888f380c09d40898db0f1df5d847623701d5 \ - --hash=sha256:225c2b3303eb2ebf745ab954ef8723cd60f64d926edd73dc963141538ddc48ed \ - --hash=sha256:24bcf0cdd31debdcb80e1f3bb7dba224c9a93a66f48ff1b1df2cb9a53eede944 \ - --hash=sha256:2a9a5ee4b090af548a1019bb76b53b02cb37f09dc002386349ee5e79ff54c40e \ - --hash=sha256:2bc615498fce8e15b99c1b4d7e018ffebf7bd1a288665b3b916357bdf6725d6a \ - --hash=sha256:32207294f21331ae0d7fd33dc9324447a8117d5af15a0895f39af3441d4af70e \ - --hash=sha256:32725b717f902e7243d270e50ff9487a499820233b57c3e71b33f65a84707e38 \ - --hash=sha256:4f3a4c676ab4ce211e5886cb16cc282e9e18b352b2b1427bbb4c104f9d80f12a \ - --hash=sha256:5262d69981502ded9b397c3fd5a20a1f2c91a66b21325ddff5e6d88486eee6fa \ - --hash=sha256:53286933bf3be7a13459c7a7885ce0935aff56fe0baf280f0e6d80e75cc3ee3c \ - --hash=sha256:6aba1c147f69ee1fb8afb44e93178e92d2aa9a3bf5374b6f1cb53ee1cae1376d \ - --hash=sha256:6b93d6b401db67948a4997e2c45e958df65b98b1a9183e96e96e491f9fb3c2fe \ - --hash=sha256:6d23b0db1fd4ad8225fd32f39036b07a5052398929a5af5291379bceac49d95a \ - --hash=sha256:6fe254c271f8ce4c2e60250f8ee80684abd2be748af84312a05b7614c3ae3b8d \ - --hash=sha256:7288d8ac70be23ff29df8da51840aad8f7acd9120d27cd7a61488b96bc5ad68b \ - --hash=sha256:74dcc392725837896532ec7d65506cbeaecee237871b36ae813521bc3e2c40ed \ - --hash=sha256:800ff28d0da25fca3f843c19035005b73c76350be7c6fa6061c8fcdd248aced9 \ - --hash=sha256:83c76a11c5e5a343fb1cb87afec147d6bebac91758c9c9f01d2c692ae4750e27 \ - --hash=sha256:868e9edbee689d6fdb7957c0b790de2b2123e6feff5d66045d10760c521f2c00 \ - --hash=sha256:87172a69d7eafb00ea1b734dba9ffebb474505082078ec2d95b99918f14a0a0e \ - --hash=sha256:951164e9919664a3e5e605715809173b47f14329b586e24ec05e66ae516ce11b \ - --hash=sha256:9b07a5c460941ae5ef8cde51c04b635af58abbbd55387ad6257dbdfda043290a \ - --hash=sha256:9dd61b79856aed44f818fffe1555fa7ef8f6ffa5b5211cde473e2e33f7a5bd92 \ - --hash=sha256:9e00367261ee0347208a8bcc355b6470b084cb777c45141e098328b67b02c98b \ - --hash=sha256:9ea90fb601d5ac32ff7f9f0a3bf7ccab5971a0196364b9429734bd270cd2fa67 \ - --hash=sha256:a0202e282ec9d45fc6ddb85777fddeea1107fe4555be50dd22d044e7fe01860c \ - --hash=sha256:a44b0ebf7ef61c289a33c76247874177c446083c5236c7e7e0595350883e0424 \ - --hash=sha256:a666cc3d55f301b86edc7f1eaef10ffa1f79206c4b196a1f2649f91c8a1b49b6 \ - --hash=sha256:a99ac361ddb0ef14894c3e7405aa98ffdfe6d0101b9f4a2e931f3912f3b43085 \ - --hash=sha256:b3ba5f436c6de9b8829f231e9eb9e394aa819efce9eab697cd4e558b0b8c6cc8 \ - --hash=sha256:c58bc6aac83175dcfa02a0ef92b7a7fff5a0420014202f052a9af6214684e6ac \ - --hash=sha256:d5211fd4e126699b16b8573eef007f25afb9459d966b35430908798b24298e3b \ - --hash=sha256:da6ab9dab471668155e0b208ab710417a7407397794a88b3ccbece5bcf10091d \ - --hash=sha256:e13a1fa60a471b79a53de8abb87e1e0ad53e6899edee8a29b4db3edccee53d65 \ - --hash=sha256:f8c7012dd6779f078e3f42e19a2204275abe4d68a80dc807a97caf42e825d9c3 \ - --hash=sha256:fa5485c565ca222ba69c5fe04ebd8a89f884615466d74e0856e03fff873bcc43 +numpy==2.0.0 \ + --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ + --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ + --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ + --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ + --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ + --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ + --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ + --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ + --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ + --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ + --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ + --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ + --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ + --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ + --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ + --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ + --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ + --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ + --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ + --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ + --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ + --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ + --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ + --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ + --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ + --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ + --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ + --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ + --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ + --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ + --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ + --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ + --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ + --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ + --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ + --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ + --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ + --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ + --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ + --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ + --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ + --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ + --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ + --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ + --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in # -r build/test-requirements.txt @@ -512,32 +515,32 @@ rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 # via -r build/test-requirements.txt -scipy==1.13.0 \ - --hash=sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922 \ - --hash=sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5 \ - --hash=sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa \ - --hash=sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820 \ - --hash=sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd \ - --hash=sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42 \ - --hash=sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e \ - --hash=sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d \ - --hash=sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86 \ - --hash=sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e \ - --hash=sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c \ - --hash=sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602 \ - --hash=sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e \ - --hash=sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5 \ - --hash=sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a \ - --hash=sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21 \ - --hash=sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d \ - --hash=sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6 \ - --hash=sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78 \ - --hash=sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551 \ - --hash=sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7 \ - --hash=sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4 \ - --hash=sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d \ - --hash=sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b \ - --hash=sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9 +scipy==1.13.1 \ + --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ + --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ + --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ + --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ + --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ + --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ + --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ + --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ + --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ + --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ + --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ + --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ + --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ + --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ + --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ + --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ + --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ + --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ + --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ + --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ + --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ + --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ + --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ + --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ + --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f # via -r build/requirements.in six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index e9649f45d32e..053e996cefad 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -82,6 +82,10 @@ execnet==2.1.1 \ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt flatbuffers==24.3.25 \ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 @@ -256,7 +260,7 @@ markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.0 \ +matplotlib==3.9.0 ; python_version >= "3.11" \ --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ @@ -314,52 +318,52 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0rc2 \ - --hash=sha256:01ac116e2f053f283ac5292fcd146f8f223d4b6cd343beab341748818692a2de \ - --hash=sha256:0a2cf839a7d6cc0b854ba81cdfee96aad2c7e4d558c7e23ca82d08e4f7d7daa7 \ - --hash=sha256:0a49e876be11b4409eb3120841c7d2dba1f63549224f85fa4ab7ee83288c3b41 \ - --hash=sha256:0d5cfbf693408cf1ee72d79d36d51f7b63f5e46a5e9cf12f63d4ed07c0f876e0 \ - --hash=sha256:0e146557fdede5a7434a788648e62a9e87db8c6e05136a92483e2c2180ad4bab \ - --hash=sha256:12d3bf0cac2aec23e10b6927ee063aa6cf7ca8deba1d3c5702faa0ea5cfb8049 \ - --hash=sha256:159d9c21a2989afdfebb638f60268becbc3da07eb224d9221a7c37255216feb6 \ - --hash=sha256:1691e64c838d33fdba59ac7043144194f8f847b5fec6f47ecd9e117418cc9bdc \ - --hash=sha256:201c0e05854d25f16b15851380c07d61aab34eef76a2acf1c3fcc4bda0879b0b \ - --hash=sha256:2202abe3e8afb2b88102a75f1beb888f380c09d40898db0f1df5d847623701d5 \ - --hash=sha256:225c2b3303eb2ebf745ab954ef8723cd60f64d926edd73dc963141538ddc48ed \ - --hash=sha256:24bcf0cdd31debdcb80e1f3bb7dba224c9a93a66f48ff1b1df2cb9a53eede944 \ - --hash=sha256:2a9a5ee4b090af548a1019bb76b53b02cb37f09dc002386349ee5e79ff54c40e \ - --hash=sha256:2bc615498fce8e15b99c1b4d7e018ffebf7bd1a288665b3b916357bdf6725d6a \ - --hash=sha256:32207294f21331ae0d7fd33dc9324447a8117d5af15a0895f39af3441d4af70e \ - --hash=sha256:32725b717f902e7243d270e50ff9487a499820233b57c3e71b33f65a84707e38 \ - --hash=sha256:4f3a4c676ab4ce211e5886cb16cc282e9e18b352b2b1427bbb4c104f9d80f12a \ - --hash=sha256:5262d69981502ded9b397c3fd5a20a1f2c91a66b21325ddff5e6d88486eee6fa \ - --hash=sha256:53286933bf3be7a13459c7a7885ce0935aff56fe0baf280f0e6d80e75cc3ee3c \ - --hash=sha256:6aba1c147f69ee1fb8afb44e93178e92d2aa9a3bf5374b6f1cb53ee1cae1376d \ - --hash=sha256:6b93d6b401db67948a4997e2c45e958df65b98b1a9183e96e96e491f9fb3c2fe \ - --hash=sha256:6d23b0db1fd4ad8225fd32f39036b07a5052398929a5af5291379bceac49d95a \ - --hash=sha256:6fe254c271f8ce4c2e60250f8ee80684abd2be748af84312a05b7614c3ae3b8d \ - --hash=sha256:7288d8ac70be23ff29df8da51840aad8f7acd9120d27cd7a61488b96bc5ad68b \ - --hash=sha256:74dcc392725837896532ec7d65506cbeaecee237871b36ae813521bc3e2c40ed \ - --hash=sha256:800ff28d0da25fca3f843c19035005b73c76350be7c6fa6061c8fcdd248aced9 \ - --hash=sha256:83c76a11c5e5a343fb1cb87afec147d6bebac91758c9c9f01d2c692ae4750e27 \ - --hash=sha256:868e9edbee689d6fdb7957c0b790de2b2123e6feff5d66045d10760c521f2c00 \ - --hash=sha256:87172a69d7eafb00ea1b734dba9ffebb474505082078ec2d95b99918f14a0a0e \ - --hash=sha256:951164e9919664a3e5e605715809173b47f14329b586e24ec05e66ae516ce11b \ - --hash=sha256:9b07a5c460941ae5ef8cde51c04b635af58abbbd55387ad6257dbdfda043290a \ - --hash=sha256:9dd61b79856aed44f818fffe1555fa7ef8f6ffa5b5211cde473e2e33f7a5bd92 \ - --hash=sha256:9e00367261ee0347208a8bcc355b6470b084cb777c45141e098328b67b02c98b \ - --hash=sha256:9ea90fb601d5ac32ff7f9f0a3bf7ccab5971a0196364b9429734bd270cd2fa67 \ - --hash=sha256:a0202e282ec9d45fc6ddb85777fddeea1107fe4555be50dd22d044e7fe01860c \ - --hash=sha256:a44b0ebf7ef61c289a33c76247874177c446083c5236c7e7e0595350883e0424 \ - --hash=sha256:a666cc3d55f301b86edc7f1eaef10ffa1f79206c4b196a1f2649f91c8a1b49b6 \ - --hash=sha256:a99ac361ddb0ef14894c3e7405aa98ffdfe6d0101b9f4a2e931f3912f3b43085 \ - --hash=sha256:b3ba5f436c6de9b8829f231e9eb9e394aa819efce9eab697cd4e558b0b8c6cc8 \ - --hash=sha256:c58bc6aac83175dcfa02a0ef92b7a7fff5a0420014202f052a9af6214684e6ac \ - --hash=sha256:d5211fd4e126699b16b8573eef007f25afb9459d966b35430908798b24298e3b \ - --hash=sha256:da6ab9dab471668155e0b208ab710417a7407397794a88b3ccbece5bcf10091d \ - --hash=sha256:e13a1fa60a471b79a53de8abb87e1e0ad53e6899edee8a29b4db3edccee53d65 \ - --hash=sha256:f8c7012dd6779f078e3f42e19a2204275abe4d68a80dc807a97caf42e825d9c3 \ - --hash=sha256:fa5485c565ca222ba69c5fe04ebd8a89f884615466d74e0856e03fff873bcc43 +numpy==2.0.0 \ + --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ + --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ + --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ + --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ + --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ + --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ + --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ + --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ + --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ + --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ + --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ + --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ + --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ + --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ + --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ + --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ + --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ + --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ + --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ + --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ + --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ + --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ + --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ + --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ + --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ + --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ + --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ + --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ + --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ + --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ + --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ + --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ + --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ + --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ + --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ + --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ + --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ + --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ + --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ + --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ + --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ + --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ + --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ + --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ + --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in # -r build/test-requirements.txt @@ -506,32 +510,32 @@ rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 # via -r build/test-requirements.txt -scipy==1.13.0 \ - --hash=sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922 \ - --hash=sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5 \ - --hash=sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa \ - --hash=sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820 \ - --hash=sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd \ - --hash=sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42 \ - --hash=sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e \ - --hash=sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d \ - --hash=sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86 \ - --hash=sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e \ - --hash=sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c \ - --hash=sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602 \ - --hash=sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e \ - --hash=sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5 \ - --hash=sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a \ - --hash=sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21 \ - --hash=sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d \ - --hash=sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6 \ - --hash=sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78 \ - --hash=sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551 \ - --hash=sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7 \ - --hash=sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4 \ - --hash=sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d \ - --hash=sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b \ - --hash=sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9 +scipy==1.13.1 \ + --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ + --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ + --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ + --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ + --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ + --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ + --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ + --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ + --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ + --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ + --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ + --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ + --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ + --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ + --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ + --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ + --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ + --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ + --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ + --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ + --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ + --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ + --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ + --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ + --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f # via -r build/requirements.in six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 606b042eb4aa..1468e64c29cd 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -82,6 +82,10 @@ execnet==2.1.1 \ --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via -r build/test-requirements.txt flatbuffers==24.3.25 \ --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 @@ -256,7 +260,7 @@ markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.0 \ +matplotlib==3.9.0 ; python_version >= "3.11" \ --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ @@ -314,52 +318,52 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0rc2 \ - --hash=sha256:01ac116e2f053f283ac5292fcd146f8f223d4b6cd343beab341748818692a2de \ - --hash=sha256:0a2cf839a7d6cc0b854ba81cdfee96aad2c7e4d558c7e23ca82d08e4f7d7daa7 \ - --hash=sha256:0a49e876be11b4409eb3120841c7d2dba1f63549224f85fa4ab7ee83288c3b41 \ - --hash=sha256:0d5cfbf693408cf1ee72d79d36d51f7b63f5e46a5e9cf12f63d4ed07c0f876e0 \ - --hash=sha256:0e146557fdede5a7434a788648e62a9e87db8c6e05136a92483e2c2180ad4bab \ - --hash=sha256:12d3bf0cac2aec23e10b6927ee063aa6cf7ca8deba1d3c5702faa0ea5cfb8049 \ - --hash=sha256:159d9c21a2989afdfebb638f60268becbc3da07eb224d9221a7c37255216feb6 \ - --hash=sha256:1691e64c838d33fdba59ac7043144194f8f847b5fec6f47ecd9e117418cc9bdc \ - --hash=sha256:201c0e05854d25f16b15851380c07d61aab34eef76a2acf1c3fcc4bda0879b0b \ - --hash=sha256:2202abe3e8afb2b88102a75f1beb888f380c09d40898db0f1df5d847623701d5 \ - --hash=sha256:225c2b3303eb2ebf745ab954ef8723cd60f64d926edd73dc963141538ddc48ed \ - --hash=sha256:24bcf0cdd31debdcb80e1f3bb7dba224c9a93a66f48ff1b1df2cb9a53eede944 \ - --hash=sha256:2a9a5ee4b090af548a1019bb76b53b02cb37f09dc002386349ee5e79ff54c40e \ - --hash=sha256:2bc615498fce8e15b99c1b4d7e018ffebf7bd1a288665b3b916357bdf6725d6a \ - --hash=sha256:32207294f21331ae0d7fd33dc9324447a8117d5af15a0895f39af3441d4af70e \ - --hash=sha256:32725b717f902e7243d270e50ff9487a499820233b57c3e71b33f65a84707e38 \ - --hash=sha256:4f3a4c676ab4ce211e5886cb16cc282e9e18b352b2b1427bbb4c104f9d80f12a \ - --hash=sha256:5262d69981502ded9b397c3fd5a20a1f2c91a66b21325ddff5e6d88486eee6fa \ - --hash=sha256:53286933bf3be7a13459c7a7885ce0935aff56fe0baf280f0e6d80e75cc3ee3c \ - --hash=sha256:6aba1c147f69ee1fb8afb44e93178e92d2aa9a3bf5374b6f1cb53ee1cae1376d \ - --hash=sha256:6b93d6b401db67948a4997e2c45e958df65b98b1a9183e96e96e491f9fb3c2fe \ - --hash=sha256:6d23b0db1fd4ad8225fd32f39036b07a5052398929a5af5291379bceac49d95a \ - --hash=sha256:6fe254c271f8ce4c2e60250f8ee80684abd2be748af84312a05b7614c3ae3b8d \ - --hash=sha256:7288d8ac70be23ff29df8da51840aad8f7acd9120d27cd7a61488b96bc5ad68b \ - --hash=sha256:74dcc392725837896532ec7d65506cbeaecee237871b36ae813521bc3e2c40ed \ - --hash=sha256:800ff28d0da25fca3f843c19035005b73c76350be7c6fa6061c8fcdd248aced9 \ - --hash=sha256:83c76a11c5e5a343fb1cb87afec147d6bebac91758c9c9f01d2c692ae4750e27 \ - --hash=sha256:868e9edbee689d6fdb7957c0b790de2b2123e6feff5d66045d10760c521f2c00 \ - --hash=sha256:87172a69d7eafb00ea1b734dba9ffebb474505082078ec2d95b99918f14a0a0e \ - --hash=sha256:951164e9919664a3e5e605715809173b47f14329b586e24ec05e66ae516ce11b \ - --hash=sha256:9b07a5c460941ae5ef8cde51c04b635af58abbbd55387ad6257dbdfda043290a \ - --hash=sha256:9dd61b79856aed44f818fffe1555fa7ef8f6ffa5b5211cde473e2e33f7a5bd92 \ - --hash=sha256:9e00367261ee0347208a8bcc355b6470b084cb777c45141e098328b67b02c98b \ - --hash=sha256:9ea90fb601d5ac32ff7f9f0a3bf7ccab5971a0196364b9429734bd270cd2fa67 \ - --hash=sha256:a0202e282ec9d45fc6ddb85777fddeea1107fe4555be50dd22d044e7fe01860c \ - --hash=sha256:a44b0ebf7ef61c289a33c76247874177c446083c5236c7e7e0595350883e0424 \ - --hash=sha256:a666cc3d55f301b86edc7f1eaef10ffa1f79206c4b196a1f2649f91c8a1b49b6 \ - --hash=sha256:a99ac361ddb0ef14894c3e7405aa98ffdfe6d0101b9f4a2e931f3912f3b43085 \ - --hash=sha256:b3ba5f436c6de9b8829f231e9eb9e394aa819efce9eab697cd4e558b0b8c6cc8 \ - --hash=sha256:c58bc6aac83175dcfa02a0ef92b7a7fff5a0420014202f052a9af6214684e6ac \ - --hash=sha256:d5211fd4e126699b16b8573eef007f25afb9459d966b35430908798b24298e3b \ - --hash=sha256:da6ab9dab471668155e0b208ab710417a7407397794a88b3ccbece5bcf10091d \ - --hash=sha256:e13a1fa60a471b79a53de8abb87e1e0ad53e6899edee8a29b4db3edccee53d65 \ - --hash=sha256:f8c7012dd6779f078e3f42e19a2204275abe4d68a80dc807a97caf42e825d9c3 \ - --hash=sha256:fa5485c565ca222ba69c5fe04ebd8a89f884615466d74e0856e03fff873bcc43 +numpy==2.0.0 \ + --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ + --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ + --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ + --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ + --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ + --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ + --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ + --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ + --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ + --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ + --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ + --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ + --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ + --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ + --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ + --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ + --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ + --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ + --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ + --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ + --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ + --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ + --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ + --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ + --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ + --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ + --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ + --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ + --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ + --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ + --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ + --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ + --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ + --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ + --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ + --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ + --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ + --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ + --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ + --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ + --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ + --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ + --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ + --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ + --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in # -r build/test-requirements.txt @@ -506,32 +510,32 @@ rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 # via -r build/test-requirements.txt -scipy==1.13.0 \ - --hash=sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922 \ - --hash=sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5 \ - --hash=sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa \ - --hash=sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820 \ - --hash=sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd \ - --hash=sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42 \ - --hash=sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e \ - --hash=sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d \ - --hash=sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86 \ - --hash=sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e \ - --hash=sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c \ - --hash=sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602 \ - --hash=sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e \ - --hash=sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5 \ - --hash=sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a \ - --hash=sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21 \ - --hash=sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d \ - --hash=sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6 \ - --hash=sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78 \ - --hash=sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551 \ - --hash=sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7 \ - --hash=sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4 \ - --hash=sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d \ - --hash=sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b \ - --hash=sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9 +scipy==1.13.1 \ + --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ + --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ + --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ + --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ + --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ + --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ + --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ + --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ + --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ + --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ + --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ + --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ + --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ + --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ + --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ + --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ + --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ + --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ + --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ + --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ + --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ + --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ + --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ + --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ + --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f # via -r build/requirements.in six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 79b96452ebf4..62b5e14e65b4 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -86,7 +86,7 @@ python-dateutil==2.9.0.post0 # via matplotlib rich==13.7.1 # via -r build/test-requirements.txt -scipy==1.13.0 +scipy==1.13.1 # via -r build/requirements.in six==1.16.0 # via python-dateutil diff --git a/build/requirements_lock_3_9.txt b/build/requirements_lock_3_9.txt deleted file mode 100644 index 105554c64633..000000000000 --- a/build/requirements_lock_3_9.txt +++ /dev/null @@ -1,632 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# bazel run //build:requirements.update -# -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff - # via -r build/test-requirements.txt -attrs==23.2.0 \ - --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ - --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 - # via hypothesis -build==1.2.1 \ - --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ - --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt -cloudpickle==3.0.0 \ - --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ - --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 - # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt -contourpy==1.2.1 \ - --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ - --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ - --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ - --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ - --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ - --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ - --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ - --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ - --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ - --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ - --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ - --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ - --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ - --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ - --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ - --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ - --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ - --hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ - --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ - --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ - --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ - --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ - --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ - --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ - --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ - --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ - --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ - --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ - --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ - --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ - --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ - --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ - --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ - --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ - --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ - --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ - --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ - --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ - --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ - --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ - --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ - --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ - --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ - --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 - # via matplotlib -cycler==0.12.1 \ - --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ - --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c - # via matplotlib -etils[epath,epy]==1.5.2 \ - --hash=sha256:6dc882d355e1e98a5d1a148d6323679dc47c9a5792939b9de72615aa4737eb0b \ - --hash=sha256:ba6a3e1aff95c769130776aa176c11540637f5dd881f3b79172a5149b6b1c446 - # via -r build/requirements.in -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via - # hypothesis - # pytest -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 - # via pytest-xdist -flatbuffers==24.3.25 \ - --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ - --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 - # via -r build/test-requirements.txt -fonttools==4.51.0 \ - --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ - --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ - --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ - --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ - --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ - --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ - --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ - --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ - --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ - --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ - --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ - --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ - --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ - --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ - --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ - --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ - --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ - --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ - --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ - --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ - --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ - --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ - --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ - --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ - --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ - --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ - --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ - --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ - --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ - --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ - --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ - --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ - --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ - --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ - --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ - --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ - --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ - --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ - --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ - --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ - --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ - --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 - # via matplotlib -fsspec==2024.5.0 \ - --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ - --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c - # via etils -hypothesis==6.102.4 \ - --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ - --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 - # via -r build/test-requirements.txt -importlib-metadata==7.1.0 ; python_version < "3.10" \ - --hash=sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570 \ - --hash=sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2 - # via - # -r build/requirements.in - # build -importlib-resources==6.4.0 \ - --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ - --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 - # via - # etils - # matplotlib -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 - # via pytest -kiwisolver==1.4.5 \ - --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ - --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ - --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ - --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ - --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ - --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ - --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ - --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ - --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ - --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ - --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ - --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ - --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ - --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ - --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ - --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ - --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ - --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ - --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ - --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ - --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ - --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ - --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ - --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ - --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ - --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ - --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ - --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ - --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ - --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ - --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ - --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ - --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ - --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ - --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ - --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ - --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ - --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ - --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ - --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ - --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ - --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ - --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ - --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ - --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ - --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ - --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ - --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ - --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ - --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ - --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ - --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ - --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ - --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ - --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ - --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ - --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ - --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ - --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ - --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ - --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ - --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ - --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ - --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ - --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ - --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ - --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ - --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ - --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ - --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ - --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ - --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ - --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ - --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ - --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ - --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ - --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ - --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ - --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ - --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ - --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ - --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ - --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ - --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ - --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ - --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ - --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ - --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ - --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ - --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ - --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ - --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ - --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ - --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ - --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ - --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ - --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ - --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ - --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ - --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ - --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ - --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ - --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ - --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f - # via matplotlib -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb - # via rich -matplotlib==3.9.0 \ - --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ - --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ - --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ - --hash=sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888 \ - --hash=sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463 \ - --hash=sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03 \ - --hash=sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56 \ - --hash=sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4 \ - --hash=sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b \ - --hash=sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b \ - --hash=sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85 \ - --hash=sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956 \ - --hash=sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb \ - --hash=sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd \ - --hash=sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7 \ - --hash=sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89 \ - --hash=sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152 \ - --hash=sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be \ - --hash=sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e \ - --hash=sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0 \ - --hash=sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84 \ - --hash=sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674 \ - --hash=sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382 \ - --hash=sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a \ - --hash=sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5 \ - --hash=sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf \ - --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ - --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ - --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 - # via -r build/requirements.in -mdurl==0.1.2 \ - --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ - --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba - # via markdown-it-py -ml-dtypes==0.4.0 \ - --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ - --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ - --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ - --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ - --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ - --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ - --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ - --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ - --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ - --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ - --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ - --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ - --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ - --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ - --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ - --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ - --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 - # via -r build/requirements.in -mpmath==1.4.0a1 \ - --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ - --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 - # via -r build/test-requirements.txt -numpy==2.0.0rc2 \ - --hash=sha256:01ac116e2f053f283ac5292fcd146f8f223d4b6cd343beab341748818692a2de \ - --hash=sha256:0a2cf839a7d6cc0b854ba81cdfee96aad2c7e4d558c7e23ca82d08e4f7d7daa7 \ - --hash=sha256:0a49e876be11b4409eb3120841c7d2dba1f63549224f85fa4ab7ee83288c3b41 \ - --hash=sha256:0d5cfbf693408cf1ee72d79d36d51f7b63f5e46a5e9cf12f63d4ed07c0f876e0 \ - --hash=sha256:0e146557fdede5a7434a788648e62a9e87db8c6e05136a92483e2c2180ad4bab \ - --hash=sha256:12d3bf0cac2aec23e10b6927ee063aa6cf7ca8deba1d3c5702faa0ea5cfb8049 \ - --hash=sha256:159d9c21a2989afdfebb638f60268becbc3da07eb224d9221a7c37255216feb6 \ - --hash=sha256:1691e64c838d33fdba59ac7043144194f8f847b5fec6f47ecd9e117418cc9bdc \ - --hash=sha256:201c0e05854d25f16b15851380c07d61aab34eef76a2acf1c3fcc4bda0879b0b \ - --hash=sha256:2202abe3e8afb2b88102a75f1beb888f380c09d40898db0f1df5d847623701d5 \ - --hash=sha256:225c2b3303eb2ebf745ab954ef8723cd60f64d926edd73dc963141538ddc48ed \ - --hash=sha256:24bcf0cdd31debdcb80e1f3bb7dba224c9a93a66f48ff1b1df2cb9a53eede944 \ - --hash=sha256:2a9a5ee4b090af548a1019bb76b53b02cb37f09dc002386349ee5e79ff54c40e \ - --hash=sha256:2bc615498fce8e15b99c1b4d7e018ffebf7bd1a288665b3b916357bdf6725d6a \ - --hash=sha256:32207294f21331ae0d7fd33dc9324447a8117d5af15a0895f39af3441d4af70e \ - --hash=sha256:32725b717f902e7243d270e50ff9487a499820233b57c3e71b33f65a84707e38 \ - --hash=sha256:4f3a4c676ab4ce211e5886cb16cc282e9e18b352b2b1427bbb4c104f9d80f12a \ - --hash=sha256:5262d69981502ded9b397c3fd5a20a1f2c91a66b21325ddff5e6d88486eee6fa \ - --hash=sha256:53286933bf3be7a13459c7a7885ce0935aff56fe0baf280f0e6d80e75cc3ee3c \ - --hash=sha256:6aba1c147f69ee1fb8afb44e93178e92d2aa9a3bf5374b6f1cb53ee1cae1376d \ - --hash=sha256:6b93d6b401db67948a4997e2c45e958df65b98b1a9183e96e96e491f9fb3c2fe \ - --hash=sha256:6d23b0db1fd4ad8225fd32f39036b07a5052398929a5af5291379bceac49d95a \ - --hash=sha256:6fe254c271f8ce4c2e60250f8ee80684abd2be748af84312a05b7614c3ae3b8d \ - --hash=sha256:7288d8ac70be23ff29df8da51840aad8f7acd9120d27cd7a61488b96bc5ad68b \ - --hash=sha256:74dcc392725837896532ec7d65506cbeaecee237871b36ae813521bc3e2c40ed \ - --hash=sha256:800ff28d0da25fca3f843c19035005b73c76350be7c6fa6061c8fcdd248aced9 \ - --hash=sha256:83c76a11c5e5a343fb1cb87afec147d6bebac91758c9c9f01d2c692ae4750e27 \ - --hash=sha256:868e9edbee689d6fdb7957c0b790de2b2123e6feff5d66045d10760c521f2c00 \ - --hash=sha256:87172a69d7eafb00ea1b734dba9ffebb474505082078ec2d95b99918f14a0a0e \ - --hash=sha256:951164e9919664a3e5e605715809173b47f14329b586e24ec05e66ae516ce11b \ - --hash=sha256:9b07a5c460941ae5ef8cde51c04b635af58abbbd55387ad6257dbdfda043290a \ - --hash=sha256:9dd61b79856aed44f818fffe1555fa7ef8f6ffa5b5211cde473e2e33f7a5bd92 \ - --hash=sha256:9e00367261ee0347208a8bcc355b6470b084cb777c45141e098328b67b02c98b \ - --hash=sha256:9ea90fb601d5ac32ff7f9f0a3bf7ccab5971a0196364b9429734bd270cd2fa67 \ - --hash=sha256:a0202e282ec9d45fc6ddb85777fddeea1107fe4555be50dd22d044e7fe01860c \ - --hash=sha256:a44b0ebf7ef61c289a33c76247874177c446083c5236c7e7e0595350883e0424 \ - --hash=sha256:a666cc3d55f301b86edc7f1eaef10ffa1f79206c4b196a1f2649f91c8a1b49b6 \ - --hash=sha256:a99ac361ddb0ef14894c3e7405aa98ffdfe6d0101b9f4a2e931f3912f3b43085 \ - --hash=sha256:b3ba5f436c6de9b8829f231e9eb9e394aa819efce9eab697cd4e558b0b8c6cc8 \ - --hash=sha256:c58bc6aac83175dcfa02a0ef92b7a7fff5a0420014202f052a9af6214684e6ac \ - --hash=sha256:d5211fd4e126699b16b8573eef007f25afb9459d966b35430908798b24298e3b \ - --hash=sha256:da6ab9dab471668155e0b208ab710417a7407397794a88b3ccbece5bcf10091d \ - --hash=sha256:e13a1fa60a471b79a53de8abb87e1e0ad53e6899edee8a29b4db3edccee53d65 \ - --hash=sha256:f8c7012dd6779f078e3f42e19a2204275abe4d68a80dc807a97caf42e825d9c3 \ - --hash=sha256:fa5485c565ca222ba69c5fe04ebd8a89f884615466d74e0856e03fff873bcc43 - # via - # -r build/requirements.in - # -r build/test-requirements.txt - # contourpy - # matplotlib - # ml-dtypes - # opt-einsum - # scipy -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in -packaging==24.0 \ - --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ - --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 - # via - # build - # matplotlib - # pytest -pillow==10.3.0 \ - --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ - --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ - --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ - --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ - --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ - --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ - --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ - --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ - --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ - --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ - --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ - --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ - --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ - --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ - --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ - --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ - --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ - --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ - --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ - --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ - --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ - --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ - --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ - --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ - --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ - --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ - --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ - --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ - --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ - --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ - --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ - --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ - --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ - --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ - --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ - --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ - --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ - --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ - --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ - --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ - --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ - --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ - --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ - --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ - --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ - --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ - --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ - --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ - --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ - --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ - --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ - --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ - --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ - --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ - --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ - --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ - --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ - --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ - --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ - --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ - --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ - --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ - --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ - --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ - --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ - --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ - --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ - --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ - --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a - # via - # -r build/test-requirements.txt - # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 - # via pytest -portpicker==1.6.0 \ - --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ - --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt -psutil==5.9.8 \ - --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ - --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ - --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ - --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ - --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ - --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ - --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ - --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ - --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ - --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ - --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ - --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ - --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ - --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ - --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ - --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 - # via portpicker -pygments==2.18.0 \ - --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ - --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via rich -pyparsing==3.1.2 \ - --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ - --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 - # via matplotlib -pyproject-hooks==1.1.0 \ - --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ - --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 - # via build -pytest==8.2.0 \ - --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ - --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f - # via pytest-xdist -pytest-xdist==3.6.1 \ - --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d - # via -r build/test-requirements.txt -python-dateutil==2.9.0.post0 \ - --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ - --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 - # via matplotlib -rich==13.7.1 \ - --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ - --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 - # via -r build/test-requirements.txt -scipy==1.13.0 \ - --hash=sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922 \ - --hash=sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5 \ - --hash=sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa \ - --hash=sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820 \ - --hash=sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd \ - --hash=sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42 \ - --hash=sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e \ - --hash=sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d \ - --hash=sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86 \ - --hash=sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e \ - --hash=sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c \ - --hash=sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602 \ - --hash=sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e \ - --hash=sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5 \ - --hash=sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a \ - --hash=sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21 \ - --hash=sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d \ - --hash=sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6 \ - --hash=sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78 \ - --hash=sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551 \ - --hash=sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7 \ - --hash=sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4 \ - --hash=sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d \ - --hash=sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b \ - --hash=sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9 - # via -r build/requirements.in -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil -sortedcontainers==2.4.0 \ - --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ - --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 - # via hypothesis -tomli==2.0.1 \ - --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ - --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f - # via - # build - # pytest -typing-extensions==4.12.0rc1 \ - --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ - --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe - # via etils -wheel==0.43.0 \ - --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ - --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt -zipp==3.18.2 \ - --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ - --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e - # via - # etils - # importlib-metadata - # importlib-resources -zstandard==0.22.0 \ - --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ - --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ - --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ - --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ - --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ - --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ - --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ - --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ - --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ - --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ - --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ - --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ - --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ - --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ - --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ - --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ - --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ - --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ - --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ - --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ - --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ - --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ - --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ - --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ - --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ - --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ - --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ - --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ - --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ - --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ - --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ - --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ - --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ - --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ - --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ - --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ - --hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ - --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ - --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ - --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ - --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ - --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ - --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ - --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ - --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ - --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in - -# The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r build/test-requirements.txt diff --git a/build/rocm/build_rocm.sh b/build/rocm/build_rocm.sh index 64e166239fc8..6374a2a18929 100755 --- a/build/rocm/build_rocm.sh +++ b/build/rocm/build_rocm.sh @@ -57,7 +57,7 @@ rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1) export JAX_ROCM_VERSION=${rocm_version//./} #Build and install wheel -python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} +python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} JAX_RELEASE=1 python -m build pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA) diff --git a/build/rocm/run_single_gpu.py b/build/rocm/run_single_gpu.py index bf37a49ee61e..add7ee3d86b5 100755 --- a/build/rocm/run_single_gpu.py +++ b/build/rocm/run_single_gpu.py @@ -22,7 +22,7 @@ GPU_LOCK = threading.Lock() LAST_CODE = 0 -base_dir="./logs" +base_dir = "./logs" def extract_filename(path): base_name = os.path.basename(path) @@ -32,7 +32,7 @@ def extract_filename(path): def generate_final_report(shell=False, env_vars={}): env = os.environ env = {**env, **env_vars} - cmd = ["pytest_html_merger", "-i", '{}'.format(base_dir), "-o", '{}/final_compiled_report.html'.format(base_dir)] + cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html'] result = subprocess.run(cmd, shell=shell, capture_output=True, @@ -90,7 +90,7 @@ def run_test(testmodule, gpu_tokens): "XLA_PYTHON_CLIENT_ALLOCATOR": "default", } testfile = extract_filename(testmodule) - cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule] + cmd = ["python3", "-m", "pytest", f'--html={base_dir}/{testfile}_log.html', "--reruns", "3", "-x", testmodule] return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars) with GPU_LOCK: gpu_tokens.append(target_gpu) diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 800bc735daf7..4f9d19e76ba2 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -2,6 +2,7 @@ absl-py build cloudpickle colorama>=0.4.4 +filelock flatbuffers hypothesis mpmath>=1.3 diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index bb5a16c8b4b9..2d9c43831e4d 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -1,5 +1,7 @@ # Custom operations for GPUs with C++ and CUDA + + JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX. To accommodate such scenarios, JAX allows users to define custom operations and this tutorial is to explain how we can define one for GPUs and use it in single-GPU and multi-GPU environments. @@ -752,7 +754,7 @@ class RmsNormBwdClass: return mesh, impl, output_shardings, arg_shardings register_primitive(RmsNormBwdClass) ``` -Plumbing to establish the forward and backward primtives with a custom_vjp rule as before: +Plumbing to establish the forward and backward primitives with a custom_vjp rule as before: ```python @partial(jax.custom_vjp, nondiff_argnums=(2,)) diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py index 09aec08be4d4..4c0b4b6f7b38 100644 --- a/docs/Custom_Operation_for_GPUs.py +++ b/docs/Custom_Operation_for_GPUs.py @@ -14,7 +14,6 @@ from functools import partial, reduce import math -from typing import Tuple import jax import jax.numpy as jnp @@ -325,9 +324,9 @@ def batcher(batched_args, batch_dims, *, eps): return RmsNormFwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims @staticmethod - def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh, - arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], - result_infos : Tuple[jax._src.core.ShapedArray]): + def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh, + arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...], + result_infos: tuple[jax._src.core.ShapedArray, ...]): del eps, result_infos # Not needed for this example. x_info, weight_info = arg_infos assert len(x_info.shape) == 3 @@ -340,9 +339,9 @@ def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh, return (output_sharding, invvar_sharding) @staticmethod - def partition(eps : float, mesh : jax.sharding.Mesh, - arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], - result_infos : Tuple[jax._src.api.ShapeDtypeStruct]): + def partition(eps: float, mesh : jax.sharding.Mesh, + arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...], + result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]): del result_infos # Not needed for this example. x_info, weight_info = arg_infos assert len(x_info.shape) == 3 @@ -395,9 +394,9 @@ def batcher(batched_args, batch_dims, *, eps): return RmsNormBwdClass.outer_primitive.bind(x, gamma, eps=eps), out_bdims @staticmethod - def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh, - arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], - result_infos : Tuple[jax._src.core.ShapedArray]): + def infer_sharding_from_operands(eps: float, mesh : jax.sharding.Mesh, + arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...], + result_infos: tuple[jax._src.core.ShapedArray, ...]): del eps, result_infos # Not needed for this example. g_info, invvar_info, x_info, weight_info = arg_infos assert len(g_info.shape) == 3 @@ -411,9 +410,9 @@ def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh, return (output_sharding, invvar_sharding, output_sharding, ) @staticmethod - def partition(eps : float, mesh : jax.sharding.Mesh, - arg_infos : Tuple[jax._src.api.ShapeDtypeStruct], - result_infos : Tuple[jax._src.api.ShapeDtypeStruct]): + def partition(eps: float, mesh : jax.sharding.Mesh, + arg_infos: tuple[jax._src.api.ShapeDtypeStruct, ...], + result_infos: tuple[jax._src.api.ShapeDtypeStruct, ...]): del result_infos # Not needed for this example. g_info, invvar_info, x_info, weight_info = arg_infos assert len(g_info.shape) == 3 diff --git a/docs/_static/style.css b/docs/_static/style.css index f2b855838064..7a5c647052f0 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -27,3 +27,17 @@ div.red-background pre { div.green-background pre { background-color: rgba(204, 244, 204, var(--block-bg-opacity)); } + +/* Python code block comments */ +html[data-theme="light"] .highlight span.c1 { + color: #fa8d59; +} + +/* Python code traceback and exception */ +html[data-theme="light"] .highlight span.gt { + color: #ff0000; +} + +html[data-theme="light"] .highlight span.gr { + color: #ff0000; +} diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index 56ac53baf814..da95f96d8b25 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -15,6 +15,8 @@ kernelspec: (advanced-autodiff)= # Advanced automatic differentiation + + In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful. Make sure to check out the {ref}`automatic-differentiation` tutorial to go over the JAX autodiff basics, if you haven't already. diff --git a/docs/_tutorials/advanced-compilation.md b/docs/_tutorials/advanced-compilation.md index a3aeeaf3c31e..09535f2fce96 100644 --- a/docs/_tutorials/advanced-compilation.md +++ b/docs/_tutorials/advanced-compilation.md @@ -1,5 +1,7 @@ # Advanced compilation + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. diff --git a/docs/_tutorials/advanced-debugging.md b/docs/_tutorials/advanced-debugging.md index 34d15e30bc62..56188e0958fa 100644 --- a/docs/_tutorials/advanced-debugging.md +++ b/docs/_tutorials/advanced-debugging.md @@ -14,6 +14,9 @@ kernelspec: (advanced-debugging)= # Advanced debugging + + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. diff --git a/docs/_tutorials/external-callbacks.md b/docs/_tutorials/external-callbacks.md index 0420afaaab4a..a46927e6a8b4 100644 --- a/docs/_tutorials/external-callbacks.md +++ b/docs/_tutorials/external-callbacks.md @@ -22,6 +22,8 @@ kernelspec: (external-callbacks)= # External callbacks + + This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`. ## Why callbacks? diff --git a/docs/_tutorials/gradient-checkpointing.md b/docs/_tutorials/gradient-checkpointing.md index a15ee1941837..b768514e4bb0 100644 --- a/docs/_tutorials/gradient-checkpointing.md +++ b/docs/_tutorials/gradient-checkpointing.md @@ -15,6 +15,8 @@ kernelspec: (gradient-checkpointing)= ## Gradient checkpointing with `jax.checkpoint` (`jax.remat`) + + In this tutorial, you will learn how to control JAX automatic differentiation's saved values using {func}`jax.checkpoint` (also known as {func}`jax.remat`), which can be particularly helpful in machine learning. If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has {ref}`automatic-differentiation` and {ref}`advanced-autodiff` tutorials. diff --git a/docs/_tutorials/jax-primitives.md b/docs/_tutorials/jax-primitives.md index e5fab275a106..51abe2916693 100644 --- a/docs/_tutorials/jax-primitives.md +++ b/docs/_tutorials/jax-primitives.md @@ -15,6 +15,8 @@ kernelspec: (jax-internals-jax-primitives)= # JAX Internals: primitives + + ## Introduction to JAX primitives A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide). diff --git a/docs/_tutorials/jaxpr.md b/docs/_tutorials/jaxpr.md index 03c0bef0849b..9fe990c0a8ba 100644 --- a/docs/_tutorials/jaxpr.md +++ b/docs/_tutorials/jaxpr.md @@ -15,6 +15,8 @@ kernelspec: (jax-internals-jaxpr)= # JAX internals: The jaxpr language + + Jaxprs are JAX’s internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in algebraic normal form (ANF). Conceptually, one can think of JAX transformations, such as {func}`jax.jit` or {func}`jax.grad`, as first trace-specializing the Python function to be transformed into a small and well-behaved intermediate form that is then interpreted with transformation-specific interpretation rules. diff --git a/docs/_tutorials/parallelism.md b/docs/_tutorials/parallelism.md index 8bf6957432f0..9b840357e8aa 100644 --- a/docs/_tutorials/parallelism.md +++ b/docs/_tutorials/parallelism.md @@ -1,5 +1,7 @@ # Parallel computation + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. diff --git a/docs/_tutorials/profiling-and-performance.md b/docs/_tutorials/profiling-and-performance.md index e540c920e962..d9a13b213f70 100644 --- a/docs/_tutorials/profiling-and-performance.md +++ b/docs/_tutorials/profiling-and-performance.md @@ -1,5 +1,7 @@ # Profiling and performance + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. diff --git a/docs/_tutorials/simple-neural-network.md b/docs/_tutorials/simple-neural-network.md index b5c91ffd0e22..76e98db88d82 100644 --- a/docs/_tutorials/simple-neural-network.md +++ b/docs/_tutorials/simple-neural-network.md @@ -1,5 +1,7 @@ # Example: Writing a simple neural network + + ```{note} This is a placeholder for a section in the new {ref}`jax-tutorials-draft`. ``` diff --git a/docs/aot.md b/docs/aot.md index 3304f4081b6a..ed7f4574900b 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -2,6 +2,8 @@ # Ahead-of-time lowering and compilation + + JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning a function that is compiled and runs on accelerators or the CPU. As the JIT acronym indicates, all compilation happens _just-in-time_ for execution. @@ -35,8 +37,6 @@ way. An example: ```python >>> import jax ->>> import jax.numpy as jnp ->>> import numpy as np >>> def f(x, y): return 2 * x + y >>> x, y = 3, 4 @@ -45,12 +45,12 @@ way. An example: >>> # Print lowered HLO >>> print(lowered.as_text()) -module @jit_f.0 { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = stablehlo.constant dense<2> : tensor - %1 = stablehlo.multiply %0, %arg0 : tensor - %2 = stablehlo.add %1, %arg1 : tensor - return %2 : tensor +module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}, %arg1: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor + %0 = stablehlo.multiply %c, %arg0 : tensor + %1 = stablehlo.add %0, %arg1 : tensor + return %1 : tensor } } @@ -62,9 +62,14 @@ module @jit_f.0 { >>> # Execute the compiled function! >>> compiled(x, y) -DeviceArray(10, dtype=int32) +Array(10, dtype=int32, weak_type=True) + ``` +Note that the lowered objects can be used only in the same process +in which they were lowered. For exporting use cases, +see the {ref}`export` APIs. + See the {mod}`jax.stages` documentation for more details on what functionality the lowering and compiled functions provide. @@ -83,7 +88,8 @@ that have `shape` and `dtype` attributes: ```python >>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32')) >>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y) -DeviceArray(10, dtype=int32) +Array(10, dtype=int32) + ``` More generally, `lower` only needs its arguments to structurally supply what JAX @@ -97,18 +103,21 @@ lowering raises an error: ```python >>> x_1d = y_1d = jnp.arange(3) ->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) +>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL ... +Traceback (most recent call last): TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: Argument 'x' compiled with int32[] and called with int32[3] Argument 'y' compiled with int32[] and called with int32[3] >>> x_f = y_f = jnp.float32(72.) ->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) +>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL ... +Traceback (most recent call last): TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: Argument 'x' compiled with int32[] and called with float32[] Argument 'y' compiled with int32[] and called with float32[] + ``` Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time @@ -127,17 +136,23 @@ to invoke the resulting compiled function. Continuing with our example above: >>> # Lowered HLO, specialized to the *value* of the first argument (7) >>> print(lowered_with_x.as_text()) -module @jit_f.1 { - func.func public @main(%arg0: tensor) -> tensor { - %0 = stablehlo.constant dense<14> : tensor - %1 = stablehlo.add %0, %arg0 : tensor - return %1 : tensor +module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<14> : tensor + %0 = stablehlo.add %c, %arg0 : tensor + return %0 : tensor } } + >>> lowered_with_x.compile()(5) -DeviceArray(19, dtype=int32) +Array(19, dtype=int32, weak_type=True) + ``` +The result of `lower` is not safe to serialize directly for use +in a different process. +See {ref}`export` for additional APIs for this purpose. + Note that `lower` here takes two arguments as usual, but the subsequent compiled function accepts only the remaining non-static second argument. The static first argument (value 7) is taken as a constant at lowering time and built into the @@ -149,11 +164,13 @@ shape/dtype structure, it is necessary that the static first argument be a concrete value. Otherwise, lowering would err: ```python ->>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) +>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP +Traceback (most recent call last): TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct' >>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5) -DeviceArray(25, dtype=int32) +Array(25, dtype=int32) + ``` ## AOT-compiled functions cannot be transformed @@ -179,13 +196,15 @@ in transformations. Example: >>> g_aot = jax.jit(g).lower(z).compile() >>> jax.vmap(g_jit)(zs) -DeviceArray([[ 1., 5., 9.], - [13., 17., 21.], - [25., 29., 33.], - [37., 41., 45.]], dtype=float32) +Array([[ 1., 5., 9.], + [13., 17., 21.], + [25., 29., 33.], + [37., 41., 45.]], dtype=float32) + +>>> jax.vmap(g_aot)(zs) # doctest: +SKIP +Traceback (most recent call last): +TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type ->>> jax.vmap(g_aot)(zs) -TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type . ``` A similar error is raised when `g_aot` is involved in autodiff diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 39c4386f447e..b3019bfc11f1 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -2,6 +2,8 @@ # API compatibility + + JAX is constantly evolving, and we want to be able to make improvements to its APIs. That said, we want to minimize churn for the JAX user community, and we try to make breaking changes rarely. diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 24980cf30d04..ed242ecc5710 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -36,6 +36,8 @@ "source": [ "# Autodidax: JAX core from scratch\n", "\n", + "\n", + "\n", "Ever want to learn how JAX works, but the implementation seemed impenetrable?\n", "Well, you're in luck! By reading this tutorial, you'll learn every big idea in\n", "JAX's core system. You'll even get clued into our weird jargon!\n", @@ -165,15 +167,15 @@ "source": [ "from collections.abc import Sequence\n", "from contextlib import contextmanager\n", - "from typing import Optional, Any\n", + "from typing import Any\n", "\n", "class MainTrace(NamedTuple):\n", " level: int\n", " trace_type: type['Trace']\n", - " global_data: Optional[Any]\n", + " global_data: Any | None\n", "\n", "trace_stack: list[MainTrace] = []\n", - "dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3\n", + "dynamic_trace: MainTrace | None = None # to be employed in Part 3\n", "\n", "@contextmanager\n", "def new_main(trace_type: type['Trace'], global_data=None):\n", @@ -910,7 +912,7 @@ "source": [ "from collections.abc import Hashable, Iterable, Iterator\n", "import itertools as it\n", - "from typing import Callable\n", + "from collections.abc import Callable\n", "\n", "class NodeType(NamedTuple):\n", " name: str\n", @@ -1649,7 +1651,7 @@ "source": [ "from functools import lru_cache\n", "\n", - "@lru_cache() # ShapedArrays are hashable\n", + "@lru_cache # ShapedArrays are hashable\n", "def make_jaxpr_v1(f, *avals_in):\n", " avals_in, in_tree = tree_flatten(avals_in)\n", " f, out_tree = flatten_fun(f, in_tree)\n", @@ -1801,7 +1803,7 @@ " finally:\n", " dynamic_trace = prev_dynamic_trace\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n", " ) -> tuple[Jaxpr, list[Any], PyTreeDef]:\n", " avals_in, in_tree = tree_flatten(avals_in)\n", @@ -1992,7 +1994,7 @@ " return execute(*args)\n", "impl_rules[xla_call_p] = xla_call_impl\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def xla_callable(hashable_jaxpr: IDHashable,\n", " hashable_consts: tuple[IDHashable, ...]):\n", " jaxpr: Jaxpr = hashable_jaxpr.val\n", @@ -2225,7 +2227,7 @@ " return primals_out, tangents_out\n", "jvp_rules[xla_call_p] = xla_call_jvp_rule\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:\n", " def jvp_traceable(*primals_and_tangents):\n", " n = len(primals_and_tangents) // 2\n", @@ -2251,7 +2253,7 @@ " return outs, [0] * len(outs)\n", "vmap_rules[xla_call_p] = xla_call_vmap_rule\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]\n", " ) -> tuple[Jaxpr, list[Any]]:\n", " vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n", @@ -2636,7 +2638,7 @@ "source": [ "class PartialVal(NamedTuple):\n", " aval: ShapedArray\n", - " const: Optional[Any]\n", + " const: Any | None\n", "\n", " @classmethod\n", " def known(cls, val: Any):\n", @@ -2725,7 +2727,7 @@ "source": [ "class PartialEvalTracer(Tracer):\n", " pval: PartialVal\n", - " recipe: Optional[JaxprRecipe]\n", + " recipe: JaxprRecipe | None\n", "\n", " def __init__(self, trace, pval, recipe):\n", " self._trace = trace\n", @@ -2972,7 +2974,7 @@ "partial_eval_rules[xla_call_p] = xla_call_partial_eval\n", "\n", "def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],\n", - " instantiate: Optional[list[bool]] = None,\n", + " instantiate: list[bool] | None = None,\n", " ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:\n", " env: dict[Var, bool] = {}\n", " residuals: set[Var] = set()\n", @@ -3269,7 +3271,7 @@ " return [next(outs) if undef else None for undef in undef_primals]\n", "transpose_rules[xla_call_p] = xla_call_transpose_rule\n", "\n", - "@lru_cache()\n", + "@lru_cache\n", "def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]\n", " ) -> tuple[Jaxpr, list[Any]]:\n", " avals_in, avals_out = typecheck_jaxpr(jaxpr)\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index b5f82ec4ffa3..0551b9905db3 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -39,6 +39,8 @@ Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab. # Autodidax: JAX core from scratch + + Ever want to learn how JAX works, but the implementation seemed impenetrable? Well, you're in luck! By reading this tutorial, you'll learn every big idea in JAX's core system. You'll even get clued into our weird jargon! @@ -146,15 +148,15 @@ more descriptive. ```{code-cell} from collections.abc import Sequence from contextlib import contextmanager -from typing import Optional, Any +from typing import Any class MainTrace(NamedTuple): level: int trace_type: type['Trace'] - global_data: Optional[Any] + global_data: Any | None trace_stack: list[MainTrace] = [] -dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 +dynamic_trace: MainTrace | None = None # to be employed in Part 3 @contextmanager def new_main(trace_type: type['Trace'], global_data=None): @@ -703,7 +705,7 @@ class Store: from collections.abc import Hashable, Iterable, Iterator import itertools as it -from typing import Callable +from collections.abc import Callable class NodeType(NamedTuple): name: str @@ -1293,7 +1295,7 @@ transformation and a pretty-printer: ```{code-cell} from functools import lru_cache -@lru_cache() # ShapedArrays are hashable +@lru_cache # ShapedArrays are hashable def make_jaxpr_v1(f, *avals_in): avals_in, in_tree = tree_flatten(avals_in) f, out_tree = flatten_fun(f, in_tree) @@ -1413,7 +1415,7 @@ def new_dynamic(main: MainTrace): finally: dynamic_trace = prev_dynamic_trace -@lru_cache() +@lru_cache def make_jaxpr(f: Callable, *avals_in: ShapedArray, ) -> tuple[Jaxpr, list[Any], PyTreeDef]: avals_in, in_tree = tree_flatten(avals_in) @@ -1562,7 +1564,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int): return execute(*args) impl_rules[xla_call_p] = xla_call_impl -@lru_cache() +@lru_cache def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: tuple[IDHashable, ...]): jaxpr: Jaxpr = hashable_jaxpr.val @@ -1732,7 +1734,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): return primals_out, tangents_out jvp_rules[xla_call_p] = xla_call_jvp_rule -@lru_cache() +@lru_cache def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]: def jvp_traceable(*primals_and_tangents): n = len(primals_and_tangents) // 2 @@ -1753,7 +1755,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): return outs, [0] * len(outs) vmap_rules[xla_call_p] = xla_call_vmap_rule -@lru_cache() +@lru_cache def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...] ) -> tuple[Jaxpr, list[Any]]: vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) @@ -2063,7 +2065,7 @@ be either known or unknown: ```{code-cell} class PartialVal(NamedTuple): aval: ShapedArray - const: Optional[Any] + const: Any | None @classmethod def known(cls, val: Any): @@ -2127,7 +2129,7 @@ JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe] ```{code-cell} class PartialEvalTracer(Tracer): pval: PartialVal - recipe: Optional[JaxprRecipe] + recipe: JaxprRecipe | None def __init__(self, trace, pval, recipe): self._trace = trace @@ -2327,7 +2329,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts): partial_eval_rules[xla_call_p] = xla_call_partial_eval def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool], - instantiate: Optional[list[bool]] = None, + instantiate: list[bool] | None = None, ) -> tuple[Jaxpr, Jaxpr, list[bool], int]: env: dict[Var, bool] = {} residuals: set[Var] = set() @@ -2584,7 +2586,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts): return [next(outs) if undef else None for undef in undef_primals] transpose_rules[xla_call_p] = xla_call_transpose_rule -@lru_cache() +@lru_cache def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...] ) -> tuple[Jaxpr, list[Any]]: avals_in, avals_out = typecheck_jaxpr(jaxpr) diff --git a/docs/autodidax.py b/docs/autodidax.py index 3a9f1f4156ef..b09534381c69 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -31,6 +31,8 @@ # # Autodidax: JAX core from scratch # +# +# # Ever want to learn how JAX works, but the implementation seemed impenetrable? # Well, you're in luck! By reading this tutorial, you'll learn every big idea in # JAX's core system. You'll even get clued into our weird jargon! @@ -136,15 +138,15 @@ def bind1(prim, *args, **params): # + from collections.abc import Sequence from contextlib import contextmanager -from typing import Optional, Any +from typing import Any class MainTrace(NamedTuple): level: int trace_type: type['Trace'] - global_data: Optional[Any] + global_data: Any | None trace_stack: list[MainTrace] = [] -dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 +dynamic_trace: MainTrace | None = None # to be employed in Part 3 @contextmanager def new_main(trace_type: type['Trace'], global_data=None): @@ -695,7 +697,7 @@ def __call__(self): # + tags=["hide-input"] from collections.abc import Hashable, Iterable, Iterator import itertools as it -from typing import Callable +from collections.abc import Callable class NodeType(NamedTuple): name: str @@ -1295,7 +1297,7 @@ def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int], # + from functools import lru_cache -@lru_cache() # ShapedArrays are hashable +@lru_cache # ShapedArrays are hashable def make_jaxpr_v1(f, *avals_in): avals_in, in_tree = tree_flatten(avals_in) f, out_tree = flatten_fun(f, in_tree) @@ -1410,7 +1412,7 @@ def new_dynamic(main: MainTrace): finally: dynamic_trace = prev_dynamic_trace -@lru_cache() +@lru_cache def make_jaxpr(f: Callable, *avals_in: ShapedArray, ) -> tuple[Jaxpr, list[Any], PyTreeDef]: avals_in, in_tree = tree_flatten(avals_in) @@ -1554,7 +1556,7 @@ def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int): return execute(*args) impl_rules[xla_call_p] = xla_call_impl -@lru_cache() +@lru_cache def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: tuple[IDHashable, ...]): jaxpr: Jaxpr = hashable_jaxpr.val @@ -1726,7 +1728,7 @@ def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): return primals_out, tangents_out jvp_rules[xla_call_p] = xla_call_jvp_rule -@lru_cache() +@lru_cache def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]: def jvp_traceable(*primals_and_tangents): n = len(primals_and_tangents) // 2 @@ -1747,7 +1749,7 @@ def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): return outs, [0] * len(outs) vmap_rules[xla_call_p] = xla_call_vmap_rule -@lru_cache() +@lru_cache def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...] ) -> tuple[Jaxpr, list[Any]]: vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) @@ -2055,7 +2057,7 @@ def vspace(aval: ShapedArray) -> ShapedArray: class PartialVal(NamedTuple): aval: ShapedArray - const: Optional[Any] + const: Any | None @classmethod def known(cls, val: Any): @@ -2119,7 +2121,7 @@ class JaxprEqnRecipe(NamedTuple): class PartialEvalTracer(Tracer): pval: PartialVal - recipe: Optional[JaxprRecipe] + recipe: JaxprRecipe | None def __init__(self, trace, pval, recipe): self._trace = trace @@ -2320,7 +2322,7 @@ def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts): partial_eval_rules[xla_call_p] = xla_call_partial_eval def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool], - instantiate: Optional[list[bool]] = None, + instantiate: list[bool] | None = None, ) -> tuple[Jaxpr, Jaxpr, list[bool], int]: env: dict[Var, bool] = {} residuals: set[Var] = set() @@ -2583,7 +2585,7 @@ def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts): return [next(outs) if undef else None for undef in undef_primals] transpose_rules[xla_call_p] = xla_call_transpose_rule -@lru_cache() +@lru_cache def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...] ) -> tuple[Jaxpr, list[Any]]: avals_in, avals_out = typecheck_jaxpr(jaxpr) diff --git a/docs/automatic-differentiation.md b/docs/automatic-differentiation.md index 4a8922dab900..cc4a19aaba64 100644 --- a/docs/automatic-differentiation.md +++ b/docs/automatic-differentiation.md @@ -15,6 +15,8 @@ kernelspec: (automatic-differentiation)= # Automatic differentiation + + In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as: diff --git a/docs/automatic-vectorization.md b/docs/automatic-vectorization.md index 794a9f11307b..7559155e2e9e 100644 --- a/docs/automatic-vectorization.md +++ b/docs/automatic-vectorization.md @@ -15,6 +15,8 @@ kernelspec: (automatic-vectorization)= # Automatic vectorization + + In the previous section we discussed JIT compilation via the {func}`jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via {func}`jax.vmap`. diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index f6b8e84cded1..e0a4404911a7 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -1,5 +1,7 @@ # Building on JAX + + A great way to learn advanced JAX usage is to see how other libraries are using JAX, both how they integrate the library into their API, what functionality it adds mathematically, diff --git a/docs/contributing.md b/docs/contributing.md index 5040fbd9f17e..cad7cfc1ea64 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,5 +1,7 @@ # Contributing to JAX + + Everyone can contribute to JAX, and we value everyone's contributions. There are several ways to contribute, including: @@ -34,7 +36,7 @@ Follow these steps to contribute code: [repository page](http://www.github.com/google/jax). This creates a copy of the JAX repository in your own account. -3. Install Python >= 3.9 locally in order to run tests. +3. Install Python >= 3.10 locally in order to run tests. 4. `pip` installing your fork from source. This allows you to modify the code and immediately test it out: diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD new file mode 100644 index 000000000000..93715bdac171 --- /dev/null +++ b/docs/cuda_custom_call/BUILD @@ -0,0 +1,63 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "cuda_library", + "jax_generate_backend_suites", + "jax_test", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites() + +jax_test( + name = "cuda_custom_call_test", + srcs = ["cuda_custom_call_test.py"], + data = [":foo"], + disable_backends = [ + "cpu", + "tpu", + ], + tags = ["notap"], + deps = [ + "//jax:extend", + ], +) + +# this second target is needed to properly link in CUDA runtime symbols +# such as cudaLaunchKernel, even though we are only building one library. +cc_shared_library( + name = "foo", + deps = [ + ":foo_", + "@xla//xla/tsl/cuda:cudart", + ], +) + +cuda_library( + name = "foo_", + srcs = ["foo.cu.cc"], + deps = [ + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@local_config_cuda//cuda:cuda_headers", + ], +) diff --git a/docs/cuda_custom_call/Makefile b/docs/cuda_custom_call/Makefile new file mode 100644 index 000000000000..ca51b63b5eaf --- /dev/null +++ b/docs/cuda_custom_call/Makefile @@ -0,0 +1,35 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This Makefile is not used by Bazel for this test, it is intended to serve as +# documentation of build instructions for JAX users that are not using Bazel to +# build their custom call code. For that reason, this Makefile is likely subject +# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in +# this directory no longer runs the test to completion. +NVCC = nvcc +NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())') +NVCCFLAGS += -arch native +# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu +NVCCFLAGS += -x cu + +# depends on libfoo.so being in the same directory as cuda_custom_call_test.py +check: libfoo.so + python cuda_custom_call_test.py + +lib%.so: %.cu.cc + $(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $< + +clean: + rm -rf *.so diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py new file mode 100644 index 000000000000..563462feb472 --- /dev/null +++ b/docs/cuda_custom_call/cuda_custom_call_test.py @@ -0,0 +1,216 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This test is intentionally structured to stay close to what a standalone JAX +# custom call integration might look like. JAX test harness is in a separate +# section towards the end of this file. The test can be run standalone by typing +# "make" in the directory containing this file. + +import os +import ctypes +import unittest + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.extend import ffi +from jax.lib import xla_client +from jax.interpreters import mlir + +# start test boilerplate +from absl.testing import absltest +from jax._src import config +from jax._src import test_util as jtu + +config.parse_flags_with_absl() +# end test boilerplate + +# XLA needs uppercase, "cuda" isn't recognized +XLA_PLATFORM = "CUDA" + +# JAX needs lowercase, "CUDA" isn't recognized +JAX_PLATFORM = "cuda" + +# 0 = original ("opaque"), 1 = FFI +XLA_CUSTOM_CALL_API_VERSION = 1 + +# these strings are how we identify kernels to XLA: +# - first we register a pointer to the kernel with XLA under this name +# - then we "tell" JAX to emit StableHLO specifying this name to XLA +XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd" +XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd" + +# independently, corresponding JAX primitives must also be named, +# names can be different from XLA targets, here they are the same +JAX_PRIMITIVE_FWD = "foo-fwd" +JAX_PRIMITIVE_BWD = "foo-bwd" + +if jtu.is_running_under_pytest(): + raise unittest.SkipTest("libfoo.so hasn't been built") +SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so") + +library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) + +#-----------------------------------------------------------------------------# +# Forward pass # +#-----------------------------------------------------------------------------# + +# register the XLA FFI binding pointer with XLA +xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD, + fn=ffi.pycapsule(library.FooFwd), + platform=XLA_PLATFORM, + api_version=XLA_CUSTOM_CALL_API_VERSION) + + +# our forward primitive will also return the intermediate output b+1 +# so it can be reused in the backward pass computation +def _foo_fwd_abstract_eval(a, b): + assert a.shape == b.shape + assert a.dtype == b.dtype + shaped_array = jax.core.ShapedArray(a.shape, a.dtype) + return ( + shaped_array, # output c + shaped_array, # intermediate output b+1 + ) + + +def _foo_fwd_lowering(ctx, a, b): + # ffi.ffi_lowering does most of the heavy lifting building a lowering. + # Keyword arguments passed to the lowering constructed by ffi_lowering are + # turned into custom call backend_config entries, which we take advantage of + # here for the dynamically computed n. + n = np.prod(a.type.shape).astype(np.uint64) + return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_FWD)(ctx, a, b, n=n) + + +# construct a new JAX primitive +foo_fwd_p = jax.core.Primitive(JAX_PRIMITIVE_FWD) +# register the abstract evaluation rule for the forward primitive +foo_fwd_p.def_abstract_eval(_foo_fwd_abstract_eval) +foo_fwd_p.multiple_results = True +mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM) + +#-----------------------------------------------------------------------------# +# Backward pass # +#-----------------------------------------------------------------------------# + +# register the XLA FFI binding pointer with XLA +xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD, + fn=ffi.pycapsule(library.FooBwd), + platform=XLA_PLATFORM, + api_version=XLA_CUSTOM_CALL_API_VERSION) + + +def _foo_bwd_abstract_eval(c_grad, a, b_plus_1): + assert c_grad.shape == a.shape + assert a.shape == b_plus_1.shape + assert c_grad.dtype == a.dtype + assert a.dtype == b_plus_1.dtype + + shaped_array = jax.core.ShapedArray(a.shape, a.dtype) + return ( + shaped_array, # a_grad + shaped_array, # b_grad + ) + + +def _foo_bwd_lowering(ctx, c_grad, a, b_plus_1): + n = np.prod(a.type.shape).astype(np.uint64) + return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_BWD)(ctx, + c_grad, + a, + b_plus_1, + n=n) + + +# construct a new JAX primitive +foo_bwd_p = jax.core.Primitive(JAX_PRIMITIVE_BWD) +# register the abstract evaluation rule for the backward primitive +foo_bwd_p.def_abstract_eval(_foo_bwd_abstract_eval) +foo_bwd_p.multiple_results = True +mlir.register_lowering(foo_bwd_p, _foo_bwd_lowering, platform=JAX_PLATFORM) + +#-----------------------------------------------------------------------------# +# User facing API # +#-----------------------------------------------------------------------------# + + +def foo_fwd(a, b): + c, b_plus_1 = foo_fwd_p.bind(a, b) + return c, (a, b_plus_1) + + +def foo_bwd(res, c_grad): + a, b_plus_1 = res + return foo_bwd_p.bind(c_grad, a, b_plus_1) + + +@jax.custom_vjp +def foo(a, b): + c, _ = foo_fwd(a, b) + return c + + +foo.defvjp(foo_fwd, foo_bwd) + +#-----------------------------------------------------------------------------# +# Test # +#-----------------------------------------------------------------------------# + + +class CustomCallTest(jtu.JaxTestCase): + + def test_fwd_interpretable(self): + shape = (2, 3) + a = 2. * jnp.ones(shape) + b = 3. * jnp.ones(shape) + observed = jax.jit(foo)(a, b) + expected = (2. * (3. + 1.)) + self.assertArraysEqual(observed, expected) + + def test_bwd_interpretable(self): + shape = (2, 3) + a = 2. * jnp.ones(shape) + b = 3. * jnp.ones(shape) + + def loss(a, b): + return jnp.sum(foo(a, b)) + + da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b) + da_expected = b + 1 + db_expected = a + self.assertArraysEqual(da_observed, da_expected) + self.assertArraysEqual(db_observed, db_expected) + + def test_fwd_random(self): + shape = (2, 3) + akey, bkey = jax.random.split(jax.random.key(0)) + a = jax.random.normal(key=akey, shape=shape) + b = jax.random.normal(key=bkey, shape=shape) + observed = jax.jit(foo)(a, b) + expected = a * (b + 1) + self.assertAllClose(observed, expected) + + def test_bwd_random(self): + shape = (2, 3) + akey, bkey = jax.random.split(jax.random.key(0)) + a = jax.random.normal(key=akey, shape=shape) + b = jax.random.normal(key=bkey, shape=shape) + jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/docs/cuda_custom_call/foo.cu.cc b/docs/cuda_custom_call/foo.cu.cc new file mode 100644 index 000000000000..7072a822f929 --- /dev/null +++ b/docs/cuda_custom_call/foo.cu.cc @@ -0,0 +1,136 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; + +//----------------------------------------------------------------------------// +// Forward pass // +//----------------------------------------------------------------------------// + +// c = a * (b+1) +// This strawman operation works well for demo purposes because: +// 1. it's simple enough to be quickly understood, +// 2. it's complex enough to require intermediate outputs in grad computation, +// like many operations in practice do, and +// 3. it does not have a built-in implementation in JAX. +__global__ void FooFwdKernel(const float *a, const float *b, float *c, + float *b_plus_1, // intermediate output b+1 + size_t n) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t grid_stride = blockDim.x * gridDim.x; + for (size_t i = tid; i < n; i += grid_stride) { + b_plus_1[i] = b[i] + 1.0f; + c[i] = a[i] * b_plus_1[i]; + } +} + +// Host function wrapper that launches the kernel with hardcoded grid/block +// size. Note, it uses types from XLA FFI. The return type must be ffi::Error. +// Buffer type provides buffer dimensions, so the "n" argument here is not +// strictly necessary, but it allows us to demonstrate the use of attributes +// (.Attr in the FFI handler definition above). +ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, + ffi::Buffer b, + ffi::Result> c, + ffi::Result> b_plus_1, + size_t n) { + const int block_dim = 128; + const int grid_dim = 1; + // Note how we access regular Buffer data vs Result Buffer data: + FooFwdKernel<<>>( + a.data, b.data, c->data, b_plus_1->data, n); + // Check for launch time errors. Note that this function may also + // return error codes from previous, asynchronous launches. This + // means that an error status returned here could have been caused + // by a different kernel previously launched by XLA. + cudaError_t last_error = cudaGetLastError(); + if (last_error != cudaSuccess) { + return ffi::Error( + XLA_FFI_Error_Code_INTERNAL, + std::string("CUDA error: ") + cudaGetErrorString(last_error)); + } + return ffi::Error::Success(); +} + +// Creates symbol FooFwd with C linkage that can be loaded using Python ctypes +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FooFwd, FooFwdHost, + ffi::Ffi::Bind() + .Ctx>() // stream + .Arg>() // a + .Arg>() // b + .Ret>() // c + .Ret>() // b_plus_1 + .Attr("n")); + +//----------------------------------------------------------------------------// +// Backward pass // +//----------------------------------------------------------------------------// + +// compute da = dc * (b+1), and +// db = dc * a +__global__ void FooBwdKernel(const float *c_grad, // incoming gradient wrt c + const float *a, // original input a + const float *b_plus_1, // intermediate output b+1 + float *a_grad, // outgoing gradient wrt a + float *b_grad, // outgoing gradient wrt b + size_t n) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t grid_stride = blockDim.x * gridDim.x; + for (size_t i = tid; i < n; i += grid_stride) { + // In practice on GPUs b_plus_1 can be recomputed for practically free + // instead of storing it out and reusing, so the reuse here is a bit + // contrived. We do it to demonstrate residual/intermediate output passing + // between the forward and the backward pass which becomes useful when + // recomputation is more expensive than reuse. + a_grad[i] = c_grad[i] * b_plus_1[i]; + b_grad[i] = c_grad[i] * a[i]; + } +} + +ffi::Error FooBwdHost(cudaStream_t stream, + ffi::Buffer c_grad, + ffi::Buffer a, + ffi::Result> b_plus_1, + ffi::Result> a_grad, + ffi::Result> b_grad, + size_t n) { + const int block_dim = 128; + const int grid_dim = 1; + FooBwdKernel<<>>( + c_grad.data, a.data, b_plus_1->data, a_grad->data, b_grad->data, n); + cudaError_t last_error = cudaGetLastError(); + if (last_error != cudaSuccess) { + return ffi::Error( + XLA_FFI_Error_Code_INTERNAL, + std::string("CUDA error: ") + cudaGetErrorString(last_error)); + } + return ffi::Error::Success(); +} + +// Creates symbol FooBwd with C linkage that can be loaded using Python ctypes +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FooBwd, FooBwdHost, + ffi::Ffi::Bind() + .Ctx>() // stream + .Arg>() // c_grad + .Arg>() // a + .Arg>() // b_plus_1 + .Ret>() // a_grad + .Ret>() // b_grad + .Attr("n")); diff --git a/docs/debugging.md b/docs/debugging.md index b53f08139fd2..7ee36f19f5bf 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -15,6 +15,8 @@ kernelspec: (debugging)= # Introduction to debugging + + This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations. Let's begin with {func}`jax.debug.print`. diff --git a/docs/debugging/checkify_guide.md b/docs/debugging/checkify_guide.md index a804d36038a1..2dad9b863b06 100644 --- a/docs/debugging/checkify_guide.md +++ b/docs/debugging/checkify_guide.md @@ -1,5 +1,7 @@ # The `checkify` transformation + + **TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 90a6cb3bbfbd..1cf1829e5152 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -1,5 +1,7 @@ # JAX debugging flags + + JAX offers flags and context managers that enable catching errors more easily. ## `jax_debug_nans` configuration option and context manager diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 35e0f68950c4..b00fcc13d0a0 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -1,5 +1,7 @@ # Runtime value debugging in JAX + + Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more. Table of contents: diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index f29f68c4dbc4..440cc38d99f0 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -1,5 +1,7 @@ # `jax.debug.print` and `jax.debug.breakpoint` + + The {mod}`jax.debug` package offers some useful tools for inspecting values inside of JIT-ted functions. diff --git a/docs/deprecation.md b/docs/deprecation.md index 5ee58882acc3..7a8b867b6f2e 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -1,6 +1,8 @@ (version-support-policy)= # Python and NumPy version support policy + + For NumPy and SciPy version support, JAX follows the Python scientific community's [SPEC 0](https://scientific-python.org/specs/spec-0000/). diff --git a/docs/developer.md b/docs/developer.md index c936d1ba20ad..018982f4c00d 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,6 +1,8 @@ (building-from-source)= # Building from source + + First, obtain the JAX source code: ``` diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index a6f27e9e9710..e4d871b780f3 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -1,5 +1,6 @@ # Device Memory Profiling + ```{note} May 2023 update: we recommend using [Tensorboard diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md index 5f2ea8e610e3..70cbd26baa5c 100644 --- a/docs/distributed_data_loading.md +++ b/docs/distributed_data_loading.md @@ -14,6 +14,8 @@ kernelspec: # Distributed data loading in a multi-host/multi-process environment + + This high-level guide demonstrates how you can perform distributed data loading — when you run JAX in a {doc}`multi-host or multi-process environment <./multi_process>`, and the data required for the JAX computations is split across the multiple processes. This document covers the overall approach for how to think about distributed data loading, and then how to apply it to *data-parallel* (simpler) and *model-parallel* (more complicated) workloads. Distributed data loading is usually more efficient (the data is split across processes) but also *more complex* compared with its alternatives, such as: 1) loading the *full global data in a single process*, splitting it up and sending the needed parts to the other processes via RPC; and 2) loading the *full global data in all processes* and only using the needed parts in each process. Loading the full global data is often simpler but more expensive. For example, in machine learning the training loop can get blocked while waiting for data, and additional network bandwidth gets used per each process. diff --git a/docs/export/export.md b/docs/export/export.md new file mode 100644 index 000000000000..d8ddd91b91d5 --- /dev/null +++ b/docs/export/export.md @@ -0,0 +1,660 @@ +# Exporting and serializing staged-out computations + +The {ref}`ahead-of-time-lowering` APIs produce +objects that can be used for debugging or for compilation and +execution in the same process. +Sometimes you want to serialize a lowered JAX function for +compilation and execution in a separate process, perhaps +at a later time. This would allow you to: + + * compile and execute the function in another process or machine + without requiring access to the JAX program, + and without having to repeat the staging-out and lowering, e.g., + in an inference system. + * trace and lower a function on a machine that does not have access + to the accelerator for which you want to later compile and execute + the function. + * archive a snapshot of a JAX function, e.g., to be able to + reproduce later your results. **Note:** check out the [compatibility + guarantees](#compatibility-guarantees) for this use case. + +Here is an example: + +```python +>>> import re +>>> import numpy as np +>>> import jax +>>> from jax import export + +>>> def f(x): return 2 * x * x + + +>>> exported: export.Exported = export.export(jax.jit(f))( +... jax.ShapeDtypeStruct((), np.float32)) + +>>> # You can inspect the Exported object +>>> exported.fun_name +'f' + +>>> exported.in_avals +(ShapedArray(float32[]),) + +>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0)) + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"} loc("x")) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + +>>> # And you can serialize the Exported to a bytearray. +>>> serialized: bytearray = exported.serialize() + +>>> # The serialized function can later be rehydrated and called from +>>> # another JAX computation, possibly in another process. +>>> rehydrated_exp: export.Exported = export.deserialize(serialized) +>>> rehydrated_exp.in_avals +(ShapedArray(float32[]),) + +>>> def callee(y): +... return 3. * rehydrated_exp.call(y * 4.) + +>>> callee(1.) +Array(96., dtype=float32) + +``` + +Serialization is broken down into two stages: + 1. exporting to produce an {class}`jax.export.Exported` object that contains + the StableHLO for the lowered function along with the metadata necessary to + call it from another JAX function. We have plans to add code to generate + `Exported` objects from TensorFlow, and to use `Exported` objects from + TensorFlow and PyTorch. + 2. the actual serialization to a byte array using the flatbuffers format. + See {ref}`jax2tf` for + an alternative serialization to TensorFlow graph that can be used + for interoperation with TensorFlow. + +## Support for reverse-mode AD + +Serialization can optionally support higher-order reverse-mode AD. This is done +by serializing the {func}`jax.vjp` of the primal function along with the primal function, +up to a user-specified order (default is 0, meaning that the rehydrated +function cannot be differentiated): + +```python +>>> import jax +>>> from jax import export +>>> from typing import Callable + +>>> def f(x): return 7 * x * x * x + +>>> # Serialize 3 levels of VJP along with the primal function +>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3) +>>> rehydrated_f: Callable = export.deserialize(blob).call + +>>> rehydrated_f(0.1) # 7 * 0.1^3 +Array(0.007, dtype=float32) + +>>> jax.grad(rehydrated_f)(0.1) # 7*3 * 0.1^2 +Array(0.21000001, dtype=float32) + +>>> jax.grad(jax.grad(rehydrated_f))(0.1) # 7*3*2 * 0.1 +Array(4.2, dtype=float32) + +>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1) # 7*3*2 +Array(42., dtype=float32) + +>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: No VJP is available + +``` + +Note that the VJP function is computed lazily while serializing, +when the JAX program is still available. +This means that it respects all features of JAX VJP, +e.g., {func}`jax.custom_vjp` and {func}`jax.remat`. + +Note that the rehydrated function does not support any other +transformations, e.g., forward-mode AD (jvp), or {func}`jax.vmap`. + +## Compatibility guarantees + +You should not use the raw StableHLO that is obtained from just lowering +(`jax.jit(f).lower(1.).compiler_ir()`) +for archival and for compilation in another process, for several reasons. + +First, the compilation may use a different version of the compiler, supporting a +different version of StableHLO. The {class}`jax.export` module takes +care of this by using the +[portable-artifact feature of StableHLO](https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md) +to deal with the possible evolution of the StableHLO opset. + +### Compatibility guarantees for custom calls + +Second, the raw StableHLO may contain custom calls referencing C++ +functions. +JAX uses custom calls for lowering of a small number of primitives, +e.g., linear algebra primitives, sharding annotations, or Pallas kernels. +These do not fall under the compatibility guarantees for StableHLO. +The C++ implementations of these functions change rarely, but they can change. + +`jax.export` makes the following export compatibility guarantees: +A JAX exported artifact can be compiled and executed by a compiler and +JAX runtime system that are: + + * **up to 6 months newer** than the version of JAX used for exporting + (we say that JAX export offers **6 months backward compatibility**). + This is useful if we want to archive the exported artifact to be compiled and executed later. + * **up to 3 weeks older** than the version of JAX used for exporting + (we say that JAX export offers **3 weeks forward compatibility**). + This is useful if we want to compile and run an exported artifact with a + consumer that was built and deployed before the export, e.g., + an inference system that is already deployed when the exporting is done. + +(The particular compatibility window lengths are the same that JAX +[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), +and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow). +The terminology “backward compatibility” is from the perspective of the consumer, +e.g., the inference system.) + +What **matters is when the exporting and consuming components were built**, +not the time when the exporting and the compilation happen. +For external JAX users, it is +[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +what matters is when the jaxlib release was built. + +To reduce chances of incompatibility, internal JAX users should: + * **rebuild and redeploy consumer systems as frequently as possible**. + +and external users should: + * run the exporting and consumer systems with the same version of jaxlib, whenever possible, and + * export for archival **with the latest released version of jaxlib**. + +The compatibility guarantees do not apply if you bypass the `jax.export` APIs +to obtain the StableHLO code. + +Only a subset of custom calls are guaranteed stable and have +compatibility guarantees ([see list](https://github.com/search?q=repo%3Agoogle%2Fjax%20_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE&type=code)). +We continuously +add more custom call targets to the allowed list along with backwards +compatibility tests. If you try to serialize +code that invokes other custom call targets you will get an error +during exporting. + +If you want to disable this safety check for a specific custom call, +e.g., with target `my_target`, you can add +`export.DisabledSafetyCheck.custom_call("my_target")` to the +`disabled_checks` parameter of the `export` method, +as in the following example: + +```python +>>> import jax +>>> from jax import export +>>> from jax import lax +>>> from jax._src import core +>>> from jax._src.interpreters import mlir +>>> # Define a new primitive backed by a custom call +>>> new_prim = core.Primitive("new_prim") +>>> _ = new_prim.def_abstract_eval(lambda x: x) +>>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results) +>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir()) +module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor) -> tensor + return %0 : tensor + } +} + +>>> # If we try to export, we get an error +>>> export.export(jax.jit(new_prim.bind))(1.) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_bind + +>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call` +>>> exp = export.export( +... jax.jit(new_prim.bind), +... disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_prim")])(1.) + +``` + +## Cross-platform and multi-platform export + +JAX lowering is platform specific for a small number of JAX primitives. +By default, the code is lowered and exported for the accelerator +present on the exporting machine: + +```python +>>> from jax import export +>>> export.default_export_platform() +'cpu' + +``` + +There is a safety check that will be raise an error when trying to compile +an `Exported` object on a machine that does not have the accelerator +for which the code was exported. + +You can specify explicitly for what platforms the code should be exported. +This allows you to specify a different accelerator than you have +available at export time, +and it even allows you to specify multi-platform lexport to +obtain an `Exported` object that can be compiled and executed +on multiple platforms. + + +```python +>>> import jax +>>> from jax import export +>>> from jax import lax + +>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm` +>>> # even if the current machine does not have that accelerator. +>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.) + +>>> # But you will get an error if you try to compile `exp` +>>> # on a machine that does not have TPUs. +>>> exp.call(1.) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'. + +>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform` +>>> # parameter to `export`, e.g., because you have reasons to believe +>>> # that the code lowered will run adequately on the current +>>> # compilation platform (which is the case for `cos` in this +>>> # example): +>>> exp_unsafe = export.export(jax.jit(lax.cos), +... lowering_platforms=['tpu'], +... disabled_checks=[export.DisabledSafetyCheck.platform()])(1.) + +>>> exp_unsafe.call(1.) +Array(0.5403023, dtype=float32, weak_type=True) + +# and similarly with multi-platform lowering +>>> exp_multi = export.export(jax.jit(lax.cos), +... lowering_platforms=['tpu', 'cpu', 'cuda'])(1.) +>>> exp_multi.call(1.) +Array(0.5403023, dtype=float32, weak_type=True) + +``` + +For multi-platform export, the StableHLO will contain multiple +lowerings but only for those primitives that require it, so the +resulting module size should be only marginally larger than the +size of a module with default export. +As an extreme case, when serializing a module without any +primitives with platform-specific lowering, you will get +the same StableHLO as for the single-plaform export. + +```python +>>> import jax +>>> from jax import export +>>> from jax import lax +>>> # A largish function +>>> def f(x): +... for i in range(1000): +... x = jnp.cos(x) +... return x + +>>> exp_single = export.export(jax.jit(f))(1.) +>>> len(exp_single.mlir_module_serialized) # doctest: +SKIP +9220 + +>>> exp_multi = export.export(jax.jit(f), +... lowering_platforms=["cpu", "tpu", "cuda"])(1.) +>>> len(exp_multi.mlir_module_serialized) # doctest: +SKIP +9282 + +``` + +## Shape polymorphic export + +When used in JIT mode, JAX will trace and lower a function separately +for each combination of input shapes. When exporting, it is possible +in some cases to use dimension variables for some input dimensions +in order to obtain an exported artifact that can be used with multiple +combinations of input shapes. + +See the {ref}`shape_poly` documentation. + +## Device polymorphic export + +An exported artifact may contain sharding annotations for inputs, +outputs and for some intermediates, but these annotations do not refer +directly to the actual physical devices that existed at exporting time. +Instead, the sharding annotations refer to logical devices. This +means that you can compile and run the exported artifacts on different +physical devices that were used for exporting. + +```python +>>> import jax +>>> from jax import export +>>> from jax.sharding import Mesh, NamedSharding +>>> from jax.sharding import PartitionSpec as P + +>>> # Use the first 4 devices for exporting. +>>> export_devices = jax.local_devices()[:4] +>>> export_mesh = Mesh(export_devices, ("a",)) +>>> def f(x): +... return x.T + +>>> arg = jnp.arange(8 * len(export_devices)) +>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg) + +>>> # `exp` knows for how many devices it was exported. +>>> exp.nr_devices +4 + +>>> # and it knows the shardings for the inputs. These will be applied +>>> # when the exported is called. +>>> exp.in_shardings_hlo +({devices=[4]<=[4]},) + +>>> res1 = exp.call(jax.device_put(arg, +... NamedSharding(export_mesh, P("a")))) + +>>> # Check out the first 2 shards of the result +>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]] +['device=TFRT_CPU_0 index=(slice(0, 8, None),)', + 'device=TFRT_CPU_1 index=(slice(8, 16, None),)'] + +>>> # We can call `exp` with some other 4 devices and another +>>> # mesh with a different shape, as long as the number of devices is +>>> # the same. +>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c")) +>>> res2 = exp.call(jax.device_put(arg, +... NamedSharding(other_mesh, P("b")))) + +>>> # Check out the first 2 shards of the result. Notice that the output is +>>> # sharded similarly; this means that the input was resharded according to the +>>> # exp.in_shardings. +>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]] +['device=TFRT_CPU_2 index=(slice(0, 8, None),)', + 'device=TFRT_CPU_3 index=(slice(8, 16, None),)'] + +``` + +It is an error to try to invoke an exported artifact with a different number +of devices than it was exported for: + +```python +>>> import jax +>>> from jax import export +>>> from jax.sharding import Mesh, NamedSharding +>>> from jax.sharding import PartitionSpec as P + +>>> export_devices = jax.local_devices() +>>> export_mesh = Mesh(np.array(export_devices), ("a",)) +>>> def f(x): +... return x.T + +>>> arg = jnp.arange(4 * len(export_devices)) +>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg) + +>>> exp.call(arg) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device. + +``` + +There are helper functions to shard the inputs for calling an exported +artifacts using a new mesh constructed at the call site: + +```python +>>> import jax +>>> from jax import export +>>> from jax.sharding import Mesh, NamedSharding +>>> from jax.sharding import PartitionSpec as P + +>>> export_devices = jax.local_devices() +>>> export_mesh = Mesh(np.array(export_devices), ("a",)) +>>> def f(x): +... return x.T + +>>> arg = jnp.arange(4 * len(export_devices)) +>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg) + +>>> # Prepare the mesh for calling `exp`. +>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",)) + +>>> # Shard the arg according to what `exp` expects. +>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0]) +>>> res = exp.call(sharded_arg) + +``` + +As a special facility, if a function was exported for 1 device and if it contains no +sharding annotations, then it can be invoked on an argument of the same shape but sharded +on multiple devices, and the compiler will shard the function appropriately: + +```python +```python +>>> import jax +>>> from jax import export +>>> from jax.sharding import Mesh, NamedSharding +>>> from jax.sharding import PartitionSpec as P + +>>> def f(x): +... return jnp.cos(x) + +>>> arg = jnp.arange(4) +>>> exp = export.export(jax.jit(f))(arg) +>>> exp.in_avals +(ShapedArray(int32[4]),) + +>>> exp.nr_devices +1 + +>>> # Prepare the mesh for calling `exp`. +>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",)) + +>>> # Shard the arg according to what `exp` expects. +>>> sharded_arg = jax.device_put(arg, +... NamedSharding(calling_mesh, P("b"))) +>>> res = exp.call(sharded_arg) + +``` + +## Calling convention versions + +The JAX export support has evolved over time, e.g., to support +effects. In order to support compatibility (see [compatibility guarantees](#compatibility-guarantees)) +we maintain a calling convention version for each `Exported`. +As of June 2024, all function exported with version 9 +(the latest, see [all calling convention versions](#calling-convention-versions)): + +```python +>>> from jax import export +>>> exp: export.Exported = export.export(jnp.cos)(1.) +>>> exp.calling_convention_version +9 + +``` + +At any given time, the export APIs may support a range +of calling convention versions. You can control which calling convention +version to use using the `--jax-export-calling-convention-version` flag +or the `JAX_EXPORT_CALLING_CONVENTION_VERSION` environment variable: + +```python +>>> from jax import export +>>> (export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version) +(9, 9) + +>>> from jax._src import config +>>> with config.jax_export_calling_convention_version(9): +... exp = export.export(jnp.cos)(1.) +... exp.calling_convention_version +9 + +``` + +We reserve the right to remove support for +generating or consuming calling convention versions older than 6 months. + +### Module calling convention + +The `Exported.mlir_module` has a `main` function that takes an optional first +platform index argument if the module supports multiple platforms +(`len(platforms) > 1`), followed by the token arguments corresponding +to the ordered effects, followed by the kept array +arguments (corresponding to `module_kept_var_idx` and `in_avals`). +The platform index is a i32 or i64 scalar encoding the index of the current +compilation platform into the `platforms` sequence. + +Inner functions use a different calling convention: an optional +platform index argument, optional dimension variable arguments +(scalar tensors of type i32 or i64), +followed by optional token arguments (in presence of ordered effects), +followed by the regular array arguments. +The dimension arguments correspond to the dimension variables appearing in +the `args_avals`, in sorted order of their names. + +Consider the lowering of a function with one array argument of type +`f32[w, 2 * h]`, where `w` and `h` are two dimension variables. +Assume that we use multi-platform lowering, and we have +one ordered effect. The `main` function will be as follows: + +``` + func public main( + platform_index: i32 {jax.global_constant="_platform_index"}, + token_in: token, + arg: f32[?, ?]) { + arg_w = hlo.get_dimension_size(arg, 0) + dim1 = hlo.get_dimension_size(arg, 1) + arg_h = hlo.floordiv(dim1, 2) + call _check_shape_assertions(arg) # See below + token = new_token() + token_out, res = call _wrapped_jax_export_main(platform_index, + arg_h, + arg_w, + token_in, + arg) + return token_out, res + } +``` + +The actual computation is in `_wrapped_jax_export_main`, taking also +the values of `h` and `w` dimension variables. + +The signature of the `_wrapped_jax_export_main` is: + +``` + func private _wrapped_jax_export_main( + platform_index: i32 {jax.global_constant="_platform_index"}, + arg_h: i32 {jax.global_constant="h"}, + arg_w: i32 {jax.global_constant="w"}, + arg_token: stablehlo.token {jax.token=True}, + arg: f32[?, ?]) -> (stablehlo.token, ...) +``` + +Prior to calling convention version 9 the calling convention for effects was +different: the `main` function does not take or return a token. Instead +the function creates dummy tokens of type `i1[0]` and passes them to the +`_wrapped_jax_export_main`. The `_wrapped_jax_export_main` +takes dummy tokens of type `i1[0]` and will create internally real +tokens to pass to the inner functions. The inner functions use real +tokens (both before and after calling convention version 9) + +Also starting with calling convention version 9, function arguments that contain +the platform index or the dimension variable values have a +`jax.global_constant` string attribute whose value is the name of the +global constant, either `_platform_index` or a dimension variable name. +The global constant name may be empty if it is not known. +Some global constant computations use inner functions, e.g., for +`floor_divide`. The arguments of such functions have a `jax.global_constant` +attribute for all attributes, meaning that the result of the function is +also a global constant. + +Note that `main` contains a call to `_check_shape_assertions`. +JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` +have values >= 1. We must check these constraints when we invoke the +module. We use a special custom call `@shape_assertion` that takes +a boolean first operand, a string `error_message` attribute that may contain +format specifiers `{0}`, `{1}`, ..., and a variadic number of integer +scalar operands corresponding to the format specifiers. + +``` + func private _check_shape_assertions(arg: f32[?, ?]) { + # Check that w is >= 1 + arg_w = hlo.get_dimension_size(arg, 0) + custom_call @shape_assertion(arg_w >= 1, arg_w, + error_message="Dimension variable 'w' must have integer value >= 1. Found {0}") + # Check that dim1 is even + dim1 = hlo.get_dimension_size(arg, 1) + custom_call @shape_assertion(dim1 % 2 == 0, dim1 % 2, + error_message="Division had remainder {0} when computing the value of 'h') + # Check that h >= 1 + arg_h = hlo.floordiv(dim1, 2) + custom_call @shape_assertion(arg_h >= 1, arg_h, + error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}") +``` + +(export-calling-convention-version)= + +### Calling convention versions + +We list here a history of the calling convention version numbers: + + * Version 1 used MHLO & CHLO to serialize the code, not supported anymore. + * Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported + anymore. + * Version 3 supports platform checking and multiple platforms. + Used from February 2023. Not supported anymore. + * Version 4 supports StableHLO with compatibility guarantees. + This is the earliest version at the time of the JAX native serialization + launch. + Used in JAX from March 15, 2023 (cl/516885716). Starting with + March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493). + The support for this version was dropped on + October 17th, 2023 (cl/573858283). + * Version 5 adds support for `call_tf_graph`. This is currently used + for some specialized use cases. Used in JAX from May 3rd, 2023 + (cl/529106145). + * Version 6 adds support for the `disabled_checks` attribute. This version + mandates a non-empty `platforms` attribute. Supported by XlaCallModule + since June 7th, 2023 and available in JAX since + June 13th, 2023 (JAX 0.4.13). + * Version 7 adds support for `stablehlo.shape_assertion` operations and + for `shape_assertions` specified in `disabled_checks`. + See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + since July 12th, 2023 (cl/547482522), + available in JAX serialization since July 20th, 2023 (JAX 0.4.14), + and the default since August 12th, 2023 (JAX 0.4.15). + * Version 8 adds support for the `jax.uses_shape_polymorphism` module + attribute and enables the shape refinement pass only when the + attribute is present. Supported by XlaCallModule since July 21st, 2023 + (cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14), + and the default since October 21st, 2023 (JAX 0.4.20). + * Version 9 adds support for effects. + See the docstring for `export.Exported` for the precise calling convention. + In this calling convention version we also tag the platform index and the + dimension variables arguments with `jax.global_constant` attributes. + Supported by XlaCallModule since October 27th, 2023, + available in JAX since October 20th, 2023 (JAX 0.4.20), + and the default since February 1st, 2024 (JAX 0.4.24). + This is the only supported version as of 27th of March, 2024. + + +## Migration guide from jax.experimental.export + +On June 14, 2024 we deprecated the `jax.experimental.export` APIs +in favor of `jax.export` APIs. There have been some minor changes: + + * `jax.experimental.export.export`: + * The old function used to allow any Python callable, or the result of + `jax.jit`. Now only the latter is accepted. You have to manually apply + `jax.jit` to the function to export before calling `export`. + * The old `lowering_parameters` kwarg is now named `platforms` + * `jax.experimental.export.default_lowering_platform()` is now + at {func}`jax.export.default_export_platform`. + * `jax.experimental.export.call` is now a method of the {class}`jax.export.Exported` object. + Instead of `export.call(exp)` you should use `exp.call`. + * `jax.experimental.export.serialize` is now a method of the {class}`jax.export.Exported` + object. Instead of `export.serialize(exp)` you should use `exp.serialize()`. + * The configuration flag `--jax-serialization-version` is deprecated. + Use `--jax-export-calling-convention-version`. + * The value `jax.experimental.export.minimum_supported_serialization_version` + is now at `jax.export.minimum_supported_calling_convention_version`. + * The following fields of {class}`jax.export.Exported` have been renamed + * `uses_shape_polymorphism` is now `uses_global_constants` + * `mlir_module_serialization_version` is now `calling_convention_version` + * `lowering_platforms` is now `platforms`. + + diff --git a/docs/export/index.rst b/docs/export/index.rst new file mode 100644 index 000000000000..24cf2716cafe --- /dev/null +++ b/docs/export/index.rst @@ -0,0 +1,13 @@ +.. _export: + +Exporting and serialization +============================= + +.. toctree:: + :caption: Guides + :maxdepth: 2 + + export + shape_poly + + jax2tf diff --git a/docs/export/jax2tf.md b/docs/export/jax2tf.md new file mode 100644 index 000000000000..498a0418f232 --- /dev/null +++ b/docs/export/jax2tf.md @@ -0,0 +1,5 @@ +(jax2tf)= + +## Interoperation with TensorFlow + +See the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md new file mode 100644 index 000000000000..8b07a3666c87 --- /dev/null +++ b/docs/export/shape_poly.md @@ -0,0 +1,642 @@ +(shape_poly)= + +# Shape polymorphism + +When JAX is used in JIT mode, a function will be traced, lowered to StableHLO, and compiled for each +combination of input types and shapes. After exporting a function and +deserializing it on another system we don't have the Python sources available anymore, +so we cannot re-trace and re-lower it. **Shape polymorphism** is a feature of JAX export +to allow some exported functions to be used for a whole family of input shapes. +These functions are traced and lowered once, during exporting, and `Exported` +object contains the information needed to be able to compile and execute the function +on many concrete input shapes. We do this by specifying shapes that contain +dimension variables (symbolic shapes) when exporting, as in the +following example: + +```python +>>> import jax +>>> from jax import export +>>> from jax import numpy as jnp +>>> def f(x): # f: f32[a, b] +... return jnp.concatenate([x, x], axis=1) + +>>> # We construct symbolic dimension variables. +>>> a, b = export.symbolic_shape("a, b") + +>>> # We can use the symbolic dimensions to construct shapes. +>>> x_shape = (a, b) +>>> x_shape +(a, b) + +>>> # Then we export with symbolic shapes: +>>> exp: export.Exported = export.export(jax.jit(f))( +... jax.ShapeDtypeStruct(x_shape, jnp.int32)) +>>> exp.in_avals +(ShapedArray(int32[a,b]),) +>>> exp.out_avals +(ShapedArray(int32[a,2*b]),) + +>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`. +>>> res = exp.call(np.ones((3, 4), dtype=np.int32)) +>>> res.shape +(3, 8) + +``` + +Note that such functions are still re-compiled on demand for +each concrete input shapes they are invoked on. Only the +tracing and the lowering are saved. + +The {func}`jax.export.symbolic_shape` is used in the above +example to parse a string representation of a symbolic shape +into dimension expressions objects (of type `_DimExpr`) that are usable in place of integer +constants to construct shapes. The dimension expression objects +overload most integer operators, so you can use them as +you'd use integer constants in most cases. +See {ref}`computing-with-dimension-variables` for more details. + +Additionally, we provide the {func}`jax.export.symbolic_args_specs` that +can be used to construct pytrees of `jax.ShapeDtypeStruct` objects based +on a polymorphic shape specification: + +```python +>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4] +... return x + y + +>>> # Assuming you have some actual args with concrete shapes +>>> x = np.ones((3, 1), dtype=np.int32) +>>> y = np.ones((3, 4), dtype=np.int32) +>>> args_specs = export.symbolic_args_specs((x, y), "a, ...") +>>> exp = export.export(jax.jit(f1))(* args_specs) +>>> exp.in_avals +(ShapedArray(int32[a,1]), ShapedArray(int32[a,4])) + +``` + +Note how the polymorphic shape specification `"a, ..."` contains +the placeholder `...` to be filled from the concrete shapes of +the concrete shapes of the arguments `(x, y)`. +The placeholder `...` stands for 0 or more dimensions, while the +placeholder `_` stands for one dimension. +The {func}`jax.export.symbolic_args_specs` supports pytrees of arguments, +which are used to fill-in the dtypes and any placeholders. +The function will construct a pytree of +argument specifications ({class}`jax.ShapeDtypeStruct`) +matching the structure of the arguments passed to it. +The polymorphic shapes specification can be a +pytree prefix in cases where one specification should apply +to multiple arguments, as in the above example. +See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + + +A few examples of shape specifications: + + * `("(b, _, _)", None)` can be used for a function with two arguments, the first + being a 3D array with a batch leading dimension that should be symbolic. + The other dimensions for the + first argument and the shape of the second argument are specialized based on the actual + arguments. Note that the same specification would work if the first + argument is a pytree of 3D arrays, all with the same leading dimension + but possibly with different trailing dimensions. + The value `None` for the second arugment means that the argument + is not symbolic. Equivalently, one can use `...`. + + * `("(batch, ...)", "(batch,)")` specifies that the two arguments + have matching leading dimensions, the first argument has rank at + least 1, and the second has rank 1. + +## Correctness of shape polymorphism + +We want to trust that the exported program produces the same results as the +original JAX program when compiled and executed for any applicable concrete shapes. +More precisely: + +For any JAX function `f` and any argument specification `arg_spec` containing a +symbolic shape, and any concrete argument `arg` whose shape matches `arg_spec`: + + * If the JAX native execution succeeds on the concrete argument: `res = f(arg)`, + * and if the exporting succeeds with symbolic shapes: `exp = export.export(f)(arg_spec)`, + * then compiling and running the export will succeed with the same result: `res == exp.call(arg)` + +It is crucial to understand that `f(arg)` has the freedom to re-invoke +the JAX tracing machinery, +and in fact it does so for each distinct concrete `arg` shape, +while the execution of `exp.call(arg)` cannot use JAX tracing anymore +(this execution may happen in an environment where the source code +of `f` is not available). + +Ensuring this form of correctness is hard, and in the hardest cases +exporting fails. The rest of this chapter describes how to handle these failures. + +(computing-with-dimension-variables)= + +## Computing with dimension variables + +JAX keeps track of the shapes of all intermediate results. When those shapes depend +on dimension variables JAX computes them as symbolic dimension expressions +involving dimension variables. +Dimension variables stand for integer values greater or equal to 1. +The symbolic expressions can represent the result +of applying arithmetic operators (add, sub, mul, floordiv, mod, +including the NumPy variants `np.sum`, `np.prod`, etc.) **on dimension +expressions and integers** (`int`, `np.int`, or anything convertible by `operator.index`). +These symbolic dimensions can then be used in shape-parameters of JAX primitives +and APIs, e.g., in `jnp.reshape`, `jnp.arange`, slicing indices, etc. + +For example, in the following code to flatten a 2D array, the computation +`x.shape[0] * x.shape[1]` computes the symbolic dimension `4 * b` as the +new shape: + +```python +>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)) +>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32) +>>> exp = export.export(jax.jit(f))(arg_spec) +>>> exp.out_avals +(ShapedArray(int32[4*b]),) + +``` + +It is possible to convert dimension expressions explicitly +to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`. +The result of these operations can be used as regular JAX arrays, +bug cannot be used anymore as dimensions in shapes. + +```python +>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))( +... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32)) +>>> exp.call(jnp.arange(3, dtype=np.int32)) +Array([3, 4, 5], dtype=int32) + +>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))( +... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Tracedwith]. + +``` + +When a symbolic dimension is used in arithmetic operations with **non-integers**, +e.g., `float`, `np.float`, `np.ndarray`, or JAX arrays, it is automatically +converted to a JAX array using `jnp.array`. +For example, in the function below all occurrences of `x.shape[0]` +are converted implicitly to `jnp.array(x.shape[0])` because +they are involved in operations with non-integer scalars or with +JAX arrays: + +```python +>>> exp = export.export(jax.jit( +... lambda x: (5. + x.shape[0], +... x.shape[0] - np.arange(5, dtype=jnp.int32), +... x + x.shape[0] + jnp.sin(x.shape[0]))))( +... jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32)) +>>> exp.out_avals +(ShapedArray(float32[], weak_type=True), + ShapedArray(int32[5]), + ShapedArray(float32[b], weak_type=True)) + +>>> exp.call(jnp.ones((3,), jnp.int32)) + (Array(8., dtype=float32, weak_type=True), + Array([ 3, 2, 1, 0, -1], dtype=int32), + Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True)) + +``` + +Another typical example is when computing averages +(observe how `x.shape[0]` is automatically turned into a JAX array): + +```python +>>> exp = export.export(jax.jit( +... lambda x: jnp.sum(x, axis=0) / x.shape[0]))( +... jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32)) +>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4))) +Array([4., 5., 6., 7.], dtype=float32) + +``` + +### Errors in presence of shape polymorphism + +Most JAX code assumes that the shapes of JAX arrays are tuples of integers, +but with shape polymorphism some dimensions may be symbolic expressions. +This can lead to a number of errors. For example, we can have the usual +JAX shape check errors: + +```python +>>> v, = export.symbolic_shape("v,") +>>> export.export(jax.jit(lambda x, y: x + y))( +... jax.ShapeDtypeStruct((v,), dtype=np.int32), +... jax.ShapeDtypeStruct((4,), dtype=np.int32)) +Traceback (most recent call last): +TypeError: add got incompatible shapes for broadcasting: (v,), (4,). + +>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))( +... jax.ShapeDtypeStruct((v, 4), dtype=np.int32)) +Traceback (most recent call last): +TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,). + +``` + +We can fix the above matmul example by specifying that the +argument has shape `(v, v)`. + +### Comparison of symbolic dimensions is partially supported + +Inside JAX there are a number of equality and inequality comparisons +involving shapes, e.g., for doing shape checking or even for choosing +the implementation for some primitives. Comparisons are supported +as follows: + + * equality is supported with a caveat: if the two symbolic dimensions denote the same + value under all valuations for dimension variables, then equality evaluates to `True`, + e.g., for `b + b == 2*b`; otherwise the equality evaluates to `False`. + See [below](#caveat-for-equality-comparisons) + for a discussion of important consequences of this behavior. + * disequality is always the negation of equality. + * inequality is partially supported, in a similar way as partial equality. + However, in this + case we take into consideration that dimension variables range over strictly positive + integers. E.g., `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`, + `a >= b`, `a - b >= 0` are inconclusive and result in an exception. + +In cases where a comparison operation cannot be resolve to a boolean, +we raise {class}`InconclusiveDimensionOperation`. E.g., + +```python +import jax +>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))( +... jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive. +This error arises for comparison operations with shapes that +are non-constant, and the result of the operation cannot be represented as +a boolean value for all values of the symbolic dimensions involved. + +``` + +If you do get a `InconclusiveDimensionOperation`, you can try +several strategies: + + * If your code uses the built-in `max` or `min`, or the + `np.max` or `np.min` then you can replace those with + `core.max_dim` and `core.min_dim`, which have the effect + of delaying the inequality comparison to the compilation + time, when shapes become known. + * Try to rewrite conditionals using `core.max_dim` and + `core.min_dim`, e.g., instead of `d if d > 0 else 0` + you can write `core.max_dim(d, 0)`. + * Try to rewrite the code to be less dependent on the fact + that dimensions should be integers, and rely on the fact + that symbolic dimensions duck-type as integers for most + arithmetic operations. E.g., instead of `int(d) + 5` write + `d + 5`. + * Specify symbolic constraints, as explained below. + +#### User-specified symbolic constraints + +By default, JAX assumes that all dimension variables range +over values greater-or-equal to 1, and it tries to derive +other simple inequalities from that, e.g.: + + * `a + 2 >= 3`, + * `a * 2 >= 1`, + * `a + b + c >= 3`, + * `a // 4 >= 0`, `a**2 >= 1`, and so on. + +You can avoid some inequality comparison failures if you +change the symbolic shape specifications to add **implicit** constraints +for dimension sizes. E.g., + + * You can use `2*b` for a dimension to constrain it to be even and greater or equal + to 2. + * You can use `b + 15` for a dimension to constrain it to + be at least 16. E.g., the following code would fail without + the `+ 15` part, because JAX will want to verify that slice sizes + are at most as large as the axis size. + +```python +>>> _ = export.export(jax.jit(lambda x: x[0:16]))( +... jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32)) + +``` + +Such implicit symbolic constraints are used for deciding comparisons and are +checked at compile time, as explained [below](#shape-assertion-errors). + +You can also specify **explicit** symbolic constraints: + +```python +>>> # Introduce dimension variable with constraints. +>>> a, b = export.symbolic_shape("a, b", +... constraints=("a >= b", "b >= 16")) +>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))( +... jax.ShapeDtypeStruct((a, b), dtype=np.int32)) + +``` + +The constraints form a conjunction together with the implicit +constraints. You can specify `>=`, `<=`, and `==` constraints. +At the moment, JAX has limited support for reasoning with +symbolic constraints: + + * You get the most from constraints of the form + of a variable being greater-or-equal or + less-or-equal to a constant. + For example, from the constraints that + `a >= 16` and `b >= 8` we can infer + that `a + 2*b >= 32`. + * You get limited power when the constraint involves + more complex expressions, e.g., from `a >= b + 8` we + can infer that `a - b >= 8` but not that `a >= 9`. + We may improve somewhat this area in the future. + * Equality constraints are treated as normalization rules. + E.g., `floordiv(a, b) = c` works by replacing all + occurences of the left-hand-side with the right-hand-side. + You can only have equality constraints where the left-hand-side + is a multiplication of factors, e.g, `a * b`, or `4 * a`, or + `floordiv(a, b)`. Thus, the left-hand-side cannot contain + addition or subtraction at the top-level. + +The symbolic constraints can also help to work around the +limitations in the JAX reasoning mechanisms. +For example, in the code below JAX will attempt to prove that +the slice size `x.shape[0] % 3`, which is the symbolic expression +`mod(b, 3)`, is less or equal to the axis size, which is `b`. +This happens to be true for all strictly positive values of +`b`, but it is not something JAX's symbolic comparison rules +can prove. Hence the following code raises an error: + +```python +from jax import lax +>>> b, = export.symbolic_shape("b") +>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3) +>>> export.export(jax.jit(f))( +... jax.ShapeDtypeStruct((b,), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive. +This error arises for comparison operations with shapes that +are non-constant, and the result of the operation cannot be represented as +a boolean value for all values of the symbolic dimensions involved. + +``` + +One option here would be to restrict the code to work only on +axis sizes that are multiple of `3` (by replacing +`b` with `3*b` in the shape). Then, JAX would be able +to simplify the modulo operation `mod(3*b, 3)` to `0`. +Another option is to add a symbolic constraint +with the exact inconclusive inequality that JAX +is attempting to prove: + +```python +>>> b, = export.symbolic_shape("b", +... constraints=["b >= mod(b, 3)"]) +>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3) +>>> _ = export.export(jax.jit(f))( +... jax.ShapeDtypeStruct((b,), dtype=np.int32)) + +``` + +Just like the implicit constraints, the explicit +symbolic constraints are checked at compile time, +using the same mechanism as explained [below](#shape-assertion-errors). + +#### Symbolic dimension scopes + +The symbolic constraints are stored in αn +{class}`jax.export.SymbolicScope` object, which is created implicitly +for each call to {func}`jax.export.symbolic_shapes`. You must be careful +to not mix symbolic expressions that use different scopes. +For example, +the following code will fail because `a1` and `a2` +use different scopes (created by different invocations of +{func}`jax.export.symbolic_shape`): + +```python +>>> a1, = export.symbolic_shape("a,") +>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",)) + +>>> a1 + a2 # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: Invalid mixing of symbolic scopes for linear combination. +Expected scope 4776451856 created at :1:6 () +and found for 'a' (unknown) scope 4776979920 created at :1:6 () with constraints: + a >= 8 +``` + +The symbolic expressions that originate from a single call +to {func}`jax.export.symbolic_shape` share a scope and +can be mixed up in arithmetic operations. The result would +also share the same scope. + +You can re-use scopes: + +```python +>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",)) +>>> b, = export.symbolic_shape("b,", scope=a.scope) # Reuse the scope of `a` + +>>> a + b # Allowed +b + a + +``` + +You can also create scopes explicitly: + +```python +>>> my_scope = export.SymbolicScope() +>>> c, = export.symbolic_shape("c", scope=my_scope) +>>> d, = export.symbolic_shape("d", scope=my_scope) +>>> c + d # Allowed +d + c + +``` + +JAX tracing uses caches keyed partially by shapes, and +symbolic shapes that are printed identically will be considered +distinct if they use different scopes. + +### Caveat for equality comparisons + +The equality comparison returns `False` for `b + 1 == b` or `b == 0` +(in which case it is certain that the dimensions are different for all values +of the dimension variables), +but also for `b == 1` and for `a == b`. This is unsound, and we +ought to raise `core.InconclusiveDimensionOperation` because under +some valuations the result should be `True` and under other +valuations it should be `False`. We choose to make equality total +thus allowing unsoundness because otherwise we may get spurious errors +in presence of hash collisions +when hashing dimension expressions or objects that include +them (shapes, `core.AbstractValue`, `core.Jaxpr`). +Besides the hashing errors, a partial semantics of equality +leads to errors for the following expressions `b == a or b == b` or `b in [a, b]` +even though the error is avoided if we change the order of the comparisons. + +Code of the form `if x.shape[0] != 1: raise NiceErrorMessage` is sound even +with this treatment of equality, but code of the form `if x.shape[0] != 1: return 1` +is unsound. + +### Dimension variables must be solvable from the input shapes + +Currently, the only way to pass the values of dimension variables +when an exported object is invoked is indirectly through the shapes +of the array arguments. E.g., the value of `b` can be inferred at the +call site from the shape of the first argument of type `f32[b]`. +This works well for most use cases, and +it mirrors the calling convention of JIT functions. + +Sometimes you may want to export a function parameterized +by an integer values that determines some shapes in the program. +For example, we may +want to export the function `my_top_k` defined below, +parameterized by the +value of `k`, which determined the shape of the result. +The following attempt will lead to an error since the dimension +variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`: + +```python +>>> def my_top_k(k, x): # x: i32[4, 10], k <= 10 +... return lax.top_k(x, k)[0] # : i32[4, 3] +>>> x = np.arange(40, dtype=np.int32).reshape((4, 10)) + +>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`. +>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x) +>>> exp_static_k.in_avals[0] +ShapedArray(int32[4,10]) + +>>> exp_static_k.out_avals[0] +ShapedArray(int32[4,3]) + +>>> # When calling the exported function we pass only the non-static arguments +>>> exp_static_k.call(x) +Array([[ 9, 8, 7], + [19, 18, 17], + [29, 28, 27], + [39, 38, 37]], dtype=int32) + +>>> # Now attempt to export with symbolic `k` so that we choose `k` after export. +>>> k, = export.symbolic_shape("k", constraints=["k <= 10"]) +>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments + +``` + +In the future, we may add an additional mechanism to pass the values of +dimension variables, besides implicitly through the input shapes. +Meanwhile, the workaround for the above use case is to replace the +function parameter `k` with an array of shape `(0, k)`, so that +`k` can be derived from the input shape of an array. +The first dimension is 0 to ensure that the whole array is empty +and there is no performance penalty when we call the exported function. + +```python +>>> def my_top_k_with_dimensions(dimensions, x): # dimensions: i32[0, k], x: i32[4, 10] +... return my_top_k(dimensions.shape[1], x) +>>> exp = export.export(jax.jit(my_top_k_with_dimensions))( +... jax.ShapeDtypeStruct((0, k), dtype=np.int32), +... x) +>>> exp.in_avals +(ShapedArray(int32[0,k]), ShapedArray(int32[4,10])) + +>>> exp.out_avals[0] +ShapedArray(int32[4,k]) + +>>> # When we invoke `exp` we must construct and pass an array of shape (0, k) +>>> exp.call(np.zeros((0, 3), dtype=np.int32), x) +Array([[ 9, 8, 7], + [19, 18, 17], + [29, 28, 27], + [39, 38, 37]], dtype=int32) + +``` + +Another situation when you may get an error is when some dimension +variables do appear in the input shapes, but in a non-linear +expression that JAX cannot currently solve: + +```python +>>> a, = export.symbolic_shape("a") +>>> export.export(jax.jit(lambda x: x.shape[0]))( +... jax.ShapeDtypeStruct((a * a,), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: Cannot solve for values of dimension variables {'a'}. +We can only solve linear uni-variate constraints. +Using the following polymorphic shapes specifications: args[0].shape = (a^2,). +Unprocessed specifications: 'a^2' for dimension size args[0].shape[0]. + +``` + +### Shape assertion errors + +JAX assumes that dimension variables range over strictly positive integers, +and this assumption is checked when the code is compiled for concrete +input shapes. + +For example, given the symbolic input shape `(b, b, 2*d)`, +JAX will generate code to check the following assertions when +invoked with actual argument `arg`: + + * `arg.shape[0] >= 1` + * `arg.shape[1] == arg.shape[0]` + * `arg.shape[2] % 2 == 0` + * `arg.shape[2] // 2 >= 1` + +For example, here is the error we get when we call the exported +on an argument of shape `(3, 3, 5)`: + +```python +>>> def f(x): # x: f32[b, b, 2*d] +... return x +>>> exp = export.export(jax.jit(f))( +... jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32)) +>>> exp.call(np.ones((3, 3, 5), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: Input shapes do not match the polymorphic shapes specification. +Division had remainder 1 when computing the value of 'd'. +Using the following polymorphic shapes specifications: + args[0].shape = (b, b, 2*d). +Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), . +Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details. + +``` + +These errors arise in a pre-processing step before the +compilation. + +### Division of symbolic dimensions is partially supported + +JAX will attempt to simplify division and modulo operations, +e.g., `(a * b + a) // (b + 1) == a` and `6*a + 4 % 3 == 1`. +In particular, JAX will handle the cases when either (a) there +is no remainder, or (b) the divisor is a constant +in which case there may be a constant remainder. + +For example, the code below results in a division error when trying to +compute the inferred dimension for a `reshape` operation: + +```python +>>> b, = export.symbolic_shape("b") +>>> export.export(jax.jit(lambda x: x.reshape((2, -1))))( +... jax.ShapeDtypeStruct((b,), dtype=np.int32)) +Traceback (most recent call last): +jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1). +The remainder mod(b, - 2) should be 0. + +``` + +Note that the following will succeed: + +```python +>>> b, = export.symbolic_shape("b") +>>> # We specify that the first dimension is a multiple of 4 +>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( +... jax.ShapeDtypeStruct((4*b,), dtype=np.int32)) +>>> exp.out_avals +(ShapedArray(int32[2,2*b]),) + +>>> # We specify that some other dimension is even +>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( +... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32)) +>>> exp.out_avals +(ShapedArray(int32[2,15*b]),) + +``` + diff --git a/docs/faq.rst b/docs/faq.rst index d654d67b9d5e..3b63128d2c28 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -686,7 +686,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.:: def my_log_or_y(x, y): """Return log(x) if x > 0 or y""" - return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.), y) + return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y) Additional reading: @@ -849,4 +849,4 @@ see the page on `JAX GPU memory allocation`_. .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function .. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266 .. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html -.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html \ No newline at end of file +.. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index ee9f63b924ad..1f5cc0727605 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -1,5 +1,7 @@ # GPU performance tips + + This document focuses on performance tips for neural network workloads ## Matmul precision @@ -23,6 +25,10 @@ code examples: ## XLA performance flags +```{note} + JAX-Toolbox also has a page on [NVIDIA XLA performance FLAGS](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/GPU_performance.md). +``` + The existence and exact behavior of XLA flags may be `jaxlib`-version dependent. As of `jaxlib==0.4.18` (released [Oct 6 @@ -60,10 +66,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta ### Communication flags -* **--xla_gpu_enable_async_collectives** This flag enables the collective ops - such as `AllReduce`, `AllGather`, `ReduceScatter` and `CollectivePermute` to - be asynchronous. Asynchronous communication can overlap cross-core - communication with computation. The default value is False. * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False. @@ -107,3 +109,12 @@ os.environ.update({ These NCCL flags could improve single-host communication speed. These flags don't seem useful for multi-host communication yet. + +## Multi-Process + +We recommand using one process per GPU and not one per node. In some +cases, this can speed up jitted computation. The +{func}`jax.distributed.initialize` API will automatically understand +that configuration when run under SLURM. However, this only a rule of +thumb and it may be useful to test both one process per GPU and one +process per node on your use case. diff --git a/docs/installation.md b/docs/installation.md index ed3639a6086c..fa77d1fc29f6 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,6 +1,8 @@ (installation)= # Installing JAX + + Using JAX requires installing two packages: `jax`, which is pure Python and cross-platform, and `jaxlib` which contains compiled binaries, and requires different builds for different operating systems and accelerators. @@ -9,20 +11,16 @@ different builds for different operating systems and accelerators. * **CPU-only (Linux/macOS/Windows)** ``` - pip install -U "jax[cpu]" + pip install -U jax ``` -* **GPU (NVIDIA, CUDA 12, x86_64)** +* **GPU (NVIDIA, CUDA 12)** ``` pip install -U "jax[cuda12]" ``` -* **GPU (NVIDIA, CUDA 12, x86_64) legacy** - -You should prefer `jax[cuda12]`, which uses the common CPU jaxlib and adds GPU -support as a plugin. The monolithic `jax[cuda12_pip]` option will be removed in -a future JAX release. +* **TPU (Google Cloud TPU VM)** ``` - pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` (install-supported-platforms)= @@ -48,6 +46,7 @@ Currently, the JAX team releases `jaxlib` wheels for the following operating systems and architectures: - Linux, x86_64 +- Linux, aarch64 - macOS, Intel - macOS, Apple ARM-based - Windows, x86_64 (*experimental*) @@ -57,7 +56,7 @@ development on a laptop, you can run: ```bash pip install --upgrade pip -pip install --upgrade "jax[cpu]" +pip install --upgrade jax ``` On Windows, you may also need to install the @@ -97,8 +96,8 @@ There are two ways to install JAX with NVIDIA GPU support: The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels, since it is much easier! -This method is only supported on x86_64, because NVIDIA has not released aarch64 -CUDA pip packages. +NVIDIA has released CUDA pip packages only for x86_64 and aarch64; on other +platforms you must use a local installation of CUDA. ```bash pip install --upgrade pip @@ -106,11 +105,6 @@ pip install --upgrade pip # NVIDIA CUDA 12 installation # Note: wheels only available on linux. pip install --upgrade "jax[cuda12]" - -# Legacy way of NVIDIA CUDA 12 installation. You should prefer `jax[cuda12]`, -# which uses the common CPU jaxlib and adds GPU support as a plugin. The -# monolithic `jax[cuda12_pip]` option will be removed in a future JAX release. -pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things @@ -127,7 +121,7 @@ If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first install NVIDIA [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/CUDNN). -JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other +JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 and Linux aarch64 only**. Other combinations of operating system and architecture are possible, but require building from source (refer to {ref}`building-from-source` to learn more}. @@ -141,11 +135,11 @@ that NVIDIA provides for this purpose. JAX currently ships one CUDA wheel variant: -| Built with | Compatible with | -|------------|-------------------| -| CUDA 12.3 | CUDA >=12.1 | -| CUDNN 8.9 | CUDNN >=8.9, <9.0 | -| NCCL 2.19 | NCCL >=2.18 | +| Built with | Compatible with | +|------------|--------------------| +| CUDA 12.3 | CUDA >=12.1 | +| CUDNN 9.0 | CUDNN >=9.0, <10.0 | +| NCCL 2.19 | NCCL >=2.18 | JAX checks the versions of your libraries, and will report an error if they are not sufficiently new. @@ -161,9 +155,9 @@ To install, run: ```bash pip install --upgrade pip -# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer. +# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer. # Note: wheels only available on linux. -pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install --upgrade "jax[cuda12_local]" ``` **These `pip` installations do not work with Windows, and may fail silently; refer to the table @@ -179,6 +173,9 @@ JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries (`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA installation. +JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. +Make sure that it is present in your CUDA installation. + Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) if you run into any errors or problems with the pre-built wheels. @@ -190,43 +187,6 @@ Toolbox](https://github.com/NVIDIA/JAX-Toolbox) containers, which are bleeding edge containers containing nightly releases of jax and some models/frameworks. -## JAX nightly installation - -Nightly releases reflect the state of the main JAX repository at the time they are -built, and may not pass the full test suite. - -- `jax`: - -```bash -pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -``` - -- `jaxlib` CPU: - -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -``` - -- `jaxlib` Google Cloud TPU: - -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -``` - -- `jaxlib` NVIDIA GPU (CUDA 12): - -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html -``` - -- `jaxlib` NVIDIA GPU (CUDA 12) legacy: - -```bash -pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html -``` - (install-google-tpu)= ## Google Cloud TPU @@ -304,6 +264,39 @@ Go to the `conda-forge` for more details. +## JAX nightly installation + +Nightly releases reflect the state of the main JAX repository at the time they are +built, and may not pass the full test suite. + +- CPU only: + +```bash +pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +``` + +- Google Cloud TPU: + +```bash +pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +``` + +- NVIDIA GPU (CUDA 12): + +```bash +pip install -U --pre jax[cuda12] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +``` + +- NVIDIA GPU (CUDA 12) legacy: + +Use the following for historical nightly releases of monolithic CUDA jaxlibs. +You most likely do not want this; no further monolithic CUDA jaxlibs will be +built and those that exist will expire by Sep 2024. Use the "CUDA 12" option above. + +```bash +pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html +``` + (building-jax-from-source)= ## Building JAX from source diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index fb5293a82274..4affae3a65d8 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -1,6 +1,8 @@ (investigating-a-regression)= # Investigating a regression + + So you updated JAX and you hit a speed regression? You have a little bit of time and are ready to investigate this? Let's first make a JAX issue. diff --git a/docs/jax.export.rst b/docs/jax.export.rst new file mode 100644 index 000000000000..6caecc5a085d --- /dev/null +++ b/docs/jax.export.rst @@ -0,0 +1,47 @@ +``jax.export`` module +===================== + +.. automodule:: jax.export + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + + Exported + DisabledSafetyCheck + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary + + export + deserialize + minimum_supported_calling_convention_version + maximum_supported_calling_convention_version + default_export_platform + +Functions related to shape polymorphism +--------------------------------------- + +.. autosummary:: + :toctree: _autosummary + + symbolic_shape + symbolic_args_specs + is_symbolic_dim + SymbolicScope + +Constants +--------- + +.. data:: jax.export.minimum_supported_serialization_version + + The minimum supported serialization version; see :ref:`export-calling-convention-version`. + +.. data:: jax.export.maximum_supported_serialization_version + + The maximum supported serialization version; see :ref:`export-calling-convention-version`. diff --git a/docs/jax.extend.ffi.rst b/docs/jax.extend.ffi.rst new file mode 100644 index 000000000000..070778b8f065 --- /dev/null +++ b/docs/jax.extend.ffi.rst @@ -0,0 +1,10 @@ +``jax.extend.ffi`` module +========================= + +.. automodule:: jax.extend.ffi + +.. autosummary:: + :toctree: _autosummary + + ffi_lowering + pycapsule diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index 3b4ec41ea680..9cbee08e8e50 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -11,6 +11,7 @@ Modules .. toctree:: :maxdepth: 1 + jax.extend.ffi jax.extend.linear_util jax.extend.mlir jax.extend.random diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index 67d80e1605d3..33223ee755e5 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -23,6 +23,7 @@ Activation functions sigmoid softplus sparse_plus + sparse_sigmoid soft_sign silu swish diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 5ccf043d2282..b96dfcdfb208 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -495,6 +495,7 @@ jax.numpy.linalg tensordot tensorinv tensorsolve + trace vector_norm vecdot diff --git a/docs/jax.rst b/docs/jax.rst index 9979fb464435..b112490a0912 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -27,6 +27,7 @@ Subpackages jax.tree jax.tree_util jax.typing + jax.export jax.extend jax.example_libraries jax.experimental diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index a60b7397019d..f6d8a151440b 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -174,7 +174,9 @@ jax.scipy.special i0e i1 i1e + kl_div log_ndtr + log_softmax logit logsumexp lpmn @@ -184,13 +186,13 @@ jax.scipy.special ndtri poch polygamma + rel_entr + softmax spence sph_harm xlog1py xlogy zeta - kl_div - rel_entr jax.scipy.stats diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst index 7b1393d8e2c4..954f62b8a52d 100644 --- a/docs/jax.sharding.rst +++ b/docs/jax.sharding.rst @@ -10,9 +10,6 @@ Classes .. autoclass:: Sharding :members: -.. autoclass:: XLACompatibleSharding - :members: - :show-inheritance: .. autoclass:: SingleDeviceSharding :members: :show-inheritance: diff --git a/docs/jax.stages.rst b/docs/jax.stages.rst index f8adce32b7c6..804019ee1cc6 100644 --- a/docs/jax.stages.rst +++ b/docs/jax.stages.rst @@ -9,7 +9,7 @@ Classes .. currentmodule:: jax.stages .. autoclass:: Wrapped - :members: lower + :members: trace, lower :special-members: __call__ .. autoclass:: Lowered diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index 8d892f9f4c7f..35bce340d4de 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -20,19 +20,27 @@ List of Functions register_pytree_with_keys register_pytree_with_keys_class register_static + tree_flatten_with_path + tree_leaves_with_path + tree_map_with_path + treedef_children + treedef_is_leaf + treedef_tuple + keystr + +Legacy APIs +----------- +These APIs are now accessed via :mod:`jax.tree`. + +.. autosummary:: + :toctree: _autosummary + tree_all tree_flatten - tree_flatten_with_path tree_leaves - tree_leaves_with_path tree_map - tree_map_with_path tree_reduce tree_structure tree_transpose tree_unflatten - treedef_children - treedef_is_leaf - treedef_tuple - keystr diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 860197becf6e..95d4a632a295 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,6 +1,8 @@ (jax-array-migration)= # jax.Array migration + + **yashkatariya@** ## TL;DR diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md index 925fc2c47fa2..828b95e8ce00 100644 --- a/docs/jep/9263-typed-keys.md +++ b/docs/jep/9263-typed-keys.md @@ -21,6 +21,7 @@ Array([0, 0], dtype=uint32) (2,) >>> key.dtype dtype('uint32') + ``` Starting now, new-style RNG keys can be created with {func}`jax.random.key`: @@ -33,6 +34,7 @@ Array((), dtype=key) overlaying: () >>> key.dtype key + ``` This (scalar-shaped) array behaves the same as any other JAX array, except that its element type is a key (and associated metadata). We can make @@ -48,6 +50,7 @@ Array((4,), dtype=key) overlaying: [0 3]] >>> key_arr.shape (4,) + ``` Aside from switching to a new constructor, most PRNG-related code should continue to work as expected. You can continue to use keys in @@ -62,14 +65,17 @@ data = jax.random.uniform(key, shape=(5,)) However, not all numerical operations work on key arrays. They now intentionally raise errors: ```python ->>> key = key + 1 -ValueError: dtype=key is not a valid dtype for JAX type promotion. +>>> key = key + 1 # doctest: +SKIP +Traceback (most recent call last): +TypeError: add does not accept dtypes key, int32. + ``` If for some reason you need to recover the underlying buffer (the old-style key), you can do so with {func}`jax.random.key_data`: ```python >>> jax.random.key_data(key) Array([0, 0], dtype=uint32) + ``` For old-style keys, {func}`~jax.random.key_data` is an identity operation. @@ -108,6 +114,7 @@ True >>> raw_key = jax.random.PRNGKey(0) >>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key) False + ``` ### Type annotations for PRNG Keys @@ -173,6 +180,7 @@ Array((), dtype=key) overlaying: [0 0 0 0] >>> jax.random.uniform(key, shape=(3,)) Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32) + ``` ### Safe PRNG key use @@ -322,6 +330,7 @@ which has the following property: ```python >>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended) True + ``` PRNG key arrays then have a dtype with the following properties: ```python @@ -330,6 +339,7 @@ PRNG key arrays then have a dtype with the following properties: True >>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key) True + ``` And in addition to `key.dtype._rules` as outlined for extended dtypes in general, PRNG dtypes define `key.dtype._impl`, which contains the metadata diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index ddf99db90aab..2d442c8411aa 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -22,6 +22,8 @@ kernelspec: (jit-compilation)= # Just-in-time compilation + + In this section, we will further explore how JAX works, and how we can make it performant. We will discuss the {func}`jax.jit` transformation, which will perform *Just In Time* (JIT) compilation of a JAX Python function so it can be executed efficiently in XLA. diff --git a/docs/key-concepts.md b/docs/key-concepts.md index 90a491b6a0c7..4b114c857460 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -15,6 +15,8 @@ kernelspec: (key-concepts)= # Key Concepts + + This section briefly introduces some key concepts of the JAX package. (key-concepts-jax-arrays)= diff --git a/docs/multi_process.md b/docs/multi_process.md index 2405a208b60e..7d7083bde10f 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -1,5 +1,7 @@ # Using JAX in multi-host and multi-process environments + + ## Introduction This guide explains how to use JAX in environments such as @@ -28,12 +30,15 @@ Key concepts: * Each process has a distinct set of *local* devices it can address. The *global* devices are the set of all devices across all processes. - * Use standard JAX parallelism APIs like {func}`~jax.pmap` and - {func}`~jax.experimental.maps.xmap`. Each process “sees” *local* input and - output to parallelized functions, but communication inside the computations - is *global*. + * Use standard JAX parallelism APIs like {func}`~jax.jit` (see + {doc}`/sharded-computation` tutorial) and + {func}`~jax.experimental.shard_map.shard_map`. jax.jit only accepts + globally shaped arrays. shard_map allows you to drop to per-device + shape. * Make sure all processes run the same parallel computations in the same order. + * Make sure all processes has the same number of local devices. + * Make sure all devices are the same (e.g., all V100, or all H100). ### Launching JAX processes @@ -123,18 +128,13 @@ global devices. So how do you actually run a computation involving cross-process communication? **Use the same parallel evaluation APIs that you would in a single process!** -For example, {func}`~jax.experimental.shard_map.shard_map` can be used to -run a parallel computation across -multiple processes. (If you’re not already familiar with how to use -`shard_map` to run across multiple devices within a single process, check -out the {doc}`/sharded-computation` tutorial.) Each process should call the -same pmapped function and pass in arguments to be mapped across its *local* -devices (i.e., the pmapped axis size is equal to the number of local devices). -Similarly, the function will return outputs sharded across *local* devices only. -Inside the function, however, collective communication operations are run across -all *global* devices, across all processes. Conceptually, this can be thought of -as running a pmap over a single array sharded across hosts, where each host -“sees” only its local shard of the input and output. +For example, {func}`~jax.experimental.shard_map.shard_map` can be used +to run a parallel computation across multiple processes. (If you’re +not already familiar with how to use `shard_map` to run across +multiple devices within a single process, check out the +{doc}`/sharded-computation` tutorial.) Conceptually, this can be +thought of as running a pmap over a single array sharded across hosts, +where each host “sees” only its local shard of the input and output. Here’s an example of multi-process pmap in action: @@ -152,12 +152,6 @@ Here’s an example of multi-process pmap in action: ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32) ``` -{func}`~jax.experimental.maps.xmap` works similarly when using a physical -hardware mesh (see the {doc}`xmap tutorial` if you’re -not familiar with the single-process version). Like {func}`~jax.pmap` , the -inputs and outputs are local and any parallel communication inside the xmapped -function is global. The mesh is also global. - **It’s very important that all processes run the same cross-process computations in the same order.** Running the same JAX Python program in each process is usually sufficient. Some common pitfalls to look out for that may cause diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index d8dffdb8a2f1..2665e25fdd43 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -8,6 +8,8 @@ "source": [ "# 🔪 JAX - The Sharp Bits 🔪\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" ] }, @@ -2223,6 +2225,7 @@ "\n", " >>> jnp.arange(254.0, 258.0).astype('uint8')\n", " Array([254, 255, 255, 255], dtype=uint8)\n", + "\n", " ```\n", " This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index e63d64d94e77..58fcb4310bc7 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -16,6 +16,8 @@ kernelspec: # 🔪 JAX - The Sharp Bits 🔪 + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +++ {"id": "4k5PVzEo2uJO"} @@ -1143,6 +1145,7 @@ Many such cases are discussed in detail in the sections above; here we list seve >>> jnp.arange(254.0, 258.0).astype('uint8') Array([254, 255, 255, 255], dtype=uint8) + ``` This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa. diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 70f499a89a5d..3abb6d9cbaec 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -8,6 +8,8 @@ "source": [ "# Custom derivative rules for JAX-transformable Python functions\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 02b9e08274f4..ad577d55cd0d 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -15,6 +15,8 @@ kernelspec: # Custom derivative rules for JAX-transformable Python functions + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) *mattjj@ Mar 19 2020, last updated Oct 14 2020* diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 54eb00d780ab..2face1d4a0b2 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -6,7 +6,9 @@ "id": "PxHrg4Cjuapm" }, "source": [ - "# Distributed arrays and automatic parallelization" + "# Distributed arrays and automatic parallelization\n", + "\n", + "" ] }, { diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index dad695d121be..b9ec9dc694d2 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -15,6 +15,8 @@ kernelspec: # Distributed arrays and automatic parallelization + + +++ {"id": "pFtQjv4SzHRj"} [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index 89e774d84321..f42e3f74b4e3 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -8,6 +8,8 @@ "source": [ "# How JAX primitives work\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", "\n", "*necula@google.com*, October 2019.\n", diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md index 17a7379dd677..0ebf202f2258 100644 --- a/docs/notebooks/How_JAX_primitives_work.md +++ b/docs/notebooks/How_JAX_primitives_work.md @@ -15,6 +15,8 @@ kernelspec: # How JAX primitives work + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) *necula@google.com*, October 2019. diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index b48ac353c505..f0c157655790 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -8,6 +8,8 @@ "source": [ "# Training a Simple Neural Network, with PyTorch Data Loading\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", "\n", "**Copyright 2018 The JAX Authors.**\n", diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 8fb2d4f06a45..2c53bb1e4ab5 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -16,6 +16,8 @@ kernelspec: # Training a Simple Neural Network, with PyTorch Data Loading + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) **Copyright 2018 The JAX Authors.** diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 1a1a77eb9ee3..7e65aefe359c 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -8,6 +8,8 @@ "source": [ "# Writing custom Jaxpr interpreters in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" ] }, diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 866eeffe1eb1..e52c6a5f8742 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -16,6 +16,8 @@ kernelspec: # Writing custom Jaxpr interpreters in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +++ {"id": "r-3vMiKRYXPJ"} diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index d0b8fe0c0c23..edfd0d4535f8 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -8,6 +8,8 @@ "source": [ "# The Autodiff Cookbook\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "\n", "*alexbw@, mattjj@* \n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index 0d08e5061f13..c24d05c0e7c9 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -16,6 +16,8 @@ kernelspec: # The Autodiff Cookbook + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) *alexbw@, mattjj@* diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 9aec8b1a23df..f0552e52688f 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -6,7 +6,9 @@ "id": "29WqUVkCXjDD" }, "source": [ - "## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)" + "## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)\n", + "\n", + "" ] }, { diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 3b8d6218a56a..b31e093b6f91 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -15,6 +15,8 @@ kernelspec: ## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`) + + ```{code-cell} import jax import jax.numpy as jnp diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index c4ef1961bbd2..0a823353068b 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -8,6 +8,8 @@ "source": [ "# Generalized Convolutions in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", "\n", "JAX provides a number of interfaces to compute convolutions across data, including:\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 4216e0ffc744..3de8f261aa5b 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -16,6 +16,8 @@ kernelspec: # Generalized Convolutions in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb) JAX provides a number of interfaces to compute convolutions across data, including: diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index 5cda80620961..bdf71004c01b 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -6,7 +6,9 @@ "id": "7XNMxdTwURqI" }, "source": [ - "# External Callbacks in JAX" + "# External Callbacks in JAX\n", + "\n", + "" ] }, { diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index 582b3536e78c..857eef42e2b3 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -15,6 +15,8 @@ kernelspec: # External Callbacks in JAX + + +++ {"id": "h6lXo6bSUYGq"} This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation. diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 8368fc3aae29..95c00bf1e689 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -38,6 +38,8 @@ "source": [ "# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", "\n", "_Forked from_ `neural_network_and_data_loading.ipynb`\n", diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index e16c5ce25cd4..8f795484d5b9 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -36,6 +36,8 @@ limitations under the License. # Training a Simple Neural Network, with tensorflow/datasets Data Loading + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) _Forked from_ `neural_network_and_data_loading.ipynb` diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index e9fafa2c1f42..ed0a13d8702a 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -7,6 +7,8 @@ "source": [ "# SPMD multi-device parallelism with `shard_map`\n", "\n", + "\n", + "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 403463812bb5..67494cfd4a02 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -16,6 +16,8 @@ kernelspec: # SPMD multi-device parallelism with `shard_map` + + `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. `shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index bcaa1f42b9a7..1c1c9729b654 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -8,6 +8,8 @@ "source": [ "# How to Think in JAX\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", "\n", "JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively." diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 1f25bdc4e305..14089fa36e32 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -15,6 +15,8 @@ kernelspec: # How to Think in JAX + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively. diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 833c3a40d145..96b334296667 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -8,6 +8,8 @@ "source": [ "# Autobatching for Bayesian Inference\n", "\n", + "\n", + "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", "\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index ac0864fb1c15..ea8b4fce2f70 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -16,6 +16,8 @@ kernelspec: # Autobatching for Bayesian Inference + + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs. diff --git a/docs/notebooks/xmap_tutorial.ipynb b/docs/notebooks/xmap_tutorial.ipynb index 4e216e58813e..a8eb76c353ed 100644 --- a/docs/notebooks/xmap_tutorial.ipynb +++ b/docs/notebooks/xmap_tutorial.ipynb @@ -8,6 +8,8 @@ "source": [ "# Named axes and easy-to-revise parallelism with `xmap`\n", "\n", + "\n", + "\n", "**_UPDATE:_** `xmap` is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html).\n", "\n", "This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer.\n", diff --git a/docs/notebooks/xmap_tutorial.md b/docs/notebooks/xmap_tutorial.md index d4557078811b..c4b511dbe711 100644 --- a/docs/notebooks/xmap_tutorial.md +++ b/docs/notebooks/xmap_tutorial.md @@ -15,6 +15,8 @@ kernelspec: # Named axes and easy-to-revise parallelism with `xmap` + + **_UPDATE:_** `xmap` is deprecated and will be removed in a future release. The recommended ways to do multi-device programming in JAX are using: 1) [`jit` (automatic partitioning of computation and `jax.Array` sharding)](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html); and/or 2) [`shard_map` (manual data sharding)](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). Learn more in [Why don’t `pmap` or `xmap` already solve this?](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this) in the [`shard_map` JEP document](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html). This tutorial introduces JAX `xmap` (`jax.experimental.maps.xmap`) and the named-axis programming model that comes with it. By reading this, you'll learn how to write error-avoiding, self-documenting functions using named axes, then control how they're executed on hardware at any scale, from your laptop CPU to the largest TPU supercomputer. diff --git a/docs/pallas/design.md b/docs/pallas/design.md index 4fae9a11dfc7..caa1f1eb2b04 100644 --- a/docs/pallas/design.md +++ b/docs/pallas/design.md @@ -1,23 +1,96 @@ # Pallas Design -In this document, we explain the initial Pallas design. This is a snapshot of some of the earlier design decisions made and Pallas's specific APIs might have changed since. + -## Introduction - -JAX is being used for a diverse set of workloads, from large scale machine learning to scientific computing. JAX’s success story is as much a success story for XLA, the primary compiler that JAX targets – XLA compiles JAX programs for accelerators and has enabled JAX to scale to the largest ML models. JAX describes logical computations in XLA’s representation, HLO. HLO describes how computations happen logically but not physically. Given a logical HLO computation, XLA decides how that computation is to be executed physically. For a wide variety of ML applications, XLA does a good job of compiling user programs but inevitably some users hit XLA's limitations. In these cases, we need to provide an “escape hatch” to allow experts to write hand-tuned kernels that outperform XLA at that point in time. Furthermore, advances in ML systems research take some time to be incorporated into XLA and users often want to run ahead with them. Over time, the compiler can incorporate the optimizations that were proven out experimentally through hand-tuned kernels. - -XLA does offer the `CustomCall` mechanism as an escape hatch, but it requires users to write C++ and on GPU it requires users to learn the CUDA programming model. The CUDA programming model is arguably too low-level for many machine learning GPU kernels, like matrix multiplication, and even expert users will have trouble using CUDA to implement efficient matrix multiplication or multi-headed attention. Not only this, JAX users are usually familiar with Python and NumPy-style array programming which doesn’t involve writing any C++ or thinking about GPU parallelism. All popular machine learning frameworks share this idea: manipulating (usually) arrays with high level operations like `matmul` or `convolution`. Unfortunately, this means implementing a custom operation via `CustomCall` is a big investment, involving potentially learning C++ and/or GPU programming. - -[Triton](https://triton-lang.org/main/index.html), a GPU compiler built and maintained by OpenAI, has taken the ML compiler world by storm. Triton offers the best of both worlds: an array-based programming model for GPU kernels. Triton is the primary code generation route for `torch.compile` in PyTorch 2.0, via the Torch Inductor library. Triton actively hides some aspects of GPU programming in the name of a more accessible programming model that can be used from Python and to generate optimized code from a higher-level representation. While GPUs are more flexible than what Triton offers, in the ML domain, Triton seems to be expressive enough for many applications. +In this document, we explain the initial Pallas design. +This is a snapshot of some of the earlier design decisions made +and Pallas's specific APIs might have changed since. -In this document, we describe Pallas, an extension to JAX that enables kernel programming for both GPUs and TPUs using a Triton-like model. A JAX-based kernel language offers several advantages: -* Although Triton exposes a TPU-like programming model to users, i.e. writing programs for tiles of arrays in L1-cache, it is specialized enough to GPU that we cannot directly compile Triton for TPU. For example, Triton offers atomic operations specifically meant to handle parallel writes that don’t necessarily make sense on TPU. A higher level front end can abstract away details of the platform while surfacing just that tile-based programming model. The kernels will thus be portable across different hardware platforms. -* JAX as a tracing-based frontend for numerical computing is both mature and well-used. By embedding the kernel programming language in JAX itself, we can re-use JAX’s tracing infrastructure and provide a NumPy-like frontend that’s already familiar to users. -* JAX transformations are key to its success, allowing users to express simple programs but transform them to achieve complex functionality. We can leverage the same transformations (vmap, jvp, etc.) to transform user-written kernels. - -The open question is: is JAX a good fit for a kernel language at all? We think so. Triton demonstrates that an array programming language can be practical for writing GPU kernels and JAX is just that. JAX has also proven to be a flexible front-end for compilers and for program transformations. +## Introduction -We describe Pallas as follows: we first describe the ways in which we extend JAX to support writing custom kernels. We then show how we can lower Pallas to both Triton and Mosaic. We conclude by describing existing and potential ways to transform Pallas kernels via JAX transformations. +JAX is being used for a diverse set of workloads, from large scale machine +learning to scientific computing. +JAX’s success story is as much a success story for XLA, +the primary compiler that JAX targets – XLA compiles JAX +programs for accelerators and has enabled JAX to scale to the largest ML +models. +JAX describes logical computations in XLA’s representation, HLO. +HLO describes how computations happen logically but not physically. +Given a logical HLO computation, XLA decides how that computation is to be +executed physically. +For a wide variety of ML applications, XLA does a good +job of compiling user programs but inevitably some users hit XLA's +limitations. +In these cases, we need to provide an “escape hatch” to allow +experts to write hand-tuned kernels that outperform XLA at that +point in time. +Furthermore, advances in ML systems research take some time to be +incorporated into XLA and users often want to run ahead with them. +Over time, the compiler can incorporate the optimizations that were proven +out experimentally through hand-tuned kernels. + +XLA does offer the `CustomCall` mechanism as an escape hatch, but it +requires users to write C++ and on GPU it requires users to learn the +CUDA programming model. +The CUDA programming model is arguably too low-level for many machine +learning GPU kernels, like matrix multiplication, +and even expert users will have trouble using CUDA to implement efficient +matrix multiplication or multi-headed attention. +Not only this, JAX users are usually familiar with Python and NumPy-style +array programming which doesn’t involve writing any C++ or thinking about +GPU parallelism. +All popular machine learning frameworks share this +idea: manipulating (usually) arrays with high level operations +like `matmul` or `convolution`. +Unfortunately, this means implementing a custom operation via `CustomCall` +is a big investment, involving potentially learning C++ and/or GPU +programming. + +[Triton](https://triton-lang.org/main/index.html), a GPU compiler built +and maintained by OpenAI, has taken the ML compiler world by storm. +Triton offers the best of both worlds: an array-based programming model +for GPU kernels. Triton is the primary code generation route +for `torch.compile` in PyTorch 2.0, via the Torch Inductor library. +Triton actively hides some aspects of GPU programming in the name of a +more accessible programming model that can be used from Python and to +generate optimized code from a higher-level representation. +While GPUs are more flexible than what Triton offers, in the ML domain, +Triton seems to be expressive enough for many applications. + +In this document, we describe Pallas, an extension to JAX that enables +kernel programming for both GPUs and TPUs using a Triton-like model. +A JAX-based kernel language offers several advantages: +* Although Triton exposes a TPU-like programming model to users, + i.e. writing programs for tiles of arrays in L1-cache, it is specialized + enough to GPU that we cannot directly compile Triton for TPU. + For example, Triton offers atomic operations specifically meant to + handle parallel writes that don’t necessarily make sense on TPU. + A higher level front end can abstract away details of the platform + while surfacing just that tile-based programming model. + The kernels will thus be portable across different hardware platforms. +* JAX as a tracing-based frontend for numerical computing is both + mature and well-used. + By embedding the kernel programming language in JAX itself, + we can re-use JAX’s tracing infrastructure and provide a + NumPy-like frontend that’s already familiar to users. +* JAX transformations are key to its success, allowing users to + express simple programs but transform them to achieve complex + functionality. + We can leverage the same transformations (vmap, jvp, etc.) to + transform user-written kernels. + +The open question is: is JAX a good fit for a kernel language at all? +We think so. +Triton demonstrates that an array programming language can be +practical for writing GPU kernels and JAX is just that. +JAX has also proven to be a flexible front-end for compilers and +for program transformations. + +We describe Pallas as follows: we first describe the ways in which +we extend JAX to support writing custom kernels. +We then show how we can lower Pallas to both Triton and Mosaic. +We conclude by describing existing and potential ways to transform +Pallas kernels via JAX transformations.
@@ -28,10 +101,17 @@ Visualization of Pallas lowering paths ## Pallas: Extending JAX for kernels -The key point we’d like to make is that Pallas is just JAX, with some extensions: -1. Users now use reference types called `Ref`s in their JAX code. This gives users more precise control over memory access and layout in JAX will more closely resemble physical layout. -2. Users write their JAX programs using a subset of JAX primitives, along with a set of Pallas-specific primitives. -3. Users embed their Pallas kernels in an outer JAX program via a special `pallas_call` higher-order function, that executes the kernel in a map. It is analogous to `pmap` or `shard_map`, except with references to shared memory. +The key point we’d like to make is that Pallas is just JAX, with some +extensions: +1. Users now use reference types called `Ref`s in their JAX code. + This gives users more precise control over memory access and + layout in JAX will more closely resemble physical layout. +2. Users write their JAX programs using a subset of JAX primitives, + along with a set of Pallas-specific primitives. +3. Users embed their Pallas kernels in an outer JAX program via a + special `pallas_call` higher-order function, that executes the + kernel in a map. It is analogous to `pmap` or `shard_map`, + except with references to shared memory. We’ll go over these three extensions one at a time, by example. @@ -56,13 +136,28 @@ add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32) add(x, y) ``` -Unlike a regular JAX program, `add_kernel` does not receive immutable array arguments. Instead, it’s provided with references that can be read from and updated in-place using NumPy-like syntax. `Ref`s are not a Pallas-specific concept – they were introduced to JAX to represent stateful computations. However, we can leverage them when writing kernels that operate on mutable memory too. - -Pallas kernels not only receive `Ref`s corresponding to the inputs to the kernel, but also receive `Ref`s for the outputs as well (specified in `pallas_call` via `out_shape`). `Ref`s are special types that cannot be passed into the usual set of JAX primitives without being read from first. When you read from a `Ref` you get a JAX `Array` type out, and you must write an `Array` into a `Ref`. +Unlike a regular JAX program, `add_kernel` does not receive immutable +array arguments. +Instead, it’s provided with references that can be read from and +updated in-place using NumPy-like syntax. +`Ref`s are not a Pallas-specific concept – they were introduced to +JAX to represent stateful computations. +However, we can leverage them when writing kernels that operate on +mutable memory too. + +Pallas kernels not only receive `Ref`s corresponding to the inputs +to the kernel, but also receive `Ref`s for the outputs as well +(specified in `pallas_call` via `out_shape`). +`Ref`s are special types that cannot be passed into the usual set of +JAX primitives without being read from first. +When you read from a `Ref` you get a JAX `Array` type out, and you +must write an `Array` into a `Ref`. #### Reading from/writing into Refs -Reading from a `Ref` corresponds to loading an array into the lowest level of the memory hierarchy (L1-cache on GPU and vector registers on TPU). Writing into a `Ref` is analogous. +Reading from a `Ref` corresponds to loading an array into the +lowest level of the memory hierarchy (L1-cache on GPU and vector +registers on TPU). Writing into a `Ref` is analogous. ```python def f(x_ref, o_ref): @@ -77,19 +172,37 @@ def f(x_ref): x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]] ``` -Writing to `Ref`s can be done via analogous `__setitem__` style indexing. +Writing to `Ref`s can be done via analogous `__setitem__` style +indexing. -Other forms of indexing (for example, dynamic slicing) can be done via `pallas.load` and `pallas.store`, new JAX primitives designed to make loading from/storing into memory easier. We’ll discuss these new primitives later. +Other forms of indexing (for example, dynamic slicing) can be done +via `pallas.load` and `pallas.store`, new JAX primitives designed to +make loading from/storing into memory easier. +We’ll discuss these new primitives later. ### Extending JAX with new Pallas primitives -Because JAX was designed with HLO in mind, the set of JAX primitives closely mirrors the set of HLO operations. Targeting a new compiler (e.g. Triton or Mosaic) means we might need to supplement JAX’s primitives with new ones specific to the new compiler. At the same time, we may not be able to lower all JAX primitives, so we need to restrict it to a subset. +Because JAX was designed with HLO in mind, the set of JAX primitives +closely mirrors the set of HLO operations. +Targeting a new compiler (e.g. Triton or Mosaic) means we might need +to supplement JAX’s primitives with new ones specific to the new +compiler. +At the same time, we may not be able to lower all JAX primitives, +so we need to restrict it to a subset. -Because Pallas was initially designed with Triton in mind, we offer a set of new primitives targeting the Triton programming model. As we’ll show later, we can lower these primitives to Mosaic as well. +Because Pallas was initially designed with Triton in mind, +we offer a set of new primitives targeting the Triton programming model. +As we’ll show later, we can lower these primitives to Mosaic as well. #### `pallas.load` and `pallas.store` -`pallas.load` and `pallas.store` are primitives that allow loading from memory and storing into memory. Unlike `__getitem__` and `__setitem__` they are more flexible at the cost of being more verbose. Specifically, you can use the `pallas.dynamic_slice` (`pallas.ds` for short) construct (which should maybe be upstreamed into JAX to be used with Ref `__getitem__` and `__setitem__`). +`pallas.load` and `pallas.store` are primitives that allow loading +from memory and storing into memory. +Unlike `__getitem__` and `__setitem__` they are more flexible at the +cost of being more verbose. +Specifically, you can use the `pallas.dynamic_slice` (`pallas.ds` for +short) construct (which should maybe be upstreamed into JAX to be +used with Ref `__getitem__` and `__setitem__`). ```python def f(x_ref, o_ref): @@ -102,7 +215,8 @@ def f(x_ref, o_ref): ``` -`pallas.load` and `pallas.store` also support masking via the mask argument. +`pallas.load` and `pallas.store` also support masking via the mask +argument. ```python def f(x_ref, o_ref): @@ -112,13 +226,25 @@ def f(x_ref, o_ref): x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf')) ``` -Masking is important when doing out-of-bounds loads/stores. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked). +Masking is important when doing out-of-bounds loads/stores. +The operational semantics of masking can be compiler-determined +(if we understand the documentation properly, Triton avoids the read +from/write to memory if it’s masked). #### `pallas.program_id` and `pallas.num_programs` -As we’ll soon see, we’ll be executing the same Pallas kernels many times (either in parallel or in a pipeline depending on the backend). These new primitives tell us “where” we are in the execution of the kernel. +As we’ll soon see, we’ll be executing the same Pallas kernels many +times (either in parallel or in a pipeline depending on the backend). +These new primitives tell us “where” we are in the execution of the +kernel. -`pallas.program_id` takes in an axis argument, which tells us which index in an axis of a multidimensional grid this kernel is currently executing in (analogous to `threadId` from CUDA programming or `lax.axis_index` in `jax.pmap`). Note that we are currently borrowing the “program” terminology from Triton and in the future we might want to change it to something more familiar to JAX users. +`pallas.program_id` takes in an axis argument, which tells us which +index in an axis of a multidimensional grid this kernel is currently +executing in (analogous to `threadId` from CUDA programming or +`lax.axis_index` in `jax.pmap`). +Note that we are currently borrowing the “program” terminology from +Triton and in the future we might want to change it to something more +familiar to JAX users. ```python def f(x_ref, o_ref): @@ -126,21 +252,34 @@ def f(x_ref, o_ref): o_ref[i] = jnp.exp(x_ref[i]) ``` -`pallas.num_programs` also takes in an axis and returns the grid size for that axis. +`pallas.num_programs` also takes in an axis and returns the grid size +for that axis. -Note that while `program_id` and `num_programs` are Triton-specific terminology they are easily generalized to make sense on TPU as well. +Note that while `program_id` and `num_programs` are Triton-specific +terminology they are easily generalized to make sense on TPU as well. #### Using a subset of JAX primitives in Pallas -Because we’re writing kernels, not high-level HLO programs, some JAX primitives may not be able to be represented in our underlying substrate efficiently. However, we know we can support most elementwise operations, simple dot products, and JAX control flow. +Because we’re writing kernels, not high-level HLO programs, some JAX +primitives may not be able to be represented in our underlying +substrate efficiently. +However, we know we can support most elementwise operations, +simple dot products, and JAX control flow. -While we haven’t yet mapped out exactly all the JAX primitives that we can support in Pallas kernels, we can certainly identify some that are not easy to lower or are unlikely to be useful: -* `conv_general` - convolution usually isn’t offered as primitive in the underlying hardware. -* `gather/scatter` - the underlying compiler may not support noncontiguous memory reads and writes +While we haven’t yet mapped out exactly all the JAX primitives that +we can support in Pallas kernels, we can certainly identify some that +are not easy to lower or are unlikely to be useful: +* `conv_general` - convolution usually isn’t offered as primitive in + the underlying hardware. +* `gather/scatter` - the underlying compiler may not support + noncontiguous memory reads and writes ### Executing Pallas kernels with `pallas_call` -Now that we’ve written our Pallas kernels (a.k.a. JAX with `Ref`s and the extra Pallas primitives), how do we execute them on a GPU or TPU? We use `pallas_call`, a higher order function (akin to `jax.jit` and `jax.pmap`) that executes the kernel. +Now that we’ve written our Pallas kernels (a.k.a. JAX with `Ref`s and +the extra Pallas primitives), how do we execute them on a GPU or TPU? +We use `pallas_call`, a higher order function (akin to `jax.jit` and +`jax.pmap`) that executes the kernel. The signature of `pallas_call` is as follows: @@ -154,7 +293,12 @@ def pallas_call( ... ``` -When we provide a kernel to `pallas_call` we provide additional information. The first is `out_shape` which tells the kernel what the outputs look like (`pallas_call` will pass a `Ref` corresponding to these into the kernel to be written to). The rest of the information (`in_specs`, `out_specs`, and `grid`) are information about how the kernel will be scheduled on the accelerator. +When we provide a kernel to `pallas_call` we provide additional +information. The first is `out_shape` which tells the kernel what the +outputs look like (`pallas_call` will pass a `Ref` corresponding to +these into the kernel to be written to). +The rest of the information (`in_specs`, `out_specs`, and `grid`) are +information about how the kernel will be scheduled on the accelerator. The (rough) semantics for `pallas_call` are as follows: @@ -172,13 +316,37 @@ def pallas_call(kernel, in_specs, out_specs, out_shapes, grid): return execute ``` -Specifically, `pallas_call` will “loop” over grid iteration space, applying a transformation to the inputs and outputs specified via the `in_specs` and `out_specs`. In each iteration, the kernel will be called on the transformed inputs and outputs. Note that the “loop” over the iteration space could be executed in parallel (e.g. on GPU). `pallas_call` also provides no guarantees on the order of loop iterations over the iteration space, just that every member of the iteration space will be looped over. Compilers like Triton and Mosaic will have more specific operational semantics associated with the grid. +Specifically, `pallas_call` will “loop” over grid iteration space, +applying a transformation to the inputs and outputs specified via +the `in_specs` and `out_specs`. +In each iteration, the kernel will be called on the transformed +inputs and outputs. Note that the “loop” over the iteration space +could be executed in parallel (e.g. on GPU). +`pallas_call` also provides no guarantees on the order of loop +iterations over the iteration space, just that every member of the +iteration space will be looped over. +Compilers like Triton and Mosaic will have more specific operational +semantics associated with the grid. #### Transformation functions -The `in_specs` and `out_specs` arguments to `pallas_call` allow inputs and outputs to be transformed in some way. The two options that Pallas offers right now are an identity transformation (where inputs and outputs are left unchanged), and `BlockSpec`s, take fixed-size slices of `Ref`s determined by the loop index. - -A `BlockSpec` takes an `index_map` function and a `block_shape`. Logically, it takes an array and slices it along each axis into `block_shape` sizes blocks. The `index_map` function takes loop indices (from the grid index set) and maps them to block indices. The transform function converts `Ref`s into logical views of the `Ref` at the corresponding block. When we specify `None` in an entry in block_shape, that corresponds to “mapping” over that dimension, removing it from the block within the kernel. +The `in_specs` and `out_specs` arguments to `pallas_call` allow +inputs and outputs to be transformed in some way. +The two options that Pallas offers right now are an identity +transformation (where inputs and outputs are left unchanged), +and `BlockSpec`s, take fixed-size slices of `Ref`s determined by the +loop index. + +A `BlockSpec` takes an `index_map` function and a `block_shape`. +Logically, it takes an array and slices it along each axis into +`block_shape` sizes blocks. +The `index_map` function takes loop indices (from the grid index set) +and maps them to block indices. +The transform function converts `Ref`s into logical views of the +`Ref` at the corresponding block. +When we specify `None` in an entry in block_shape, +that corresponds to “mapping” over that dimension, +removing it from the block within the kernel. ```python class BlockSpec: @@ -191,16 +359,28 @@ class BlockSpec: ... ``` -We could also imagine other `Spec`s that are used with `pallas_call`, for example a `Spec` that corresponds to overlapping windows to, say, implement convolutions. +We could also imagine other `Spec`s that are used with `pallas_call`, +for example a `Spec` that corresponds to overlapping windows to, say, +implement convolutions. ### Immediate benefits of Pallas as a front-end -By offering a JAX front-end for kernel writing, we can immediately reap some benefits. +By offering a JAX front-end for kernel writing, we can immediately +reap some benefits. #### More flexible front end -The first is that JAX users are already accustomed to the benefits (and limitations) of programming with JAX and its tracing-based transformations. This means users can use closures and other familiar Python constructs when writing Pallas kernels. This is unlike the existing AST-parsing-based Triton front end or the MLIR builders for Mosaic. For example, this makes Pallas far more amenable to templating than Triton. +The first is that JAX users are already accustomed to the benefits +(and limitations) of programming with JAX and its tracing-based +transformations. +This means users can use closures and other familiar Python constructs +when writing Pallas kernels. +This is unlike the existing AST-parsing-based Triton front end or the +MLIR builders for Mosaic. +For example, this makes Pallas far more amenable to templating than +Triton. -See this example of how we can use higher-order functions in Python to template a kernel. +See this example of how we can use higher-order functions in Python +to template a kernel. ```python def make_kernel(eltwise_kernel): @@ -219,13 +399,25 @@ pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.) #### Emulation mode -By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to StableHLO directly and compile/execute them with XLA. Specifically, a `pallas_call` can be implemented as a `lax.scan` over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like `jax.debug.print`). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU. +By representing kernels as programs with JAX primitives and some new +Pallas primitives, we can also lower Pallas programs to StableHLO +directly and compile/execute them with XLA. +Specifically, a `pallas_call` can be implemented as a `lax.scan` over +the grid. +This enables us to develop GPU or TPU kernels on any XLA-supported +platform (even CPU!) and debug them using JAX/XLA debugging tools +(like `jax.debug.print`). +We can also use the more reliable and better tested XLA numerics to +verify the correctness of the Triton and Mosaic compilers. +One could also imagine perturbing the `scan` ordering to simulate the +parallel reads and writes that happen on GPU. ### Examples #### `add` -We modify our `add_kernel` example to operate over (2,)-sized blocks using `BlockSpec`s. +We modify our `add_kernel` example to operate over (2,)-sized blocks +using `BlockSpec`s. ```python def add_kernel(x_ref, y_ref, o_ref): @@ -248,7 +440,10 @@ add(x, y) #### Templated matmul -In this example, we compute tiles of the output by doing an unrolled accumulation over blocks of rows and columns from our input arrays. We inline an activation function into the body of the kernel using a higher order function so we can emit a fused kernel. +In this example, we compute tiles of the output by doing an unrolled +accumulation over blocks of rows and columns from our input arrays. +We inline an activation function into the body of the kernel using a +higher order function so we can emit a fused kernel. ```python def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k): @@ -281,42 +476,112 @@ z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu) ``` ### Lowering Pallas -After users express their Pallas kernels, we lower them to different representations depending on the target backend. On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to Mosaic. +After users express their Pallas kernels, we lower them to different +representations depending on the target backend. +On GPUs, we lower Pallas to Triton IR, and on TPU we lower Pallas to +Mosaic. #### Lowering Pallas to Triton for GPU -Lowering Pallas to Triton is easy because Pallas was designed with Triton as a target language in mind. The main differences between Pallas and Triton is that Triton doesn’t have a notion of `BlockSpec`s and also uses pointers when doing memory loads and stores as opposed to indices. - -Triton supports pointers as an array element type in its language and in Triton you can load from and store to arrays of pointers. In Pallas, when given a `(4, 5)`-shaped `Ref`, `x_ref`, and then do like `x_ref[3, 2]`, we need to lower this to computing a Triton pointer to the appropriate row-major position in `x_ref` (that is, doing 5 * 3 + 2 * 1). Similarly, when we lower slices to Triton, e.g. `x_ref[4, :]` we need to produce an array of pointers `5 * 4 + jnp.arange(3)`. - -Other than that, lowering to Triton is fairly straightforward. JAX dot products can be lowered to Triton dot products and JAX unary primitives are lowered to their Triton equivalents. Triton’s atomic operations are lowered via new Pallas atomic primitives. +Lowering Pallas to Triton is easy because Pallas was designed with +Triton as a target language in mind. +The main differences between Pallas and Triton is that Triton doesn’t +have a notion of `BlockSpec`s and also uses pointers when doing +memory loads and stores as opposed to indices. + +Triton supports pointers as an array element type in its language +and in Triton you can load from and store to arrays of pointers. +In Pallas, when given a `(4, 5)`-shaped `Ref`, `x_ref`, and then do +like `x_ref[3, 2]`, we need to lower this to computing a Triton +pointer to the appropriate row-major position in `x_ref` (that is, +doing 5 * 3 + 2 * 1). +Similarly, when we lower slices to Triton, e.g. `x_ref[4, :]` we need +to produce an array of pointers `5 * 4 + jnp.arange(3)`. + +Other than that, lowering to Triton is fairly straightforward. +JAX dot products can be lowered to Triton dot products and JAX unary +primitives are lowered to their Triton equivalents. +Triton’s atomic operations are lowered via new Pallas atomic +primitives. #### Lowering Pallas to Mosaic for TPU -Mosaic consumes (mostly) standard dialect MLIR and emits LLO to be compiled for TPU. Pallas can be lowered to Mosaic via translating JAX primitives to MLIR (mostly the `vector` and `arith` dialects). The `BlockSpec`s can be converted into pipeline schedules (i.e. the `transform_func`s in Mosaic). +Mosaic consumes (mostly) standard dialect MLIR and emits LLO to be +compiled for TPU. +Pallas can be lowered to Mosaic via translating JAX primitives to +MLIR (mostly the `vector` and `arith` dialects). +The `BlockSpec`s can be converted into pipeline schedules +(i.e. the `transform_func`s in Mosaic). ### Transforming Pallas -A natural question is how do JAX transformations interact with Pallas kernels? There are two main ways: transformations inside Pallas kernels and transformations outside Pallas kernels. +A natural question is how do JAX transformations interact with Pallas +kernels? +There are two main ways: transformations inside Pallas kernels and +transformations outside Pallas kernels. -Transformation inside Pallas kernels should actually “just work”, so long as we are able to lower the transformed code. For example, we could use `jax.grad(jnp.sin)(...)` inside of a JAX kernel because we can lower a `cos` to both Triton and Mosaic. However, we might not be able to lower a `jax.vmap(lax.dynamic_slice)` because it could turn into a gather that we cannot lower. +Transformation inside Pallas kernels should actually “just work”, +so long as we are able to lower the transformed code. +For example, we could use `jax.grad(jnp.sin)(...)` inside of a JAX +kernel because we can lower a `cos` to both Triton and Mosaic. +However, we might not be able to lower a `jax.vmap(lax.dynamic_slice)` +because it could turn into a gather that we cannot lower. -Transformations of Pallas kernels from the outer JAX programs is perhaps the more interesting case. How do we handle things like `vmap(pallas_call)` and `grad(pallas_call)`? +Transformations of Pallas kernels from the outer JAX programs is +perhaps the more interesting case. How do we handle things like +`vmap(pallas_call)` and `grad(pallas_call)`? #### `vmap-of-pallas_call` -vmap automatically vectorizes JAX programs. While kernel writers might want precise control over how a batched kernel will behave differently from its unbatched variant, we can offer a reasonable default `vmap` rule for `pallas_call` while offering the `jax.custom_vmap` customization mechanism. When `pallas_call` is `vmap`-ed, we augment the `pallas_call` to have an extra grid dimension corresponding to the new batch dimension and transform the `BlockSpec`s to handle indexing along that dimension. +vmap automatically vectorizes JAX programs. While kernel writers might +want precise control over how a batched kernel will behave differently +from its unbatched variant, we can offer a reasonable default `vmap` +rule for `pallas_call` while offering the `jax.custom_vmap` +customization mechanism. When `pallas_call` is `vmap`-ed, we augment +the `pallas_call` to have an extra grid dimension corresponding to the +new batch dimension and transform the `BlockSpec`s to handle indexing +along that dimension. #### `grad-of-pallas_call` -`grad` of `pallas_call` enables automatic differentiation of kernels. `jax.grad` breaks down into applications of three distinct transforms: `jvp`, `partial_eval` and `transpose`. In principle, we can re-use most of JAX’s infrastructure when implementing these rules for `pallas_call` (since it behaves much like existing JAX higher order primitives). - -However, automatic differentiation of kernels can result in a performance hit due to how memory access is transposed. If we write a GPU kernel with overlapping-and-parallel reads and disjoint-but-parallel writes, we automatically transpose it into a kernel that has overlapping-but-parallel writes (which are slow when done atomically) and disjoint-and-parallel reads. To emit a kernel that better uses parallelism with shared memory, we would need to reorder loops and change how the kernel is vectorized. Unfortunately, we do not have a program representation amenable to that in Pallas. A potential direction to automatically differentiating kernels efficiently is to explore a different representation, perhaps one like that in Dex. We could also look at how Enzyme approaches this problem. However, AD of Pallas kernels may still be useful for a class of kernels that does transpose efficiently (for example elementwise kernels). - -In general, though, `jax.custom_vjp` is a viable escape hatch to express Pallas kernels that work with `jax.grad`. +`grad` of `pallas_call` enables automatic differentiation of kernels. +`jax.grad` breaks down into applications of three distinct transforms: +`jvp`, `partial_eval` and `transpose`. +In principle, we can re-use most of JAX’s infrastructure when +implementing these rules for `pallas_call` (since it behaves much like +existing JAX higher order primitives). + +However, automatic differentiation of kernels can result in a +performance hit due to how memory access is transposed. +If we write a GPU kernel with overlapping-and-parallel reads and +disjoint-but-parallel writes, we automatically transpose it into a +kernel that has overlapping-but-parallel writes (which are slow when +done atomically) and disjoint-and-parallel reads. +To emit a kernel that better uses parallelism with shared memory, +we would need to reorder loops and change how the kernel is vectorized. +Unfortunately, we do not have a program representation amenable to +that in Pallas. +A potential direction to automatically differentiating kernels +efficiently is to explore a different representation, perhaps one +like that in Dex. +We could also look at how Enzyme approaches this problem. +However, AD of Pallas kernels may still be useful for a class of +kernels that does transpose efficiently (for example elementwise +kernels). + +In general, though, `jax.custom_vjp` is a viable escape hatch to +express Pallas kernels that work with `jax.grad`. #### Other transformations -We could imagine other JAX transformations applying to Pallas kernels that we haven’t explicitly explored yet. For example, `checkify` is a JAX transformation that does functional error handling. We could imagine using `checkify` with pallas_call to allow plumbing out error codes from GPU kernels that indicate if OOB access or NaNs were produced. - -Another potential transformation to integrate with is custom_partitioning to enable automatically partitionable kernels to be used with pjit. +We could imagine other JAX transformations applying to Pallas kernels +that we haven’t explicitly explored yet. +For example, `checkify` is a JAX transformation that does functional +error handling. +We could imagine using `checkify` with pallas_call to allow plumbing +out error codes from GPU kernels that indicate if OOB access or NaNs +were produced. + +Another potential transformation to integrate with is +custom_partitioning to enable automatically partitionable kernels to +be used with pjit. diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 94673e753d22..47d1a1409b52 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -7,9 +7,15 @@ "source": [ "# Pallas Quickstart\n", "\n", - "Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction.\n", + "\n", "\n", - "Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.\n", + "Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.\n", + "Pallas allows you to use the same JAX functions and APIs but operates at a\n", + "*lower* level of abstraction.\n", + "\n", + "Specifically, Pallas requires users to think about memory access and how to\n", + "divide up computations across multiple compute units in a hardware accelerator.\n", + "On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic.\n", "\n", "Let's dive into some examples.\n", "\n", @@ -64,15 +70,24 @@ "source": [ "**`Ref` types**\n", "\n", - "Let's dissect this function a bit. Unlike most JAX functions you've probably written, it does not take in `jax.Array`s as inputs and doesn't return any values. Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs but we are given an `o_ref`, which corresponds to the desired output.\n", + "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n", + "it does not take in `jax.Array`s as inputs and doesn't return any values.\n", + "Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n", + "but we are given an `o_ref`, which corresponds to the desired output.\n", "\n", "**Reading from `Ref`s**\n", "\n", - "In the body, we are first reading from `x_ref` and `y_ref`, indicated by the `[...]` (the ellipsis means we are reading the whole `Ref`; alternatively we also could have used `x_ref[:]`). Reading from a `Ref` like this returns a `jax.Array`.\n", + "In the body, we are first reading from `x_ref` and `y_ref`, indicated by the `[...]`\n", + "(the ellipsis means we are reading the whole `Ref`;\n", + "alternatively we also could have used `x_ref[:]`).\n", + "Reading from a `Ref` like this returns a `jax.Array`.\n", "\n", "**Writing to `Ref`s**\n", "\n", - "We then write `x + y` to `o_ref`. Mutation has not historically been supported in JAX -- `jax.Array`s are immutable! `Ref`s are new (experimental) types that allow mutation under certain circumstances. We can interpret writing to a `Ref` as mutating its underlying buffer." + "We then write `x + y` to `o_ref`.\n", + "Mutation has not historically been supported in JAX -- `jax.Array`s are immutable!\n", + "`Ref`s are new (experimental) types that allow mutation under certain circumstances.\n", + "We can interpret writing to a `Ref` as mutating its underlying buffer." ] }, { @@ -80,7 +95,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "So we've written what we call a \"kernel\", which we define as a program that will run as an atomic unit of execution on an accelerator, without any interaction with the host. How do we invoke it from a JAX computation? We use the `pallas_call` higher-order function." + "So we've written what we call a \"kernel\", which we define as a program that will\n", + "run as an atomic unit of execution on an accelerator,\n", + "without any interaction with the host.\n", + "How do we invoke it from a JAX computation?\n", + "We use the `pallas_call` higher-order function." ] }, { @@ -113,7 +132,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`pallas_call` lifts the Pallas kernel function into an operation that can be called as part of a larger JAX program. But, to do so, it needs a few more details. Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list thereof).\n", + "`pallas_call` lifts the Pallas kernel function into an operation that can be called\n", + "as part of a larger JAX program. But, to do so, it needs a few more details.\n", + "Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list\n", + "thereof).\n", "`out_shape` determines the shape/dtype of `o_ref` in our `add_vector_kernel`.\n", "\n", "`pallas_call` returns a function that takes in and returns `jax.Array`s." @@ -126,11 +148,20 @@ "source": [ "**What's actually happening here?**\n", "\n", - "Thus far we've described how to think about Pallas kernels but what we've actually accomplished is we're writing a function that's executed very close to the compute units.\n", + "Thus far we've described how to think about Pallas kernels but what we've actually\n", + "accomplished is we're writing a function that's executed very close to the compute units.\n", "\n", - "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM) (this is a costly operation generally speaking!). We then use GPU vector compute to execute the addition, then copy the resulting value in SRAM back to HBM.\n", + "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n", + "we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n", + "(this is a costly operation generally speaking!).\n", + "We then use GPU vector compute to execute the addition, then copy the resulting value\n", + "in SRAM back to HBM.\n", "\n", - "On TPU, we do something slightly different. Before the kernel is ever executed, we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register. We then use TPU vector compute to execute the addition, then copy the resulting value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.\n", + "On TPU, we do something slightly different. Before the kernel is ever executed,\n", + "we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in\n", + "SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register.\n", + "We then use TPU vector compute to execute the addition, then copy the resulting\n", + "value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM.\n", "\n", "We are in the process of writing backend-specific Pallas guides. Coming soon!" ] @@ -148,7 +179,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In our \"hello world\" example, we wrote a very simple kernel. It takes advantage of the fact that our 8-sized arrays can comfortably fit inside the SRAM of hardware accelerators. In most real-world applications, this will not be the case!" + "In our \"hello world\" example, we wrote a very simple kernel.\n", + "It takes advantage of the fact that our 8-sized arrays can comfortably fit inside\n", + "the SRAM of hardware accelerators.\n", + "In most real-world applications, this will not be the case!" ] }, { @@ -156,15 +190,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on \"blocks\" of those arrays that can fit in SRAM.\n", + "Part of writing Pallas kernels is thinking about how to take big arrays that\n", + "live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n", + "that operate on \"blocks\" of those arrays that can fit in SRAM.\n", "\n", "### Grids\n", "\n", - "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and `BlockSpec`s to `pallas_call`.\n", + "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n", + "`BlockSpec`s to `pallas_call`.\n", "\n", - "A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies an iteration space.\n", - "For example, a grid `(4, 5)` would have 20 elements: `(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`.\n", - "We run the kernel function once for each element, a style of single-program multiple-data (SPMD) programming.\n", + "A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies\n", + "an iteration space.\n", + "For example, a grid `(4, 5)` would have 20 elements:\n", + "`(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`.\n", + "We run the kernel function once for each element, a style of single-program\n", + "multiple-data (SPMD) programming.\n", "\n", "
\n", "\n", @@ -173,7 +213,12 @@ "A 2D grid\n", "
\n", "\n", - "When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invocations is referred to as a \"program\", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.\n", + "When we provide a `grid` to `pallas_call`, the kernel is executed as many times\n", + "as `prod(grid)`. Each of these invocations is referred to as a \"program\".\n", + "To access which program (i.e. which element of the grid) the kernel is currently\n", + "executing, we use `program_id(axis=...)`.\n", + "For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and\n", + "`program_id(axis=1)` returns `2`.\n", "\n", "Here's an example kernel that uses a `grid` and `program_id`." ] @@ -226,9 +271,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "On GPUs, each program is executed in parallel on separate threads. Thus, we need to think about race conditions on writes to HBM. A reasonable approach is to write our kernels in such a way that different programs write to disjoint places in HBM to avoid these parallel writes. On the other hand, parallelizing the computation is how we can execute operations like matrix multiplications really quickly.\n", - "\n", - "On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations." + "On GPUs, each program is executed in parallel on separate threads.\n", + "Thus, we need to think about race conditions on writes to HBM.\n", + "A reasonable approach is to write our kernels in such a way that different\n", + "programs write to disjoint places in HBM to avoid these parallel writes.\n", + "On the other hand, parallelizing the computation is how we can execute\n", + "operations like matrix multiplications really quickly.\n", + "\n", + "On TPUs, programs are executed in a combination of parallel and sequential\n", + "(depending on the architecture) so there are slightly different considerations." ] }, { @@ -244,12 +295,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "With `grid` and `program_id` in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels.\n", + "With `grid` and `program_id` in mind, Pallas provides an abstraction that\n", + "takes care of some common indexing patterns seen in a lot of kernels.\n", "To build intuition, let's try to implement a matrix multiplication.\n", "\n", - "A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones.\n", + "A simple strategy for implementing a matrix multiplication in Pallas is to\n", + "implement it recursively.\n", + "We know our underlying hardware has support for small matrix multiplications\n", + "(using GPU and TPU tensorcores), so we just express a big matrix multiplication\n", + "in terms of smaller ones.\n", "\n", - "Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$. We first express $X$ and $Y$ as block matrices. $X$ will have \"row\" blocks and $Y$ will have \"column\" blocks.\n", + "Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$.\n", + "We first express $X$ and $Y$ as block matrices. $X$ will have \"row\" blocks\n", + "and $Y$ will have \"column\" blocks.\n", "\n", "$$\n", "\\begin{align*}\n", @@ -289,7 +347,10 @@ "\\end{align*}\n", "$$\n", "\n", - "Our strategy is that because $Z$ is also a block matrix, we can assign each of the programs in our Pallas kernel one of the output blocks. Computing each output block corresponds to doing a smaller matrix multiply between a \"row\" block of $X$ and a \"column\" block of $Y$." + "Our strategy is that because $Z$ is also a block matrix, we can assign each of\n", + "the programs in our Pallas kernel one of the output blocks.\n", + "Computing each output block corresponds to doing a smaller matrix multiply\n", + "between a \"row\" block of $X$ and a \"column\" block of $Y$." ] }, { @@ -297,7 +358,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block shape for each input and output, and an \"index map\" function, that maps a set of program indices to a block index.\n", + "To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block\n", + "shape for each input and output, and an \"index map\" function, that maps a\n", + "set of program indices to a block index.\n", "\n", "
\n", "\n", @@ -307,13 +370,23 @@ "\n", "
\n", "\n", - "For a concrete example, let's say we'd like to multiply two `(1024, 1024)` matrices `x` and `y` together to produce `z`, and would like to parallelize the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication. To express this, we'd first use a `(2, 2)` grid (one block for each program).\n", + "For a concrete example, let's say we'd like to multiply two `(1024, 1024)`\n", + "matrices `x` and `y` together to produce `z`, and would like to parallelize\n", + "the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where\n", + "each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication.\n", + "To express this, we'd first use a `(2, 2)` grid (one block for each program).\n", "\n", - "For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this carves `x` up into \"row\" blocks. To see this see how both program instances `(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`. Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.\n", + "For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this\n", + "carves `x` up into \"row\" blocks.\n", + "To see this see how both program instances\n", + "`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.\n", + "For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`.\n", + "Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.\n", "\n", "These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n", "\n", - "Underneath the hood, `pallas_call` will automatically carve up your inputs and outputs into `Ref`s for each block that will be passed into the kernel." + "Underneath the hood, `pallas_call` will automatically carve up your inputs and\n", + "outputs into `Ref`s for each block that will be passed into the kernel." ] }, { @@ -350,8 +423,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that this is a very naive implementation of a matrix multiplication but consider it a starting point for various types of optimizations.\n", - "Let's add an additional feature to our matrix multiply: fused activation. It's actually really easy! Just pass a higher-order activation function into the kernel." + "Note that this is a very naive implementation of a matrix multiplication but\n", + "consider it a starting point for various types of optimizations.\n", + "Let's add an additional feature to our matrix multiply: fused activation.\n", + "It's actually really easy! Just pass a higher-order activation function into the kernel." ] }, { @@ -388,7 +463,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`! To turn this matrix multiplication into a batched version, we just need to `vmap` it." + "To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`!\n", + "To turn this matrix multiplication into a batched version, we just need to `vmap` it." ] }, { diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 931d1e96f11d..d42459f891ca 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -14,9 +14,15 @@ kernelspec: # Pallas Quickstart -Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. Pallas allows you to use the same JAX functions and APIs but operates at a *lower* level of abstraction. + -Specifically, Pallas requires users to think about memory access and how to divide up computations across multiple compute units in a hardware accelerator. On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic. +Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. +Pallas allows you to use the same JAX functions and APIs but operates at a +*lower* level of abstraction. + +Specifically, Pallas requires users to think about memory access and how to +divide up computations across multiple compute units in a hardware accelerator. +On GPUs, Pallas lowers to Triton and on TPUs, Pallas lowers to Mosaic. Let's dive into some examples. @@ -45,19 +51,32 @@ def add_vectors_kernel(x_ref, y_ref, o_ref): **`Ref` types** -Let's dissect this function a bit. Unlike most JAX functions you've probably written, it does not take in `jax.Array`s as inputs and doesn't return any values. Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs but we are given an `o_ref`, which corresponds to the desired output. +Let's dissect this function a bit. Unlike most JAX functions you've probably written, +it does not take in `jax.Array`s as inputs and doesn't return any values. +Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs +but we are given an `o_ref`, which corresponds to the desired output. **Reading from `Ref`s** -In the body, we are first reading from `x_ref` and `y_ref`, indicated by the `[...]` (the ellipsis means we are reading the whole `Ref`; alternatively we also could have used `x_ref[:]`). Reading from a `Ref` like this returns a `jax.Array`. +In the body, we are first reading from `x_ref` and `y_ref`, indicated by the `[...]` +(the ellipsis means we are reading the whole `Ref`; +alternatively we also could have used `x_ref[:]`). +Reading from a `Ref` like this returns a `jax.Array`. **Writing to `Ref`s** -We then write `x + y` to `o_ref`. Mutation has not historically been supported in JAX -- `jax.Array`s are immutable! `Ref`s are new (experimental) types that allow mutation under certain circumstances. We can interpret writing to a `Ref` as mutating its underlying buffer. +We then write `x + y` to `o_ref`. +Mutation has not historically been supported in JAX -- `jax.Array`s are immutable! +`Ref`s are new (experimental) types that allow mutation under certain circumstances. +We can interpret writing to a `Ref` as mutating its underlying buffer. +++ -So we've written what we call a "kernel", which we define as a program that will run as an atomic unit of execution on an accelerator, without any interaction with the host. How do we invoke it from a JAX computation? We use the `pallas_call` higher-order function. +So we've written what we call a "kernel", which we define as a program that will +run as an atomic unit of execution on an accelerator, +without any interaction with the host. +How do we invoke it from a JAX computation? +We use the `pallas_call` higher-order function. ```{code-cell} ipython3 @jax.jit @@ -68,7 +87,10 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: add_vectors(jnp.arange(8), jnp.arange(8)) ``` -`pallas_call` lifts the Pallas kernel function into an operation that can be called as part of a larger JAX program. But, to do so, it needs a few more details. Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list thereof). +`pallas_call` lifts the Pallas kernel function into an operation that can be called +as part of a larger JAX program. But, to do so, it needs a few more details. +Here we specify `out_shape`, an object that has a `.shape` and `.dtype` (or a list +thereof). `out_shape` determines the shape/dtype of `o_ref` in our `add_vector_kernel`. `pallas_call` returns a function that takes in and returns `jax.Array`s. @@ -77,11 +99,20 @@ add_vectors(jnp.arange(8), jnp.arange(8)) **What's actually happening here?** -Thus far we've described how to think about Pallas kernels but what we've actually accomplished is we're writing a function that's executed very close to the compute units. +Thus far we've described how to think about Pallas kernels but what we've actually +accomplished is we're writing a function that's executed very close to the compute units. -On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM) (this is a costly operation generally speaking!). We then use GPU vector compute to execute the addition, then copy the resulting value in SRAM back to HBM. +On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when +we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM) +(this is a costly operation generally speaking!). +We then use GPU vector compute to execute the addition, then copy the resulting value +in SRAM back to HBM. -On TPU, we do something slightly different. Before the kernel is ever executed, we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register. We then use TPU vector compute to execute the addition, then copy the resulting value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM. +On TPU, we do something slightly different. Before the kernel is ever executed, +we fetch the value from HBM into SRAM. `x_ref` therefore corresponds to a value in +SRAM and when we do `x_ref[...]` we are copying the value from SRAM into a register. +We then use TPU vector compute to execute the addition, then copy the resulting +value back into SRAM. After the kernel is executed, the SRAM value is copied back into HBM. We are in the process of writing backend-specific Pallas guides. Coming soon! @@ -91,19 +122,28 @@ We are in the process of writing backend-specific Pallas guides. Coming soon! +++ -In our "hello world" example, we wrote a very simple kernel. It takes advantage of the fact that our 8-sized arrays can comfortably fit inside the SRAM of hardware accelerators. In most real-world applications, this will not be the case! +In our "hello world" example, we wrote a very simple kernel. +It takes advantage of the fact that our 8-sized arrays can comfortably fit inside +the SRAM of hardware accelerators. +In most real-world applications, this will not be the case! +++ -Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on "blocks" of those arrays that can fit in SRAM. +Part of writing Pallas kernels is thinking about how to take big arrays that +live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations +that operate on "blocks" of those arrays that can fit in SRAM. ### Grids -To automatically "carve" up the inputs and outputs, you provide a `grid` and `BlockSpec`s to `pallas_call`. +To automatically "carve" up the inputs and outputs, you provide a `grid` and +`BlockSpec`s to `pallas_call`. -A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies an iteration space. -For example, a grid `(4, 5)` would have 20 elements: `(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`. -We run the kernel function once for each element, a style of single-program multiple-data (SPMD) programming. +A `grid` is a tuple of integers (e.g. `()`, `(2, 3, 4)`, or `(8,)`) that specifies +an iteration space. +For example, a grid `(4, 5)` would have 20 elements: +`(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)`. +We run the kernel function once for each element, a style of single-program +multiple-data (SPMD) programming.
@@ -112,7 +152,12 @@ We run the kernel function once for each element, a style of single-program mult A 2D grid
-When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invocations is referred to as a "program", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. +When we provide a `grid` to `pallas_call`, the kernel is executed as many times +as `prod(grid)`. Each of these invocations is referred to as a "program". +To access which program (i.e. which element of the grid) the kernel is currently +executing, we use `program_id(axis=...)`. +For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and +`program_id(axis=1)` returns `2`. Here's an example kernel that uses a `grid` and `program_id`. @@ -132,9 +177,15 @@ def iota(len: int): iota(8) ``` -On GPUs, each program is executed in parallel on separate threads. Thus, we need to think about race conditions on writes to HBM. A reasonable approach is to write our kernels in such a way that different programs write to disjoint places in HBM to avoid these parallel writes. On the other hand, parallelizing the computation is how we can execute operations like matrix multiplications really quickly. +On GPUs, each program is executed in parallel on separate threads. +Thus, we need to think about race conditions on writes to HBM. +A reasonable approach is to write our kernels in such a way that different +programs write to disjoint places in HBM to avoid these parallel writes. +On the other hand, parallelizing the computation is how we can execute +operations like matrix multiplications really quickly. -On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations. +On TPUs, programs are executed in a combination of parallel and sequential +(depending on the architecture) so there are slightly different considerations. +++ @@ -142,12 +193,19 @@ On TPUs, programs are executed in a combination of parallel and sequential (depe +++ -With `grid` and `program_id` in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels. +With `grid` and `program_id` in mind, Pallas provides an abstraction that +takes care of some common indexing patterns seen in a lot of kernels. To build intuition, let's try to implement a matrix multiplication. -A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones. +A simple strategy for implementing a matrix multiplication in Pallas is to +implement it recursively. +We know our underlying hardware has support for small matrix multiplications +(using GPU and TPU tensorcores), so we just express a big matrix multiplication +in terms of smaller ones. -Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$. We first express $X$ and $Y$ as block matrices. $X$ will have "row" blocks and $Y$ will have "column" blocks. +Suppose we have input matrices $X$ and $Y$ and are computing $Z = XY$. +We first express $X$ and $Y$ as block matrices. $X$ will have "row" blocks +and $Y$ will have "column" blocks. $$ \begin{align*} @@ -187,11 +245,16 @@ X_1 Y_0 & X_1 Y_1 \end{align*} $$ -Our strategy is that because $Z$ is also a block matrix, we can assign each of the programs in our Pallas kernel one of the output blocks. Computing each output block corresponds to doing a smaller matrix multiply between a "row" block of $X$ and a "column" block of $Y$. +Our strategy is that because $Z$ is also a block matrix, we can assign each of +the programs in our Pallas kernel one of the output blocks. +Computing each output block corresponds to doing a smaller matrix multiply +between a "row" block of $X$ and a "column" block of $Y$. +++ -To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block shape for each input and output, and an "index map" function, that maps a set of program indices to a block index. +To express this pattern, we use `BlockSpec`s. A `BlockSpec` specifies a block +shape for each input and output, and an "index map" function, that maps a +set of program indices to a block index.
@@ -201,13 +264,23 @@ A visualization of a `BlockSpec`
-For a concrete example, let's say we'd like to multiply two `(1024, 1024)` matrices `x` and `y` together to produce `z`, and would like to parallelize the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication. To express this, we'd first use a `(2, 2)` grid (one block for each program). +For a concrete example, let's say we'd like to multiply two `(1024, 1024)` +matrices `x` and `y` together to produce `z`, and would like to parallelize +the computation 4 ways. We split up `z` into 4 `(512, 512)` blocks where +each block is computed with a `(512, 1024) x (1024, 512)` matrix multiplication. +To express this, we'd first use a `(2, 2)` grid (one block for each program). -For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this carves `x` up into "row" blocks. To see this see how both program instances `(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`. Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`. +For `x`, we use `BlockSpec(lambda i, j: (i, 0), (512, 1024))` -- this +carves `x` up into "row" blocks. +To see this see how both program instances +`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. +For `y`, we use a transposed version `BlockSpec(lambda i, j: (0, j), (1024, 512))`. +Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`. These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`. -Underneath the hood, `pallas_call` will automatically carve up your inputs and outputs into `Ref`s for each block that will be passed into the kernel. +Underneath the hood, `pallas_call` will automatically carve up your inputs and +outputs into `Ref`s for each block that will be passed into the kernel. ```{code-cell} ipython3 def matmul_kernel(x_ref, y_ref, z_ref): @@ -233,8 +306,10 @@ z = matmul(x, y) np.testing.assert_allclose(z, x @ y) ``` -Note that this is a very naive implementation of a matrix multiplication but consider it a starting point for various types of optimizations. -Let's add an additional feature to our matrix multiply: fused activation. It's actually really easy! Just pass a higher-order activation function into the kernel. +Note that this is a very naive implementation of a matrix multiplication but +consider it a starting point for various types of optimizations. +Let's add an additional feature to our matrix multiply: fused activation. +It's actually really easy! Just pass a higher-order activation function into the kernel. ```{code-cell} ipython3 def matmul_kernel(x_ref, y_ref, z_ref, *, activation): @@ -260,7 +335,8 @@ z = matmul(x, y, activation=jax.nn.relu) np.testing.assert_allclose(z, jax.nn.relu(x @ y)) ``` -To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`! To turn this matrix multiplication into a batched version, we just need to `vmap` it. +To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`! +To turn this matrix multiplication into a batched version, we just need to `vmap` it. ```{code-cell} ipython3 k1, k2 = jax.random.split(jax.random.key(0)) diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 87fa02ec1bbf..06bbe9135c9e 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -6,7 +6,9 @@ "id": "teoJ_fUwlu0l" }, "source": [ - "# Pipelining and `BlockSpec`s" + "# Pipelining and `BlockSpec`s\n", + "\n", + "" ] }, { @@ -15,7 +17,8 @@ "id": "gAJDZh1gBh-h" }, "source": [ - "In this guide we'll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute." + "In this guide we'll cover how memory spaces in TPU work and how to write\n", + "pipelines in Pallas that overlap memory I/O with compute." ] }, { @@ -42,17 +45,33 @@ "source": [ "## TPU and its memory spaces\n", "\n", - "A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). Below is a diagram of a TPU in which `x` and `y` are arrays that live in high-bandwidth memory (HBM):\n", + "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n", + "registers (which temporarily store scalar and array values) and compute units\n", + "(that do computation with values in registers).\n", + "Below is a diagram of a TPU in which `x` and `y` are arrays that live in\n", + "high-bandwidth memory (HBM):\n", "\n", "![TPU Memory Space Cartoon.png]()\n", "\n", "Let's talk about the components of this diagram in more detail:\n", "\n", - "* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we often think of as \"device memory\". There is also vector memory (VMEM), a cache meant for storing vector and array values, and scalar memory (SMEM), a cache designed to store scalar values.\n", - "* **Registers**: A TensorCore has two main types of registers: vector registers (VREGs) store array values, and scalar registers (SREGs) store scalar values. Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs).\n", - "* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and matrix unit (MXU) that can do numerical computation. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well.\n", - "\n", - "In order to do a vectorized computation on our values `x` and `y` that live in HBM, we need to:\n", + "* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we\n", + " often think of as \"device memory\".\n", + " There is also vector memory (VMEM),\n", + " a cache meant for storing vector and array values, and scalar memory (SMEM),\n", + " a cache designed to store scalar values.\n", + "* **Registers**: A TensorCore has two main types of registers: vector\n", + " registers (VREGs) store array values, and scalar registers (SREGs) store\n", + " scalar values.\n", + " Values can be loaded into memory from their respective caches (VMEM for\n", + " VREGs and SMEM for SREGs).\n", + "* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and\n", + " matrix unit (MXU) that can do numerical computation.\n", + " Compute units operate on values that live in SREGs and VREGs and output\n", + " values into those registers as well.\n", + "\n", + "In order to do a vectorized computation on our values `x` and `y` that live \n", + "in HBM, we need to:\n", "\n", "1. Copy the values `x` and `y` into VMEM.\n", "2. Load the values from VMEM into VREGs.\n", @@ -128,9 +147,20 @@ "source": [ "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", "\n", - "`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`.\n", + "`add_matrices_kernel` operates using `Ref`s that live in VMEM.\n", + "Loading from a VMEM `Ref` produces a value that lives in VREGs.\n", + "Values in VREGs behave like `jax.Array`s in that we can use `jnp` and\n", + "`jax.lax` operations on them to produce new values that live in VREGs.\n", + "When we produce the values we'd like to return, we store them in the output\n", + "VMEM `Ref`.\n", "\n", - "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output VMEM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`." + "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`.\n", + "Inside it, we pass `x` and `y` into `pallas_call`.\n", + "`pallas_call` is responsible for copying `x` and `y` into VMEM and for\n", + "allocating the VMEM buffers that the kernel operates on (including allocating\n", + "`z_vmem_ref`, the output VMEM buffer).\n", + "After the kernel function is finished running, `pallas_call` will also copy\n", + "the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`." ] }, { @@ -141,13 +171,22 @@ "source": [ "## Constraints of using VMEM/SMEM\n", "\n", - "Pallas exposes access to lower level memory spaces like VMEM and SMEM but writing kernels utilizing them adds some considerations.\n", + "Pallas exposes access to lower level memory spaces like VMEM and SMEM but\n", + "writing kernels utilizing them adds some considerations.\n", "\n", - "1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won't even be able to fit them into VMEM at all. For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n", + "1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB \n", + " and SMEM ranges in the tens to hundreds of KiB.\n", + " If our arrays are too big, we won't even be able to fit them into VMEM at all.\n", + " For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't\n", + " scale beyond moderately sized arrays.\n", "\n", - "2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and VMEM than actually performing the addition itself.\n", + "2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least\n", + " compared to most compute instructions.\n", + " The `add_matrices` function above will likely spend more time copying\n", + " between HBM and VMEM than actually performing the addition itself.\n", "\n", - "With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our TPUs." + "With these two constraints in mind, we'll have to rethink our strategy for\n", + "getting performance out of our TPUs." ] }, { @@ -158,13 +197,26 @@ "source": [ "## Primer: Pipelining\n", "\n", - "Pipelining our computation offers a way of dealing with both the memory capacity and bandwidth constraints in one fell swoop. What do we mean by pipelining?\n", + "Pipelining our computation offers a way of dealing with both the memory\n", + "capacity and bandwidth constraints in one fell swoop.\n", + "What do we mean by pipelining?\n", "\n", - "The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our compute units. Naively this is difficult because in our program above we copy *all* of `x` and `y` before we start doing any compute with them, creating a dependence between the copy and the compute.\n", + "The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our\n", + "compute units.\n", + "Naively this is difficult because in our program above we copy *all* of `x`\n", + "and `y` before we start doing any compute with them, creating a dependence\n", + "between the copy and the compute.\n", "\n", - "However, if we can chunk up our computation into several subcomputations (e.g. when we add two matrices, we can express that as addition of \"blocks\" of the original matrices together), we can now overlap the copies of one of those subcomputations with the compute of the other. Let's walk through a simple example:\n", + "However, if we can chunk up our computation into several subcomputations\n", + "(e.g. when we add two matrices, we can express that as addition of \"blocks\"\n", + "of the original matrices together), we can now overlap the copies of one of\n", + "those subcomputations with the compute of the other. Let's walk through a\n", + "simple example:\n", "\n", - "Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for example, split along the leading axis, resulting in two `(256, 512)` arrays for each input. We can now execute the following pipelined computation.\n", + "Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for\n", + "example, split along the leading axis, resulting in two `(256, 512)` arrays\n", + "for each input.\n", + "We can now execute the following pipelined computation.\n", "\n", "1. Copy `x1` and `y1` into VMEM.\n", "1. Start copying `x2` and `y2` into VMEM\n", @@ -180,9 +232,15 @@ "10. Start copying `z2` from VMEM back into HBM.\n", "10. Wait until `z2` is copied into HBM.\n", "\n", - "Any time we are doing compute here, we are asynchronously copying something. This means that some of the time spent copying is not wasted.\n", + "Any time we are doing compute here, we are asynchronously copying something.\n", + "This means that some of the time spent copying is not wasted.\n", "\n", - "The two most important numbers for determining how efficient a pipelined computation are a) how many floating point operations (FLOPs) we need to execute and b) how many bytes we need to copy to execute that computation. The ratio of these two (FLOPs/memory usage) is called the *arithmetic intensity* of an operation and determines if our pipeline will be compute bound or memory bound." + "The two most important numbers for determining how efficient a pipelined\n", + "computation are a) how many floating point operations (FLOPs) we need to\n", + "execute and b) how many bytes we need to copy to execute that computation.\n", + "The ratio of these two (FLOPs/memory usage) is called the\n", + "*arithmetic intensity* of an operation and determines if our pipeline will\n", + "be compute bound or memory bound." ] }, { @@ -200,7 +258,11 @@ "id": "U-dPTjlBverB" }, "source": [ - "How do we implement a pipeline like the one above in Pallas? It seems like a complex sequence of asynchronous data operations and executing kernels that would be a pain to implement manually. Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through `grid`s and `BlockSpec`s." + "How do we implement a pipeline like the one above in Pallas?\n", + "It seems like a complex sequence of asynchronous data operations and\n", + "executing kernels that would be a pain to implement manually.\n", + "Fear not! Pallas offers an API for expressing pipelines without too much\n", + "boilerplate, namely through `grid`s and `BlockSpec`s." ] }, { @@ -211,9 +273,15 @@ "source": [ "### `grid`, a.k.a. kernels in a loop\n", "\n", - "See how in the above pipelined example, we are executing the same logic multiple times: steps 3-5 and 8-10 both execute the same operations, only on different inputs. The generalized version of this is a loop in which the same kernel is executed multiple times. `pallas_call` provides an option to do exactly that.\n", + "See how in the above pipelined example, we are executing the same logic\n", + "multiple times: steps 3-5 and 8-10 both execute the same operations,\n", + "only on different inputs.\n", + "The generalized version of this is a loop in which the same kernel is\n", + "executed multiple times.\n", + "`pallas_call` provides an option to do exactly that.\n", "\n", - "The number of iterations in the loop is specified via the `grid` argument to `pallas_call`. Conceptually:\n", + "The number of iterations in the loop is specified via the `grid` argument\n", + "to `pallas_call`. Conceptually:\n", "```python\n", "pl.pallas_call(some_kernel, grid=n)(...)\n", "```\n", @@ -224,7 +292,8 @@ " some_kernel(...)\n", " # do VMEM -> HBM copies\n", "```\n", - "Grids can be generalized to be multi-dimensional, corresponding to nested loops. For example,\n", + "Grids can be generalized to be multi-dimensional, corresponding to nested\n", + "loops. For example,\n", "\n", "```python\n", "pl.pallas_call(some_kernel, grid=(n, m))(...)\n", @@ -237,7 +306,8 @@ " some_kernel(...)\n", " # do VMEM -> HBM copies\n", "```\n", - "This generalizes to any tuple of integers (a length `d` grid will correspond to `d` nested loops)." + "This generalizes to any tuple of integers (a length `d` grid will correspond\n", + "to `d` nested loops)." ] }, { @@ -255,11 +325,22 @@ "id": "miWgPkytyIIa" }, "source": [ - "The next piece of information we need to provide Pallas in order to automatically pipeline our computation is information on how to chunk it up. Specifically, we need to provide a mapping between *the iteration of the loop* to *which block of our inputs and outputs to be operated on*. A `BlockSpec` is exactly these two pieces of information.\n", + "The next piece of information we need to provide Pallas in order to\n", + "automatically pipeline our computation is information on how to chunk it up.\n", + "Specifically, we need to provide a mapping between *the iteration of the loop*\n", + "to *which block of our inputs and outputs to be operated on*.\n", + "A `BlockSpec` is exactly these two pieces of information.\n", "\n", - " First we pick a `block_shape` for our inputs. In the pipelining example above, we had `(512, 512)`-shaped arrays and split them along the leading dimension into two `(256, 512)`-shaped arrays. In this pipeline, our `block_shape` would be `(256, 512)`.\n", + "First we pick a `block_shape` for our inputs.\n", + "In the pipelining example above, we had `(512, 512)`-shaped arrays and\n", + "split them along the leading dimension into two `(256, 512)`-shaped arrays.\n", + "In this pipeline, our `block_shape` would be `(256, 512)`.\n", "\n", - "We then provide an `index_map` function that maps the iteration space to the blocks. Specifically, in the aforementioned pipeline, on the 1st iteration we'd like to select `x1` and on the second iteration we'd like to use `x2`. This can be expressed with the following `index_map`:\n", + "We then provide an `index_map` function that maps the iteration space to the\n", + "blocks.\n", + "Specifically, in the aforementioned pipeline, on the 1st iteration we'd\n", + "like to select `x1` and on the second iteration we'd like to use `x2`.\n", + "This can be expressed with the following `index_map`:\n", "\n", "```python\n", "def x_index_map(i):\n", @@ -282,7 +363,9 @@ "source": [ "### Putting it together\n", "\n", - "We provide these arguments to `pallas_call` via `grid`, `in_specs` and `out_specs` (`in_specs` corresponds to the tuple of positional arguments, and `out_specs` corresponds to the output)." + "We provide these arguments to `pallas_call` via `grid`, `in_specs` and\n", + "`out_specs` (`in_specs` corresponds to the tuple of positional arguments,\n", + "and `out_specs` corresponds to the output)." ] }, { @@ -329,9 +412,17 @@ "id": "rkytgIZYzz4t" }, "source": [ - "We've only added a little bit of code to our original function to add automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy lifting!\n", + "We've only added a little bit of code to our original function to add\n", + "automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy\n", + "lifting!\n", "\n", - "How does it work? Well, the `BlockSpec`s provide enough information to start *prefetching* blocks of our input from HBM into VMEM. For example, if we are starting iteration `i` of our `grid`, we can pass `i + 1` into the `index_map` functions to obtain the blocks needed for the next iteration. We can then start an asynchronous copy for those blocks. Similarly for outputs, we can wait for the outputs of the previous iteration to be copied before starting the copy for the current iteration's outputs." + "How does it work? Well, the `BlockSpec`s provide enough information to start\n", + "*prefetching* blocks of our input from HBM into VMEM.\n", + "For example, if we are starting iteration `i` of our `grid`, we can pass\n", + "`i + 1` into the `index_map` functions to obtain the blocks needed for the\n", + "next iteration. We can then start an asynchronous copy for those blocks.\n", + "Similarly for outputs, we can wait for the outputs of the previous iteration\n", + "to be copied before starting the copy for the current iteration's outputs." ] }, { @@ -349,9 +440,15 @@ "id": "esY4GcIB0bqQ" }, "source": [ - "It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do).\n", + "It's common to parameterize the block shapes in our kernel. Block sizes are\n", + "perhaps the most important parameter to tune when optimizing the performance\n", + "of Pallas kernels! They give us control over the pipeline (for example,\n", + "picking smaller blocks adds more iterations to our pipelined loop where each\n", + "iteration has less work to do).\n", "\n", - "Furthermore, we could also carve up the inputs and outputs along the 2nd dimension (we are only splitting along the first right now). Let's write a more general kernel that handles both of these features." + "Furthermore, we could also carve up the inputs and outputs along the 2nd\n", + "dimension (we are only splitting along the first right now). Let's write a\n", + "more general kernel that handles both of these features." ] }, { @@ -403,9 +500,11 @@ "id": "P3SqEKDe3Mar" }, "source": [ - "How would you implement something like `jnp.sum` using `pallas_call`? Specifically, we'd like to pipeline across the reduction dimension.\n", + "How would you implement something like `jnp.sum` using `pallas_call`?\n", + "Specifically, we'd like to pipeline across the reduction dimension.\n", "\n", - "Take the example of reducing a `(8, 512, 512)`-shaped array to a `(512, 512)`-shaped one." + "Take the example of reducing a `(8, 512, 512)`-shaped array to a\n", + "`(512, 512)`-shaped one." ] }, { @@ -444,7 +543,10 @@ "id": "5O3ByvuT3iyC" }, "source": [ - "To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration `i` load `x[i]` into VMEM. Then we could add `x[i]` to an output VMEM buffer. Let's implement this naively first." + "To do this using `pallas_call`, we could use a grid of size `(8,)` and in\n", + "each iteration `i` load `x[i]` into VMEM.\n", + "Then we could add `x[i]` to an output VMEM buffer. Let's implement this\n", + "naively first." ] }, { @@ -497,11 +599,29 @@ "id": "Kv9qJYJY4jbK" }, "source": [ - "Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimension into VMEM (no pipelining there) but selecting the `i`-th dimension of `x` each iteration in the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like to squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n", + "Notice how we've set up the `BlockSpec`s: we're loading the entirety of\n", + "the `(512, 512)` dimension into VMEM (no pipelining there) but selecting\n", + "the `i`-th dimension of `x` each iteration in the `index_map`.\n", + "We are using a `None` for that dimension in the block shape, which indicates\n", + "that we are selecting a singleton dimension from `x` that we would like\n", + "to squeeze away in the kernel.\n", + "Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n", "\n", - "`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that `o_ref` is unchanged over the course of the pipeline. This means that we can update its value each iteration by reading from and writing to it. Or can it? Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll be accumulating into garbage. This will result in the overall function outputting the incorrect value!\n", + "`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that\n", + "`o_ref` is unchanged over the course of the pipeline.\n", + "This means that we can update its value each iteration by reading from and\n", + "writing to it. Or can it?\n", + "Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll\n", + "be accumulating into garbage.\n", + "This will result in the overall function outputting the incorrect value!\n", "\n", - "Therefore, **whenever we do a reduction in a kernel, we need to make sure to initialize the `Ref` that is storing the reduced value**. We can accomplish this by conditionally writing a value to `out_ref` when we're on iteration 0. We can do this with the helper function `pl.when`, a convenience wrapper around `jax.lax.cond`, and `pl.program_id`, which queries which iteration in a grid axis we are in." + "Therefore, **whenever we do a reduction in a kernel, we need to make sure\n", + "to initialize the `Ref` that is storing the reduced value**.\n", + "We can accomplish this by conditionally writing a value to `out_ref`\n", + "when we're on iteration 0.\n", + "We can do this with the helper function `pl.when`, a convenience wrapper\n", + "around `jax.lax.cond`, and `pl.program_id`,\n", + "which queries which iteration in a grid axis we are in." ] }, { @@ -558,7 +678,16 @@ "source": [ "This `sum` function now outputs the correct values!\n", "\n", - "One last thing to note about reductions in Pallas are that **they must be done in the minormost (rightmost) dimensions of our grid** (our grid is 1-dimensional in the above example so we are reducing over its minormost dimension). This is because the pipeline that Pallas generates using the `BlockSpec`s, `grid` and kernel function *does not read outputs back from HBM*. Once you've written an output value back to HBM you cannot revisit it. Therefore, you cannot do a reduction across a grid dimension that has any revisiting and therefore all reductions need to happen in the rightmost dimensions." + "One last thing to note about reductions in Pallas are that **they must be\n", + "done in the minormost (rightmost) dimensions of our grid** (our grid is\n", + "1-dimensional in the above example so we are reducing over its minormost\n", + "dimension). This is because the pipeline that Pallas generates using\n", + "the `BlockSpec`s, `grid` and kernel function *does not read outputs back\n", + "from HBM*.\n", + "Once you've written an output value back to HBM you cannot revisit it.\n", + "Therefore, you cannot do a reduction across a grid dimension that has any\n", + "revisiting and therefore all reductions need to happen in the rightmost\n", + "dimensions." ] }, { @@ -576,13 +705,21 @@ "id": "0f4HAVzQ8n71" }, "source": [ - "Some TPU chips have two TensorCores but appear as one device to JAX users. This is called \"megacore\". The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs and compute units but *share HBM*.\n", + "Some TPU chips have two TensorCores but appear as one device to JAX users.\n", + "This is called \"megacore\".\n", + "The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs\n", + "and compute units but *share HBM*.\n", "\n", "![TPU Memory Space Cartoon (Megacore).png]()\n", "\n", - "Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?\n", + "Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have\n", + "only two threads.\n", + "How do we modify our kernels to utilize both TensorCores simultaneously?\n", "\n", - "The basic idea is that if we have embarrassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`." + "The basic idea is that if we have embarrassingly parallel dimensions in our\n", + "computation, we can split up those dimensions across the TensorCores.\n", + "We can indicate which dimensions are parallelizable by providing an\n", + "annotation to `pallas_call` called `dimension_semantics`." ] }, { @@ -632,9 +769,12 @@ "id": "xG51AiUC-8cl" }, "source": [ - "`dimension_semantics` should be a tuple of same length as `grid` where each entry is either `\"parallel\"` or `\"arbitrary\"`. `\"parallel\"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `\"arbitrary\"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized.\n", + "`dimension_semantics` should be a tuple of same length as `grid` where each\n", + "entry is either `\"parallel\"` or `\"arbitrary\"`. `\"parallel\"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `\"arbitrary\"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized.\n", "\n", - "By specifying `dimension_semantics`, we now execute the kernel simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically.\n", + "By specifying `dimension_semantics`, we now execute the kernel\n", + "simultaneously on each TensorCore. Pallas will handle splitting up the grid\n", + "automatically.\n", "\n", "> Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available)." ] @@ -647,7 +787,11 @@ "source": [ "## Conclusion\n", "\n", - "In this guide we covered how to express TPU pipelines using `pallas_call`, `grid` and `BlockSpec`s. We covered how to express nested loops via a multi-dimensional grid and how to handle reductions by initialize our accumulators at the beginning of the reduction. We also learned how to handle Megacore by adding annotations to the kernel.\n", + "In this guide we covered how to express TPU pipelines using `pallas_call`,\n", + "`grid` and `BlockSpec`s. We covered how to express nested loops via a\n", + "multi-dimensional grid and how to handle reductions by initialize our\n", + "accumulators at the beginning of the reduction.\n", + "We also learned how to handle Megacore by adding annotations to the kernel.\n", "\n", "Exercises left to the reader:\n", "* Try implementing a `sum` kernel that pipelines the other dimensions as well\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 6acae60cf9b7..77d029229ceb 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -15,9 +15,12 @@ kernelspec: # Pipelining and `BlockSpec`s + + +++ {"id": "gAJDZh1gBh-h"} -In this guide we'll cover how memory spaces in TPU work and how to write pipelines in Pallas that overlap memory I/O with compute. +In this guide we'll cover how memory spaces in TPU work and how to write +pipelines in Pallas that overlap memory I/O with compute. ```{code-cell} :id: ejAVO6ikUUuF @@ -34,17 +37,33 @@ import numpy as np ## TPU and its memory spaces -A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). Below is a diagram of a TPU in which `x` and `y` are arrays that live in high-bandwidth memory (HBM): +A TPU and its TensorCore consist of memory spaces (where arrays can reside), +registers (which temporarily store scalar and array values) and compute units +(that do computation with values in registers). +Below is a diagram of a TPU in which `x` and `y` are arrays that live in +high-bandwidth memory (HBM): ![TPU Memory Space Cartoon.png]() Let's talk about the components of this diagram in more detail: -* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we often think of as "device memory". There is also vector memory (VMEM), a cache meant for storing vector and array values, and scalar memory (SMEM), a cache designed to store scalar values. -* **Registers**: A TensorCore has two main types of registers: vector registers (VREGs) store array values, and scalar registers (SREGs) store scalar values. Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs). -* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and matrix unit (MXU) that can do numerical computation. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well. - -In order to do a vectorized computation on our values `x` and `y` that live in HBM, we need to: +* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we + often think of as "device memory". + There is also vector memory (VMEM), + a cache meant for storing vector and array values, and scalar memory (SMEM), + a cache designed to store scalar values. +* **Registers**: A TensorCore has two main types of registers: vector + registers (VREGs) store array values, and scalar registers (SREGs) store + scalar values. + Values can be loaded into memory from their respective caches (VMEM for + VREGs and SMEM for SREGs). +* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and + matrix unit (MXU) that can do numerical computation. + Compute units operate on values that live in SREGs and VREGs and output + values into those registers as well. + +In order to do a vectorized computation on our values `x` and `y` that live +in HBM, we need to: 1. Copy the values `x` and `y` into VMEM. 2. Load the values from VMEM into VREGs. @@ -88,33 +107,66 @@ add_matrices(x, y) We've written two functions: `add_matrices_kernel` and `add_matrices`. -`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`. - -The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output VMEM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. +`add_matrices_kernel` operates using `Ref`s that live in VMEM. +Loading from a VMEM `Ref` produces a value that lives in VREGs. +Values in VREGs behave like `jax.Array`s in that we can use `jnp` and +`jax.lax` operations on them to produce new values that live in VREGs. +When we produce the values we'd like to return, we store them in the output +VMEM `Ref`. + +The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. +Inside it, we pass `x` and `y` into `pallas_call`. +`pallas_call` is responsible for copying `x` and `y` into VMEM and for +allocating the VMEM buffers that the kernel operates on (including allocating +`z_vmem_ref`, the output VMEM buffer). +After the kernel function is finished running, `pallas_call` will also copy +the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. +++ {"id": "5kWr-1tKpYro"} ## Constraints of using VMEM/SMEM -Pallas exposes access to lower level memory spaces like VMEM and SMEM but writing kernels utilizing them adds some considerations. +Pallas exposes access to lower level memory spaces like VMEM and SMEM but +writing kernels utilizing them adds some considerations. -1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB and SMEM ranges in the tens to hundreds of KiB. If our arrays are too big, we won't even be able to fit them into VMEM at all. For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays. +1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB + and SMEM ranges in the tens to hundreds of KiB. + If our arrays are too big, we won't even be able to fit them into VMEM at all. + For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't + scale beyond moderately sized arrays. -2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and VMEM than actually performing the addition itself. +2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least + compared to most compute instructions. + The `add_matrices` function above will likely spend more time copying + between HBM and VMEM than actually performing the addition itself. -With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our TPUs. +With these two constraints in mind, we'll have to rethink our strategy for +getting performance out of our TPUs. +++ {"id": "_NTqvlbetB3P"} ## Primer: Pipelining -Pipelining our computation offers a way of dealing with both the memory capacity and bandwidth constraints in one fell swoop. What do we mean by pipelining? +Pipelining our computation offers a way of dealing with both the memory +capacity and bandwidth constraints in one fell swoop. +What do we mean by pipelining? -The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our compute units. Naively this is difficult because in our program above we copy *all* of `x` and `y` before we start doing any compute with them, creating a dependence between the copy and the compute. +The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our +compute units. +Naively this is difficult because in our program above we copy *all* of `x` +and `y` before we start doing any compute with them, creating a dependence +between the copy and the compute. -However, if we can chunk up our computation into several subcomputations (e.g. when we add two matrices, we can express that as addition of "blocks" of the original matrices together), we can now overlap the copies of one of those subcomputations with the compute of the other. Let's walk through a simple example: +However, if we can chunk up our computation into several subcomputations +(e.g. when we add two matrices, we can express that as addition of "blocks" +of the original matrices together), we can now overlap the copies of one of +those subcomputations with the compute of the other. Let's walk through a +simple example: -Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for example, split along the leading axis, resulting in two `(256, 512)` arrays for each input. We can now execute the following pipelined computation. +Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for +example, split along the leading axis, resulting in two `(256, 512)` arrays +for each input. +We can now execute the following pipelined computation. 1. Copy `x1` and `y1` into VMEM. 1. Start copying `x2` and `y2` into VMEM @@ -130,9 +182,15 @@ Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for exampl 10. Start copying `z2` from VMEM back into HBM. 10. Wait until `z2` is copied into HBM. -Any time we are doing compute here, we are asynchronously copying something. This means that some of the time spent copying is not wasted. +Any time we are doing compute here, we are asynchronously copying something. +This means that some of the time spent copying is not wasted. -The two most important numbers for determining how efficient a pipelined computation are a) how many floating point operations (FLOPs) we need to execute and b) how many bytes we need to copy to execute that computation. The ratio of these two (FLOPs/memory usage) is called the *arithmetic intensity* of an operation and determines if our pipeline will be compute bound or memory bound. +The two most important numbers for determining how efficient a pipelined +computation are a) how many floating point operations (FLOPs) we need to +execute and b) how many bytes we need to copy to execute that computation. +The ratio of these two (FLOPs/memory usage) is called the +*arithmetic intensity* of an operation and determines if our pipeline will +be compute bound or memory bound. +++ {"id": "gutx7y8uvZKH"} @@ -140,15 +198,25 @@ The two most important numbers for determining how efficient a pipelined computa +++ {"id": "U-dPTjlBverB"} -How do we implement a pipeline like the one above in Pallas? It seems like a complex sequence of asynchronous data operations and executing kernels that would be a pain to implement manually. Fear not! Pallas offers an API for expressing pipelines without too much boilerplate, namely through `grid`s and `BlockSpec`s. +How do we implement a pipeline like the one above in Pallas? +It seems like a complex sequence of asynchronous data operations and +executing kernels that would be a pain to implement manually. +Fear not! Pallas offers an API for expressing pipelines without too much +boilerplate, namely through `grid`s and `BlockSpec`s. +++ {"id": "x-LQKu8HwED7"} ### `grid`, a.k.a. kernels in a loop -See how in the above pipelined example, we are executing the same logic multiple times: steps 3-5 and 8-10 both execute the same operations, only on different inputs. The generalized version of this is a loop in which the same kernel is executed multiple times. `pallas_call` provides an option to do exactly that. +See how in the above pipelined example, we are executing the same logic +multiple times: steps 3-5 and 8-10 both execute the same operations, +only on different inputs. +The generalized version of this is a loop in which the same kernel is +executed multiple times. +`pallas_call` provides an option to do exactly that. -The number of iterations in the loop is specified via the `grid` argument to `pallas_call`. Conceptually: +The number of iterations in the loop is specified via the `grid` argument +to `pallas_call`. Conceptually: ```python pl.pallas_call(some_kernel, grid=n)(...) ``` @@ -159,7 +227,8 @@ for i in range(n): some_kernel(...) # do VMEM -> HBM copies ``` -Grids can be generalized to be multi-dimensional, corresponding to nested loops. For example, +Grids can be generalized to be multi-dimensional, corresponding to nested +loops. For example, ```python pl.pallas_call(some_kernel, grid=(n, m))(...) @@ -172,7 +241,8 @@ for i in range(n): some_kernel(...) # do VMEM -> HBM copies ``` -This generalizes to any tuple of integers (a length `d` grid will correspond to `d` nested loops). +This generalizes to any tuple of integers (a length `d` grid will correspond +to `d` nested loops). +++ {"id": "hRLr5JeyyEwM"} @@ -180,11 +250,22 @@ This generalizes to any tuple of integers (a length `d` grid will correspond to +++ {"id": "miWgPkytyIIa"} -The next piece of information we need to provide Pallas in order to automatically pipeline our computation is information on how to chunk it up. Specifically, we need to provide a mapping between *the iteration of the loop* to *which block of our inputs and outputs to be operated on*. A `BlockSpec` is exactly these two pieces of information. +The next piece of information we need to provide Pallas in order to +automatically pipeline our computation is information on how to chunk it up. +Specifically, we need to provide a mapping between *the iteration of the loop* +to *which block of our inputs and outputs to be operated on*. +A `BlockSpec` is exactly these two pieces of information. - First we pick a `block_shape` for our inputs. In the pipelining example above, we had `(512, 512)`-shaped arrays and split them along the leading dimension into two `(256, 512)`-shaped arrays. In this pipeline, our `block_shape` would be `(256, 512)`. +First we pick a `block_shape` for our inputs. +In the pipelining example above, we had `(512, 512)`-shaped arrays and +split them along the leading dimension into two `(256, 512)`-shaped arrays. +In this pipeline, our `block_shape` would be `(256, 512)`. -We then provide an `index_map` function that maps the iteration space to the blocks. Specifically, in the aforementioned pipeline, on the 1st iteration we'd like to select `x1` and on the second iteration we'd like to use `x2`. This can be expressed with the following `index_map`: +We then provide an `index_map` function that maps the iteration space to the +blocks. +Specifically, in the aforementioned pipeline, on the 1st iteration we'd +like to select `x1` and on the second iteration we'd like to use `x2`. +This can be expressed with the following `index_map`: ```python def x_index_map(i): @@ -202,7 +283,9 @@ The `BlockSpec`s for `y` and `z` will be the same as the one for `x`. ### Putting it together -We provide these arguments to `pallas_call` via `grid`, `in_specs` and `out_specs` (`in_specs` corresponds to the tuple of positional arguments, and `out_specs` corresponds to the output). +We provide these arguments to `pallas_call` via `grid`, `in_specs` and +`out_specs` (`in_specs` corresponds to the tuple of positional arguments, +and `out_specs` corresponds to the output). ```{code-cell} :id: ehKAYAwIojfv @@ -222,9 +305,17 @@ add_matrices_pipelined(x, y) +++ {"id": "rkytgIZYzz4t"} -We've only added a little bit of code to our original function to add automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy lifting! +We've only added a little bit of code to our original function to add +automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy +lifting! -How does it work? Well, the `BlockSpec`s provide enough information to start *prefetching* blocks of our input from HBM into VMEM. For example, if we are starting iteration `i` of our `grid`, we can pass `i + 1` into the `index_map` functions to obtain the blocks needed for the next iteration. We can then start an asynchronous copy for those blocks. Similarly for outputs, we can wait for the outputs of the previous iteration to be copied before starting the copy for the current iteration's outputs. +How does it work? Well, the `BlockSpec`s provide enough information to start +*prefetching* blocks of our input from HBM into VMEM. +For example, if we are starting iteration `i` of our `grid`, we can pass +`i + 1` into the `index_map` functions to obtain the blocks needed for the +next iteration. We can then start an asynchronous copy for those blocks. +Similarly for outputs, we can wait for the outputs of the previous iteration +to be copied before starting the copy for the current iteration's outputs. +++ {"id": "7Xtz9oMs0ZRL"} @@ -232,9 +323,15 @@ How does it work? Well, the `BlockSpec`s provide enough information to start *pr +++ {"id": "esY4GcIB0bqQ"} -It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). +It's common to parameterize the block shapes in our kernel. Block sizes are +perhaps the most important parameter to tune when optimizing the performance +of Pallas kernels! They give us control over the pipeline (for example, +picking smaller blocks adds more iterations to our pipelined loop where each +iteration has less work to do). -Furthermore, we could also carve up the inputs and outputs along the 2nd dimension (we are only splitting along the first right now). Let's write a more general kernel that handles both of these features. +Furthermore, we could also carve up the inputs and outputs along the 2nd +dimension (we are only splitting along the first right now). Let's write a +more general kernel that handles both of these features. ```{code-cell} :id: VartelFd0YfY @@ -271,9 +368,11 @@ np.testing.assert_array_equal( +++ {"id": "P3SqEKDe3Mar"} -How would you implement something like `jnp.sum` using `pallas_call`? Specifically, we'd like to pipeline across the reduction dimension. +How would you implement something like `jnp.sum` using `pallas_call`? +Specifically, we'd like to pipeline across the reduction dimension. -Take the example of reducing a `(8, 512, 512)`-shaped array to a `(512, 512)`-shaped one. +Take the example of reducing a `(8, 512, 512)`-shaped array to a +`(512, 512)`-shaped one. ```{code-cell} :id: JoT-ZKEk1R7l @@ -285,7 +384,10 @@ jnp.sum(x, axis=0) +++ {"id": "5O3ByvuT3iyC"} -To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration `i` load `x[i]` into VMEM. Then we could add `x[i]` to an output VMEM buffer. Let's implement this naively first. +To do this using `pallas_call`, we could use a grid of size `(8,)` and in +each iteration `i` load `x[i]` into VMEM. +Then we could add `x[i]` to an output VMEM buffer. Let's implement this +naively first. ```{code-cell} :id: hqvv_WRQ3bvP @@ -311,11 +413,29 @@ naive_sum(x) +++ {"id": "Kv9qJYJY4jbK"} -Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimension into VMEM (no pipelining there) but selecting the `i`-th dimension of `x` each iteration in the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like to squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well. - -`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that `o_ref` is unchanged over the course of the pipeline. This means that we can update its value each iteration by reading from and writing to it. Or can it? Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll be accumulating into garbage. This will result in the overall function outputting the incorrect value! - -Therefore, **whenever we do a reduction in a kernel, we need to make sure to initialize the `Ref` that is storing the reduced value**. We can accomplish this by conditionally writing a value to `out_ref` when we're on iteration 0. We can do this with the helper function `pl.when`, a convenience wrapper around `jax.lax.cond`, and `pl.program_id`, which queries which iteration in a grid axis we are in. +Notice how we've set up the `BlockSpec`s: we're loading the entirety of +the `(512, 512)` dimension into VMEM (no pipelining there) but selecting +the `i`-th dimension of `x` each iteration in the `index_map`. +We are using a `None` for that dimension in the block shape, which indicates +that we are selecting a singleton dimension from `x` that we would like +to squeeze away in the kernel. +Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well. + +`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that +`o_ref` is unchanged over the course of the pipeline. +This means that we can update its value each iteration by reading from and +writing to it. Or can it? +Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll +be accumulating into garbage. +This will result in the overall function outputting the incorrect value! + +Therefore, **whenever we do a reduction in a kernel, we need to make sure +to initialize the `Ref` that is storing the reduced value**. +We can accomplish this by conditionally writing a value to `out_ref` +when we're on iteration 0. +We can do this with the helper function `pl.when`, a convenience wrapper +around `jax.lax.cond`, and `pl.program_id`, +which queries which iteration in a grid axis we are in. ```{code-cell} :id: JXN2RthX5cSw @@ -345,7 +465,16 @@ sum(x) This `sum` function now outputs the correct values! -One last thing to note about reductions in Pallas are that **they must be done in the minormost (rightmost) dimensions of our grid** (our grid is 1-dimensional in the above example so we are reducing over its minormost dimension). This is because the pipeline that Pallas generates using the `BlockSpec`s, `grid` and kernel function *does not read outputs back from HBM*. Once you've written an output value back to HBM you cannot revisit it. Therefore, you cannot do a reduction across a grid dimension that has any revisiting and therefore all reductions need to happen in the rightmost dimensions. +One last thing to note about reductions in Pallas are that **they must be +done in the minormost (rightmost) dimensions of our grid** (our grid is +1-dimensional in the above example so we are reducing over its minormost +dimension). This is because the pipeline that Pallas generates using +the `BlockSpec`s, `grid` and kernel function *does not read outputs back +from HBM*. +Once you've written an output value back to HBM you cannot revisit it. +Therefore, you cannot do a reduction across a grid dimension that has any +revisiting and therefore all reductions need to happen in the rightmost +dimensions. +++ {"id": "KvPFez9N8cKJ"} @@ -353,13 +482,21 @@ One last thing to note about reductions in Pallas are that **they must be done i +++ {"id": "0f4HAVzQ8n71"} -Some TPU chips have two TensorCores but appear as one device to JAX users. This is called "megacore". The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs and compute units but *share HBM*. +Some TPU chips have two TensorCores but appear as one device to JAX users. +This is called "megacore". +The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs +and compute units but *share HBM*. ![TPU Memory Space Cartoon (Megacore).png]() -Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously? +Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have +only two threads. +How do we modify our kernels to utilize both TensorCores simultaneously? -The basic idea is that if we have embarrassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`. +The basic idea is that if we have embarrassingly parallel dimensions in our +computation, we can split up those dimensions across the TensorCores. +We can indicate which dimensions are parallelizable by providing an +annotation to `pallas_call` called `dimension_semantics`. ```{code-cell} :id: nQNa8RaQ-TR1 @@ -382,9 +519,12 @@ add_matrices_pipelined_megacore(x, y) +++ {"id": "xG51AiUC-8cl"} -`dimension_semantics` should be a tuple of same length as `grid` where each entry is either `"parallel"` or `"arbitrary"`. `"parallel"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `"arbitrary"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized. +`dimension_semantics` should be a tuple of same length as `grid` where each +entry is either `"parallel"` or `"arbitrary"`. `"parallel"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `"arbitrary"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized. -By specifying `dimension_semantics`, we now execute the kernel simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically. +By specifying `dimension_semantics`, we now execute the kernel +simultaneously on each TensorCore. Pallas will handle splitting up the grid +automatically. > Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available). @@ -392,7 +532,11 @@ By specifying `dimension_semantics`, we now execute the kernel simultaneously on ## Conclusion -In this guide we covered how to express TPU pipelines using `pallas_call`, `grid` and `BlockSpec`s. We covered how to express nested loops via a multi-dimensional grid and how to handle reductions by initialize our accumulators at the beginning of the reduction. We also learned how to handle Megacore by adding annotations to the kernel. +In this guide we covered how to express TPU pipelines using `pallas_call`, +`grid` and `BlockSpec`s. We covered how to express nested loops via a +multi-dimensional grid and how to handle reductions by initialize our +accumulators at the beginning of the reduction. +We also learned how to handle Megacore by adding annotations to the kernel. Exercises left to the reader: * Try implementing a `sum` kernel that pipelines the other dimensions as well diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 7fd0e81a96fb..2f748825af1f 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -1,5 +1,7 @@ # Persistent Compilation Cache + + JAX has an optional disk cache for compiled programs. If enabled, JAX will store copies of compiled programs on disk, which can save recompilation time when running the same or similar tasks repeatedly. diff --git a/docs/profiling.md b/docs/profiling.md index fe92b1b0e934..6eceec8f54b8 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -1,5 +1,7 @@ # Profiling JAX programs + + ## Viewing program traces with Perfetto We can use the JAX profiler to generate traces of a JAX program that can be diff --git a/docs/pytrees.md b/docs/pytrees.md index 80860b1b8dde..a39c36db5de6 100644 --- a/docs/pytrees.md +++ b/docs/pytrees.md @@ -16,6 +16,8 @@ language_info: # Pytrees + + ## What is a pytree? In JAX, we use the term *pytree* to refer to a tree-like structure built out of diff --git a/docs/quickstart.md b/docs/quickstart.md index 5c3562b8b2ea..91ac5a63be20 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -14,6 +14,8 @@ kernelspec: # Quickstart + + **JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**. This document provides a quick overview of essential JAX features, so you can get started with JAX quickly: @@ -27,7 +29,7 @@ This document provides a quick overview of essential JAX features, so you can ge JAX can be installed for CPU on Linux, Windows, and macOS directly from the [Python Package Index](https://pypi.org/project/jax/): ``` -pip install "jax[cpu]" +pip install jax ``` or, for NVIDIA GPU: ``` diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 4a88ed5cc865..85bb5ce01974 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -15,6 +15,8 @@ kernelspec: (pseudorandom-numbers)= # Pseudorandom numbers + + In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next. diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index e6c16e2de7e1..8fa2107795fd 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -7,6 +7,8 @@ "(sharded-computation)=\n", "# Introduction to sharded computation\n", "\n", + "\n", + "\n", "This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n", "\n", "The tutorial covers three modes of parallel computation:\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 6a7dd36c2083..345ca7987b41 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -14,6 +14,8 @@ kernelspec: (sharded-computation)= # Introduction to sharded computation + + This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs. The tutorial covers three modes of parallel computation: diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index b802be0e0afc..5a8af2b74142 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -14,6 +14,8 @@ kernelspec: # Stateful Computations + + JAX transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, require the functions they wrap to be pure: that is, functions whose outputs depend *solely* on the inputs, and which have no side effects such as updating of global state. diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst index c70371478f8e..103a8331df2b 100644 --- a/docs/type_promotion.rst +++ b/docs/type_promotion.rst @@ -226,7 +226,7 @@ context manager: >>> x = jnp.float32(1) >>> y = jnp.int32(1) >>> with jax.numpy_dtype_promotion('strict'): - ... z = x + y # doctest: +IGNORE_EXCEPTION_DETAIL + ... z = x + y # doctest: +SKIP ... Traceback (most recent call last): TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit diff --git a/docs/user_guides.rst b/docs/user_guides.rst index c57609c17fd9..f46d6b027471 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -32,6 +32,7 @@ or deployed codebases. :caption: Run Time aot + export/index errors transfer_guard diff --git a/docs/working-with-pytrees.md b/docs/working-with-pytrees.md index 6521a9b85e4f..2bd1cc08ecdf 100644 --- a/docs/working-with-pytrees.md +++ b/docs/working-with-pytrees.md @@ -22,6 +22,8 @@ kernelspec: (working-with-pytrees)= # Working with pytrees + + JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees. This section will explain how to use them, provide useful code examples, and point out common "gotchas" and patterns. diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index b7abb136aca4..fccf0cc37048 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -21,9 +21,9 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ + "//third_party/absl/status:statusor", "@xla//xla:literal", "@xla//xla:literal_util", - "@xla//xla:statusor", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt/cpu:cpu_client", "@xla//xla/tools:hlo_module_loader", diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 42ceb6f51a9e..2a8f8d4debba 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -40,11 +40,11 @@ limitations under the License. #include #include +#include "third_party/absl/status/statusor.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/statusor.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" diff --git a/jax/BUILD b/jax/BUILD index ea3ddcc76d1d..2f7480e31b1a 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -19,7 +19,6 @@ load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_building_jaxlib", - "if_building_mosaic_gpu", "jax_extend_internal_users", "jax_extra_deps", "jax_internal_export_back_compat_test_util_visibility", @@ -72,19 +71,6 @@ config_setting( }, ) -# If this flag is true, jaxlib will be built with Mosaic GPU. VERY EXPERIMENTAL. -bool_flag( - name = "build_mosaic_gpu", - build_setting_default = False, -) - -config_setting( - name = "enable_mosaic_gpu", - flag_values = { - ":build_mosaic_gpu": "True", - }, -) - exports_files([ "LICENSE", "version.py", @@ -219,6 +205,7 @@ py_library_providing_imports_info( "_src/ad_checkpoint.py", "_src/api.py", "_src/array.py", + "_src/blocked_sampler.py", "_src/callback.py", "_src/checkify.py", "_src/custom_batching.py", @@ -228,7 +215,6 @@ py_library_providing_imports_info( "_src/dispatch.py", "_src/dlpack.py", "_src/earray.py", - "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", "_src/interpreters/ad.py", @@ -249,6 +235,7 @@ py_library_providing_imports_info( "_src/debugger/**/*.py", "_src/extend/**/*.py", "_src/image/**/*.py", + "_src/export/**/*.py", "_src/lax/**/*.py", "_src/nn/**/*.py", "_src/numpy/**/*.py", @@ -335,7 +322,7 @@ py_library_providing_imports_info( ":xla", ":xla_bridge", "//jax/_src/lib", - ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + jax_extra_deps, + ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps, ) pytype_strict_library( @@ -404,6 +391,7 @@ pytype_strict_library( ":compilation_cache_interface", ":config", ":gfile_cache", + ":lru_cache", ":monitoring", ":path", "//jax/_src/lib", @@ -429,6 +417,15 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "lru_cache", + srcs = ["_src/lru_cache.py"], + deps = [ + ":compilation_cache_interface", + ":config", + ] + py_deps("filelock"), +) + pytype_strict_library( name = "config", srcs = ["_src/config.py"], @@ -468,6 +465,7 @@ pytype_strict_library( deps = [ ":compute_on", ":config", + ":deprecations", ":dtypes", ":effects", ":pretty_printer", @@ -583,6 +581,7 @@ pytype_strict_library( ":partial_eval", ":path", ":pickle_util", + ":sharding", ":sharding_impls", ":source_info_util", ":state_types", @@ -615,6 +614,7 @@ pytype_strict_library( exclude = [ "experimental/pallas/gpu.py", "experimental/pallas/tpu.py", + "experimental/pallas/ops/gpu/**/*.py", "experimental/pallas/ops/tpu/**/*.py", ], ), @@ -635,9 +635,14 @@ pytype_strict_library( ":pallas_tpu_users", ], deps = [ - ":pallas", # buildcleaner: keep + ":pallas", # build_cleaner: keep ":tpu_custom_call", - "//jax/_src/pallas/mosaic", + "//jax/_src/pallas/mosaic:core", + "//jax/_src/pallas/mosaic:lowering", + "//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/mosaic:pipeline", + "//jax/_src/pallas/mosaic:primitives", + "//jax/_src/pallas/mosaic:random", ], ) @@ -675,8 +680,10 @@ pytype_strict_library( ], deps = [ ":pallas", - "//jax/_src/pallas/triton", - ] + if_building_mosaic_gpu(["//third_party/py/jax/_src/pallas/mosaic_gpu"]), + "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:primitives", + ], ) # This target only supports sm_90 GPUs. @@ -692,10 +699,8 @@ py_library( ":jax", ":mlir", "//jax/_src/lib", - "//third_party/py/absl/flags", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", - "//jaxlib/mlir:execution_engine", "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:gpu_dialect", "//jaxlib/mlir:ir", @@ -707,8 +712,7 @@ py_library( "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:scf_dialect", "//jaxlib/mlir:vector_dialect", - "//third_party/py/numpy", - ], + ] + py_deps("absl/flags") + py_deps("numpy"), ) pytype_strict_library( @@ -779,6 +783,7 @@ pytype_strict_library( name = "sharding", srcs = ["_src/sharding.py"], deps = [ + ":op_shardings", ":util", ":xla_bridge", "//jax/_src/lib", @@ -806,6 +811,7 @@ pytype_strict_library( srcs = ["_src/sharding_impls.py"], deps = [ ":config", + ":core", ":mesh", ":op_shardings", ":partition_spec", @@ -957,7 +963,7 @@ pytype_strict_library( ":traceback_util", ":util", "//jax/_src/lib", - ] + py_deps("importlib_metadata"), + ], ) # Public JAX libraries below this point. diff --git a/jax/__init__.py b/jax/__init__.py index d7b4479e2d3f..df96b98cd4ba 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -124,7 +124,7 @@ from jax._src.api import value_and_grad as value_and_grad from jax._src.api import vjp as vjp from jax._src.api import vmap as vmap -from jax._src.api import xla_computation as xla_computation +from jax._src.api import xla_computation as _deprecated_xla_computation from jax._src.sharding_impls import NamedSharding as NamedSharding # Force import, allowing jax.interpreters.* to be used after import jax. @@ -158,7 +158,6 @@ from jax import dlpack as dlpack from jax import dtypes as dtypes from jax import errors as errors -from jax import ffi as ffi from jax import image as image from jax import lax as lax from jax import monitoring as monitoring @@ -181,6 +180,8 @@ from jax._src.deprecations import register as _register_deprecation _register_deprecation("jax-experimental-maps-module") +_register_deprecation('jax-scipy-beta-args') +_register_deprecation('tracer-hash') del _register_deprecation _deprecations = { @@ -225,11 +226,16 @@ "jax.clear_backends is deprecated.", _deprecated_clear_backends ), + "xla_computation": ( + "jax.xla_computation is deprecated. Please use the AOT APIs.", + _deprecated_xla_computation + ), } import typing as _typing if _typing.TYPE_CHECKING: from jax._src.api import clear_backends as clear_backends + from jax._src.api import xla_computation as xla_computation from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index d7a86687cacb..4b52292f0e2d 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -14,11 +14,11 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools from functools import partial import logging -from typing import Any, Callable +from typing import Any import types import numpy as np diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 257fcc7c527c..90ae6c1413ec 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -13,8 +13,9 @@ # limitations under the License. from __future__ import annotations +from collections.abc import Callable import types -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar from jax._src import core from jax._src import traceback_util diff --git a/jax/_src/api.py b/jax/_src/api.py index ec08c5c1288e..4a42693c2e8f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -23,12 +23,12 @@ from __future__ import annotations import collections -from collections.abc import Generator, Hashable, Iterable, Sequence +from collections.abc import Callable, Generator, Hashable, Iterable, Sequence from functools import partial, lru_cache import inspect import math import typing -from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload, +from typing import (Any, Literal, NamedTuple, TypeVar, overload, cast) import weakref @@ -67,12 +67,12 @@ from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib from jax._src.sharding import Sharding -from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind, - XLACompatibleSharding) +from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind from jax._src.layout import Layout, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util -from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps +from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps, + split_list) from jax._src import util from jax._src.interpreters import ad @@ -1807,60 +1807,55 @@ def cache_miss(*args, **kwargs): cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - pxla.shard_arg, pytree_registry=tree_util.default_registry) + lambda x, s: pxla.shard_args([s], [x])[0], + pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) pmap_f = wraps(fun)(cpp_mapped_f) - pmap_f.lower = _pmap_lower( - fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices, - backend, axis_size, donate_tuple) - - return pmap_f - -_pmap_cache_clears = weakref.WeakSet() # type: ignore - + @api_boundary + def lower(*args, **kwargs): + return trace(*args, **kwargs).lower() -def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, - devices, backend, axis_size, donate_tuple): # noqa: F811 - """Make a ``lower`` method for pmapped functions.""" - # If the function we returned from ``pmap`` were a class instance, - # this might naturally be a method, with ``fun`` as a ``self`` and - # all the other arguments stored as attributes. @api_boundary - def lower(*args, **kwargs) -> stages.Lowered: - """Lower a parallel-mapped form of this function for the given arguments. - - A parallel-mapped and lowered function is staged out of Python and - translated to a compiler's input language, possibly in a - backend-dependent manner. It is ready for compilation but is not yet - compiled. It represents a function intended for SPMD execution on - multiple devices. - - Returns: - A ``Lowered`` instance representing the post-map lowering. - """ - lowering_parameters = kwargs.pop( - '_experimental_lowering_parameters', mlir.LoweringParameters()) + def trace(*args, **kwargs): p = _prepare_pmap( fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, devices, backend, axis_size, args, kwargs) abstract_args = list(map(shaped_abstractify, p.flat_args)) - computation = pxla.lower_parallel_callable( + closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( p.flat_fun, backend, axis_name, axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, devices=p.devices, name=p.flat_fun.__name__, in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk, + avals=abstract_args) + lower_callable = partial( + pxla.lower_parallel_callable, p.flat_fun, axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, donated_invars=p.donated_invars, is_explicit_global_axis_size=p.is_explicit_global_axis_size, avals=abstract_args, - lowering_parameters=lowering_parameters) - return stages.Lowered.from_flat_info( - computation, p.in_tree, abstract_args, donate_tuple, p.out_tree()) + closed_jaxpr=closed_jaxpr, + backend=xc_backend, + replicas=replicas, + shards=shards, + pci=pci) + args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) + return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, + p.out_tree(), lower_callable) + + pmap_f.lower = lower + pmap_f.trace = trace + + return pmap_f + +_pmap_cache_clears = weakref.WeakSet() # type: ignore - return lower def jvp( fun: Callable, primals, tangents, has_aux: bool = False @@ -1918,7 +1913,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): raise TypeError("primal and tangent arguments to jax.jvp must have the same tree " f"structure; primals have tree structure {tree_def} whereas tangents have " f"tree structure {tree_def_2}.") - for p, t in safe_zip(ps_flat, ts_flat): + for p, t in zip(ps_flat, ts_flat): if core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t): raise TypeError("primal and tangent arguments to jax.jvp do not match; " "dtypes must be equal, or in case of int/bool primal dtype " @@ -2064,8 +2059,7 @@ def fun(*tangents): return apply_flat_fun_nokwargs(fun, io_tree, py_args) -def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree, - fun, *py_args_): +def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): if len(py_args_) != 1: msg = (f"The function returned by `jax.vjp` applied to {name} was called " f"with {len(py_args_)} arguments, but functions returned by " @@ -2091,23 +2085,27 @@ def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree, in_tree_expected, out_tree = io_tree args, in_tree = tree_flatten(py_args) if in_tree != in_tree_expected: - raise TypeError(f"Tree structure of cotangent input {in_tree}, does not match structure of " - f"primal output {in_tree_expected}.") - for arg, ct_dtype, ct_shape in safe_zip(args, cotangent_dtypes, cotangent_shapes): - expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(_dtype(arg)) - if expected_tangent_dtype != ct_dtype: - raise TypeError( - f"Type of cotangent input to vjp pullback function ({ct_dtype}) is not " - f"the expected tangent type ({expected_tangent_dtype}) of corresponding primal output " - f"with dtype {_dtype(arg)}.") - if np.shape(arg) != ct_shape: + raise ValueError(f"unexpected tree structure of argument to vjp function: " + f"got {in_tree}, but expected to match {in_tree_expected}") + for arg, aval in zip(args, out_primal_avals): + ct_aval = shaped_abstractify(arg) + ct_aval_expected = aval.at_least_vspace() + if (not core.typecompat(ct_aval, ct_aval_expected) and + not _temporary_dtype_exception(ct_aval, ct_aval_expected)): raise ValueError( - f"Shape of cotangent input to vjp pullback function {np.shape(arg)} " - "must be the same as the shape of corresponding primal input " - f"{ct_shape}.") + "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: " + f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} " + f"because the corresponding output of the function {name} had JAX type " + f"{aval.str_short()}") ans = fun(*args) return tree_unflatten(out_tree, ans) +# TODO(mattjj): see similar function in custom_derivatives.py +def _temporary_dtype_exception(a, a_) -> bool: + if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): + return a.shape == a_.shape and a_.dtype == float0 + return False + @overload def vjp(fun: Callable[..., T], *primals: Any, @@ -2175,21 +2173,16 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): for arg in primals_flat: dispatch.check_arg(arg) if not has_aux: flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) - out_primal, out_vjp = ad.vjp(flat_fun, primals_flat) + out_primals, vjp = ad.vjp(flat_fun, primals_flat) out_tree = out_tree() else: flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree) - out_primal, out_vjp, aux = ad.vjp( - flat_fun, primals_flat, has_aux=True) + out_primals, vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True) out_tree, aux_tree = out_aux_trees() - out_primal_py = tree_unflatten(out_tree, out_primal) - ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(x)) for x in out_primal] - ct_shapes = [np.shape(x) for x in out_primal] - # Ensure that vjp_py is a PyTree so that we can pass it from the forward to the - # backward pass in a custom VJP. + out_primal_avals = map(shaped_abstractify, out_primals) + out_primal_py = tree_unflatten(out_tree, out_primals) vjp_py = Partial(partial(_vjp_pullback_wrapper, fun.__name__, - ct_dtypes, ct_shapes, (out_tree, in_tree)), - out_vjp) + out_primal_avals, (out_tree, in_tree)), vjp) if not has_aux: return out_primal_py, vjp_py else: @@ -2364,43 +2357,34 @@ def make_jaxpr(fun: Callable, g:f32[] = mul f c in (g,) } """ - check_callable(fun) - static_argnums = _ensure_index_tuple(static_argnums) - - def abstractify(args, kwargs): - flat_args, in_tree = tree_flatten((args, kwargs)) - if abstracted_axes is None: - return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args) - else: - axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs) - in_type = pe.infer_lambda_input_type(axes_specs, flat_args) - in_avals, keep_inputs = unzip2(in_type) - return in_avals, in_tree, keep_inputs + try: + hash(fun) + weakref.ref(fun) + except TypeError: + fun = partial(fun) @wraps(fun) @api_boundary def make_jaxpr_f(*args, **kwargs): - f = lu.wrap_init(fun) - if static_argnums: - dyn_argnums = [i for i in range(len(args)) if i not in static_argnums] - f, args = argnums_partial(f, dyn_argnums, args) - in_avals, in_tree, keep_inputs = abstractify(args, kwargs) - in_type = tuple(zip(in_avals, keep_inputs)) - f, out_tree = flatten_fun(f, in_tree) - f = lu.annotate(f, in_type) - debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr') with ExitStack() as stack: for axis_name, size in axis_env or []: stack.enter_context(core.extend_axis_env(axis_name, size, None)) - jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2( - f, debug_info=debug_info) - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + traced = jit(fun, static_argnums=static_argnums, + abstracted_axes=abstracted_axes).trace(*args, **kwargs) + # `jit` converts tracers in consts to args but that breaks the semantics of + # `make_jaxpr`. Hence convert the tracers in args back to consts in jaxpr. + if traced._num_consts: + consts, _ = split_list(traced._args_flat, [traced._num_consts]) + jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr, + traced._num_consts) + jaxpr = core.ClosedJaxpr(jaxpr_, consts) + else: + jaxpr = traced.jaxpr if return_shape: - out_avals, _ = unzip2(out_type) - out_shapes_flat = [ - ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] - return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat) - return closed_jaxpr + out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None)) + for o in jaxpr.out_avals] + return jaxpr, tree_unflatten(tree_structure(traced.out_info), out) + return jaxpr make_jaxpr_f.__module__ = "jax" if hasattr(fun, "__qualname__"): @@ -2422,14 +2406,20 @@ def _infer_src_sharding(src, x) -> Sharding | None: return None -# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that -# to check if shardings are compatible with the input. +# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use +# that to check if shardings are compatible with the input. @lru_cache(maxsize=2048) def _check_sharding(aval, s): + if (s is not None and + not isinstance(s, (xc.Device, Sharding, Layout, TransferToMemoryKind))): + raise ValueError( + "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`," + " `jax.Device`, `Layout` or a pytree of these values. Received" + f" invalid value: {s}") if isinstance(s, Sharding): if isinstance(aval, core.AbstractToken): aval = core.token_shaped_array - if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding): + if not isinstance(s, PmapSharding): pjit.pjit_check_aval_sharding( (s,), (aval,), None, "device_put args", allow_uneven_sharding=False) s.shard_shape(aval.shape) # should raise an Error if incompatible @@ -2462,24 +2452,25 @@ def device_put( blocking the calling Python thread until any transfers are completed. """ with config.explicit_device_put_scope(): - if ((device is None or - isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and - (src is None or - isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))): - def _map(y): - _check_sharding(shaped_abstractify(y), s=device) - return dispatch.device_put_p.bind( - y, device=device, src=_infer_src_sharding(src, y)) - return tree_map(_map, x) - x_flat, treedef = tree_flatten(x) - device_flat = flatten_axes("device_put device", treedef, device) - src_flat = flatten_axes("device_put source", treedef, src) - out_flat = [] - for xf, d, s in zip(x_flat, device_flat, src_flat): + if (device is None or + isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))): + device_flat = [device] * len(x_flat) + else: + device_flat = flatten_axes("device_put device", treedef, device) + + if (src is None or + isinstance(src, (xc.Device, Sharding, TransferToMemoryKind))): + src_flat = [_infer_src_sharding(src, xf) for xf in x_flat] + else: + src_flat = flatten_axes("device_put source", treedef, src) + src_flat = list(map(_infer_src_sharding, src_flat, x_flat)) + + for xf, d in zip(x_flat, device_flat): _check_sharding(shaped_abstractify(xf), d) - out_flat.append(dispatch.device_put_p.bind( - xf, device=d, src=_infer_src_sharding(s, xf))) + out_flat = dispatch.device_put_p.bind( + *x_flat, devices=device_flat, srcs=src_flat + ) return tree_unflatten(treedef, out_flat) @@ -2954,6 +2945,7 @@ def clear_backends(): xb.local_devices.cache_clear() xb.process_count.cache_clear() dispatch.xla_primitive_callable.cache_clear() + pjit._infer_params_cached.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error pjit._cpp_pjit_cache.clear() @@ -2972,13 +2964,14 @@ def clear_caches(): This doesn't clear the persistent cache; to disable it (e.g. for benchmarks), set the jax_enable_compilation_cache config option to False. """ - # Clear all lu.cache and util.weakref_lru_cache instances (used for staging - # and Python-dispatch compiled executable caches). - lu.clear_all_caches() + # Clear all lu.cache, util.cache and util.weakref_lru_cache instances + # (used for staging and Python-dispatch compiled executable caches). + util.clear_all_caches() util.clear_all_weakref_lru_caches() # Clear all C++ compiled executable caches for pjit pjit._cpp_pjit_cache.clear() + pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() # Clear all C++ compiled executable caches for pmap diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 67a8ee8bdfac..16a29e699bbc 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -14,11 +14,11 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import inspect import operator from functools import partial, lru_cache -from typing import Any, Callable, Type +from typing import Any import numpy as np @@ -713,6 +713,6 @@ def __hash__(self): def __eq__(self, other): return self.val is other.val -def register_class_with_attrs(t: Type) -> None: +def register_class_with_attrs(t: type) -> None: _class_with_attrs.add(t) -_class_with_attrs: set[Type] = set() +_class_with_attrs: set[type] = set() diff --git a/jax/_src/array.py b/jax/_src/array.py index 2a265f58278b..6e3f0a76f512 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -15,12 +15,12 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence import enum import functools import math import operator as op -from typing import Any, Callable, TYPE_CHECKING, cast +from typing import Any, TYPE_CHECKING, cast from jax._src import abstract_arrays from jax._src import api @@ -42,10 +42,10 @@ from jax._src.lib import xla_extension as xe from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, XLACompatibleSharding, - device_replica_id_map, hashed_index, num_addressable_indices) # pyformat: disable + PmapSharding, SingleDeviceSharding, + device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType -from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method +from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache import numpy as np @@ -120,7 +120,7 @@ def _reconstruct_array(fun, args, arr_state, aval_state): return jnp_value -@functools.lru_cache(maxsize=4096) +@cache(max_size=4096, trace_context_in_key=False) def _cached_index_calc(s, shape): map_ = s.addressable_devices_indices_map(shape) seen_h_indices = set() @@ -133,7 +133,7 @@ def _cached_index_calc(s, shape): return l -@functools.lru_cache(maxsize=4096) +@cache(max_size=4096, trace_context_in_key=False) def _process_has_full_value_in_mcjax(s, shape): # Return False for single host as a fast path. if xla_bridge.process_count() == 1: @@ -218,9 +218,8 @@ def _check_and_rearrange(self): f"{self.aval.str_short()} with sharding {self.sharding}") # Rearrange arrays based on the device assignment. - if isinstance(self.sharding, XLACompatibleSharding): - addressable_da = self.sharding._addressable_device_assignment - self._arrays = [device_id_to_buffer[device.id] for device in addressable_da] + addressable_da = self.sharding._addressable_device_assignment + self._arrays = [device_id_to_buffer[device.id] for device in addressable_da] @property def shape(self) -> Shape: @@ -654,7 +653,7 @@ def make_array_from_callback( Returns: A ``jax.Array`` via data fetched from ``data_callback``. - Example: + Examples: >>> import math >>> from jax.sharding import Mesh @@ -735,7 +734,7 @@ def make_array_from_callback( def make_array_from_process_local_data( sharding: Sharding, local_data: np.ndarray, - global_shape: tuple[int, ...], + global_shape: Shape | None = None, ) -> ArrayImpl: # pyformat: disable """Creates distributed tensor using the data available in process. @@ -744,26 +743,38 @@ def make_array_from_process_local_data( assumes that the data is available in the process and takes care of the index wrangling. - Note, if the two hosts are replicas, host_local_data should be identical as - well. - Each dimension of the shape of host_local_data should either match - global_shape or the # indices the devices on this process need to - address. For example if dimension $i$ is fully sharded then this size would be - `per_device_shape[i] * jax.local_device_count()`. + The most common case is when the sharding is sharded across the batch + dimension and each host just loads its corresponding sub-batch. This function + supports more general cases as well, such as mixed multi-host and multi-axis + replication and sharding but you would need to compute the size and the + contents of process-local data correctly to satisfy the sharding constraints. + + In particular, if any two hosts are replicas, host_local_data should be + identical as well. - If the shape matches global shape, each device slice will just lookup - the slice in the local_data. In the latter case the global slice of each - device will be mapped into local slice of `local_data` array. For example, - if given process only addresses slices (8, 12) and (24, 28), then - these slices will be mapped into (0, 4) and (4, 8) of the `local_data`. + The global_shape is optional. If not provided it will be be inferred from + the local_data and sharding, under the assumption that + each host represents only their own data for uniform sharding. If sharding + is non-uniform, (see note below) an exception will be raised. - This function can be used to create tensors from dataset feeding pipelines. + Setting global_shape explicitly allows for finer grain control and works with + non-uniform shardings. Each dimension of global_shape must either match + host_local_data, or match the inferred global shape of the sharding (in which + case it is equivalent to setting it to None, but is more explicit). - The most common case is when the sharding is fully sharded across the batch - dimension and each host just loads its corresponding sub-batch. This function - supports more general case as well, such as multi-host replication - but you would need to compute the size and the contents of process-local data - correctly to satisfy the replication constraints. + For example if dimension `i` is fully sharded then this size would be + `per_device_shape[i] * jax.local_device_count()`. Each device will be mapped + into local slice of `local_data` array. For example, if given process + addresses slices (8, 12) and (24, 28), then these slices will be mapped + into (0, 4) and (4, 8) of the `local_data`. + + For each dimension where global_shapes matches local_shape, each device + will lookup the slice in the local_data. For example if + global_shape == local_data.shape, the local data is assumed to be the + actual target array that will be sharded into device. + + If global_shape is the same as local_data.shape, then the data must + be the same across all hosts. Examples: >>> from jax.sharding import PartitionSpec as P @@ -785,19 +796,71 @@ def make_array_from_process_local_data( >>> assert output_global_array.addressable_data(0).shape == per_device_shape >>> assert output_global_array.shape == global_shape + NB: While most shardings are uniform, It is possible to design am exotic + sharding mesh where each process's devices will be arranged in a non-grid + like pattern in some dimensions, or for indices to overlap non-trivially. + Such sharding is called "non-uniform" in those dimensions. In that case, + the global shape along those directions must match local shape as there is + no meaningful way to represent all needed + per-process data in non-overlapping fashion. For example for global_shape 4x4 + if sharding looks like this: + + 0123 + 2103 + 4675 + 4567 + + with 4 processes, containing devices (0,1), (2, 3), (4, 5), (6, 7) respectively. + Then the data for each host look like + + xx.. ..xx .... .... + .xx. x..x .... .... + .... .... x..x .xx. + .... .... xx.. ..xx + + the sharding is uniform on rows (each host requires either rows 1-2, or rows 3-4) + and non-uniform on columns (hosts require overlapping but not matching + set of columns). Thus local data must have the shape 2x4 or 4x4 + for all hosts, even though each host can potentially fit into 2x2 shape. + In this case user must provide global_shape explicitly and for + local_shape=(2, 4), potentially valid global shapes are (2, 4) and (4, 4). + + On the other hand for sharding: + + 0213 x.x. .x.x. .... .... + 0213 x.x. .x.x. .... .... + 4657 .... .... .x.x x.x. + 4657 .... .... .x.x x.x. + + for local_shape=(2, 2) this function can accept a choice of 2x2, 2x4, 4x2 + and 4x4 global shapes. Setting global_shape to None, is equivalent to + setting it to (4, 4) in this case. + Args: sharding: sharding of the global tensor. host_local_data: data on the host to be placed on local devices. Each dimension should either match global_shape, or match num_addressable_indices(dim). - global_shape: the target shape of the global tensor. In some cases this - parameter can be inferred from sharding and host_local_data, however it is - useful to catch common sharding errors. + global_shape: the target shape of the global tensor. If None, + will infer from host_local_data and sharding. Returns: - Tensor that will have sharding=sharding. + Tensor that will have sharding=sharding and of shape global_shape. """ # pyformat: enable + # TODO(sandler): consider supporting partially specified global_shape or + # making local_to_global_shape available in the api. + local_shape = local_data.shape + if global_shape is None: + global_shape = local_to_global_shape(sharding, local_shape) # type: ignore[assignment] + assert global_shape is not None + if None in global_shape: + raise ValueError( + "Unable to compute global_shape due to non-uniform sharding." + f" Specify global shape directly. Partially computed {global_shape=}." + ) + elif None in global_shape: + raise ValueError(f"{global_shape=} has Nones. This is not supported.") shard_shape = sharding.shard_shape(global_shape) full_dim = [] for i, (data_dim, global_dim) in enumerate( @@ -1005,7 +1068,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): if not candidates_list: # This array isn't sharded correctly. Reshard it via host roundtrip. # TODO(skye): more efficient reshard? - return pxla.shard_arg(x._value, sharding, canonicalize=False) + return pxla.shard_args([sharding], [x._value], canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. for buf in candidates_list: @@ -1018,32 +1081,51 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): return pxla.batched_device_put(x.aval, sharding, bufs, devices) -@functools.lru_cache(maxsize=4096) +@cache(max_size=4096, trace_context_in_key=False) def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): src_indices = src_sharding.addressable_devices_indices_map(shape).values() dst_indices = dst_sharding.addressable_devices_indices_map(shape).values() return dst_indices, tuple(src_indices) == tuple(dst_indices) -def _array_shard_arg(x, sharding): - x._check_if_deleted() +def _array_shard_arg(xs, shardings): + results = [] + batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] + for i, (x, sharding) in enumerate(safe_zip(xs, shardings)): + x._check_if_deleted() - indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) - if not x.is_fully_addressable: - if same_indices: - return x - else: - raise NotImplementedError( - "Cannot reshard an input that is not fully addressable") - else: - devices = sharding._addressable_device_assignment - if same_indices: - return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding) - # Resharding starts here: - if dispatch.is_single_device_sharding(x.sharding): - return shard_device_array(x, devices, indices, sharding) + indices, same_indices = _sharding_indices_and_eq( + x.sharding, x.shape, sharding) + if not x.is_fully_addressable: + if same_indices: + results.append(x) + else: + raise NotImplementedError( + "Cannot reshard an input that is not fully addressable") else: - return shard_sharded_device_array_slow_path(x, devices, indices, sharding) + devices = sharding._addressable_device_assignment + if same_indices: + # Add a placeholder result that will be filled in later. + results.append(None) + # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`. + batch_xs.append(x) + batch_devs.append(list(devices)) + batch_shardings.append(sharding) + batch_indices.append(i) + # Resharding starts here: + elif dispatch.is_single_device_sharding(x.sharding): + results.append(shard_device_array(x, devices, indices, sharding)) + else: + results.append( + shard_sharded_device_array_slow_path(x, devices, indices, sharding)) + + copy_outs = xc.batched_copy_array_to_devices_with_sharding( + batch_xs, batch_devs, batch_shardings) + for i, copy_out in safe_zip(batch_indices, copy_outs): + assert results[i] is None + results[i] = copy_out + return results + pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg @@ -1076,8 +1158,8 @@ def _array_local_result_handler(aval, sharding, indices): # Token handlers -def _token_shard_arg(x, sharding): - return _array_shard_arg(x._buf, sharding) +def _token_shard_arg(xs, shardings): + return _array_shard_arg([x._buf for x in xs], shardings) pxla.shard_arg_handlers[core.Token] = _token_shard_arg diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 6f37c16e6715..fbdd4894843e 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any, Union import numpy as np from jax._src.sharding import Sharding @@ -113,24 +114,24 @@ class Array(abc.ABC): def __release_buffer__(self, view: memoryview) -> None: ... # np.ndarray methods: - def all(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, - keepdims=None, *, where: Optional[ArrayLike] = ...) -> Array: ... - def any(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, - keepdims=None, *, where: Optional[ArrayLike] = ...) -> Array: ... - def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ... - def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ... + def all(self, axis: int | Sequence[int] | None = None, out=None, + keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... + def any(self, axis: int | Sequence[int] | None = None, out=None, + keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... + def argmax(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... + def argmin(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ... - def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ... + def argsort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... def astype(self, dtype) -> Array: ... def choose(self, choices, out=None, mode='raise') -> Array: ... def clip(self, min=None, max=None, out=None) -> Array: ... - def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ... + def compress(self, condition, axis: int | None = None, out=None) -> Array: ... def conj(self) -> Array: ... def conjugate(self) -> Array: ... def copy(self) -> Array: ... - def cumprod(self, axis: Optional[Union[int, Sequence[int]]] = None, + def cumprod(self, axis: int | Sequence[int] | None = None, dtype=None, out=None) -> Array: ... - def cumsum(self, axis: Optional[Union[int, Sequence[int]]] = None, + def cumsum(self, axis: int | Sequence[int] | None = None, dtype=None, out=None) -> Array: ... def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ... def dot(self, b, *, precision=None) -> Array: ... @@ -138,35 +139,35 @@ class Array(abc.ABC): @property def imag(self) -> Array: ... def item(self, *args) -> Any: ... - def max(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def max(self, axis: int | Sequence[int] | None = None, out=None, keepdims=None, initial=None, where=None) -> Array: ... - def mean(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def mean(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=False, *, where=None,) -> Array: ... - def min(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def min(self, axis: int | Sequence[int] | None = None, out=None, keepdims=None, initial=None, where=None) -> Array: ... @property def nbytes(self) -> int: ... def nonzero(self, *, size=None, fill_value=None) -> Array: ... - def prod(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def prod(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=None, initial=None, where=None) -> Array: ... - def ptp(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def ptp(self, axis: int | Sequence[int] | None = None, out=None, keepdims=False,) -> Array: ... def ravel(self, order='C') -> Array: ... @property def real(self) -> Array: ... - def repeat(self, repeats, axis: Optional[int] = None, *, + def repeat(self, repeats, axis: int | None = None, *, total_repeat_length=None) -> Array: ... def reshape(self, *args, order='C') -> Array: ... def round(self, decimals=0, out=None) -> Array: ... def searchsorted(self, v, side='left', sorter=None) -> Array: ... - def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ... - def squeeze(self, axis: Optional[Union[int, Sequence[int]]] = None) -> Array: ... - def std(self, axis: Optional[Union[int, Sequence[int]]] = None, + def sort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... + def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ... + def std(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... - def sum(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def sum(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=None, initial=None, where=None) -> Array: ... def swapaxes(self, axis1: int, axis2: int) -> Array: ... - def take(self, indices, axis: Optional[int] = None, out=None, + def take(self, indices, axis: int | None = None, out=None, mode=None) -> Array: ... def tobytes(self, order='C') -> bytes: ... def tolist(self) -> list[Any]: ... @@ -177,15 +178,15 @@ class Array(abc.ABC): def T(self) -> Array: ... @property def mT(self) -> Array: ... - def var(self, axis: Optional[Union[int, Sequence[int]]] = None, + def var(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... def view(self, dtype=None, type=None) -> Array: ... # Even though we don't always support the NumPy array protocol, e.g., for # tracer types, for type checking purposes we must declare support so we # implement the NumPy ArrayLike protocol. - def __array__(self, dtype: Optional[np.dtype] = ..., - copy: Optional[bool] = ...) -> np.ndarray: ... + def __array__(self, dtype: np.dtype | None = ..., + copy: bool | None = ...) -> np.ndarray: ... def __dlpack__(self) -> Any: ... # JAX extensions @@ -237,23 +238,23 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[str] = None, fill_value: Optional[StaticScalar] = None) -> Array: ... + mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[str] = None, fill_value: Optional[StaticScalar] = None) -> Array: ... + mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... def add(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def mul(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def multiply(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def divide(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def power(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def min(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def max(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def apply(self, func: Callable[[ArrayLike], ArrayLike], indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py new file mode 100644 index 000000000000..16da61d75b3f --- /dev/null +++ b/jax/_src/blocked_sampler.py @@ -0,0 +1,165 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from typing import Any, Protocol +import jax +from jax._src import random +from jax._src.typing import Array, ArrayLike +from jax import numpy as jnp + +NdKeyList = Any +Shape = random.Shape + +class SampleFn(Protocol): + def __call__(self, key: random.KeyArrayLike, *args, shape: Shape, + **kwargs) -> Array: + ... + + +def _compute_scalar_index(iteration_index: Sequence[int], + total_size: Shape, + block_size: Shape, + block_index: Sequence[int]) -> int: + ndims = len(iteration_index) + dim_size = 1 + total_idx = 0 + for i in range(ndims-1, -1, -1): + dim_idx = block_index[i] + iteration_index[i] * block_size[i] + total_idx += dim_idx * dim_size + dim_size *= total_size[i] + return total_idx + + +def blocked_fold_in( + global_key: random.KeyArrayLike, + total_size: Shape, + block_size: Shape, + tile_size: Shape, + block_index: Sequence[ArrayLike], + ) -> NdKeyList: + """Computes a grid of keys for block-invariant sampling. + + Suppose we wished to construct a 16x512 array of random numbers, using + block sizes of 16x128 and 16x256. We could select an tile size of 8x128 + (which divides both 16x128 and 16x256) and divide the total array in tiles as: + --------------------------------- + | 8x128 | 8x128 | 8x128 | 8x128 | + --------------------------------- + | 8x128 | 8x128 | 8x128 | 8x128 | + --------------------------------- + + We generate a key for each tile as: + tile_key = fold_in(global_key, tile_idx) + + Where the tile_idx is the row-major raveled index of each element: + ----------------- + | 0 | 1 | 2 | 3 | + ----------------- + | 4 | 5 | 6 | 7 | + ----------------- + + We then compute and return the keys required to sample the tiles that make + up the current block (specified via `block_index`). + With a 16x256 block size, each block requires 4 (2x2) tile keys: + --------------- + | 0, 1 | 2, 3 | + | 4, 5 | 6, 7 | + --------------- + Therefore, we return a grid of 2x2 keys for each block (2 blocks total). + + With a 16x128 block size, each block requires 2 (2x1) tile keys: + ----------------- + | 0 | 1 | 2 | 3 | + | 4 | 5 | 6 | 7 | + ----------------- + Therefore, we return a grid of 2x1 keys for each block (4 blocks total). + + Args: + global_key: The global key shared between all blocks. + total_size: The shape of the array being generated. + block_size: The shape of an individual block. + tile_size: The shape of a `tile`, which is the smallest unit at + which samples are generated. This should be selected to be a divisor + of all block sizes one needs to be invariant to. + block_index: The index denoting which block to generate keys for. + + Returns: + An N-dimensional nested list of keys required to sample the tiles + corresponding to the block specified by `block_index`. + """ + size_in_blocks = tuple( + _shape // _element for _shape, _element in zip(block_size, tile_size)) + + def _keygen_loop(axis, prefix): + if axis == len(size_in_blocks): + subtile_key = jax.random.fold_in( + global_key, _compute_scalar_index( + block_index, total_size, size_in_blocks, prefix)) + return subtile_key + else: + keys = [] + for i in range(size_in_blocks[axis]): + keys.append(_keygen_loop(axis+1, prefix+(i,))) + return keys + return _keygen_loop(0, tuple()) + + +def sample_block( + sampler_fn: SampleFn, + keys: NdKeyList, + block_size: Shape, + tile_size: Shape, + *args, + **kwargs + ) -> jax.Array: + """Draws random samples for a single block. + + This function is intended to be used in conjunction with `blocked_fold_in`: + ``` + key_list = blocked_fold_in(global_key, total_size, block_size, tile_size, + block_index) + samples = sample_block(jax.random.uniform, key_list, block_size, tile_size) + ``` + + Args: + sampler_fn: A random sampling function, e.g. jax.random.uniform. + keys: A grid of keys generated by `blocked_fold_in`. + block_size: The shape of an individual block. + tile_size: The shape of a `tile`, which is the smallest unit at + which samples are generated. This should be selected to be a divisor + of all block sizes one needs to be invariant to. + args: varargs for sampler_fn. + kwargs: kwargs for sampler_fn. + + Returns: + An array of random samples drawn using sampler_fn. + """ + size_in_tiles = tuple( + _shape // _element for _shape, _element in zip(block_size, tile_size)) + def _nested_index(arr: jax.Array, idx: Sequence[int]) -> jax.Array: + if len(idx) == 1: + return arr[idx[0]] + return _nested_index(arr[idx[0]], idx[1:]) + + def _sample_loop(axis: int, prefix: tuple[int, ...]) -> jax.Array: + if axis == len(size_in_tiles): + return sampler_fn(_nested_index(keys, prefix), *args, + shape=tile_size, **kwargs) + else: + samples = [] + for i in range(size_in_tiles[axis]): + samples.append(_sample_loop(axis+1, prefix+(i,))) + return jnp.concatenate(samples, axis=axis) + return _sample_loop(0, tuple()) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 8e9ab0b62593..054804379043 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -14,11 +14,11 @@ """Module for JAX callbacks.""" from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools import logging -from typing import Any, Callable +from typing import Any import jax from jax._src import core diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 3727c8364a12..9bbaa1296c93 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools import itertools as it -from typing import Callable, TypeVar, Any, Union +from typing import TypeVar, Any, Union import numpy as np diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 71827d8f8a9f..73e61de68008 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -34,6 +34,18 @@ def maybe_import_libtpu(): return libtpu +def get_tpu_library_path() -> str | None: + path_from_env = os.getenv("TPU_LIBRARY_PATH") + if path_from_env is not None and os.path.isfile(path_from_env): + return path_from_env + + libtpu_module = maybe_import_libtpu() + if libtpu_module is not None: + return libtpu_module.get_library_path() + + return None + + def jax_force_tpu_init() -> bool: return 'JAX_FORCE_TPU_INIT' in os.environ @@ -57,9 +69,9 @@ def cloud_tpu_init() -> None: global running_in_cloud_tpu_vm # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. - libtpu_module = maybe_import_libtpu() + libtpu_path = get_tpu_library_path() num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0] - if (libtpu_module is None or num_tpu_chips == 0) and not jax_force_tpu_init(): + if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init(): return running_in_cloud_tpu_vm = True @@ -68,6 +80,7 @@ def cloud_tpu_init() -> None: os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__ + os.environ['ENABLE_RUNTIME_UPTIME_TELEMETRY'] = '1' if hardware_utils.tpu_enhanced_barrier_supported(): os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index e60a29e274e9..9c276151741d 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -34,6 +34,7 @@ from jax._src.gfile_cache import GFileCache from jax._src.lib import xla_client from jax._src.lib.mlir import ir +from jax._src.lru_cache import LRUCache logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ # Mutex to protect _cache_initialized and _cache_used. _cache_initialized_mutex = threading.Lock() +_UNSUPPORTED_RUNTIMES: set[str] = set() def set_once_cache_used(f) -> None: """One-time setting of _cache_used. @@ -65,7 +67,20 @@ def set_once_cache_used(f) -> None: def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: """Returns the file cache and the path to the cache.""" - return GFileCache(path), path + + def is_local_filesystem(path: str) -> bool: + return path.startswith("file://") or "://" not in path + + # `LRUCache` currently only supports local filesystem. Therefore, if `path` + # is not on a local filesystem, instead of using `LRUCache`, we + # fallback to the old `GFileCache`, which does not support LRU eviction. + # TODO(ayx): Add cloud storage support for `LRUCache`, so that all these code + # can be removed. + if not is_local_filesystem(path): + return GFileCache(path), path + + max_size = config.compilation_cache_max_size.value + return LRUCache(path, max_size=max_size), path def set_cache_dir(path) -> None: @@ -134,10 +149,13 @@ def _initialize_cache() -> None: logger.debug("Initialized persistent compilation cache at %s", path) -def _get_cache() -> CacheInterface | None: +def _get_cache(backend) -> CacheInterface | None: # TODO(b/289098047): consider making this an API and changing the callers of # get_executable_and_time() and put_executable_and_time() to call get_cache() # and passing the result to them. + if backend.runtime_type in _UNSUPPORTED_RUNTIMES: + logger.debug("_get_cache: Unsupported runtime: %s", backend.runtime_type) + return None if _cache is None: _initialize_cache() # initialization is done at most once; see above return _cache @@ -157,13 +175,25 @@ def decompress_executable(executable): else: return zlib.decompress(executable) + +def is_executable_in_cache(backend, cache_key: str) -> bool: + """Checks if the executable is in the cache.""" + cache = _get_cache(backend) + if cache is None: + return False + + # TODO(patrios): add check cache key method to cache interface. + executable_and_time = cache.get(cache_key) + return executable_and_time is not None + + def get_executable_and_time( cache_key: str, compile_options, backend ) -> tuple[xla_client.LoadedExecutable | None, int | None]: """Returns the cached executable and its compilation time if present, or None otherwise. """ - cache = _get_cache() + cache = _get_cache(backend) if cache is None: logger.debug("get_executable_and_time: cache is disabled/not initialized") return None, None @@ -189,7 +219,7 @@ def put_executable_and_time( """Adds the 'executable' and its compilation time to the cache, possibly evicting older entries. """ - cache = _get_cache() + cache = _get_cache(backend) if cache is None: logger.debug("put_executable_and_time: cache is disabled/not initialized") return diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 7abfb915a9bd..438f1f9e5183 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -37,13 +37,13 @@ import numpy as np -_DISABLE_MOST_OPTIMIZATIONS = config.DEFINE_bool( +_DISABLE_MOST_OPTIMIZATIONS = config.bool_flag( 'jax_disable_most_optimizations', config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False), 'Try not to do much optimization work. This can be useful if the cost of ' 'optimization is greater than that of running a less-optimized program.') -_COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer( +_COMPILER_DETAILED_LOGGING_MIN_OPS = config.int_flag( "jax_compiler_detailed_logging_min_ops", config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10), help=( @@ -243,6 +243,7 @@ def compile_or_get_cached( devices: np.ndarray, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], + pgle_profiler: profiler.PGLEProfiler | None = None, ) -> xc.LoadedExecutable: sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value @@ -278,14 +279,55 @@ def compile_or_get_cached( return backend_compile(backend, computation, compile_options, host_callbacks) + is_multi_process = ( + len({device.process_index for device in devices.flatten()}) > 1) + min_device_process_id = ( + min(devices.flatten(), key=lambda device: device.id).process_index) + + # When PGLE is enabled there might be 3 types of situations: + # 1. PGLE profiled module (the one which was recompiled with FDO profile) is + # in the persistent cache. In this case the module should be returned from + # cache and PGLE should be disabled for this module. Is module is stored in + # the persistent cache under the "pgle_profiled_module_key" which calculated + # with replacing FDO profile with flag which identify that module were PGLE + # profiled. + # 2. PGLE profiled module is not in the persistent cache and the module is + # getting built with an FDO profile. In this case we need to share FDO profile + # with other processes and store the result under the + # "pgle_profiled_module_key" so later in case 1 we will be able to find the + # module. + # 3. PGLE profiled module is not in the persistent cache and the module is + # getting compiled to be PGLEd (FDO profile is empty). In this case we need to + # simply return the non-PGLE profiled module from the persistent cache. + if (config.enable_pgle.value + and config.pgle_profiling_runs.value > 0): + fdo_profile = compile_options.executable_build_options.fdo_profile + compile_options.executable_build_options.fdo_profile = b"pgle profiled" + + pgle_profiled_module_key = compilation_cache.get_cache_key( + computation, devices, compile_options, backend) + compile_options.executable_build_options.fdo_profile = fdo_profile + + if _is_executable_in_cache(backend, pgle_profiled_module_key): + # Load PGLE profiled module from the persistent cache. + cache_key = pgle_profiled_module_key + if pgle_profiler is not None: + pgle_profiler.disable() + elif fdo_profile is not None and len(fdo_profile) > 0: + # Store module under PGLE profiled module cache key. + cache_key = pgle_profiled_module_key + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = _share_fdo_profiles( + computation, devices, compile_options, backend, + distributed.global_state.client, + min_device_process_id + ) + cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( module_name, cache_key, compile_options, backend) cache_retrieval_time = time.monotonic() - cache_retrieval_start - - is_multi_process = ( - len({device.process_index for device in devices.flatten()}) > 1) if retrieved_executable is not None: assert retrieved_compile_time is not None logger.debug("Persistent compilation cache hit for '%s'", module_name) @@ -315,7 +357,7 @@ def compile_or_get_cached( distributed.global_state.client, module_name, cache_key, - min(devices.flatten(), key=lambda device: device.id).process_index + min_device_process_id ) elif ( config.share_autotune_config_between_hosts.value @@ -330,7 +372,7 @@ def compile_or_get_cached( distributed.global_state.client, module_name, cache_key, - min(devices.flatten(), key=lambda device: device.id).process_index + min_device_process_id ) else: return _compile_and_write_cache( @@ -342,6 +384,58 @@ def compile_or_get_cached( cache_key, ) +# The process that has the lowest device ID should share FDO profile before +# compilation with other processes. +def _share_fdo_profiles( + computation: ir.Module, + devices: np.ndarray, + compile_options: xc.CompileOptions, + backend: xc.Client, + global_client: lib.xla_extension.DistributedRuntimeClient, + min_process_id +) -> bytes | None: + sym_name = computation.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value + fdo_profile = compile_options.executable_build_options.fdo_profile + if fdo_profile is None or len(fdo_profile) == 0: + return fdo_profile + + compile_options.executable_build_options.fdo_profile = b"" + profile_key = ( + compilation_cache.get_cache_key( + computation, devices, compile_options, backend + ) + + "_fdo_sync" + ) + if profile_key in _share_fdo_profiles.modules_profiles: + return _share_fdo_profiles.modules_profiles[profile_key] + + share_timeout = config.share_binary_between_hosts_timeout_ms.value + if distributed.global_state.process_id == min_process_id: + logger.debug( + "Sharing FDO profile: %s. For module %s. Process %d.", + fdo_profile, + module_name, + min_process_id, + ) + global_client.key_value_set_bytes(profile_key, fdo_profile) + else: + logger.debug( + "Waiting for FDO profile: %s. For module %s. Should be set by process %d.", + fdo_profile, + module_name, + min_process_id, + ) + fdo_profile = global_client.blocking_key_value_get_bytes( + profile_key, share_timeout + ) + + _share_fdo_profiles.modules_profiles[profile_key] = fdo_profile + return fdo_profile + + +_share_fdo_profiles.modules_profiles = {} + # The process with the first_process_id should compile the module and write an # autotune config to the K-V storage. @@ -520,6 +614,20 @@ def _compile_and_write_cache( ) return executable +def _is_executable_in_cache(backend, cache_key) -> bool: + """Checks if executable is presented in cache on a given key + """ + try: + return compilation_cache.is_executable_in_cache(backend, cache_key) + except Exception as ex: + if config.raise_persistent_cache_errors.value: + raise + warnings.warn( + f"Error reading persistent compilation cache entry for " + f"'{cache_key}': {type(ex).__name__}: {ex}") + return False + + def _cache_read( module_name: str, cache_key: str, compile_options: xc.CompileOptions, backend: xc.Client diff --git a/jax/_src/config.py b/jax/_src/config.py index 4ee6b16abba1..80557b86e840 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Hashable, Iterator +from collections.abc import Callable, Hashable, Iterator, Sequence import contextlib import functools import itertools @@ -22,7 +22,7 @@ import os import sys import threading -from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast +from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast from jax._src import lib from jax._src.lib import jax_jit @@ -60,24 +60,24 @@ def int_env(varname: str, default: int) -> int: return int(os.getenv(varname, str(default))) -UPGRADE_BOOL_HELP = ( - " This will be enabled by default in future versions of JAX, at which " - "point all uses of the flag will be considered deprecated (following " - "the `API compatibility policy " - "`_).") +class ValueHolder(Protocol[_T]): + """A holder for a configuration value. -UPGRADE_BOOL_EXTRA_DESC = " (transient)" + There are two kinds of value holders: ``Flag``, which is assigned exactly + once and never modified after; and ``State``, which can be changed locally + within a thread via a context manager. + """ + + value: _T + + def _set(self, value: _T) -> None: ... class Config: _HAS_DYNAMIC_ATTRIBUTES = True def __init__(self): - # There are two kinds of value holders: FlagHolders, which hold global - # flags, and StateContextManagers, which hold state that can be changed - # locally within a thread. A value holder needs a `.value` property and a - # `._set()` method. - self._value_holders = {} + self._value_holders: dict[str, ValueHolder] = {} self.meta = {} self.use_absl = False self._contextmanager_flags = set() @@ -113,11 +113,13 @@ def add_option(self, name, holder, opt_type, meta_args, meta_kwargs): def config_with_absl(self): """Registers absl flags for the JAX configs. - E.g., for each JAX config defined using define_bool_state(), this method + E.g., for each JAX config defined using bool_state(), this method registers an absl boolean flag, with the same name. This is the recommended method to call if you use `app.run(main)` and you - need JAX flags. Example: + need JAX flags. + + Examples: ```python from absl import app @@ -217,7 +219,9 @@ def trace_context(): debug_key_reuse.value, jax_xla_profile_version.value, # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value) + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value) config = Config() @@ -235,7 +239,8 @@ class _Unset: pass _thread_local_state = threading.local() -class _StateContextManager(Generic[_T]): +class State(Generic[_T]): + __slots__ = ( '_name', '_value', '_update_thread_local_hook', '_update_global_hook', '_validator', '_default_context_manager_value', '__doc__', '__name__', @@ -269,6 +274,8 @@ def __bool__(self) -> NoReturn: type(self).__name__)) def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) self._value = value if self._update_global_hook: self._update_global_hook(value) @@ -316,7 +323,16 @@ def _add_hooks(self, update_global_hook, update_thread_local_hook): update_global_hook(self._value) -def define_bool_state( +UPGRADE_BOOL_HELP = ( + " This will be enabled by default in future versions of JAX, at which " + "point all uses of the flag will be considered deprecated (following " + "the `API compatibility policy " + "`_).") + +UPGRADE_BOOL_EXTRA_DESC = " (transient)" + + +def bool_state( name: str, default: bool, help: str, @@ -325,7 +341,7 @@ def define_bool_state( update_thread_local_hook: Callable[[bool | None], None] | None = None, upgrade: bool = False, extra_description: str = '', -) -> _StateContextManager[bool]: +) -> State[bool]: """Set up thread-local state and return a contextmanager for managing it. This function is a convenience wrapper. It defines a flag, environment @@ -356,9 +372,9 @@ def define_bool_state( Returns: A contextmanager to control the thread-local state value. - Example: + Examples: - enable_foo = config.define_bool_state( + ENABLE_FOO = config.bool_state( name='jax_enable_foo', default=False, help='Enable foo.') @@ -386,7 +402,7 @@ def define_bool_state( extra_description += UPGRADE_BOOL_EXTRA_DESC config._contextmanager_flags.add(name) - s = _StateContextManager[bool]( + s = State[bool]( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, extra_description=extra_description, default_context_manager_value=True) @@ -395,18 +411,18 @@ def define_bool_state( return s -def define_enum_state( +def enum_state( name: str, - enum_values: list[str], + enum_values: Sequence[str], default: str, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str]: +) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -435,7 +451,7 @@ def validator(new_val): raise ValueError(f"new enum value must be in {enum_values}, " f"got {new_val} of type {type(new_val)}.") - s = _StateContextManager[str]( + s = State[str]( name, default, help, @@ -452,18 +468,18 @@ def validator(new_val): return s -def define_optional_enum_state( +def optional_enum_state( name: str, - enum_values: list[str], + enum_values: Sequence[str], default: str | None, help: str, *, update_global_hook: Callable[[str | None], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str | None]: +) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -493,7 +509,7 @@ def validate(new_val): raise ValueError(f"new enum value must be None or in {enum_values}, " f"got {new_val} of type {type(new_val)}.") - s = _StateContextManager['str | None']( + s = State['str | None']( name, default, help, update_global_hook, update_thread_local_hook, validate ) @@ -506,17 +522,17 @@ def validate(new_val): return s -def define_int_state( +def int_state( name: str, default: int, help: str, *, update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, -) -> _StateContextManager[int]: +) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -546,24 +562,24 @@ def validate(new_val): raise ValueError(f'new int config value must be None or of type int, ' f'got {new_val} of type {type(new_val)}') - s = _StateContextManager[int](name, default, help, update_global_hook, - update_thread_local_hook, validate) + s = State[int](name, default, help, update_global_hook, + update_thread_local_hook, validate) config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s -def define_float_state( +def float_state( name: str, default: float, help: str, *, update_global_hook: Callable[[float], None] | None = None, update_thread_local_hook: Callable[[float | None], None] | None = None, -) -> _StateContextManager[float]: +) -> State[float]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -594,24 +610,24 @@ def validate(new_val): f'new float config value must be None or of type float, ' f'got {new_val} of type {type(new_val)}') - s = _StateContextManager[float](name, default, help, update_global_hook, - update_thread_local_hook, validate) + s = State[float](name, default, help, update_global_hook, + update_thread_local_hook, validate) config.add_option(name, s, float, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s -def define_string_state( +def string_state( name: str, default: str, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str]: +) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -638,24 +654,24 @@ def validator(new_val): raise TypeError('new string config value must be of type str,' f' got {new_val} of type {type(new_val)}.') - return define_string_or_object_state( + return string_or_object_state( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator) -def define_optional_string_state( +def optional_string_state( name: str, default: str | None, help: str, *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, -) -> _StateContextManager[str | None]: +) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. - See docstring for ``define_bool_state``. + See docstring for ``bool_state``. Args: name: string, converted to lowercase to define the name of the config @@ -682,13 +698,13 @@ def validator(new_val): raise ValueError('new string config value must be None or of type str,' f' got {new_val} of type {type(new_val)}.') - return define_string_or_object_state( + return string_or_object_state( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator) -def define_string_or_object_state( +def string_or_object_state( name: str, default: Any, help: str, @@ -696,11 +712,11 @@ def define_string_or_object_state( update_global_hook: Callable[[Any], None] | None = None, update_thread_local_hook: Callable[[Any], None] | None = None, validator: Callable[[Any], None] | None = None, -) -> _StateContextManager[Any]: +) -> State[Any]: """Set up thread-local state and return a contextmanager for managing it. - Similar to ``define_string_state``, except the context manager will accept - any object, not just a string. Any value passed via commandline flag or + Similar to ``string_state``, except the context manager will accept + any object, not just a string. Any value passed via command line flag or environment variable will be treated as a string. Args: @@ -726,7 +742,7 @@ def define_string_or_object_state( default = os.getenv(name.upper(), default) config._contextmanager_flags.add(name) - s = _StateContextManager[Any]( + s = State[Any]( name, default, help, update_global_hook, update_thread_local_hook, validator) setattr(Config, name, property(lambda _: s.value)) @@ -734,7 +750,8 @@ def define_string_or_object_state( return s -class FlagHolder(Generic[_T]): +class Flag(Generic[_T]): + __slots__ = ("_name", "value", "_update_hook") _name: str @@ -759,42 +776,37 @@ def _set(self, value: _T) -> None: self._update_hook(value) -def check_exists(name): - if name not in config._value_holders: - raise AttributeError(f"Unrecognized config option: {name}") - - -def DEFINE_bool(name, default, *args, **kwargs) -> FlagHolder[bool]: +def bool_flag(name, default, *args, **kwargs) -> Flag[bool]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, bool, args, kwargs) return holder -def DEFINE_integer(name, default, *args, **kwargs) -> FlagHolder[int]: +def int_flag(name, default, *args, **kwargs) -> Flag[int]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, int, args, kwargs) return holder -def DEFINE_float(name, default, *args, **kwargs) -> FlagHolder[float]: +def float_flag(name, default, *args, **kwargs) -> Flag[float]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, float, args, kwargs) return holder -def DEFINE_string(name, default, *args, **kwargs) -> FlagHolder[str]: +def string_flag(name, default, *args, **kwargs) -> Flag[str]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, str, args, kwargs) return holder -def DEFINE_enum(name, default, *args, **kwargs) -> FlagHolder[str]: +def enum_flag(name, default, *args, **kwargs) -> Flag[str]: update_hook = kwargs.pop("update_hook", None) - holder = FlagHolder(name, default, update_hook) + holder = Flag(name, default, update_hook) config.add_option(name, holder, 'enum', args, kwargs) return holder @@ -815,6 +827,8 @@ class _GlobalExtraJitContext(NamedTuple): threefry_gpu_kernel_lowering: bool = False softmax_custom_jvp: bool = False xla_profile_version: int = 0 + pgle_profiling_runs: int = 0 + enable_pgle: bool = False def _update_global_jit_state(**kw): @@ -850,6 +864,8 @@ class _ThreadLocalExtraJitContext(NamedTuple): threefry_gpu_kernel_lowering: bool | None = None softmax_custom_jvp: bool | None = None xla_profile_version: int | None = None + pgle_profiling_runs: int | None = None + enable_pgle: bool | None = None class _ThreadLocalStateCache(threading.local): @@ -857,7 +873,7 @@ class _ThreadLocalStateCache(threading.local): The extra_jit_context in jax_jit.thread_local_state() may get updated and thus incurring dispatch overhead for comparing this python object during jit calls. - We want to duduplicate the objects that have the same hash/equality to also + We want to deduplicate the objects that have the same hash/equality to also have the same object ID, since the equality check is much faster if the object IDs match. """ @@ -879,7 +895,7 @@ def update_thread_local_jit_state(**kw): # TODO(b/214340779): remove flag when XLA:CPU is improved. -jax2tf_associative_scan_reductions = define_bool_state( +jax2tf_associative_scan_reductions = bool_state( name='jax2tf_associative_scan_reductions', default=False, help=( @@ -894,7 +910,7 @@ def update_thread_local_jit_state(**kw): ) ) -jax2tf_default_native_serialization = define_bool_state( +jax2tf_default_native_serialization = bool_state( name='jax2tf_default_native_serialization', default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', True), help=( @@ -904,22 +920,30 @@ def update_thread_local_jit_state(**kw): ) ) -jax_serialization_version = define_int_state( +jax_serialization_version = int_state( name='jax_serialization_version', - # Note: bump the default serialization version at least one month after + default=int_env('JAX_SERIALIZATION_VERSION', 0), # We use 0 to detect default. + help=( + 'DEPRECATED: use jax_export_calling_convention_version.' + ) +) + +jax_export_calling_convention_version = int_state( + name='jax_export_calling_convention_version', + # Note: bump the default calling convention version at least one month after # we update XlaCallModule to support the new version, so that serialized # modules are forward compatible with deployed versions of XlaCallModule. # Version 9 of XlaCallModule is supported since October 27th, 2023. - default=int_env('JAX_SERIALIZATION_VERSION', 9), + default=int_env('JAX_EXPORT_CALLING_CONVENTION_VERSION', 9), help=( - 'The version number to use for native serialization. This must be ' + 'The calling convention version number to use for exporting. This must be ' 'within the range of versions supported by the tf.XlaCallModule ' 'used in your deployment environment. ' - 'See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.' + 'See https://jax.readthedocs.io/en/latest/export/shape_poly.html#calling-convention-versions.' ) ) -jax_platforms = define_optional_string_state( +jax_platforms = optional_string_state( name='jax_platforms', default=None, help=( @@ -935,18 +959,18 @@ def update_thread_local_jit_state(**kw): 'otherwise.' )) -jax_pjrt_client_create_options = define_optional_string_state( +jax_pjrt_client_create_options = optional_string_state( name='jax_pjrt_client_create_options', default=None, help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings ' 'provided to a device platform pjrt client as extra arguments.')) -enable_checks = define_bool_state( +enable_checks = bool_state( name='jax_enable_checks', default=False, help='Turn on invariant checking for JAX internals. Makes things slower.') -debug_key_reuse = define_bool_state( +debug_key_reuse = bool_state( name='jax_debug_key_reuse', default=False, help=('Turn on experimental key reuse checking. With this configuration enabled,' @@ -955,7 +979,7 @@ def update_thread_local_jit_state(**kw): ' an error. Currently enabling this leads to a small Python overhead on' ' every call to a JIT-compiled function with keys as inputs or outputs.')) -check_tracer_leaks = define_bool_state( +check_tracer_leaks = bool_state( name='jax_check_tracer_leaks', default=False, help=('Turn on checking for leaked tracers as soon as a trace completes. ' @@ -965,7 +989,7 @@ def update_thread_local_jit_state(**kw): 'to disable any debuggers while leak checking is enabled.')) checking_leaks = functools.partial(check_tracer_leaks, True) -debug_nans = define_bool_state( +debug_nans = bool_state( name='jax_debug_nans', default=False, help=('Add nan checks to every operation. When a nan is detected on the ' @@ -973,7 +997,7 @@ def update_thread_local_jit_state(**kw): 'version in an attempt to more precisely identify the operation ' 'which produced the nan.')) -debug_infs = define_bool_state( +debug_infs = bool_state( name='jax_debug_infs', default=False, help=('Add inf checks to every operation. When an inf is detected on the ' @@ -981,7 +1005,7 @@ def update_thread_local_jit_state(**kw): 'version in an attempt to more precisely identify the operation ' 'which produced the inf.')) -log_compiles = define_bool_state( +log_compiles = bool_state( name='jax_log_compiles', default=False, help=('Log a message each time `jit` or `pmap` compiles an XLA ' @@ -989,7 +1013,7 @@ def update_thread_local_jit_state(**kw): 'option is set, the log level is WARNING; otherwise the level is ' 'DEBUG.')) -explain_cache_misses = define_bool_state( +explain_cache_misses = bool_state( name='jax_explain_cache_misses', default=False, help=('Each time there is a miss on one of the main caches (e.g. the ' @@ -997,14 +1021,14 @@ def update_thread_local_jit_state(**kw): '`logging`. When this option is set, the log level is WARNING; ' 'otherwise the level is DEBUG.')) -log_checkpoint_residuals = define_bool_state( +log_checkpoint_residuals = bool_state( name='jax_log_checkpoint_residuals', default=False, help=('Log a message every time jax.checkpoint (aka jax.remat) is ' 'partially evaluated (e.g. for autodiff), printing what residuals ' 'are saved.')) -pmap_shmap_merge = define_bool_state( +pmap_shmap_merge = bool_state( name='jax_pmap_shmap_merge', default=False, upgrade=True, @@ -1016,7 +1040,7 @@ def _update_jax_memories_global(val): def _update_jax_memories_thread_local(val): lib.jax_jit.thread_local_state().enable_memories = val -enable_memories = define_bool_state( +enable_memories = bool_state( 'jax_enable_memories', default=False, upgrade=True, @@ -1025,7 +1049,7 @@ def _update_jax_memories_thread_local(val): help=("If True, will allow fetching memory kinds available on executable " "and annotate Shardings with it.")) -spmd_mode = define_enum_state( +spmd_mode = enum_state( name='jax_spmd_mode', enum_values=['allow_all', 'allow_jit'], default='allow_jit', @@ -1038,14 +1062,14 @@ def _update_jax_memories_thread_local(val): " execute on non-fully addressable `jax.Array`s.")) -distributed_debug = define_bool_state( +distributed_debug = bool_state( name='jax_distributed_debug', default=False, help=('Enable logging useful for debugging multi-process distributed ' 'computations. Logging is performed with `logging` at WARNING ' 'level.')) -random_seed_offset = define_int_state( +random_seed_offset = int_state( name='jax_random_seed_offset', default=0, help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), @@ -1055,7 +1079,7 @@ def _update_jax_memories_thread_local(val): random_seed_offset=val) ) -legacy_prng_key = define_enum_state( +legacy_prng_key = enum_state( name='jax_legacy_prng_key', enum_values=['allow', 'warn', 'error'], default='allow', @@ -1063,21 +1087,21 @@ def _update_jax_memories_thread_local(val): 'jax.random APIs.') ) -enable_custom_prng = define_bool_state( +enable_custom_prng = bool_state( name='jax_enable_custom_prng', default=False, upgrade=True, help=('Enables an internal upgrade that allows one to define custom ' 'pseudo-random number generator implementations.')) -default_prng_impl = define_enum_state( +default_prng_impl = enum_state( name='jax_default_prng_impl', enum_values=['threefry2x32', 'rbg', 'unsafe_rbg'], default='threefry2x32', help=('Select the default PRNG implementation, used when one is not ' 'explicitly provided at seeding time.')) -threefry_partitionable = define_bool_state( +threefry_partitionable = bool_state( name='jax_threefry_partitionable', default=False, upgrade=True, @@ -1092,7 +1116,7 @@ def _update_jax_memories_thread_local(val): update_thread_local_hook=lambda val: update_thread_local_jit_state( threefry_partitionable=val)) -threefry_gpu_kernel_lowering = define_bool_state( +threefry_gpu_kernel_lowering = bool_state( name='jax_threefry_gpu_kernel_lowering', default=False, help=('On GPU, lower threefry PRNG operations to a kernel implementation. ' @@ -1104,7 +1128,7 @@ def _update_jax_memories_thread_local(val): threefry_gpu_kernel_lowering=val)) -softmax_custom_jvp = define_bool_state( +softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', default=False, upgrade=True, @@ -1117,14 +1141,14 @@ def _update_jax_memories_thread_local(val): softmax_custom_jvp=val)) -enable_custom_vjp_by_custom_transpose = define_bool_state( +enable_custom_vjp_by_custom_transpose = bool_state( name='jax_enable_custom_vjp_by_custom_transpose', default=False, upgrade=True, help=('Enables an internal upgrade that implements `jax.custom_vjp` by ' 'reduction to `jax.custom_jvp` and `jax.custom_transpose`.')) -raise_persistent_cache_errors = define_bool_state( +raise_persistent_cache_errors = bool_state( name='jax_raise_persistent_cache_errors', default=False, help=('If true, exceptions raised when reading or writing to the ' @@ -1134,14 +1158,14 @@ def _update_jax_memories_thread_local(val): 'continue. Defaults to false so cache bugs or intermittent issues ' 'are non-fatal.')) -persistent_cache_min_compile_time_secs = define_float_state( +persistent_cache_min_compile_time_secs = float_state( name='jax_persistent_cache_min_compile_time_secs', default=1., help=('The minimum compile time of a computation to be written to the ' 'persistent compilation cache. This threshold can be raised to ' 'decrease the number of entries written to the cache.')) -persistent_cache_min_entry_size_bytes = define_int_state( +persistent_cache_min_entry_size_bytes = int_state( name='jax_persistent_cache_min_entry_size_bytes', default=0, help=('The minimum size (in bytes) of an entry that will be cached in the ' @@ -1152,7 +1176,7 @@ def _update_jax_memories_thread_local(val): ' filesystem being used for the cache. ' '* > 0: the actual minimum size desired; no overrides.')) -compilation_cache_include_metadata_in_key = define_bool_state( +compilation_cache_include_metadata_in_key = bool_state( name='jax_compilation_cache_include_metadata_in_key', default=False, help=( @@ -1164,7 +1188,7 @@ def _update_jax_memories_thread_local(val): ), ) -hlo_source_file_canonicalization_regex = define_optional_string_state( +hlo_source_file_canonicalization_regex = optional_string_state( name='jax_hlo_source_file_canonicalization_regex', default=None, help=('Used to canonicalize the source_path metadata of HLO instructions ' @@ -1174,7 +1198,7 @@ def _update_jax_memories_thread_local(val): 'persistent compilation cache, which includes HLO metadata in the ' 'cache key.')) -include_full_tracebacks_in_locations = define_bool_state( +include_full_tracebacks_in_locations = bool_state( name='jax_include_full_tracebacks_in_locations', default=True, help=( @@ -1182,7 +1206,7 @@ def _update_jax_memories_thread_local(val): ), ) -traceback_in_locations_limit = define_int_state( +traceback_in_locations_limit = int_state( name='jax_traceback_in_locations_limit', default=10, help=( @@ -1192,7 +1216,7 @@ def _update_jax_memories_thread_local(val): ), ) -share_autotune_config_between_hosts = define_bool_state( +share_autotune_config_between_hosts = bool_state( name='jax_share_autotune_config_between_hosts', default=False, help=( @@ -1206,7 +1230,7 @@ def _update_jax_memories_thread_local(val): ), ) -share_binary_between_hosts = define_bool_state( +share_binary_between_hosts = bool_state( name='jax_share_binary_between_hosts', default=False, help=( @@ -1215,13 +1239,49 @@ def _update_jax_memories_thread_local(val): ), ) -share_binary_between_hosts_timeout_ms = define_int_state( +share_binary_between_hosts_timeout_ms = int_state( name='jax_share_binary_between_hosts_timeout_ms', default=20 * 60 * 1000, help='Timeout for the compiled module share.', ) -enable_compilation_cache = define_bool_state( +enable_pgle = bool_state( + name='jax_enable_pgle', + default=False, + help=( + 'If set to True and the property jax_pgle_profiling_runs is set to ' + 'greater than 0, the modules will be recompiled after running specified ' + 'number times with collected data provided to the profile guided latency ' + 'estimator.' + ), + update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + enable_pgle=val), +) + +pgle_profiling_runs = int_state( + name='jax_pgle_profiling_runs', + default=3, + help=( + 'Amount of times module should be profiled before recompilation when ' + 'PGLE is used.' + ), + update_global_hook=lambda val: _update_global_jit_state( + pgle_profiling_runs=val + ), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + pgle_profiling_runs=val + ), +) + +pgle_aggregation_percentile = int_state( + name='jax_pgle_aggregation_percentile', + default=90, + help='Percentile used to aggregate performance data between devices when ' + 'PGLE is used.', +) + +enable_compilation_cache = bool_state( name='jax_enable_compilation_cache', default=True, help=('If set to False, the compilation cache will be disabled regardless ' @@ -1230,7 +1290,7 @@ def _update_jax_memories_thread_local(val): 'set_cache_dir().'), ) -compilation_cache_dir = define_optional_string_state( +compilation_cache_dir = optional_string_state( name='jax_compilation_cache_dir', default=None, help=('Path for the cache. ' @@ -1239,7 +1299,19 @@ def _update_jax_memories_thread_local(val): '2. The value of this flag set in the command line or by default.'), ) -default_dtype_bits = define_enum_state( +compilation_cache_max_size = int_state( + name='jax_compilation_cache_max_size', + default=-1, + help=('The maximum size (in bytes) allowed for the persistent compilation ' + 'cache. When set, the least recently accessed cache entry(s) ' + 'will be deleted once the total cache directory size ' + 'exceeds the specified limit. ' + 'Caching will be disabled if this value is set to 0. A ' + 'special value of -1 indicates no limit, allowing the cache ' + 'size to grow indefinitely.'), +) + +default_dtype_bits = enum_state( name='jax_default_dtype_bits', enum_values=['32', '64'], default='64', @@ -1247,7 +1319,7 @@ def _update_jax_memories_thread_local(val): 'This is a temporary flag that will be used during the process ' 'of deprecating the ``jax_enable_x64`` flag.')) -numpy_dtype_promotion = define_enum_state( +numpy_dtype_promotion = enum_state( name='jax_numpy_dtype_promotion', enum_values=['standard', 'strict'], default='standard', @@ -1266,7 +1338,7 @@ def _update_x64_global(val): def _update_x64_thread_local(val): lib.jax_jit.thread_local_state().enable_x64 = val -enable_x64 = define_bool_state( +enable_x64 = bool_state( name='jax_enable_x64', default=False, help='Enable 64-bit types to be used', @@ -1301,7 +1373,7 @@ def _validate_default_device(val): # TODO(skye): default_device only accepts devices for now. Make it work with # platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). -default_device = define_string_or_object_state( +default_device = string_or_object_state( name='jax_default_device', default=None, help=( @@ -1321,7 +1393,7 @@ def _update_disable_jit_global(val): def _update_disable_jit_thread_local(val): lib.jax_jit.thread_local_state().disable_jit = val -disable_jit = define_bool_state( +disable_jit = bool_state( name='jax_disable_jit', default=False, help=('Disable JIT compilation and just call original Python.'), @@ -1329,7 +1401,7 @@ def _update_disable_jit_thread_local(val): update_thread_local_hook=_update_disable_jit_thread_local) -numpy_rank_promotion = define_enum_state( +numpy_rank_promotion = enum_state( name='jax_numpy_rank_promotion', enum_values=['allow', 'warn', 'raise'], default='allow', @@ -1340,9 +1412,9 @@ def _update_disable_jit_thread_local(val): update_thread_local_hook=lambda val: \ update_thread_local_jit_state(numpy_rank_promotion=val)) -default_matmul_precision = define_optional_enum_state( +default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', - enum_values=['bfloat16', 'tensorfloat32', 'float32'], + enum_values=['default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32'], default=None, help=('Control the default matmul and conv precision for 32bit inputs.\n\n' @@ -1365,7 +1437,7 @@ def _update_disable_jit_thread_local(val): update_thread_local_hook=lambda val: \ update_thread_local_jit_state(default_matmul_precision=val)) -traceback_filtering = define_enum_state( +traceback_filtering = enum_state( name = 'jax_traceback_filtering', enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames", "auto"], @@ -1386,14 +1458,14 @@ def _update_disable_jit_thread_local(val): # This flag is for internal use. # TODO(tianjianlu): Removes once we always enable cusparse lowering. # TODO(b/262050896): Set to true after bug is fixed -bcoo_cusparse_lowering = define_bool_state( +bcoo_cusparse_lowering = bool_state( name='jax_bcoo_cusparse_lowering', default=False, help=('Enables lowering BCOO ops to cuSparse.')) # TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging # if the intended backend can handle lowering the result -dynamic_shapes = define_bool_state( +dynamic_shapes = bool_state( name='jax_dynamic_shapes', default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')), help=('Enables experimental features for staging out computations with ' @@ -1405,26 +1477,26 @@ def _update_disable_jit_thread_local(val): # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. -remat_opt_barrier = define_bool_state( +remat_opt_barrier = bool_state( name='jax_remat_opt_barrier', default=True, help=('Enables using optimization-barrier op for lowering remat.')) # TODO(sharadmv,mattjj): set default to True, then remove -eager_pmap = define_bool_state( +eager_pmap = bool_state( name='jax_eager_pmap', default=True, upgrade=True, help='Enable eager-mode pmap when jax_disable_jit is activated.') # TODO(mattjj): remove once we land mutable array plumbing, or face great shame -custom_vjp_disable_shape_check = define_bool_state( +custom_vjp_disable_shape_check = bool_state( name='jax_custom_vjp_disable_shape_check', default=False, upgrade=True, help='Disable the check from #19009 to enable some custom_vjp hacks.') -xla_runtime_errors = define_bool_state( +xla_runtime_errors = bool_state( name='jax_experimental_unsafe_xla_runtime_errors', default=False, help=('Enable XLA runtime errors for jax.experimental.checkify.checks ' @@ -1434,7 +1506,7 @@ def _update_disable_jit_thread_local(val): 'work under pmap/pjit.') ) -jax_xla_profile_version = define_int_state( +jax_xla_profile_version = int_state( name='jax_xla_profile_version', default=0, help=( @@ -1486,7 +1558,7 @@ def _update_transfer_guard(state, key, val): else: assert False, f'Invalid transfer guard level {val}' -transfer_guard_host_to_device = define_optional_enum_state( +transfer_guard_host_to_device = optional_enum_state( name='jax_transfer_guard_host_to_device', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1501,7 +1573,7 @@ def _update_transfer_guard(state, key, val): update_thread_local_hook=lambda val: _update_transfer_guard( transfer_guard_lib.thread_local_state(), 'host_to_device', val)) -transfer_guard_device_to_device = define_optional_enum_state( +transfer_guard_device_to_device = optional_enum_state( name='jax_transfer_guard_device_to_device', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1516,7 +1588,7 @@ def _update_transfer_guard(state, key, val): update_thread_local_hook=lambda val: _update_transfer_guard( transfer_guard_lib.thread_local_state(), 'device_to_device', val)) -transfer_guard_device_to_host = define_optional_enum_state( +transfer_guard_device_to_host = optional_enum_state( name='jax_transfer_guard_device_to_host', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1537,7 +1609,7 @@ def _update_all_transfer_guard_global(val): 'jax_transfer_guard_device_to_host'): config.update(name, val) -_transfer_guard = define_optional_enum_state( +_transfer_guard = optional_enum_state( name='jax_transfer_guard', enum_values=[ 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit' @@ -1581,7 +1653,7 @@ def _update_debug_log_modules(module_names_str: str | None): logging_config.enable_debug_logging(module_name) # Don't define a context manager since this isn't threadsafe. -define_string_state( +string_state( name='jax_debug_log_modules', default='', help=('Comma-separated list of module names (e.g. "jax" or ' @@ -1589,7 +1661,7 @@ def _update_debug_log_modules(module_names_str: str | None): 'for.'), update_global_hook=_update_debug_log_modules) -pmap_no_rank_reduction = define_bool_state( +pmap_no_rank_reduction = bool_state( name='jax_pmap_no_rank_reduction', default=False, help=( diff --git a/jax/_src/core.py b/jax/_src/core.py index 806ffeefcdd6..0a9ad6981df5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -13,10 +13,9 @@ # limitations under the License. from __future__ import annotations -import collections # noqa: F401 from collections import Counter, defaultdict, deque, namedtuple -from collections.abc import (Collection, Generator, Hashable, Iterable, - Iterator, Set, Sequence, MutableSet, +from collections.abc import (Callable, Collection, Generator, Hashable, + Iterable, Iterator, Set, Sequence, MutableSet, MutableMapping) from contextlib import contextmanager, ExitStack from dataclasses import dataclass @@ -29,13 +28,14 @@ import operator import threading import types -from typing import (Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, +from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar, cast, overload, Union) import warnings from weakref import ref import numpy as np +from jax._src import deprecations from jax._src import dtypes from jax._src import config from jax._src import effects @@ -62,7 +62,7 @@ map, unsafe_map = safe_map, map -_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.DEFINE_integer( +_TRACER_ERROR_NUM_TRACEBACK_FRAMES = config.int_flag( 'jax_tracer_error_num_traceback_frames', config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5), help='Set the number of stack frames in JAX tracer error messages.' @@ -668,13 +668,25 @@ class Tracer(typing.Array, metaclass=StrictABCMeta): size = _aval_property('size') shape = _aval_property('shape') + def __hash__(self): + # TODO(jakevdp) finalize this deprecation and set __hash__ = None + # Warning added 2024-06-13 + if deprecations.is_accelerated('tracer-hash'): + raise TypeError(f"unhashable type: {type(self)}") + # Use FutureWarning rather than DeprecationWarning because hash is likely + # not called directly by the user, so we want to warn at all stacklevels. + warnings.warn( + f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an" + " error in a future JAX release.", category=FutureWarning) + return super().__hash__() + def __init__(self, trace: Trace): self._trace = trace def _error_repr(self): if self.aval is None: return f"traced array with aval {self.aval}" - return f"traced array with shape {raise_to_shaped(self.aval).str_short()}." + return f"traced array with shape {raise_to_shaped(self.aval).str_short()}" def __array__(self, *args, **kw): raise TracerArrayConversionError(self) @@ -1028,13 +1040,8 @@ def copy(self): def _update_thread_local_jit_state(dynamic): - # Copies the MainTrace instance, removing any .debug_info or .jaxpr_stack - # fields that should not be kept alive as part of a cache key. - # TODO(mattjj): split debug_info and jaxpr_stack out of MainTrace. - # TODO(mattjj): add a test that verifies that JIT-ted functions are not kept - # alive by the JIT cache, particularly for nested JIT-ted functions. - copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload) - config.update_thread_local_jit_state(dynamic_trace_state=copy) + state = (dynamic.level, dynamic.trace_type) + config.update_thread_local_jit_state(dynamic_trace_state=state) # The global state of the tracer is accessed by a thread-local object. @@ -1059,8 +1066,8 @@ def _initialize_jax_jit_thread_local_state(): tls = jax_jit.thread_local_state() if tls.extra_jit_context is None: dynamic = thread_local_state.trace_state.trace_stack.dynamic - copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload) - config.update_thread_local_jit_state(dynamic_trace_state=copy) + state = (dynamic.level, dynamic.trace_type) + config.update_thread_local_jit_state(dynamic_trace_state=state) jax_jit.set_thread_local_state_initialization_callback( @@ -1767,6 +1774,7 @@ def join(self, other): def str_short(self, short_dtypes=False): dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = dt_str.replace('void', 'float0') shapestr = ','.join(map(str, self.shape)) if self.named_shape: named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items()) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 0e323d82c9d7..e5b1f0084d00 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -15,7 +15,6 @@ from enum import Enum from functools import partial, reduce import operator -from typing import Optional import json import jax @@ -309,23 +308,29 @@ def check_compute_capability(cc): def _dot_product_attention_fwd( query, key, value, bias, mask, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, is_training): + dropout_rate, variadic_args, mask_type, layout, cudnn_version): + # check if flash attention is supported for this attention pattern + check_is_flash_attention( + query, key, layout, cudnn_version, bias is not None, False) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, mask, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=is_training) + mask_type=mask_type, layout=layout, is_training=False) output = outputs[0] return output def _dot_product_attention_fwd_rule( query, key, value, bias, mask, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, is_training): + dropout_rate, variadic_args, mask_type, layout, cudnn_version): + # check if flash attention is supported for this attention pattern + check_is_flash_attention( + query, key, layout, cudnn_version, bias is not None, True) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, mask, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=is_training) + mask_type=mask_type, layout=layout, is_training=True) res = (query, key, value, bias, mask, q_seqlen, kv_seqlen, - outputs[1], outputs[0]) if is_training else None + outputs[1], outputs[0]) return outputs[0], res def _dot_product_attention_bwd_rule( @@ -907,11 +912,11 @@ def _dot_product_attention(query: Array, variadic_args: tuple[bool, ...], mask_type: bool, layout: int, - is_training: bool): + cudnn_version: int): output = _dot_product_attention_fwd( query, key, value, bias, mask, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=is_training) + mask_type=mask_type, layout=layout, cudnn_version=cudnn_version) return output # _dot_product_attention_fwd must have the same func signature as _dot_product_attention @@ -921,17 +926,16 @@ def _dot_product_attention(query: Array, def dot_product_attention(query: Array, key: Array, value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - q_seqlen: Optional[Array] = None, - kv_seqlen: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, + q_seqlen: Array | None = None, + kv_seqlen: Array | None = None, *, scale: float = 1.0, mask_type: MaskType = MaskType.NO_MASK, seed: int = 42, dropout_rate: float = 0., - qkv_layout: str = "BTNH", - is_training = False): + qkv_layout: str = "BTNH"): """Computes dot-product attention given query (Q), key (K), and value (V). This function serves as the core operation for applying attention @@ -963,7 +967,6 @@ def dot_product_attention(query: Array, dropout_rate: Dropout rate. qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH, BNSH. - is_training: choose to save activation or not. Returns: Output of the same shape as the query. @@ -978,14 +981,11 @@ def dot_product_attention(query: Array, bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) # check if input shape and data type is compatiable check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout) - # check if flash attention is supported for this attention pattern - check_is_flash_attention( - query, key, layout, cudnn_version, bias is not None, is_training) if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask") has_bias = bias is not None has_mask = mask is not None - has_dbias = has_bias and is_training and \ + has_dbias = has_bias and \ should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] variadic_args = (has_bias, has_mask, has_dbias) if bias is None: @@ -998,6 +998,6 @@ def dot_product_attention(query: Array, kv_seqlen = jnp.zeros(0, dtype=query.dtype) output = _dot_product_attention( query, key, value, bias, mask, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout.value, is_training + dropout_rate, variadic_args, mask_type, layout.value, cudnn_version ) return output diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 22592b854d14..4d41849b75d3 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -14,9 +14,9 @@ from __future__ import annotations +from collections.abc import Callable import functools import operator -from typing import Callable from jax import lax from jax._src import api diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index da3f5fac2c29..46d9fab00455 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -14,11 +14,11 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses from functools import update_wrapper, reduce, partial import inspect -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Generic, TypeVar from jax._src import config from jax._src import core @@ -40,10 +40,10 @@ from jax._src.interpreters import xla from jax._src.interpreters.batching import not_mapped from jax._src.lax import lax -from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_map, - treedef_is_leaf, treedef_tuple, - register_pytree_node_class, tree_leaves, - tree_flatten_with_path, keystr) +from jax._src.tree_util import ( + tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, + register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr, + treedef_children) from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable, unzip2) @@ -180,7 +180,7 @@ def defjvp(self, Returns: None. - Example:: + Examples: @jax.custom_jvp def f(x, y): @@ -212,7 +212,7 @@ def defjvps(self, *jvps: Callable[..., ReturnValue] | None): Returns: None. - Example:: + Examples: @jax.custom_jvp def f(x, y): @@ -567,7 +567,7 @@ def defvjp(self, Returns: None. - Example:: + Examples: @jax.custom_vjp def f(x, y): @@ -729,6 +729,8 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args): py_res = tree_unflatten(res_tree, res) py_cts_out = tree_unflatten(out_tree, cts_out) py_cts_in = yield (py_res, py_cts_out), {} + if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)): + py_cts_in = tuple(py_cts_in) # For each None in py_cts_in, indicating an argument for which the rule # produces no cotangent, we replace it with a pytree with the structure of the # corresponding subtree of in_tree and with leaves of a non-pytree sentinel @@ -764,7 +766,8 @@ def append(x, d): if not core.typecompat(a.at_least_vspace(), a_ := ct.aval): msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype " "that does not match the corresponding input tangent shape/dtype: " - f"the SymbolicZero had shape/dtype {a_.str_short()} while the " + f"at output{keystr(kp)} the SymbolicZero had shape/dtype " + f"{a_.str_short()} while the " f"corresponding input had shape/dtype {a.str_short()}. " "Consider just returning a None here instead of a SymbolicZero " "object.") diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index da132e085086..a4de1b8cc46c 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -14,8 +14,9 @@ from __future__ import annotations +from collections.abc import Callable import functools -from typing import Any, Callable +from typing import Any from jax._src import ad_util from jax._src import api_util diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 265dcd0de843..7d8b3a914b6d 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -16,12 +16,12 @@ from __future__ import annotations import importlib.util -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import logging import string import sys -from typing import Any, Callable, Union +from typing import Any, Union import weakref import numpy as np diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 0bbae2d140ea..5513b169cca9 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -101,3 +101,12 @@ def is_accelerated(deprecation_id: str) -> bool: if deprecation_id not in _registered_deprecations: raise ValueError(f"{deprecation_id=!r} not registered.") return _registered_deprecations[deprecation_id].accelerated + + +def warn(deprecation_id: str, message: str, stacklevel: int) -> None: + """Warns about a deprecation, or errors if the deprecation is accelerated.""" + if is_accelerated(deprecation_id): + raise ValueError(message) + else: + warnings.warn(message, category=DeprecationWarning, + stacklevel=stacklevel + 1) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 220fe063bd3e..9ae1f7a6c2a3 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -16,12 +16,13 @@ from __future__ import annotations import atexit -from collections.abc import Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence import contextlib +import dataclasses from functools import partial import itertools import time -from typing import Any, Callable, NamedTuple +from typing import Any, NamedTuple import logging import threading @@ -48,8 +49,8 @@ from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding, - GSPMDSharding, TransferToMemoryKind) + SingleDeviceSharding, NamedSharding, + GSPMDSharding, TransferToMemoryKind, is_single_device_sharding) from jax._src.layout import Layout, DeviceLocalLayout @@ -163,12 +164,6 @@ def wait_for_tokens(): runtime_tokens.block_until_ready() -def is_single_device_sharding(sharding: Sharding) -> bool: - # Special case PmapSharding here because PmapSharding maps away an axis - # and needs to be handled separately.test_pjit_single_device_sharding_add - return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) - - @contextlib.contextmanager def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None): if _on_exit: @@ -226,7 +221,7 @@ class SourceInfo(NamedTuple): def jaxpr_shardings( jaxpr: core.Jaxpr, -) -> Iterator[tuple[XLACompatibleSharding, SourceInfo]]: +) -> Iterator[tuple[Sharding, SourceInfo]]: from jax._src import pjit from jax.experimental import shard_map @@ -246,10 +241,9 @@ def _names_to_pspec(names): yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) for names in [*eqn.params['in_names'], *eqn.params['out_names']]) elif eqn.primitive is device_put_p: - s = eqn.params['device'] - if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None: - source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield (s, source_info) + source_info = SourceInfo(eqn.source_info, eqn.primitive.name) + yield from ((s, source_info) for s in eqn.params['devices'] + if isinstance(s, Sharding) and s.memory_kind is not None) for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_shardings(subjaxpr) @@ -328,10 +322,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: raise FloatingPointError(f"invalid value (inf) encountered in {name}") -def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool): - result_handler = pxla.global_aval_to_result_handler(aval, s, committed) - return result_handler(pxla.shard_arg(x, s)) - def _override_get_device_assignment(sharding, *args, **kwargs): da = sharding._device_assignment return xb.get_device_backend(da[0]), da @@ -387,6 +377,25 @@ def _mcjax_reshard(x, target_sharding): pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment +@dataclasses.dataclass(frozen=True) +class _DeferredShardArg: + """Deferred call to `pxla.shard_args`. + + Per-array impls return this object instead of a result array to indicate a + deferred `shard_args` call. `_batched_device_put_impl` then batches all + `_DeferredShardArg` objects into a single `shard_args` call. + """ + + x: Any + s: Sharding + aval: core.AbstractValue + committed: bool + + @property + def result_handler(self): + return pxla.global_aval_to_result_handler(self.aval, self.s, self.committed) + + def _device_put_sharding_impl(x, aval, device): from jax._src import array @@ -398,7 +407,7 @@ def _device_put_sharding_impl(x, aval, device): isinstance(x, array.ArrayImpl) and not x.is_fully_addressable): # This has to be XLACompatible because _mcjax_reshard will run a # XLA computation. - assert isinstance(s, XLACompatibleSharding) + assert isinstance(s, Sharding) return _mcjax_reshard(x, s) if not s.is_fully_addressable: # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array. @@ -408,7 +417,7 @@ def _device_put_sharding_impl(x, aval, device): " trying to use device_put in multi-controller JAX which is not" " supported. Please use jax.make_array_from_single_device_arrays API" " or pass device or Sharding which represents addressable devices.") - return _put_x(x, s, aval, True) + return _DeferredShardArg(x, s, aval, True) # Only `Device` exists below. `Sharding` instance is handled above. if isinstance(x, array.ArrayImpl): @@ -424,12 +433,15 @@ def _device_put_sharding_impl(x, aval, device): sh = SingleDeviceSharding(pxla._get_default_device() if device is None else device) - return _put_x(x, sh, aval, device is not None) + return _DeferredShardArg(x, sh, aval, device is not None) + def _device_put_impl( x, - device: Device | Sharding | Layout | None = None, - src: Device | Sharding | Layout | None = None): + *, + device: Device | Sharding | Layout | None, + src: Device | Sharding | Layout | None, +): if (isinstance(device, TransferToMemoryKind) or isinstance(src, TransferToMemoryKind)): raise ValueError( @@ -463,43 +475,93 @@ def _device_put_impl( return _device_put_sharding_impl(x, aval, device) +def _batched_device_put_impl( + *xs, + devices: Sequence[Device | Sharding | Layout | None], + srcs: Sequence[Device | Sharding | Layout | None], +): + ys = [] + shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], [] + for i, (x, device, src) in enumerate(zip(xs, devices, srcs)): + y = _device_put_impl(x, device=device, src=src) + if isinstance(y, _DeferredShardArg): + shard_arg_indices.append(i) + shard_arg_xs.append(y.x) + shard_arg_shardings.append(y.s) + ys.append(y) + + if shard_arg_xs: + # Batch shard_arg calls. Helps improve efficiency for backends that support + # efficient batch transfer. + shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs) + for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results): + assert isinstance(ys[i], _DeferredShardArg) + ys[i] = ys[i].result_handler(shard_arg_result) + + return ys + + device_put_p = core.Primitive('device_put') -device_put_p.def_impl(_device_put_impl) -device_put_p.def_abstract_eval(lambda x, device=None, src=None: x) - -def device_put_transpose_rule(ct, _, device, src): - return [device_put_p.bind(ct, device=src, src=device)] -ad.deflinear2(device_put_p, device_put_transpose_rule) -batching.defvectorized(device_put_p) - -def _tpu_gpu_device_put_lowering(ctx, x, *, device, src): - if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and - device.memory_kind is not None): - aval, = ctx.avals_in - out_aval, = ctx.avals_out - if isinstance(device, XLACompatibleSharding): - x = mlir.wrap_with_sharding_op( - ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) - x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval) - return [x] - return [x] +device_put_p.multiple_results = True +device_put_p.def_impl(_batched_device_put_impl) +device_put_p.def_abstract_eval(lambda *xs, devices, srcs: xs) + +def _device_put_transpose(cts, *_, devices, srcs): + results = [None] * len(cts) + dp_args = [] + for i, (ct, device, src) in enumerate(zip(cts, devices, srcs)): + if type(ct) is not ad.Zero: + dp_args.append((i, ct, device, src)) + if dp_args: + indices, args, devices, srcs = list(zip(*dp_args)) + ys = device_put_p.bind(*args, devices=srcs, srcs=devices) + for i, y in zip(indices, ys): + results[i] = y + return results +ad.primitive_jvps[device_put_p] = partial(ad.linear_jvp, device_put_p) +ad.primitive_transposes[device_put_p] = _device_put_transpose + +def _device_put_batcher(batched_args, batch_dims, **params): + mapped_batch_dims = [bd for bd in batch_dims if bd is not batching.not_mapped] + assert not mapped_batch_dims or all( + mapped_batch_dims[0] == bd for bd in mapped_batch_dims[1:] + ), batch_dims + return device_put_p.bind(*batched_args, **params), batch_dims +batching.primitive_batchers[device_put_p] = _device_put_batcher + +def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs): + def lower(x, device, src, aval, out_aval): + if (isinstance(device, (Sharding, TransferToMemoryKind)) and + device.memory_kind is not None): + if isinstance(device, Sharding): + x = mlir.wrap_with_sharding_op( + ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) + x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval) + return x + return x + return list(map(lower, xs, devices, srcs, ctx.avals_in, ctx.avals_out)) mlir.register_lowering( device_put_p, _tpu_gpu_device_put_lowering, platform='tpu') mlir.register_lowering( device_put_p, _tpu_gpu_device_put_lowering, platform='gpu') -def _common_device_put_lowering(ctx, x, *, device, src): - if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and - device.memory_kind is not None): - raise NotImplementedError( - "Passing memory_kind to device_put via Shardings is not supported on" - f" platforms {ctx.module_context.platforms}") - return [x] +def _common_device_put_lowering(ctx, *xs, devices, srcs): + for device in devices: + if (isinstance(device, (Sharding, TransferToMemoryKind)) and + device.memory_kind is not None): + raise NotImplementedError( + "Passing memory_kind to device_put via Shardings is not supported on" + f" platforms {ctx.module_context.platforms}") + return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) -def _propagate_mem_kind_dp(xm, device=None, src=None): - if isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)): - return device.memory_kind - return None +def _propagate_mem_kind_dp(*xm, devices=None, srcs=None): + memory_kinds = [] + for device in devices: + if isinstance(device, (Sharding, TransferToMemoryKind)): + memory_kinds.append(device.memory_kind) + else: + memory_kinds.append(None) + return memory_kinds pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index d71e206b9a3e..ea0961082d03 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -102,7 +102,9 @@ def initialize(self, if process_id == 0: if self.service is not None: raise RuntimeError('distributed.initialize should only be called once.') - logger.info('Starting JAX distributed service on %s', coordinator_address) + logger.info( + 'Starting JAX distributed service on %s', coordinator_bind_address + ) self.service = xla_extension.get_distributed_runtime_service( coordinator_bind_address, num_processes) @@ -208,7 +210,7 @@ def initialize(coordinator_address: str | None = None, RuntimeError: If :func:`~jax.distributed.initialize` is called more than once or if called after the backend is already initialized. - Example: + Examples: Suppose there are two GPU processes, and process 0 is the designated coordinator with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f85e4833e13c..1d50c5be74b6 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -31,7 +31,7 @@ import numpy as np from jax._src import config -from jax._src.typing import DType, DTypeLike +from jax._src.typing import Array, DType, DTypeLike from jax._src.util import set_module, StrictABC from jax._src import traceback_util @@ -42,8 +42,8 @@ except: pass else: - if _ml_dtypes_version < (0, 4, 0): - raise ValueError("JAX requires ml_dtypes version 0.4.0 or newer; " + if _ml_dtypes_version < (0, 2, 0): + raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; " f"installed version is {ml_dtypes.__version__}.") export = set_module('jax.dtypes') @@ -500,7 +500,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis This DAG maps each type to its immediately higher type on the lattice. """ b1, = _bool_types - uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types + _uint4, u1, u2, u4, u8, _int4, i1, i2, i4, i8 = _int_types *f1_types, bf, f2, f4, f8 = _float_types c4, c8 = _complex_types i_, f_, c_ = _weak_types @@ -508,13 +508,18 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis out: dict[JAXType, list[JAXType]] out = { b1: [i_], - uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - i_: [uint4, int4, u1, i1], - int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], + u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], + i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_], f_: [*f1_types, bf, f2, c_], **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], c_: [c4], c4: [c8], c8: [], } + if _int4_dtype is not None: + out[i_].append(_int4_dtype) + out[_int4_dtype] = [] + if _uint4_dtype is not None: + out[i_].append(_uint4_dtype) + out[_uint4_dtype] = [] return out elif jax_numpy_dtype_promotion == 'strict': return { @@ -736,6 +741,12 @@ def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tupl return (dtype, weak_type) if return_weak_type_flag else dtype # type: ignore[return-value] def check_user_dtype_supported(dtype, fun_name=None): + if isinstance(dtype, Array): + # Deprecation warning added 2024 June 13. + warnings.warn("Passing an array as a dtype argument is deprecated; " + "instead of dtype=arr use dtype=arr.dtype.", + category=DeprecationWarning, stacklevel=3) + return # no further check needed, as array dtypes have already been validated. if issubdtype(dtype, extended): return # Avoid using `dtype in [...]` because of numpy dtype equality overloading. @@ -747,14 +758,14 @@ def check_user_dtype_supported(dtype, fun_name=None): msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" msg += f" in {fun_name}" if fun_name else "" raise TypeError(msg) - if dtype is not None and np_dtype != canonicalize_dtype(dtype): + if dtype is not None and np_dtype != canonicalize_dtype(np_dtype): msg = ("Explicitly requested dtype {} {} is not available, " "and will be truncated to dtype {}. To enable more dtypes, set the " "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " "environment variable. " "See https://github.com/google/jax#current-gotchas for more.") fun_name = f"requested in {fun_name}" if fun_name else "" - truncated_dtype = canonicalize_dtype(dtype).name + truncated_dtype = canonicalize_dtype(np_dtype).name warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3) def safe_to_cast(input_dtype_or_value: Any, diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 176abcc5a37a..f4b5e232bc33 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -20,6 +20,7 @@ from jax._src import basearray from jax._src import core from jax._src import tree_util +from jax._src import sharding_impls from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.util import safe_zip, safe_map @@ -80,7 +81,7 @@ def __len__(self): @property def sharding(self): phys_sharding = self._data.sharding - return self.aval.dtype._rules.logical_sharding(self.aval, phys_sharding) + return sharding_impls.logical_sharding(self.aval, phys_sharding) # TODO(mattjj): not implemented below here, need more methods from ArrayImpl @@ -97,10 +98,11 @@ def global_shards(self): # TODO(mattjj): _set_array_base_attributes -def _earray_shard_arg_handler(x, sharding): - arr = x._data - phys_sharding = x.aval.dtype._rules.physical_sharding(x.aval, sharding) - return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) +def _earray_shard_arg_handler(xs, shardings): + arrs = [x._data for x in xs] + phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding) + for x, sharding in zip(xs, shardings)] + return pxla.shard_args(phys_shardings, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval diff --git a/jax/ffi.py b/jax/_src/export/__init__.py similarity index 91% rename from jax/ffi.py rename to jax/_src/export/__init__.py index ddbac4fa309c..862a661e24b9 100644 --- a/jax/ffi.py +++ b/jax/_src/export/__init__.py @@ -11,5 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from jax._src.ffi import include_dir as include_dir diff --git a/jax/experimental/export/_export.py b/jax/_src/export/_export.py similarity index 63% rename from jax/experimental/export/_export.py rename to jax/_src/export/_export.py index 1689bdc87608..a228eaa8b285 100644 --- a/jax/experimental/export/_export.py +++ b/jax/_src/export/_export.py @@ -17,13 +17,13 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import copy import dataclasses import functools import itertools import re -from typing import Any, Callable, Union +from typing import Any, Union import warnings from absl import logging @@ -47,11 +47,12 @@ from jax._src import pjit from jax._src import sharding_impls from jax._src import source_info_util +from jax._src import stages from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb -from jax.experimental.export import _shape_poly +from jax._src.export import shape_poly map = util.safe_map zip = util.safe_zip @@ -59,33 +60,31 @@ DType = Any Shape = jax._src.core.Shape # The values of input and output sharding from the lowering. -LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] +LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue] +HloSharding = xla_client.HloSharding -# None means unspecified sharding -Sharding = Union[xla_client.HloSharding, None] - -# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions -# for a description of the different versions. -minimum_supported_serialization_version = 9 -maximum_supported_serialization_version = 9 +# The minimum and maximum supported calling convention version. +# See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#calling-conventions-versions +minimum_supported_calling_convention_version = 9 +maximum_supported_calling_convention_version = 9 class DisabledSafetyCheck: - """A safety check should be skipped on (de)serialization. + """A safety check that should be skipped on (de)serialization. Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, - e.g., as a sequence of string attributes to `jax_export.Exported` or of + e.g., as a sequence of string attributes to `jax.export.Exported` or of `tf.XlaCallModuleOp`. - You can disable more deserialization safety checks by passing - `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. + When using jax2tf, you can disable more deserialization safety checks + by passing `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. """ _impl: str @classmethod def platform(cls) -> DisabledSafetyCheck: - """Allows the execution platform to differ from the serialization platform. + """Allows the compilation platform to differ from the export platform. Has effect only on deserialization. """ @@ -103,7 +102,7 @@ def custom_call(cls, target_name: str) -> DisabledSafetyCheck: @classmethod def shape_assertions(cls) -> DisabledSafetyCheck: - """A noop. DEPRECATED. + """DEPRECATED: A noop. Was used previously to allow invocations with shapes that do not meet the constraints. Has no effect anymore, shape assertions cannot be disabled. @@ -150,26 +149,37 @@ class Exported: out_tree: a PyTreeDef describing the result of the lowered JAX function. out_avals: the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in - `in_avals. - in_shardings: the flattened input shardings, as long as `in_avals`. - out_shardings: the flattened output shardings, as long as `out_avals`. + `in_avals`. + in_shardings_hlo: the flattened input shardings, a sequence as long + as `in_avals`. `None` means unspecified sharding. + Note that these do not include the mesh or the actual devices used in + the mesh. See `in_shardings_jax` for a way to turn these + into sharding specification that can be used with JAX APIs. + out_shardings_hlo: the flattened output shardings, a sequence as long + as `out_avals`. `None` means unspecified sharding. + Note that these do not include the mesh or the actual devices used in + the mesh. See `out_shardings_jax` for a way to turn these + into sharding specification that can be used with JAX APIs. nr_devices: the number of devices that the module has been lowered for. - lowering_platforms: a tuple containing at least one of 'tpu', 'cpu', - 'cuda', 'rocm'. See below for the calling convention for when - there are multiple lowering platforms. + platforms: a tuple containing the platforms for which the function should + be exported. The set of platforms in JAX is open-ended; users can + add platforms. JAX built-in platforms are: 'tpu', 'cpu', 'cuda', 'rocm'. + See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export. ordered_effects: the ordered effects present in the serialized module. - This is present from serialization version 9. See below for the - calling convention in presence of ordered effects. + This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention + for the calling convention in presence of ordered effects. unordered_effects: the unordered effects present in the serialized module. This is present from serialization version 9. mlir_module_serialized: the serialized lowered VHLO module. - mlir_module_serialization_version: a version number for the serialized module. - See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions. + calling_convention_version: a version number for the calling + convention of the exported module. + See more versioning details at https://jax.readthedocs.io/en/latest/export.html#calling-convention-versions. module_kept_var_idx: the sorted indices of the arguments among `in_avals` that must be passed to the module. The other arguments have been dropped - because they are not used. Same length as `in_shardings`. - uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape - polymorphism. This may be because `in_avals` contains dimension + because they are not used. + uses_global_constants: whether the `mlir_module_serialized` uses shape + polymorphism or multi-platform export. + This may be because `in_avals` contains dimension variables, or due to inner calls of Exported modules that have dimension variables or platform index arguments. Such modules need shape refinement before XLA compilation. @@ -182,119 +192,26 @@ class Exported: for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs. - Calling convention for the exported module (for latest supported version): - - The `mlir_module` has a `main` function that takes an optional first - platform index argument if the module supports multiple platforms - (`len(lowering_platforms) > 1`), followed by the token arguments corresponding - to the ordered effects, followed by the kept array - arguments (corresponding to `module_kept_var_idx` and `in_avals`). - The platform index is a i32 or i64 scalar encoding the index of the current - compilation platform into the `lowering_platforms` sequence. - - Inner functions use a different calling convention: an optional - platform index argument, optional dimension variable arguments - (scalar tensors of type i32 or i64), - followed by optional token arguments (in presence of ordered effects), - followed by the regular array arguments. - The dimension arguments correspond to the dimension variables appearing in - the `args_avals`, in sorted order of their names. - - Consider the lowering of a function with one array argument of type "f32[w, - 2 * h]", where "w" and "h" are two dimension variables. - Assume that we use multi-platform lowering, and we have - one ordered effect. The `main` function will be as follows: - - func public main( - platform_index: i32 {jax.global_constant="_platform_index"}, - token_in: token, - arg: f32[?, ?]) { - arg_w = hlo.get_dimension_size(arg, 0) - dim1 = hlo.get_dimension_size(arg, 1) - arg_h = hlo.floordiv(dim1, 2) - call _check_shape_assertions(arg) # See below - token = new_token() - token_out, res = call _wrapped_jax_export_main(platform_index, - arg_h, - arg_w, - token_in, - arg) - return token_out, res - } - - The actual computation is in `_wrapped_jax_export_main`, taking also - the values of `h` and `w` dimension variables. - - The signature of the `_wrapped_jax_export_main` is: - - func private _wrapped_jax_export_main( - platform_index: i32 {jax.global_constant="_platform_index"}, - arg_h: i32 {jax.global_constant="h"}, - arg_w: i32 {jax.global_constant="w"}, - arg_token: stablehlo.token {jax.token=True}, - arg: f32[?, ?]) -> (stablehlo.token, ...) - - Prior to serialization version 9 the calling convention for effects is - different: the `main` function does not take or return a token. Instead - the function creates dummy tokens of type `i1[0]` and passes them to the - `_wrapped_jax_export_main`. The `_wrapped_jax_export_main` - takes dummy tokens of type `i1[0]` and will create internally real - tokens to pass to the inner functions. The inner functions use real - tokens (both before and after serialization version 9) - - Also starting with serialization version 9, function arguments that contain - the platform index or the dimension variable values have a - `jax.global_constant` string attribute whose value is the name of the - global constant, either `_platform_index` or a dimension variable name. - The global constant name may be empty if it is not known. - Some global constant computations use inner functions, e.g., for - `floor_divide`. The arguments of such functions have a `jax.global_constant` - attribute for all attributes, meaning that the result of the function is - also a global constant. - - Note that `main` contains a call to `_check_shape_assertions. - JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` - have values >= 1. We must check these constraints when we invoke the - module. We use a special custom call `@shape_assertion` that takes - a boolean first operand, a string `error_message` attribute that may contain - format specifiers `{0}`, `{1}`, ..., and a variadic number of integer - scalar operands corresponding to the format specifiers. - - func private _check_shape_assertions(arg: f32[?, ?]) { - # Check that w is >= 1 - arg_w = hlo.get_dimension_size(arg, 0) - custom_call @shape_assertion(arg_w >= 1, arg_w, - error_message="Dimension variable 'w' must have integer value >= 1. Found {0}") - # Check that dim1 is even - dim1 = hlo.get_dimension_size(arg, 1) - custom_call @shape_assertion(dim1 % 2 == 0, dim1, - error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}") - # Check that h >= 1 - arg_h = hlo.floordiv(dim1, 2) - custom_call @shape_assertion(arg_h >= 1, arg_h, - error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}") - - If we `call_exported` with this module we perform these checks - statically (in `call_exported_abstract_eval`). + See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module-calling-convention). """ fun_name: str in_tree: tree_util.PyTreeDef - in_avals: tuple[core.AbstractValue, ...] + in_avals: tuple[core.ShapedArray, ...] out_tree: tree_util.PyTreeDef - out_avals: tuple[core.AbstractValue, ...] + out_avals: tuple[core.ShapedArray, ...] - in_shardings: tuple[Sharding, ...] - out_shardings: tuple[Sharding, ...] + in_shardings_hlo: tuple[HloSharding | None, ...] + out_shardings_hlo: tuple[HloSharding | None, ...] nr_devices: int - lowering_platforms: tuple[str, ...] + platforms: tuple[str, ...] ordered_effects: tuple[effects.Effect, ...] unordered_effects: tuple[effects.Effect, ...] disabled_safety_checks: Sequence[DisabledSafetyCheck] mlir_module_serialized: bytes - mlir_module_serialization_version: int + calling_convention_version: int module_kept_var_idx: tuple[int, ...] - uses_shape_polymorphism: bool + uses_global_constants: bool _get_vjp: Callable[[Exported], Exported] | None @@ -306,23 +223,146 @@ def __str__(self): # do not want the entire serialized module to end up in locations. return f"Exported(fun_name={self.fun_name}, ...)" + # For backwards compatibility + # TODO(necula): remove after September 2024. + @property + def in_shardings(self): + return self.in_shardings_hlo + @property + def out_shardings(self): + return self.out_shardings_hlo + + def in_shardings_jax( + self, + mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + """Creates Shardings corresponding to self.in_shardings_hlo. + + The Exported object stores `in_shardings_hlo` as HloShardings, which are + independent of a mesh or set of devices. This method constructs + Sharding that can be used in JAX APIs such as `jax.jit` or + `jax.device_put`. + + Example usage: + >>> from jax import export + >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) + >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), + ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) + ... )(np.arange(jax.device_count())) + >>> exp.in_shardings_hlo + ({devices=[8]<=[8]},) + + # Create a mesh for running the exported object + >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",)) + >>> + # Put the args and kwargs on the appropriate devices + >>> run_arg = jax.device_put(np.arange(jax.device_count()), + ... exp.in_shardings_jax(run_mesh)[0]) + >>> res = exp.call(run_arg) + >>> res.addressable_shards + [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), + Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), + Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]), + Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]), + Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]), + Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]), + Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]), + Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])] + """ + return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh) + for s in self.in_shardings_hlo) + + def out_shardings_jax( + self, + mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + """Creates Shardings corresponding to self.out_shardings_hlo. + + See documentation for in_shardings_jax. + """ + return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh) + for s in self.out_shardings_hlo) + + # For backwards compatibility + # TODO(necula): remove after September 2024. + @property + def lowering_platforms(self): + """DEPRECATED.""" + warnings.warn("lowering_platform is deprecated. Use .platforms instead.", + DeprecationWarning, stacklevel=2) + return self.platforms + + # For backwards compatibility + # TODO(necula): remove after September 2024. + @property + def mlir_module_serialization_version(self): + """DEPRECATED.""" + warnings.warn("mlir_module_serialization_version is deprecated. Use .calling_convention_version instead.", + DeprecationWarning, stacklevel=2) + return self.calling_convention_version + + # For backwards compatibility + # TODO(necula): remove after September 2024. + @property + def uses_shape_polymorphism(self): + """DEPRECATED.""" + warnings.warn("uses_shape_polymorphism is deprecated. Use .uses_global_constants instead.", + DeprecationWarning, stacklevel=2) + return self.uses_global_constants + def has_vjp(self) -> bool: + """Returns if this Exported supports VJP.""" return self._get_vjp is not None def vjp(self) -> Exported: """Gets the exported VJP. Returns None if not available, which can happen if the Exported has been - loaded from an external format, without a VJP.""" + loaded from an external format without a VJP. + """ if self._get_vjp is None: raise ValueError("No VJP is available") return self._get_vjp(self) + def serialize(self, + vjp_order: int = 0) -> bytearray: + """Serializes an Exported. + + Args: + vjp_order: The maximum vjp order to include. E.g., the value 2 means that we + serialize the primal functions and two orders of the `vjp` function. This + should allow 2nd order reverse mode differentiation of the deserialized + function. i.e., `jax.grad(jax.grad(f)).` + """ + # Lazy load the serialization module, since flatbuffers is an optional + # dependency. + from jax._src.export.serialization import serialize + return serialize(self, vjp_order=vjp_order) -def default_lowering_platform() -> str: + def call(self, *args, **kwargs): + return call_exported(self)(*args, **kwargs) + + +def deserialize(blob: bytearray) -> Exported: + """Deserializes an Exported. + + Args: + blob: a bytearray obtained from `Exported.serialize`. + """ + # Lazy load the serialization module, since flatbuffers is an optional + # dependency. + from jax._src.export.serialization import deserialize + return deserialize(blob) + + +def default_export_platform() -> str: + """Retrieves the default export platform. + + One of: `tpu`, `cpu`, `cuda`, `rocm`. + """ # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' return xb.canonicalize_platform(jax.default_backend()) +default_lowering_platform = default_export_platform + def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array or a j""" if isinstance(a, jax.ShapeDtypeStruct): @@ -343,16 +383,24 @@ def args_specs( # This was needed in some older jax2tf implementations args = tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(* get_shape_and_dtype(a)), args) - return _shape_poly.symbolic_args_specs(args, polymorphic_shapes) + return shape_poly.symbolic_args_specs(args, polymorphic_shapes) -def export(fun_jax: Callable, - *, - lowering_platforms: Sequence[str] | None = None, - disabled_checks: Sequence[DisabledSafetyCheck] = (), - ) -> Callable[..., Exported]: +# TODO(necula): remove this once we remove jax.experimental.export. +def export_back_compat( + fun_jax: Callable, + *, + lowering_platforms: Sequence[str] | None = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + _device_assignment_for_internal_jax2tf_use_only = None, + ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. + Note: this function exists only for internal usage by jax2tf and for + backwards compatibility with jax.experimental.export. Use + `jax.export` instead. + See https://jax.readthedocs.io/en/latest/export.html + Args: fun_jax: the function to lower and serialize. lowering_platforms: @@ -360,8 +408,8 @@ def export(fun_jax: Callable, 'cuda', 'rocm'. If more than one platform is specified, then the lowered code takes an argument specifying the platform. If None, then use the default JAX backend. - The calling convention for multiple platforms is explained in the - `jax_export.Exported` docstring. + The calling convention for multiple platforms is explained + at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. disabled_checks: the safety checks to disable. See docstring of `DisabledSafetyCheck`. @@ -374,143 +422,254 @@ def export(fun_jax: Callable, def f_jax(*args, **kwargs): ... exported = jax_export.export(f_jax)(*args, **kwargs) """ - fun_name = getattr(fun_jax, "__name__", "unknown") - version = config.jax_serialization_version.value - if (version < minimum_supported_serialization_version or - version > maximum_supported_serialization_version): - raise ValueError( - f"The requested jax_serialization version {version} is outside the " - f"range of supported versions [{minimum_supported_serialization_version}" - f"..{maximum_supported_serialization_version}]") def do_export(*args_specs, **kwargs_specs) -> Exported: - if not hasattr(fun_jax, "lower"): + if hasattr(fun_jax, "trace"): + # If we have a pjit or pmap already we do not wrap with another, and we + # allow shardings. + wrapped_fun_jax = fun_jax + else: # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also # convert(f_jax), in which case a "jit" is implied. In that case we raise # an error if the lowered function contains non-replicated sharding annotations. wrapped_fun_jax = jax.jit(fun_jax) - else: - # If we have a pjit or pmap already we do not wrap with another, and we - # allow shardings. - wrapped_fun_jax = fun_jax # type: ignore if lowering_platforms is not None: actual_lowering_platforms = tuple(lowering_platforms) else: - actual_lowering_platforms = (default_lowering_platform(),) + actual_lowering_platforms = (default_export_platform(),) # TODO: move to `lower` - symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None + symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: - # Static args may has no `shape` attribute. + # Static args may have no `shape` attribute. if not hasattr(aval, "shape"): continue for d in aval.shape: - if _shape_poly.is_symbolic_dim(d): + if shape_poly.is_symbolic_dim(d): if symbolic_scope is None: symbolic_scope = (d.scope, k_path) continue symbolic_scope[0]._check_same_scope( - d, when=f"when exporting {fun_name}", - self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", - other_descr=_shape_poly.args_kwargs_path_to_str(k_path)) - - lowered = wrapped_fun_jax.lower( - *args_specs, **kwargs_specs, - _experimental_lowering_parameters=mlir.LoweringParameters( - platforms=actual_lowering_platforms, - )) - - lowering = lowered._lowering - _check_lowering(lowering) - mlir_module = lowering.stablehlo() - - args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) - if "mut" in lowering.compile_args: - if lowering.compile_args["mut"]: raise NotImplementedError - if "kept_var_idx" in lowering.compile_args: - module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) - else: - # For pmap - module_kept_var_idx = tuple(range(len(args_avals_flat))) - shape_poly_state = lowering.compile_args["shape_poly_state"] - if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) - or lowering.compile_args.get("ordered_effects", [])): - mlir_module = _wrap_main_func( - mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, - has_platform_index_argument=shape_poly_state.has_platform_index_argument, - module_kept_var_idx=module_kept_var_idx, - serialization_version=version) - - with mlir_module.context: - mlir_module_attrs = mlir_module.operation.attributes - mlir_module_attrs["jax.uses_shape_polymorphism"] = ( - mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) - - mlir_module_serialized = _module_to_bytecode(mlir_module) - - # Figure out the result types and shapes - if "global_out_avals" in lowering.compile_args: - # This is currently the case for pjit - out_avals_flat = lowering.compile_args["global_out_avals"] - elif "shards" in lowering.compile_args: # for PmapComputation - out_avals_flat = lowering.compile_args["shards"].out_sharded_avals - else: - out_avals_flat = lowered.compile_args["out_avals"] - - # Log and then check the module. - if logging.vlog_is_on(3): - logmsg = (f"version={version} " - f"lowering_platforms={actual_lowering_platforms} " - f"disabled_checks={disabled_checks}") - logging.info("Lowered JAX module: %s\n", logmsg) - if dumped_to := mlir.dump_module_to_file(mlir_module, "export"): - logging.info("Dumped the exported MLIR module to %s", dumped_to) - - _check_module(mlir_module, - disabled_checks=disabled_checks) - - ordered_effects = tuple(lowering.compile_args["ordered_effects"]) - unordered_effects = tuple(lowering.compile_args["unordered_effects"]) - - nr_devices = len(lowering.compile_args["device_assignment"]) - def export_sharding(s: LoweringSharding, - aval: core.ShapedArray) -> Sharding: - if sharding_impls.is_unspecified(s): - return None - return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] - - all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], - module_kept_var_idx, - len(args_avals_flat)) - in_shardings = tuple( - export_sharding(s, aval) - for s, aval in zip(all_in_shardings, args_avals_flat)) - out_shardings = tuple( - export_sharding(s, aval) - for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) - return Exported( - fun_name=fun_name, - in_tree=lowered.in_tree, - out_tree=lowered.out_tree, - in_avals=tuple(args_avals_flat), - out_avals=tuple(out_avals_flat), - in_shardings=in_shardings, - out_shardings=out_shardings, - nr_devices=nr_devices, + d, when=f"when exporting {util.fun_name(wrapped_fun_jax)}", + self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", + other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + + traced = wrapped_fun_jax.trace(*args_specs, **kwargs_specs) + lowered = traced.lower( lowering_platforms=actual_lowering_platforms, - ordered_effects=ordered_effects, - unordered_effects=unordered_effects, - disabled_safety_checks=tuple(disabled_checks), - mlir_module_serialized=mlir_module_serialized, - module_kept_var_idx=module_kept_var_idx, - uses_shape_polymorphism=shape_poly_state.uses_dim_vars, - mlir_module_serialization_version=version, - _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported, - lowering.compile_args["device_assignment"])) + _private_parameters=mlir.LoweringParameters(for_export=True)) + return _export_lowered( + lowered, traced.jaxpr, traced.fun_name, + disabled_checks=disabled_checks, + _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only) + return do_export +def export( + fun_jit: stages.Wrapped, + *, + platforms: Sequence[str] | None = None, + lowering_platforms: Sequence[str] | None = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + ) -> Callable[..., Exported]: + """Exports a JAX function for persistent serialization. + + Args: + fun_jit: the function to export. Should be the result of `jax.jit`. + platforms: + Optional sequence containing a subset of 'tpu', 'cpu', + 'cuda', 'rocm'. If more than one platform is specified, then + the exported code takes an argument specifying the platform. + If None, then use the default JAX backend. + The calling convention for multiple platforms is explained at + https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. + lowering_platforms: DEPRECATED, use `platforms`. + disabled_checks: the safety checks to disable. See documentation for + of `jax.export.DisabledSafetyCheck`. + + Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`, + or values with `.shape` and `.dtype` attributes, and returns an + `Exported`. + + Usage: + + >>> from jax import export + >>> exported: export.Exported = export.export(jnp.sin)( + ... np.arange(4, dtype=np.float32)) + >>> + >>> # You can inspect the Exported object + >>> exported.in_avals + (ShapedArray(float32[4]),) + >>> blob: bytearray = exported.serialize() + >>> + >>> # The serialized bytes are safe to use in a separate process + >>> rehydrated: export.Exported = export.deserialize(blob) + >>> rehydrated.fun_name + 'sin' + >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32)) + Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) + """ + if not isinstance(fun_jit, stages.Wrapped): + raise ValueError( + f"Function to be exported must be the result of `jit` but is: {fun_jit}") + if platforms is not None and lowering_platforms is not None: + raise ValueError("Cannot use both `platforms` and `lowering_platforms`") + if platforms is None and lowering_platforms is not None: + platforms = lowering_platforms + if platforms is not None: + actual_lowering_platforms = tuple(platforms) + else: + actual_lowering_platforms = (default_export_platform(),) + + def do_export(*args_specs, **kwargs_specs) -> Exported: + # TODO: move to `lower` + symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] + for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: + # Static args may have no `shape` attribute. + if not hasattr(aval, "shape"): + continue + for d in aval.shape: + if shape_poly.is_symbolic_dim(d): + if symbolic_scope is None: + symbolic_scope = (d.scope, k_path) + continue + symbolic_scope[0]._check_same_scope( + d, when=f"when exporting {util.fun_name(fun_jit)}", + self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", + other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + + traced = fun_jit.trace(*args_specs, **kwargs_specs) + lowered = traced.lower( + lowering_platforms=actual_lowering_platforms, + _private_parameters=mlir.LoweringParameters(for_export=True)) + return _export_lowered( + lowered, traced.jaxpr, traced.fun_name, + disabled_checks=disabled_checks) return do_export +def _export_lowered( + lowered: stages.Lowered, + jaxpr: core.ClosedJaxpr, fun_name: str, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + _device_assignment_for_internal_jax2tf_use_only = None, + ) -> Exported: + version = config.jax_export_calling_convention_version.value + if (version < minimum_supported_calling_convention_version or + version > maximum_supported_calling_convention_version): + raise ValueError( + f"The requested export calling convention version {version} is outside the " + f"range of supported versions [{minimum_supported_calling_convention_version}" + f"..{maximum_supported_calling_convention_version}]") + + lowering = lowered._lowering + _check_lowering(lowering) + mlir_module = lowering.stablehlo() + + args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) + if "mut" in lowering.compile_args: + if lowering.compile_args["mut"]: raise NotImplementedError + if "kept_var_idx" in lowering.compile_args: + module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) + else: + # For pmap + module_kept_var_idx = tuple(range(len(args_avals_flat))) + shape_poly_state = lowering.compile_args["shape_poly_state"] + if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) + or lowering.compile_args.get("ordered_effects", [])): + mlir_module = _wrap_main_func( + mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, + has_platform_index_argument=shape_poly_state.has_platform_index_argument, + module_kept_var_idx=module_kept_var_idx, + serialization_version=version) + + with mlir_module.context: + mlir_module_attrs = mlir_module.operation.attributes + mlir_module_attrs["jax.uses_shape_polymorphism"] = ( + mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) + + mlir_module_serialized = _module_to_bytecode(mlir_module) + + # Figure out the result types and shapes + if "global_out_avals" in lowering.compile_args: + # This is currently the case for pjit + out_avals_flat = lowering.compile_args["global_out_avals"] + elif "shards" in lowering.compile_args: # for PmapComputation + out_avals_flat = lowering.compile_args["shards"].out_sharded_avals + else: + out_avals_flat = lowered.compile_args["out_avals"] # type: ignore + + # Log and then check the module. + if logging.vlog_is_on(3): + logmsg = (f"version={version} " + f"lowering_platforms={lowering.compile_args['platforms']} " + f"disabled_checks={disabled_checks}") + logging.info("Lowered JAX module: %s\n", logmsg) + if dumped_to := mlir.dump_module_to_file(mlir_module, "export"): + logging.info("Dumped the exported MLIR module to %s", dumped_to) + + _check_module(mlir_module, + disabled_checks=disabled_checks) + + ordered_effects = tuple(lowering.compile_args["ordered_effects"]) + unordered_effects = tuple(lowering.compile_args["unordered_effects"]) + + nr_devices = len(lowering.compile_args["device_assignment"]) + def export_sharding(s: LoweringSharding, + aval: core.ShapedArray) -> HloSharding | None: + if sharding_impls.is_unspecified(s): + return None + return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + + all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], + module_kept_var_idx, + len(args_avals_flat)) + in_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(all_in_shardings, args_avals_flat)) + out_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) + + device_assignment = lowering.compile_args["device_assignment"] + if _device_assignment_for_internal_jax2tf_use_only is not None: + _device_assignment_for_internal_jax2tf_use_only[0] = device_assignment + def _get_exported_vjp(exp_primal: Exported) -> Exported: + # Turn the primal jaxpr into a function, in preparation for exporting + # the VJP. Note that jaxpr_as_fun produces a function with flat arguments + assert(jaxpr is not None) # None only when the lowered was created outside JAX + fun_jax = core.jaxpr_as_fun(jaxpr) + + fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax, + in_tree=exp_primal.in_tree, + in_avals=exp_primal.in_avals, + in_shardings_hlo=exp_primal.in_shardings_hlo, + out_avals=exp_primal.out_avals, + out_shardings_hlo=exp_primal.out_shardings_hlo, + device_assignment=device_assignment, + apply_jit=True, + flat_primal_fun=True) + return export(fun_vjp_jax, # type: ignore[arg-type] + platforms=exp_primal.platforms, + disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) + + return Exported( + fun_name=fun_name, + in_tree=lowered.in_tree, + out_tree=lowered.out_tree, + in_avals=tuple(args_avals_flat), + out_avals=tuple(out_avals_flat), + in_shardings_hlo=in_shardings, + out_shardings_hlo=out_shardings, + nr_devices=nr_devices, + platforms=lowering._platforms, # type: ignore + ordered_effects=ordered_effects, + unordered_effects=unordered_effects, + disabled_safety_checks=tuple(disabled_checks), + mlir_module_serialized=mlir_module_serialized, + module_kept_var_idx=module_kept_var_idx, + uses_global_constants=shape_poly_state.uses_dim_vars, + calling_convention_version=version, + _get_vjp=_get_exported_vjp) def _module_to_bytecode(module: ir.Module) -> bytes: mlir_str = mlir.module_to_bytecode(module) @@ -531,7 +690,7 @@ def _module_to_bytecode(module: ir.Module) -> bytes: # and still have the payloads produced by `serialize_portable_artifact` # compatible with potential consumers from the past. target_version = hlo.get_minimum_version() - module_serialized = xla_client._xla.mlir.serialize_portable_artifact( + module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore mlir_str, target_version) return module_serialized @@ -547,11 +706,10 @@ def _wrap_main_func( ) -> ir.Module: """Wraps the lowered module with a new "main" handling dimension arguments. - See calling convention documentation for `jax_export.Exported`. + See calling convention documentation https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. Args: - module: the HLO module as obtained from lowering. See the calling convention - for inner functions in `jax_export.Exported`. + module: the HLO module as obtained from lowering. args_avals_flat: the avals for all the arguments of the lowered function, which correspond to the array arguments of the `module`. args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error @@ -564,7 +722,7 @@ def _wrap_main_func( Returns the wrapped module, without dimension and token arguments. """ - dim_vars = _shape_poly.all_dim_vars(args_avals_flat) + dim_vars = shape_poly.all_dim_vars(args_avals_flat) context = mlir.make_ir_context() with context, ir.Location.unknown(context): # Make a copy, do not mutate because it may be cached @@ -578,8 +736,8 @@ def _wrap_main_func( def is_token(typ, attrs): return (typ == mlir.token_type()[0]) - orig_input_types = orig_main.type.inputs - arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) + orig_input_types = orig_main.type.inputs # type: ignore + arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) # type: ignore # The order of args: platform_index_arg, dim args, token args, array args. nr_platform_index_args = 1 if has_platform_index_argument else 0 nr_dim_args = len(dim_vars) @@ -601,8 +759,8 @@ def is_token(typ, attrs): orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args]) # The order of results: tokens, array results - orig_output_types = orig_main.type.results - result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) + orig_output_types = orig_main.type.results # type: ignore + result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) # type: ignore token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types, result_attrs)) if is_token(typ, attrs)] @@ -652,7 +810,8 @@ def is_token(typ, attrs): keepalives=[], channel_iterator=itertools.count(1), host_callbacks=[], module=wrapped_module, context=context, lowering_parameters=mlir.LoweringParameters( - global_constant_computation=True + global_constant_computation=True, + for_export=True, )) ctx = mlir.LoweringRuleContext( module_context=module_context, @@ -661,12 +820,12 @@ def is_token(typ, attrs): tokens_in=mlir.TokenSet(), tokens_out=None) # We compute dim_values from the array arguments. new_main_op_array_args = new_main_op.arguments[-nr_array_args:] - if _shape_poly.all_dim_vars(args_avals_flat): + if shape_poly.all_dim_vars(args_avals_flat): # TODO(necula): handle module_kept_var_idx in presence of shape # polymorphism. For now we ensured upstream that we keep all variables. assert len(set(module_kept_var_idx)) == len(args_avals_flat) dim_values = mlir.lower_fun( - functools.partial(_shape_poly.compute_dim_vars_from_arg_shapes, + functools.partial(shape_poly.compute_dim_vars_from_arg_shapes, args_avals_flat, args_kwargs_tree=args_kwargs_tree), multiple_results=True)(ctx, *new_main_op_array_args) else: @@ -705,7 +864,7 @@ def is_token(typ, attrs): def _check_lowering(lowering) -> None: if not isinstance(lowering, pxla.MeshComputation): - raise NotImplementedError(f"serialization is supported only for pjit. {lowering}") + raise NotImplementedError(f"serialization is supported only for jit. {lowering}") if lowering.compile_args["host_callbacks"] or lowering.compile_args["keepalive"]: raise NotImplementedError("serialization of host_callbacks is not yet implemented") @@ -713,13 +872,14 @@ def _check_lowering(lowering) -> None: # safe to add it to the allowed_compile_args if it does not change the semantics # or the calling convention of the lowered module. allowed_compile_args = [ - "backend", "mesh", "global_in_avals", + "backend", "platforms", "mesh", "global_in_avals", "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx", "mut", "spmd_lowering", "auto_spmd_lowering", "tuple_args", "ordered_effects", "unordered_effects", "keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment", "jaxpr_debug_info", "shape_poly_state", - "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info"] + "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info", + "pgle_profiler"] for compile_arg in lowering.compile_args.keys(): if compile_arg not in allowed_compile_args: raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") @@ -758,7 +918,7 @@ def _check_lowering(lowering) -> None: # Their backwards compatibility is tested by back_compat_test.py. _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", - "dynamic_ducc_fft", "cu_threefry2x32", + "cu_threefry2x32", "__gpu$xla.gpu.triton", # Pallas call on GPU # cholesky on CPU "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", @@ -794,7 +954,7 @@ def _check_lowering(lowering) -> None: # lu on TPU "LuDecomposition", # ApproxTopK on TPU - "ApproxTopK", + "ApproxTopK", "stablehlo.dynamic_approx_top_k", "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) "tpu_custom_call", # Pallas/TPU kernels # TODO(burmako): maintain backwards compatibility for these, until they @@ -872,7 +1032,7 @@ def walk_operations(op): msg = ("Cannot serialize code with custom calls whose targets have no " "compatibility guarantees. Examples are:\n" f"{disallowed_custom_call_ops_str}.\n" - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls") + "See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls") raise ValueError(msg) return module_uses_non_replicated_sharding @@ -890,40 +1050,38 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding], all_in_shardings[idx] = in_s return tuple(all_in_shardings) -# TODO(yashkatariya, necula): remove this function once we relax the checks -# in the jit front-end. -def canonical_shardings( - device_assignment: Sequence[jax.Device], - in_shardings: Sequence[Sharding], - out_shardings: Sequence[Sharding] - ) -> tuple[Sequence[sharding.XLACompatibleSharding | None], - Sequence[sharding.XLACompatibleSharding | None]]: - """Prepares canonical in_ and out_shardings for a pjit invocation. - - Turns the HloSharding into XLACompatibleSharding. - - Returns: a pair with the canonicalized input and output shardings. - """ - def canonicalize( - ss: Sequence[Sharding]) -> Sequence[sharding.XLACompatibleSharding | None]: - return tuple( - sharding.GSPMDSharding(device_assignment, s) if s is not None else None - for s in ss) - return (canonicalize(in_shardings), canonicalize(out_shardings)) +def _hlo_sharding_to_xla_compatible_sharding( + hlo_sharding: HloSharding | None, + mesh: sharding.Mesh) -> sharding.Sharding | None: + if hlo_sharding is None: + return None + return sharding_impls._gspmd_to_named_sharding_via_mesh( + _hlo_sharding_to_gspmd_sharding(hlo_sharding, tuple(mesh.devices.flat)), # type: ignore[arg-type] + mesh) + +def _hlo_sharding_to_gspmd_sharding( + hlo_sharding: HloSharding | None, + device_assignment: Sequence[jax.Device]) -> sharding.GSPMDSharding | None: + if hlo_sharding is None: + return None + return sharding.GSPMDSharding(device_assignment, hlo_sharding) def _get_vjp_fun(primal_fun: Callable, *, in_tree: tree_util.PyTreeDef, in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], - in_shardings: tuple[Sharding, ...], - out_shardings: tuple[Sharding, ...], + in_shardings_hlo: tuple[HloSharding | None, ...], + out_shardings_hlo: tuple[HloSharding | None, ...], device_assignment: Sequence[sharding_impls.Device] | None, - apply_jit: bool + apply_jit: bool, + flat_primal_fun: bool = False, ) -> tuple[Callable, Sequence[core.AbstractValue]]: # Since jax.vjp does not handle kwargs, it is easier to do all the work # here with flattened functions. # apply_jit=False is only used for backwards compatibility with the graph # graph serialization. When apply_jit=True, we must pass a device assignment. + # flat_primal_fun=False is used only from jax2tf, and it means that the + # `primal_fun` takes PyTree `*args` and `**kwargs`. def fun_vjp_jax(*args_and_out_cts_flat_jax): # Takes a flat list of primals and output cotangents def flattened_primal_fun_jax(*args_flat): @@ -934,7 +1092,8 @@ def flattened_primal_fun_jax(*args_flat): args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(in_avals)]) - _, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax) + _, pullback_jax = jax.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, + *args_flat_jax) return pullback_jax(out_cts_flat_jax) vjp_in_avals = list( @@ -943,32 +1102,18 @@ def flattened_primal_fun_jax(*args_flat): if apply_jit: assert device_assignment is not None - vjp_in_shardings, vjp_out_shardings = canonical_shardings( - device_assignment, - tuple(itertools.chain(in_shardings, out_shardings)), - in_shardings) + vjp_in_shardings = tuple( + _hlo_sharding_to_gspmd_sharding(s, device_assignment) + for s in itertools.chain(in_shardings_hlo, out_shardings_hlo)) + vjp_out_shardings = tuple( + _hlo_sharding_to_gspmd_sharding(s, device_assignment) + for s in in_shardings_hlo) return pjit.pjit(fun_vjp_jax, in_shardings=vjp_in_shardings, out_shardings=vjp_out_shardings), vjp_in_avals else: return fun_vjp_jax, vjp_in_avals -def _export_native_vjp(primal_fun, - primal: Exported, - device_assignment: Sequence[sharding_impls.Device]) -> Exported: - # Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp - fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun, - in_tree=primal.in_tree, - in_avals=primal.in_avals, - in_shardings=primal.in_shardings, - out_avals=primal.out_avals, - out_shardings=primal.out_shardings, - device_assignment=device_assignment, - apply_jit=True) - return export(fun_vjp_jax, - lowering_platforms=primal.lowering_platforms, - disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals) - ### Calling the exported function def call(exported: Exported) -> Callable[..., jax.Array]: @@ -1016,7 +1161,7 @@ def f_imported(*args, **kwargs): f"as when the function '{exported.fun_name}' was exported, but they " "have the following structural differences:\n" + ("\n".join( - f" - {_shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a " + f" - {shape_poly.args_kwargs_path_to_str(path)} is a {thing1} in the invocation and a " f"{thing2} when exported, so {explanation}.\n" for path, thing1, thing2, explanation in tree_util.equality_errors(in_args, exp_in_args)))) @@ -1037,12 +1182,14 @@ def _call_exported_abstract_eval( *in_avals: core.AbstractValue, exported: Exported ) -> tuple[tuple[core.AbstractValue, ...], set[effects.Effect]]: - exported_dim_vars = _shape_poly.all_dim_vars(exported.in_avals) + exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals) assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure # Check that the expected shapes match the actual ones for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)): + if not isinstance(actual_aval, core.ShapedArray): + raise ValueError(f"Expected ShapedArray but got: {actual_aval}") def pp_arg_dim(dim_idx: int | None) -> str: - return _shape_poly.pretty_print_dimension_descriptor(exported.in_tree, + return shape_poly.pretty_print_dimension_descriptor(exported.in_tree, arg_idx, dim_idx) if len(exp_aval.shape) != len(actual_aval.shape): raise ValueError( @@ -1064,11 +1211,11 @@ def pp_arg_dim(dim_idx: int | None) -> str: f"expected {exp_aval.shape} and called with {actual_aval.shape}") # Must express the exported_dim_vars in terms of the shapes in in_avals. - solution, shape_constraints, synth_dim_vars = _shape_poly.solve_dim_vars( + solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars( exported.in_avals, args_kwargs_tree=exported.in_tree) synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx] for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = _shape_poly.CachingShapeEvaluator(**synthetic_env) + synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env) # We discharge all the constraints statically. This results in much simpler # composability (because we do not have to worry about the constraints of the # Exported called recursively; we only need to worry about entry-point @@ -1101,7 +1248,7 @@ def _call_exported_impl(*args, exported: Exported): def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, exported: Exported): - if exported.uses_shape_polymorphism: + if exported.uses_global_constants: ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module()) @@ -1119,14 +1266,14 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, # than the function was exported for. err_msg = "" if exported.nr_devices != 1: - err_msg = "the module was lowered for more than 1 device." + err_msg = "the function was exported for more than 1 device." elif (_check_module(submodule, disabled_checks=()) or any(s is not None and not s.is_replicated() - for s in exported.in_shardings + exported.out_shardings)): - err_msg = "the module contains non-replicated sharding annotations." + for s in exported.in_shardings_hlo + exported.out_shardings_hlo)): + err_msg = "the function contains non-replicated sharding annotations." if err_msg: - raise NotImplementedError( - f"Exported module {exported.fun_name} was lowered for " + raise ValueError( + f"Function {exported.fun_name} was exported for " f"{exported.nr_devices} devices and is called in a context with " f"{num_devices} devices. This is disallowed because: {err_msg}" ) @@ -1134,7 +1281,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, # Apply in_shardings args = tuple( wrap_with_sharding(ctx, x, x_aval, x_sharding) - for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings)) + for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo)) symtab = ir.SymbolTable(submodule.operation) # The called function may have been exported with polymorphic shapes and called # now with more refined shapes. We insert hlo.ConvertOp to ensure the module @@ -1160,18 +1307,18 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra callee_lowering_platform_index: list[int] = [] for platform in lowering_platforms: - if platform in exported.lowering_platforms: + if platform in exported.platforms: callee_lowering_platform_index.append( - exported.lowering_platforms.index(platform)) + exported.platforms.index(platform)) elif DisabledSafetyCheck.platform() in exported.disabled_safety_checks: callee_lowering_platform_index.append(0) else: raise ValueError( - f"The exported function '{exported.fun_name}' was lowered for " - f"platforms '{exported.lowering_platforms}' but it is used " + f"Function '{exported.fun_name}' was exported for " + f"platforms '{exported.platforms}' but it is used " f"on '{lowering_platforms}'.") - if len(exported.lowering_platforms) > 1: + if len(exported.platforms) > 1: # The exported module takes a platform index argument if len(lowering_platforms) > 1: current_platform_idx = ctx.dim_var_values[0] @@ -1223,7 +1370,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra # Apply out_shardings results = tuple( wrap_with_sharding(ctx, x, x_aval, x_sharding) - for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings) + for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo) ) return results @@ -1232,35 +1379,8 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra def wrap_with_sharding(ctx: mlir.LoweringRuleContext, x: ir.Value, x_aval: core.AbstractValue, - x_sharding: Sharding) -> ir.Value: + x_sharding: HloSharding | None) -> ir.Value: if x_sharding is None: return x return mlir.wrap_with_sharding_op( ctx, x, x_aval, x_sharding.to_proto()) - -# TODO(necula): Previously, we had `from jax.experimental.export import export` -# Now we want to simplify the usage, and export the public APIs directly -# from `jax.experimental.export` and now `jax.experimental.export.export` -# refers to the `export` function. Since there may still be users of the -# old API in other packages, we add the old public API as attributes of the -# exported function. We will clean this up after a deprecation period. -def wrap_with_deprecation_warning(f): - msg = (f"You are using function `{f.__name__}` from " - "`jax.experimental.export.export`. You should instead use it directly " - "from `jax.experimental.export`. Instead of " - "`from jax.experimental.export import export` you should use " - "`from jax.experimental import export`.") - def wrapped_f(*args, **kwargs): - warnings.warn(msg, DeprecationWarning, stacklevel=2) - return f(*args, **kwargs) - return wrapped_f - -export.export = wrap_with_deprecation_warning(export) -export.Exported = Exported -export.call_exported = wrap_with_deprecation_warning(call_exported) -export.DisabledSafetyCheck = DisabledSafetyCheck -export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform) -export.symbolic_shape = wrap_with_deprecation_warning(_shape_poly.symbolic_shape) -export.args_specs = wrap_with_deprecation_warning(args_specs) -export.minimum_supported_serialization_version = minimum_supported_serialization_version -export.maximum_supported_serialization_version = maximum_supported_serialization_version diff --git a/jax/experimental/export/serialization.fbs b/jax/_src/export/serialization.fbs similarity index 93% rename from jax/experimental/export/serialization.fbs rename to jax/_src/export/serialization.fbs index e7904954a111..758950adaa8e 100644 --- a/jax/experimental/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -20,7 +20,7 @@ // 3. Add back the licence comment at the start // -namespace jax.experimental.export.serialization; +namespace jax.export.serialization; enum PyTreeDefKind: byte { leaf = 0, @@ -38,7 +38,7 @@ table PyTreeDef { enum AbstractValueKind: byte { shapedArray = 0, - abstractToken = 1, + abstractToken = 1, // unused } enum DType: byte { @@ -119,16 +119,16 @@ table Exported { in_shardings: [Sharding]; out_shardings: [Sharding]; - lowering_platforms: [string]; + platforms: [string]; ordered_effects: [Effect]; unordered_effects: [Effect]; disabled_checks: [DisabledSafetyCheck]; mlir_module_serialized: [byte]; - mlir_module_serialization_version: uint16; + calling_convention_version: uint16; module_kept_var_idx: [uint16]; - uses_shape_polymorphism: bool; + uses_global_constants: bool; vjp: Exported; } diff --git a/jax/experimental/export/_serialization.py b/jax/_src/export/serialization.py similarity index 82% rename from jax/experimental/export/_serialization.py rename to jax/_src/export/serialization.py index c9dddbce485a..a47b095e4450 100644 --- a/jax/experimental/export/_serialization.py +++ b/jax/_src/export/serialization.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Serialization and deserialization of export.Exported +# Serialization and deserialization of _export.Exported -from typing import Callable, TypeVar -from collections.abc import Sequence +from __future__ import annotations + +from collections.abc import Callable, Sequence from functools import partial +from typing import TypeVar try: import flatbuffers @@ -29,10 +31,10 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util +from jax._src.export import serialization_generated as ser_flatbuf +from jax._src.export import _export +from jax._src.export import shape_poly from jax._src.lib import xla_client -from jax.experimental.export import serialization_generated as ser_flatbuf -from jax.experimental.export import _export -from jax.experimental import export import numpy as np @@ -45,8 +47,8 @@ # Version 2, Dec 16th, 2023, adds the f0 dtype. _SERIALIZATION_VERSION = 2 -def serialize(exp: export.Exported, vjp_order: int = 0) -> bytearray: - """Serialize an Exported. +def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: + """Serializes an Exported. Args: exp: the Exported to serialize. @@ -61,14 +63,14 @@ def serialize(exp: export.Exported, vjp_order: int = 0) -> bytearray: return builder.Output() -def deserialize(ser: bytearray) -> export.Exported: - """Deserialize an Exported.""" +def deserialize(ser: bytearray) -> _export.Exported: + """Deserializes an Exported.""" exp = ser_flatbuf.Exported.GetRootAsExported(ser) return _deserialize_exported(exp) def _serialize_exported( - builder: flatbuffers.Builder, exp: export.Exported, vjp_order: int + builder: flatbuffers.Builder, exp: _export.Exported, vjp_order: int ) -> int: # Serialize bottom-up fun_name = builder.CreateString(exp.fun_name) @@ -77,10 +79,10 @@ def _serialize_exported( out_tree = _serialize_pytreedef(builder, exp.out_tree) out_avals = _serialize_array(builder, _serialize_aval, exp.out_avals) in_shardings = _serialize_array( - builder, _serialize_sharding, exp.in_shardings + builder, _serialize_sharding, exp.in_shardings_hlo ) out_shardings = _serialize_array( - builder, _serialize_sharding, exp.out_shardings + builder, _serialize_sharding, exp.out_shardings_hlo ) ordered_effects = _serialize_array( builder, _serialize_effect, exp.ordered_effects @@ -91,8 +93,8 @@ def _serialize_exported( disabled_safety_checks = _serialize_array( builder, _serialize_disabled_safety_check, exp.disabled_safety_checks ) - lowering_platforms = _serialize_array( - builder, lambda b, p: b.CreateString(p), exp.lowering_platforms + platforms = _serialize_array( + builder, lambda b, p: b.CreateString(p), exp.platforms ) mlir_module_serialized = builder.CreateByteVector(exp.mlir_module_serialized) module_kept_var_idx = builder.CreateNumpyVector( @@ -119,17 +121,17 @@ def _serialize_exported( ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices) ser_flatbuf.ExportedAddInShardings(builder, in_shardings) ser_flatbuf.ExportedAddOutShardings(builder, out_shardings) - ser_flatbuf.ExportedAddLoweringPlatforms(builder, lowering_platforms) + ser_flatbuf.ExportedAddPlatforms(builder, platforms) ser_flatbuf.ExportedAddOrderedEffects(builder, ordered_effects) ser_flatbuf.ExportedAddUnorderedEffects(builder, unordered_effects) ser_flatbuf.ExportedAddDisabledChecks(builder, disabled_safety_checks) ser_flatbuf.ExportedAddMlirModuleSerialized(builder, mlir_module_serialized) - ser_flatbuf.ExportedAddMlirModuleSerializationVersion( - builder, exp.mlir_module_serialization_version + ser_flatbuf.ExportedAddCallingConventionVersion( + builder, exp.calling_convention_version ) ser_flatbuf.ExportedAddModuleKeptVarIdx(builder, module_kept_var_idx) - ser_flatbuf.ExportedAddUsesShapePolymorphism( - builder, exp.uses_shape_polymorphism + ser_flatbuf.ExportedAddUsesGlobalConstants( + builder, exp.uses_global_constants ) if vjp is not None: ser_flatbuf.ExportedAddVjp(builder, vjp) @@ -148,7 +150,7 @@ def _serialize_array( return builder.EndVector() -def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported: +def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: serialization_version = exp.SerializationVersion() if serialization_version != _SERIALIZATION_VERSION: raise NotImplementedError( @@ -159,7 +161,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported: _, in_tree = tree_util.tree_flatten( _deserialize_pytreedef_to_pytree(exp.InTree()) ) - scope = export.SymbolicScope(()) # TODO: serialize the constraints + scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints deser_aval = partial(_deserialize_aval, scope=scope) in_avals = _deserialize_tuple( exp.InAvalsLength, exp.InAvals, deser_aval @@ -177,9 +179,9 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported: out_shardings = _deserialize_tuple( exp.OutShardingsLength, exp.OutShardings, _deserialize_sharding ) - lowering_platforms = _deserialize_tuple( - exp.LoweringPlatformsLength, - exp.LoweringPlatforms, + platforms = _deserialize_tuple( + exp.PlatformsLength, + exp.Platforms, lambda v: v.decode("utf-8"), ) ordered_effects = _deserialize_tuple( @@ -195,31 +197,31 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> export.Exported: ) mlir_module_serialized = exp.MlirModuleSerializedAsNumpy().tobytes() - mlir_module_serialization_version = exp.MlirModuleSerializationVersion() + calling_convention_version = exp.CallingConventionVersion() module_kept_var_idx = tuple(exp.ModuleKeptVarIdxAsNumpy().tolist()) - uses_shape_polymorphism = exp.UsesShapePolymorphism() + uses_global_constants = exp.UsesGlobalConstants() _get_vjp = None if vjp := exp.Vjp(): _get_vjp = lambda _: _deserialize_exported(vjp) - return export.Exported( + return _export.Exported( fun_name=fun_name, in_tree=in_tree, in_avals=in_avals, out_tree=out_tree, out_avals=out_avals, nr_devices=nr_devices, - in_shardings=in_shardings, - out_shardings=out_shardings, - lowering_platforms=lowering_platforms, + in_shardings_hlo=in_shardings, + out_shardings_hlo=out_shardings, + platforms=platforms, ordered_effects=ordered_effects, unordered_effects=unordered_effects, disabled_safety_checks=disabled_safety_checks, mlir_module_serialized=mlir_module_serialized, - mlir_module_serialization_version=mlir_module_serialization_version, + calling_convention_version=calling_convention_version, module_kept_var_idx=module_kept_var_idx, - uses_shape_polymorphism=uses_shape_polymorphism, + uses_global_constants=uses_global_constants, _get_vjp=_get_vjp, ) @@ -329,32 +331,28 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): def _serialize_aval( - builder: flatbuffers.Builder, aval: core.AbstractValue + builder: flatbuffers.Builder, aval: core.ShapedArray ) -> int: - aval_type = type(aval) - if aval_type is core.ShapedArray: - aval_kind = ser_flatbuf.AbstractValueKind.shapedArray - shape_offsets = [builder.CreateString(str(d)) for d in aval.shape] - ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape)) - for d in reversed(shape_offsets): - builder.PrependUOffsetTRelative(d) - shape_vector_offset = builder.EndVector() - - ser_flatbuf.AbstractValueStart(builder) - ser_flatbuf.AbstractValueAddKind(builder, aval_kind) - ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset) - ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype]) - return ser_flatbuf.AbstractValueEnd(builder) - else: - raise NotImplementedError(f"serializing AbstractValue: {aval}") + aval_kind = ser_flatbuf.AbstractValueKind.shapedArray + shape_offsets = [builder.CreateString(str(d)) for d in aval.shape] + ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape)) + for d in reversed(shape_offsets): + builder.PrependUOffsetTRelative(d) + shape_vector_offset = builder.EndVector() + + ser_flatbuf.AbstractValueStart(builder) + ser_flatbuf.AbstractValueAddKind(builder, aval_kind) + ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset) + ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype]) + return ser_flatbuf.AbstractValueEnd(builder) def _deserialize_aval(aval: ser_flatbuf.AbstractValue, - scope) -> core.AbstractValue: + scope) -> core.ShapedArray: aval_kind = aval.Kind() if aval_kind == ser_flatbuf.AbstractValueKind.shapedArray: dtype = _dtype_kind_to_dtype[aval.Dtype()] - shape = export.symbolic_shape( + shape = shape_poly.symbolic_shape( ",".join( aval.Shape(i).decode("utf-8") for i in range(aval.ShapeLength()) ), @@ -366,7 +364,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue, def _serialize_sharding( - builder: flatbuffers.Builder, s: _export.Sharding + builder: flatbuffers.Builder, s: _export.HloSharding | None ) -> int: proto = None if s is None: @@ -383,7 +381,7 @@ def _serialize_sharding( return ser_flatbuf.ShardingEnd(builder) -def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding: +def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.HloSharding | None: kind = s.Kind() if kind == ser_flatbuf.ShardingKind.unspecified: return None @@ -443,16 +441,16 @@ def _deserialize_effect(eff: ser_flatbuf.Effect) -> core.Effect: def _serialize_disabled_safety_check( - builder: flatbuffers.Builder, check: export.DisabledSafetyCheck + builder: flatbuffers.Builder, check: _export.DisabledSafetyCheck ) -> int: custom_call_target_str = check.is_custom_call() custom_call_target = None if custom_call_target_str is not None: kind = ser_flatbuf.DisabledSafetyCheckKind.custom_call custom_call_target = builder.CreateString(custom_call_target_str) - elif check == export.DisabledSafetyCheck.platform(): + elif check == _export.DisabledSafetyCheck.platform(): kind = ser_flatbuf.DisabledSafetyCheckKind.platform - elif check == export.DisabledSafetyCheck.shape_assertions(): + elif check == _export.DisabledSafetyCheck.shape_assertions(): kind = ser_flatbuf.DisabledSafetyCheckKind.shape_assertions else: raise NotImplementedError(f"serializing DisabledSafetyCheck: {check}") @@ -468,14 +466,14 @@ def _serialize_disabled_safety_check( def _deserialize_disabled_safety_check( sc: ser_flatbuf.DisabledSafetyCheck, -) -> export.DisabledSafetyCheck: +) -> _export.DisabledSafetyCheck: kind = sc.Kind() if kind == ser_flatbuf.DisabledSafetyCheckKind.custom_call: - return export.DisabledSafetyCheck.custom_call( + return _export.DisabledSafetyCheck.custom_call( sc.CustomCallTarget().decode("utf-8") ) if kind == ser_flatbuf.DisabledSafetyCheckKind.platform: - return export.DisabledSafetyCheck.platform() + return _export.DisabledSafetyCheck.platform() if kind == ser_flatbuf.DisabledSafetyCheckKind.shape_assertions: - return export.DisabledSafetyCheck.shape_assertions() + return _export.DisabledSafetyCheck.shape_assertions() assert False, kind diff --git a/jax/experimental/export/serialization_generated.py b/jax/_src/export/serialization_generated.py similarity index 96% rename from jax/experimental/export/serialization_generated.py rename to jax/_src/export/serialization_generated.py index 941513667dae..a872d03a9fdd 100644 --- a/jax/experimental/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pytype: skip-file # automatically generated by the FlatBuffers compiler, do not modify # namespace: serialization @@ -20,7 +21,7 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class PyTreeDefKind(object): +class PyTreeDefKind: leaf = 0 none = 1 tuple = 2 @@ -28,12 +29,12 @@ class PyTreeDefKind(object): dict = 4 -class AbstractValueKind(object): +class AbstractValueKind: shapedArray = 0 abstractToken = 1 -class DType(object): +class DType: bool = 0 i8 = 1 i16 = 2 @@ -59,18 +60,18 @@ class DType(object): f0 = 22 -class ShardingKind(object): +class ShardingKind: unspecified = 0 hlo_sharding = 1 -class DisabledSafetyCheckKind(object): +class DisabledSafetyCheckKind: platform = 0 custom_call = 1 shape_assertions = 2 -class PyTreeDef(object): +class PyTreeDef: __slots__ = ['_tab'] @classmethod @@ -162,7 +163,7 @@ def PyTreeDefEnd(builder): -class AbstractValue(object): +class AbstractValue: __slots__ = ['_tab'] @classmethod @@ -234,7 +235,7 @@ def AbstractValueEnd(builder): -class Sharding(object): +class Sharding: __slots__ = ['_tab'] @classmethod @@ -303,7 +304,7 @@ def ShardingEnd(builder): -class Effect(object): +class Effect: __slots__ = ['_tab'] @classmethod @@ -339,7 +340,7 @@ def EffectEnd(builder): -class DisabledSafetyCheck(object): +class DisabledSafetyCheck: __slots__ = ['_tab'] @classmethod @@ -385,7 +386,7 @@ def DisabledSafetyCheckEnd(builder): -class Exported(object): +class Exported: __slots__ = ['_tab'] @classmethod @@ -546,7 +547,7 @@ def OutShardingsIsNone(self): return o == 0 # Exported - def LoweringPlatforms(self, j): + def Platforms(self, j): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) if o != 0: a = self._tab.Vector(o) @@ -554,14 +555,14 @@ def LoweringPlatforms(self, j): return "" # Exported - def LoweringPlatformsLength(self): + def PlatformsLength(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) if o != 0: return self._tab.VectorLen(o) return 0 # Exported - def LoweringPlatformsIsNone(self): + def PlatformsIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22)) return o == 0 @@ -665,7 +666,7 @@ def MlirModuleSerializedIsNone(self): return o == 0 # Exported - def MlirModuleSerializationVersion(self): + def CallingConventionVersion(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(32)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos) @@ -699,7 +700,7 @@ def ModuleKeptVarIdxIsNone(self): return o == 0 # Exported - def UsesShapePolymorphism(self): + def UsesGlobalConstants(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36)) if o != 0: return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) @@ -757,10 +758,10 @@ def ExportedAddOutShardings(builder, outShardings): def ExportedStartOutShardingsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def ExportedAddLoweringPlatforms(builder, loweringPlatforms): - builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(loweringPlatforms), 0) +def ExportedAddPlatforms(builder, platforms): + builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(platforms), 0) -def ExportedStartLoweringPlatformsVector(builder, numElems): +def ExportedStartPlatformsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ExportedAddOrderedEffects(builder, orderedEffects): @@ -787,8 +788,8 @@ def ExportedAddMlirModuleSerialized(builder, mlirModuleSerialized): def ExportedStartMlirModuleSerializedVector(builder, numElems): return builder.StartVector(1, numElems, 1) -def ExportedAddMlirModuleSerializationVersion(builder, mlirModuleSerializationVersion): - builder.PrependUint16Slot(14, mlirModuleSerializationVersion, 0) +def ExportedAddCallingConventionVersion(builder, callingConventionVersion): + builder.PrependUint16Slot(14, callingConventionVersion, 0) def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx): builder.PrependUOffsetTRelativeSlot(15, flatbuffers.number_types.UOffsetTFlags.py_type(moduleKeptVarIdx), 0) @@ -796,8 +797,8 @@ def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx): def ExportedStartModuleKeptVarIdxVector(builder, numElems): return builder.StartVector(2, numElems, 2) -def ExportedAddUsesShapePolymorphism(builder, usesShapePolymorphism): - builder.PrependBoolSlot(16, usesShapePolymorphism, 0) +def ExportedAddUsesGlobalConstants(builder, usesGlobalConstants): + builder.PrependBoolSlot(16, usesGlobalConstants, 0) def ExportedAddVjp(builder, vjp): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(vjp), 0) diff --git a/jax/experimental/export/_shape_poly.py b/jax/_src/export/shape_poly.py similarity index 93% rename from jax/experimental/export/_shape_poly.py rename to jax/_src/export/shape_poly.py index ada073bd11fe..d380bc5a2476 100644 --- a/jax/experimental/export/_shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -13,28 +13,13 @@ # limitations under the License. """Shape polymorphism support. -We introduce a set of dimension variables at the top-level of a `jit` function. -They are introduced implicitly by way of specifying for each dimension of each -argument a symbolic dimension expression in terms of some dimension variables. -All dimension variables are assumed to range over integers greater or equal to 1. - -Symbolic dimensions overload some integer operations, such as -add, multiply, divide, equality, etc. The JAX NumPy layer and the LAX layers have been -touched up to be sensitive to handling shapes that contain symbolic dimensions. -This enables many JAX programs to be traced with symbolic dimensions -in some dimensions. A priority has been to enable the batch -dimension in neural network examples to be polymorphic. - -This was built initially for jax2tf, but it is now -independent of TF. The best documentation at the moment is in the -jax2tf.convert docstring, and the -[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html. """ from __future__ import annotations import enum -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Sequence import dataclasses from enum import Enum import functools @@ -43,7 +28,7 @@ import copy import operator as op import tokenize -from typing import Any, Callable, Union, overload +from typing import Any, Union, overload import warnings import numpy as np @@ -86,7 +71,7 @@ class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): are non-constant, and the result of the operation cannot be represented as a boolean value for all values of the symbolic dimensions involved. -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported +Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported for more details. """ @@ -230,8 +215,8 @@ def evaluate(self, env: DimVarEnv): return env[self.var] except KeyError: err_msg = ( - f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the used function arguments.\n" - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") raise KeyError(err_msg) else: operand_values = [opnd._evaluate(env) for opnd in self.operands] @@ -591,15 +576,13 @@ def _get_vars(self) -> set[str]: @staticmethod def _linear_combination_sorted_pairs( e1: SortedTerms, i1: int, f1: int, - e2: SortedTerms, i2: int, f2: int) -> SortedTerms: - ... + e2: SortedTerms, i2: int, f2: int) -> SortedTerms: ... # type: ignore[bad-return-type,unused-ignore] @overload @staticmethod def _linear_combination_sorted_pairs( e1: SortedFactors, i1: int, f1: int, - e2: SortedFactors, i2: int, f2: int) -> SortedFactors: - ... + e2: SortedFactors, i2: int, f2: int) -> SortedFactors: ... # type: ignore[bad-return-type,unused-ignore] @staticmethod def _linear_combination_sorted_pairs( @@ -664,7 +647,7 @@ def _eq(self, other: _DimExpr) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported + # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -841,7 +824,7 @@ def __eq__(self, other: Any) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported + # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -875,7 +858,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]: # invariant: self = dividend + divisor * quotient # quotient and dividend are changed in the loop; the leading term of # dividend decreases at each iteration. - while is_symbolic_dim(dividend) and not dividend._is_constant: + while is_symbolic_dim(dividend) and not dividend._is_constant: # type: ignore[attribute-error,unused-ignore] mon, count = dividend._leading_term if isinstance(divisor, _DimExpr): dterm, dcount = divisor._leading_term @@ -984,7 +967,7 @@ class SymbolicScope: Holds the constraints on symbolic expressions. - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints) + See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for more details. Args: @@ -1048,9 +1031,9 @@ def _parse_and_process_explicit_constraint(self, c_str: str): if cmp_pos < 0: raise ValueError("Constraint parsing error: must contain one of '==' or '>=' or '<='") e1_str = c_str[:cmp_pos] - e1, = _Parser(e1_str, None, repr(e1_str), self).parse() + e1, = _Parser(e1_str, None, repr(e1_str), self).parse() # type: ignore[name-error,unused-ignore] e2_str = c_str[cmp_pos + 2:] - e2, = _Parser(e2_str, None, repr(e2_str), self).parse() + e2, = _Parser(e2_str, None, repr(e2_str), self).parse() # type: ignore[name-error,unused-ignore] if cmp == Comparator.GEQ and not is_geq: e1, e2 = e2, e1 @@ -1072,7 +1055,7 @@ def _parse_and_process_explicit_constraint(self, c_str: str): raise ValueError("Invalid equality constraint: {e1} == {e2}. " "The left-hand-side must be of the form `term * coefficient`.") - after = _ensure_poly(e2, "parse_constraint", e1.scope) + after = _ensure_poly(e2, "parse_constraint", e1.scope) # type: ignore[name-error,unused-ignore] if before in self._normalization_rules: raise NotImplementedError( f"Found multiple equality constraints with the same left-hand-side: {before}") @@ -1087,7 +1070,7 @@ def _check_same_scope(self, other: _DimExpr, f"Invalid mixing of symbolic scopes {when}.\n" f"Expected {self_descr}scope {self}\n" f"and found for '{other}' ({other_descr}) scope {other.scope}\n" - f"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.") + f"See https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.") def _clear_caches(self): self._bounds_cache.clear() @@ -1196,6 +1179,8 @@ def _convertible_to_poly(p: DimSize) -> bool: return isinstance(p, _DimExpr) or _convertible_to_int(p) def is_symbolic_dim(p: DimSize) -> bool: + """Checks if a dimension is symbolic. + """ return isinstance(p, _DimExpr) def is_poly_dim(p: DimSize) -> bool: @@ -1313,7 +1298,7 @@ def shape_assertion(assert_what: jax.Array, def dim_as_value_impl(dim: DimSize): raise NotImplementedError( "Evaluation rule for 'dim_as_value' is not implemented. " - "It seems that you are using shape polymorphism outside jax2tf.") + "It seems that you are using shape polymorphism outside jax.export.") dim_as_value_p.def_impl(dim_as_value_impl) def _dim_as_value(dim: DimSize): @@ -1362,26 +1347,36 @@ def symbolic_shape(shape_spec: str | None, scope: SymbolicScope | None = None, like: Sequence[int | None] | None = None ) -> Sequence[DimSize]: - """Constructs a jax.ShapeDtypeStruct with polymorphic shapes. + """Constructs a symbolic shape from a string representation. + + See https://jax.readthedocs.io/en/latest/export/shape_poly.html for examples. Args: shape_spec: a symbolic shape specification. None stands for "...". + A shape specification is the string representation of a tuple (the + parentheses are optional) with comma-separated dimension expressions. + A dimension expression can be either: an integer constant, + a dimension variable (alphanumeric + starting with a letter), e1 + e2, e1 - e2, e1 * e2, floordiv(e1, e2), + mod(e1, e2), max(e1, e2), or min(e1, e2). + constraints: a sequence of constraints on symbolic dimension expressions, of + the form `e1 >= e2` or `e1 <= e2`, or `e1 == e2`. + See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + for usage. + scope: optionally, you can specify that the parsed symbolic expressions + be created in the given scope. If this is missing, then a new + `SymbolicScope` is created with the given `constraints`. + You cannot specify both a `scope` and `constraints`. + See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + for usage. like: when `shape_spec` contains placeholders ("_", "..."), use this shape to fill in the placeholders. The dimensions of `like` that are used for filling - must be known (not `None`). If a dimension in `like` is known and + must be not `None`. If a dimension in `like` is not `None` and the corresponding dimension in `shape_spec` is a constant then they must be equal. - scope: optionally, you can specify that the parsed symbolic expressions - be created in a given scope. You cannot specify `constraints` in this case. - constraints: a sequence of constraints on symbolic dimension expressions, of - the form `e1 >= e2` or `e1 <= e2`. This is used to create a new SymbolicScope - shared by all symbolic expressions created. - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints) - for more details. - Returns: a jax.ShapeDTypeStruct with shapes that may contain symbolic - expressions involving dimension variables. + Returns: a tuple with integers or symbolic expressions involving dimension variables. """ shape_spec_repr = repr(shape_spec) if shape_spec is None: @@ -1400,43 +1395,51 @@ def symbolic_shape(shape_spec: str | None, def symbolic_args_specs( args, # pytree of arguments - polymorphic_shapes, # prefix pytree of strings - symbolic_scope: SymbolicScope | None = None, - symbolic_constraints: Sequence[str] = (), + shapes_specs, # prefix pytree of strings + constraints: Sequence[str] = (), + scope: SymbolicScope | None = None, + symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24 + symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24 ): """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. - Note that this function does not ensure that the provided `args` shapes - are compatible with `polymorphic_shapes`. The `.shape` of the `args` are - used only to fill-in placeholders from `polymorphic_shapes`. - - See docstring of `symbolic_shape` and - [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) - for more details. + See the documentation of :func:`jax.export.symbolic_shape` and + the [shape polymorphism documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html) for details. Args: args: a pytree of arguments. These can be jax.Array, or jax.ShapeDTypeSpec. - This is used to learn the pytree structure of the arguments, their dtypes, - and to fill-in the actual shapes where the `polymorphic_shapes` contains + They are used to learn the pytree structure of the arguments, their dtypes, + and to fill-in the actual shapes where the `shapes_specs` contains placeholders. Note that only the shape dimensions for which - `polymorphic_shapes` is a placeholder are used from `args`. - The unused dimensions can be `None`, which jax2tf uses when the TF - shapes are not known. - polymorphic_shapes: should be `None` (all arguments have static shapes), - a single string (applies to all arguments), or a pytree matching a prefix + `shapes_specs` is a placeholder are used from `args`. + shapes_specs: should be `None` (all arguments have static shapes), + a single string (see `shape_spec` for :func:`jax.export.symbolic_shape`; + applies to all arguments), or a pytree matching a prefix of the `args`. See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). - symbolic_scope: optionally, you can specify that the parsed symbolic expressions - be created in a given scope. You cannot specify `symbolic_constraints` in this case. - symbolic_constraints: a sequence of constraints on symbolic dimension expressions, of - the form `e1 >= e2` or `e1 <= e2`. This is used to create a new SymbolicScope - shared by all symbolic expressions created. - See more details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + constraints: as for :func:`jax.export.symbolic_shape`. + scope: as for :func:`jax.export.symbolic_shape`. + symbolic_constraints: DEPRECATED, use `constraints`. + symbolic_scope: DEPRECATED, use `scope`. Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes - replaced with symbolic dimensions as specified by `polymorphic_shapes`. + replaced with symbolic dimensions as specified by `shapes_specs`. """ + if symbolic_constraints: + warnings.warn("symbolic_constraints is deprecated, use constraints", + DeprecationWarning, stacklevel=2) + if constraints: + raise ValueError("Cannot use both symbolic_constraints and constraints") + constraints = symbolic_constraints + if symbolic_scope is not None: + warnings.warn("symbolic_scope is deprecated, use scope", + DeprecationWarning, stacklevel=2) + if scope is not None: + raise ValueError("Cannot use both symbolic_scope and scope") + scope = symbolic_scope + + polymorphic_shapes = shapes_specs args_flat, args_tree = tree_util.tree_flatten(args) shapes_and_dtypes = tuple(map(shape_and_dtype_jax_array, args_flat)) @@ -1456,15 +1459,15 @@ def symbolic_args_specs( e, *_ = tree_util.prefix_errors( polymorphic_shapes_, args, is_leaf=lambda x: x is None) - raise e("jax_export polymorphic_shapes") from None + raise e("export.symbolic_args_specs shapes_specs") from None # Now add in the polymorphic shapes - if symbolic_scope is None: - symbolic_scope = SymbolicScope(symbolic_constraints) - elif symbolic_constraints: - raise ValueError("Cannot have both `symbolic_scope` and `symbolic_constraints`") + if scope is None: + scope = SymbolicScope(constraints) + elif constraints: + raise ValueError("Cannot use both `scope` and `constraints`") args_specs_flat = ( - jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=symbolic_scope), t) + jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t) for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat)) return args_tree.unflatten(args_specs_flat) @@ -1522,7 +1525,7 @@ def parse_err(self, tok: tokenize.TokenInfo | None, detail: str) -> Exception: def next_tok(self) -> tokenize.TokenInfo: while True: try: - t = next(self.tokstream) + t = next(self.tokstream) # type: ignore[attribute-error,unused-ignore] except StopIteration: raise self.parse_err(None, "unexpected end of string") if t.exact_type not in [tokenize.NEWLINE, tokenize.INDENT, tokenize.DEDENT]: @@ -1608,7 +1611,7 @@ def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: while True: t, tok = self.term(tok) t_sign = - t if next_t_negated else t - acc = acc + t_sign if acc is not None else t_sign # type:ignore [operator] + acc = acc + t_sign if acc is not None else t_sign # type: ignore[operator] if tok.exact_type in self.FOLLOW_EXPR: return acc, tok next_t_negated = (tok.exact_type == tokenize.MINUS) @@ -1629,7 +1632,7 @@ def term(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: acc = acc * f if acc is not None else f # type: ignore[operator] if tok.exact_type in self.FOLLOW_TERM: - return acc, tok + return acc, tok # type: ignore[bad-return-type,unused-ignore] tok = self.consume_token(tok, tokenize.STAR) def factor(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: @@ -1719,7 +1722,7 @@ def _dimension_size_lowering_rule(ctx, arg, *, dimension): mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule) -def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]: +def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]: dim_vars: set[str] = set() for a in args_avals: for d in a.shape: @@ -1858,7 +1861,7 @@ def check_statically(self, eval: CachingShapeEvaluator) -> None: def shape_assertions(self, eval: CachingShapeEvaluator) -> None: """Computes the shape assertions for the set of constraints. - See jax_export._wrap_main_func docstring. + See jax_export.Exported docstring. """ # We want to report the errors in the same order as `check_statically`. # So, we process them in order, in case some fail statically, and we @@ -1913,7 +1916,7 @@ def pretty_print_dimension_descriptor( @util.cache() def solve_dim_vars( - args_avals: Sequence[core.AbstractValue], + args_avals: Sequence[core.ShapedArray], args_kwargs_tree: tree_util.PyTreeDef, ) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]: """Solves dimension variables in a called function's avals in terms of actual argument shapes. @@ -1978,7 +1981,7 @@ def solve_dim_vars( def compute_dim_vars_from_arg_shapes( - args_avals: Sequence[core.AbstractValue], + args_avals: Sequence[core.ShapedArray], *actual_args: jax.Array, args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: """Computes values of dimension variables to unify args_avals with actual arguments. @@ -2019,7 +2022,7 @@ def _solve_dim_equations( " Using the following polymorphic shapes specifications: " + ",".join(f"{arg_name}.shape = {arg_spec}" for arg_name, arg_spec in polymorphic_shape_specs)) + "." - solution_err_msg_trailer_errors = ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + solution_err_msg_trailer_errors = ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." shape_constraints = ShapeConstraints() # accumulate shape constraints scope: SymbolicScope | None = None @@ -2073,9 +2076,9 @@ def process_one_eqn(eqn: _DimEquation) -> bool: solution_err_msg_trailer_errors])) if not isinstance(var_value, _DimExpr): - assert var_value.dtype == core.dim_value_dtype() + assert var_value.dtype == core.dim_value_dtype() # type: ignore[attribute-error,unused-ignore] shape_env[var] = var_value # type: ignore - solution_error_message_pieces.extend([ + solution_error_message_pieces.extend([ # type: ignore[container-type-mismatch,unused-ignore] f"'{var}' = ", var_value, f" from specification '{eqn.aval_dim_expr}' " f"for dimension {eqn.dim_name} (= ", @@ -2148,6 +2151,6 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): " Unprocessed specifications: " + ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" for eqn in eqns) + - ". Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." ) raise ValueError(err_msg) diff --git a/jax/experimental/export/_shape_poly_decision.py b/jax/_src/export/shape_poly_decision.py similarity index 98% rename from jax/experimental/export/_shape_poly_decision.py rename to jax/_src/export/shape_poly_decision.py index 9e35d82972ee..e325722b0c26 100644 --- a/jax/experimental/export/_shape_poly_decision.py +++ b/jax/_src/export/shape_poly_decision.py @@ -23,8 +23,8 @@ import numpy as np -from jax.experimental.export import _shape_poly -from jax.experimental.export._shape_poly import ( +from jax._src.export import shape_poly +from jax._src.export.shape_poly import ( _DimExpr, _DimTerm, _DimFactor, SymbolicScope, DimSize, @@ -43,7 +43,7 @@ def bounds_decision(e: DimSize, decision = _DecisionByElimination.build(e.scope) return decision.bounds(e, prec, add_implicit_constraints=True) -_shape_poly._bounds_decision = bounds_decision +shape_poly._bounds_decision = bounds_decision class _DecisionByElimination: @@ -183,7 +183,7 @@ def add_to_state(self, lead_t_constraints.add((cmp, lead_t_k, e)) def combine_term_with_existing(self, t: _DimTerm, t_k: int, *, - scope: _shape_poly.SymbolicScope, + scope: shape_poly.SymbolicScope, only_smaller_than_t=True, ) -> Sequence[tuple[Comparator, _DimExpr, @@ -292,7 +292,7 @@ def _bounds_for_sorted_terms(self, prec: BoundsPrecision) -> tuple[float, float]: """The lower and upper bounds of e[i:]. - See comments about soundness and `cmp_with` in the `_shape_poly.bounds_decision`` method. + See comments about soundness and `cmp_with` in the `shape_poly.bounds_decision`` method. Returns (lower-bound, upper-bound) """ if i >= len(e): return (0, 0) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py new file mode 100644 index 000000000000..aec124549e1e --- /dev/null +++ b/jax/_src/extend/ffi.py @@ -0,0 +1,142 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import ctypes +from collections.abc import Iterable, Mapping, Sequence +from typing import Any + +import numpy as np + +from jax._src import dtypes +from jax._src.interpreters import mlir +from jax._src.lib import jaxlib +from jax._src.lib.mlir import ir +from jax._src.typing import DimSize + + +def pycapsule(funcptr): + """Wrap a ctypes function pointer in a PyCapsule. + + The primary use of this function, and the reason why it lives with in the + ``jax.extend.ffi`` submodule, is to wrap function calls from external + compiled libraries to be registered as XLA custom calls. + + Example usage:: + + import ctypes + import jax + from jax.lib import xla_client + + libfoo = ctypes.cdll.LoadLibrary('./foo.so') + xla_client.register_custom_call_target( + name="bar", + fn=jax.extend.ffi.pycapsule(libfoo.bar), + platform=PLATFORM, + api_version=API_VERSION + ) + + Args: + funcptr: A function pointer loaded from a dynamic library using ``ctypes``. + + Returns: + An opaque ``PyCapsule`` object wrapping ``funcptr``. + """ + destructor = ctypes.CFUNCTYPE(None, ctypes.py_object) + builder = ctypes.pythonapi.PyCapsule_New + builder.restype = ctypes.py_object + builder.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor) + return builder(funcptr, None, destructor(0)) + + +def include_dir() -> str: + """Get the path to the directory containing header files bundled with jaxlib""" + jaxlib_dir = os.path.dirname(os.path.abspath(jaxlib.__file__)) + return os.path.join(jaxlib_dir, "include") + + +def ffi_lowering( + call_target_name: str, + *, + operand_layouts: Sequence[Sequence[DimSize]] | None = None, + result_layouts: Sequence[Sequence[DimSize]] | None = None, + backend_config: Mapping[str, ir.Attribute] | None = None, + **lowering_args: Any +) -> mlir.LoweringRule: + """Build a lowering rule for an foreign function interface (FFI) target. + + By default, this lowering rule can use the input and output abstract values to + compute the input and output types and shapes for the custom call, assuming + row-major layouts. + + If keyword arguments are passed to the lowering rule, these are treated as + attributes, and added to `backend_config`. + + Args: + call_target_name: The name of the custom call target. + operand_layouts: A sequence of layouts (dimension orders) for each operand. + By default, the operands are assumed to be row-major. + result_layouts: A sequence of layouts (dimension orders) for each result. + By default, the results are assumed to be row-major. + backend_config: Configuration data for the custom call. Any keyword + arguments passed to the lowering rule will added to this dictionary. + lowering_args: Any other arguments to :func:`mlir.custom_call` will also be + passed through if provided as extra arguments to this function. + """ + + def _lowering( + ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any + ) -> Sequence[ir.Value | Sequence[ir.Value]]: + kwargs = dict(lowering_args) + kwargs.setdefault("api_version", 4) + kwargs["backend_config"] = dict( + backend_config or {}, **{k: _ir_attribute(v) for k, v in params.items()}) + if "result_types" not in kwargs: + kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] + if operand_layouts is None: + kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in) # pytype: disable=attribute-error + if result_layouts is None: + kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out) + + return mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore + + return _lowering + + +def _default_layouts(shapes: Iterable[Sequence[DimSize]]) -> list[list[DimSize]]: + return [list(reversed(range(len(shape)))) for shape in shapes] + + +def _ir_attribute(obj: Any) -> ir.Attribute: + # TODO(dfm): Similar functions exist in Pallas and Mosaic GPU. Perhaps these + # could be consolidated into mlir or similar. + if isinstance(obj, str): + return ir.StringAttr.get(obj) + elif isinstance(obj, bool): + return ir.BoolAttr.get(obj) + elif isinstance(obj, int): + return mlir.i64_attr(obj) + elif isinstance(obj, float): + return ir.FloatAttr.get_f64(obj) + elif hasattr(obj, "dtype"): + if not (dtypes.is_python_scalar(obj) or np.isscalar(obj)): + raise TypeError("Only scalar attributes are supported") + mlir_type = mlir.dtype_to_ir_type(obj.dtype) + if isinstance(mlir_type, ir.IntegerType): + return ir.IntegerAttr.get(mlir_type, obj) + elif isinstance(mlir_type, ir.FloatType): + return ir.FloatAttr.get(mlir_type, obj) + raise TypeError(f"Unsupported attribute type: {type(obj)}") diff --git a/jax/_src/extend/random.py b/jax/_src/extend/random.py index ffc7f63072b9..df927486dd2f 100644 --- a/jax/_src/extend/random.py +++ b/jax/_src/extend/random.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable -from collections.abc import Hashable +from collections.abc import Callable, Hashable from jax import Array diff --git a/jax/_src/gfile_cache.py b/jax/_src/gfile_cache.py index 301f61cc6bdb..989844b10ddb 100644 --- a/jax/_src/gfile_cache.py +++ b/jax/_src/gfile_cache.py @@ -17,6 +17,10 @@ from jax._src import path as pathlib from jax._src.compilation_cache_interface import CacheInterface + +# TODO (ayx): This class will be ultimately removed after `lru_cache.py` is +# finished. It exists because the current `lru_cache.py` does not support +# `gs://`. class GFileCache(CacheInterface): def __init__(self, path: str): diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index 1ef1db916f73..aa9910555130 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -14,10 +14,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import enum -from typing import Callable import numpy as np diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py deleted file mode 100644 index 6418f374239c..000000000000 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_ducc_fft.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -from numpy import array, float32, complex64 - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14 = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['dynamic_ducc_fft'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(array([[ 0., 1., 2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.]], dtype=float32),), - expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], - [22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j], - [38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<3x4xcomplex> {jax.result_info = ""}) { - %0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex> loc(#loc3) - return %0 : tensor<3x4xcomplex> loc(#loc) - } loc(#loc) - func.func private @fft(%arg0: tensor<3x4xf32> loc(unknown)) -> tensor<3x4xcomplex> { - %0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xcomplex> loc(#loc4) - %1 = stablehlo.constant dense<4> : tensor loc(#loc5) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.convert %1 : (tensor) -> tensor loc(#loc5) - %4 = stablehlo.reshape %3 : (tensor) -> tensor<1xi32> loc(#loc5) - %5 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %6 = stablehlo.reshape %5 : (tensor) -> tensor<1xi32> loc(#loc5) - %7 = stablehlo.concatenate %4, %6, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %8 = stablehlo.constant dense<4> : tensor loc(#loc5) - %9 = stablehlo.constant dense<1> : tensor loc(#loc5) - %10 = stablehlo.convert %8 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %13 = stablehlo.reshape %12 : (tensor) -> tensor<1xi32> loc(#loc5) - %14 = stablehlo.concatenate %11, %13, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %15 = stablehlo.constant dense<4> : tensor loc(#loc5) - %16 = stablehlo.convert %15 : (tensor) -> tensor loc(#loc5) - %17 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc5) - %18 = stablehlo.reshape %17 : (tensor) -> tensor<1xf64> loc(#loc5) - %19 = stablehlo.constant dense<3> : tensor loc(#loc5) - %20 = stablehlo.constant dense<4> : tensor loc(#loc5) - %21 = stablehlo.convert %19 : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.reshape %21 : (tensor) -> tensor<1xi32> loc(#loc5) - %23 = stablehlo.convert %20 : (tensor) -> tensor loc(#loc5) - %24 = stablehlo.reshape %23 : (tensor) -> tensor<1xi32> loc(#loc5) - %25 = stablehlo.concatenate %22, %24, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %26 = stablehlo.constant dense<3> : tensor loc(#loc5) - %27 = stablehlo.constant dense<4> : tensor loc(#loc5) - %28 = stablehlo.convert %26 : (tensor) -> tensor loc(#loc5) - %29 = stablehlo.reshape %28 : (tensor) -> tensor<1xi32> loc(#loc5) - %30 = stablehlo.convert %27 : (tensor) -> tensor loc(#loc5) - %31 = stablehlo.reshape %30 : (tensor) -> tensor<1xi32> loc(#loc5) - %32 = stablehlo.concatenate %29, %31, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %33 = stablehlo.constant dense<[20, 0, 0, 0, 0, 0, 14, 0, 16, 0, 8, 0, 0, 0, 0, 0, 12, 0, 7, 0, 14, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]> : tensor<44xui8> loc(#loc5) - %34 = stablehlo.custom_call @dynamic_ducc_fft(%33, %0, %25, %7, %14, %18, %32) {api_version = 2 : i32, indices_of_shape_operands = dense<6> : tensor<1xi64>, operand_layouts = [dense<0> : tensor<1xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<44xui8>, tensor<3x4xcomplex>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1xf64>, tensor<2xi32>) -> tensor<3x4xcomplex> loc(#loc5) - return %34 : tensor<3x4xcomplex> loc(#loc3) - } loc(#loc3) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":437:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":428:0) -#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]"(#loc2)) -#loc5 = loc("jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01!\x05\x01\x03\x01\x03\x05\x03\x11\x07\t\x0b\r\x0f\x11\x13\x15\x03\xd3\x95+\x01U\x0f\x07\x13\x0b\x13\x0b\x0f\x0f\x0b\x0b\x0b\x0b\x0b\x17\x13\x13#\x0b\x0b\x0b33\x0b\x17\x0f\x0b\x0b\x0b\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x03A/\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b//\x0f//\xbf\x0b\x0b\x0b/'\x0f\x01\x03\x0f\x03)\x0f\x0f\x13\x17\x13\x17\x07\x07\x0f\x07\x07\x13\x07\x17\x0b\x13\x07\x13\x13\x13\x02r\x06\x1d5\x1b\x1f\x03\x03\x07}\x05\x17\x03\x037\x81\x05\x19\x1d-/\x11\x01\x05\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x17\x19\xb2\x06\x01\x03\x03\x07\x7f\x03\x03\x07\x85\x03\x07#\x0f%\x0f\x0b'\x05%\x05'\x05)\x03\x0b\x11c\x13W\x15o\x0bu\x17w\x03\x0b\x11[\x13W\x15[\x0b]\x17{\x05+\x17\x19\xd6\x06\x01\x1d3\x1b\x05-\x05/\x051\x03\x03\x07\x83\x03\x03\x07\x87\x03\x13?\x89AYC\x8bE_G\x8dI\x8fK\x91M_O\x93\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03S]\x05E\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1d\x1dG\x03\x03y\x1dI\x03\x01\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03e\r\x05gikm\x1dK\x1dM\x1dO\x1dQ\x03\x03q\r\x03sY\x1dS\x1dU\x1dW\r\x01\x1dY\x1f\x03\x11\x04\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0f\x01\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf0?\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19Y\x14\x00\x00\x00\x00\x00\x0e\x00\x10\x00\x08\x00\x00\x00\x00\x00\x0c\x00\x07\x00\x0e\x00\x00\x00\x00\x00\x00\x01\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x0b\x05\x1d[\x05\x01\x1f%\x11\x06\x00\x00\x00\x00\x00\x00\x00\x03\x0fUaUUUUU\x03\x03a\x01\x02\x02)\x01\x0f)\x01\x11)\x03\x05\x11)\x05\r\x11\x1f)\x03\t\x11)\x05\r\x11\x15\x1d\x1b)\x01\x17\t\x0b)\x03\xb1#\x13\x11\x03\r\x03\t\x03\x15)\x03\x05\x17!)\x03\x05\x0f)\x03\x05\x1b)\x03\t\x1b\x04\xaa\x04\x05\x01\x11\x03!\x07\x03\x01\t\x0b\x11\x03)\x05\x03\x05\x0b\x03\r\x03\x11\x07\rQ\x03\t\x03\x01\r\x04\x03\x03\x03\x0b\x11\r+\x05\x03I\x93\x03\r\x03\x05\x061\x03\t\x03\x01\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x05\x07\x06\x01\x03\x07\x03\t\x05\x06\x01\x03\x05\x03\x07\x07\x06\x01\x03\x07\x03\r\t\x07\x01\t\x03\x0b\x05\x0b\x0f\x03\x03\x01\x05\x03\x03\x03\x03\x01\x1d\x03\x03\x05\x06\x01\x03\x05\x03\x13\x07\x06\x01\x03\x07\x03\x17\x05\x06\x01\x03\x05\x03\x15\x07\x06\x01\x03\x07\x03\x1b\t\x07\x01\t\x03\x0b\x05\x19\x1d\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x13\x03!\x03\x03\x019\x03\x13\x07\x06\x01\x03!\x03%\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x03)\x07\x06\x01\x03\x07\x03-\x05\x06\x01\x03\x05\x03+\x07\x06\x01\x03\x07\x031\t\x07\x01\t\x03\x0b\x05/3\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x05\x03\x03\x05\x06\x01\x03\x05\x037\x07\x06\x01\x03\x07\x03;\x05\x06\x01\x03\x05\x039\x07\x06\x01\x03\x07\x03?\t\x07\x01\t\x03\x0b\x05=A\x03\x03\x01;\x03\x19\x0f\x07\x01=\x03\t\x0fE\x035\x11\x1f'C\r\x04\r\x03G\x06\x03\x01\x05\x01\x00\xc6\x0e]#\x11\x0f\x0b!\x1b\x1d\x05\x1b\t\x03\x0f\x1f/!5!)#\x1f\x19\x15\x91\xaf\xbe\x02\x13%)\x83\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x11\x1f\x17\x17\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00convert_v1\x00reshape_v1\x00concatenate_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=fft keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(fft)/convert_element_type[new_dtype=complex64 weak_type=False]\x00jit(func)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(4,)]\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00indices_of_shape_operands\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00fft\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00dynamic_ducc_fft\x00", - xla_call_module_version=6, -) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py new file mode 100644 index 000000000000..b676cc8011d3 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/stablehlo_dynamic_approx_top_k.py @@ -0,0 +1,84 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +from numpy import array, float32, int32 + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_05_30 = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['shape_assertion', 'stablehlo.dynamic_approx_top_k'], + serialized_date=datetime.date(2024, 5, 30), + inputs=(array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.], + dtype=float32),), + expected_outputs=(array([23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., + 10., 9., 8., 7., 6., 5., 4.], dtype=float32), array([23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, + 6, 5, 4], dtype=int32)), + mlir_module_text=r""" +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":718:13) +#loc3 = loc("a") +#loc9 = loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2)) +module @jit_func attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"} loc(unknown)) -> (tensor {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor) -> tensor loc(#loc4) + %1 = stablehlo.convert %0 : (tensor) -> tensor loc(#loc4) + %2 = stablehlo.convert %1 : tensor loc(#loc5) + %c = stablehlo.constant dense<-4> : tensor loc(#loc) + %3 = stablehlo.add %2, %c : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc) + %4 = stablehlo.compare GE, %3, %c_0, SIGNED : (tensor, tensor) -> tensor loc(#loc7) + stablehlo.custom_call @shape_assertion(%4, %3, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'b'. Using the following polymorphic shapes specifications: args[0].shape = (b + 4,). Obtained dimension variables: 'b' = {0} from specification 'b + 4' for dimension args[0].shape[0] (= {1}), . Please see https://github.com/googlexjax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.", has_side_effect = true} : (tensor, tensor, tensor) -> () loc(#loc8) + %5:2 = call @_wrapped_jax_export_main(%3, %arg0) : (tensor, tensor) -> (tensor, tensor) loc(#loc) + return %5#0, %5#1 : tensor, tensor loc(#loc) + } loc(#loc) + func.func @top_k_gt_f32_comparator(%arg0: tensor loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2)), %arg1: tensor loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2)), %arg2: tensor loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2)), %arg3: tensor loc("jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]"(#loc2))) -> tensor { + %0 = stablehlo.compare GT, %arg0, %arg1 : (tensor, tensor) -> tensor loc(#loc9) + return %0 : tensor loc(#loc9) + } loc(#loc9) + func.func private @_wrapped_jax_export_main(%arg0: tensor {jax.global_constant = "b", mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor {mhlo.layout_mode = "default"} loc("a")) -> (tensor {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.convert %arg0 : tensor loc(#loc10) + %c = stablehlo.constant dense<4> : tensor loc(#loc9) + %1 = stablehlo.add %0, %c : tensor loc(#loc11) + %2 = stablehlo.convert %1 : (tensor) -> tensor loc(#loc9) + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> loc(#loc9) + %4 = stablehlo.dynamic_iota %3, dim = 0 : (tensor<1xi32>) -> tensor loc(#loc9) + %c_0 = stablehlo.constant dense<-1> : tensor loc(#loc9) + %cst = stablehlo.constant dense<0xFF800000> : tensor loc(#loc9) + %5 = stablehlo.convert %arg0 : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.reshape %5 : (tensor) -> tensor<1xi32> loc(#loc9) + %7 = stablehlo.convert %arg0 : (tensor) -> tensor loc(#loc9) + %8 = stablehlo.reshape %7 : (tensor) -> tensor<1xi32> loc(#loc9) + %9 = stablehlo.convert %arg0 : (tensor) -> tensor loc(#loc9) + %10:2 = stablehlo.custom_call @stablehlo.dynamic_approx_top_k(%arg1, %4, %cst, %c_0, %9, %6, %8) {called_computations = [@top_k_gt_f32_comparator], indices_of_shape_operands = dense<[5, 6]> : tensor<2xi64>, mhlo.backend_config = {aggregate_to_topk = true, is_fallback = true, recall_target = 0.949999988 : f32, reduction_dim = 0 : i64, reduction_input_size_override = -1 : i64}} : (tensor, tensor, tensor, tensor, tensor, tensor<1xi32>, tensor<1xi32>) -> (tensor, tensor) loc(#loc9) + return %10#0, %10#1 : tensor, tensor loc(#loc) + } loc(#loc) +} loc(#loc) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":738:4) +#loc4 = loc("/dimension_size[dimension=0]"(#loc1)) +#loc5 = loc("/convert_element_type[new_dtype=int64 weak_type=False]"(#loc1)) +#loc6 = loc("/add"(#loc1)) +#loc7 = loc("/ge"(#loc1)) +#loc8 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'b'. Using the following polymorphic shapes specifications: args[0].shape = (b + 4,). Obtained dimension variables: 'b' = {0} from specification 'b + 4' for dimension args[0].shape[0] (= {1}), . Please see https://github.com/googlexjax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.]"(#loc1)) +#loc10 = loc("jit(func)/jit(main)/convert_element_type[new_dtype=int64 weak_type=False]"(#loc2)) +#loc11 = loc("jit(func)/jit(main)/add"(#loc2)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\'\x05\x01\x03\x01\x03\x05\x03\x17\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x03B\x02\xeb#\x01\x85\x0f\x07\x0b\x17\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f+\x0b\x0f\x0b\x0b\x0b33\x0b3\x0f\x0b\x0f\x0b\x13\x0f\x0b\x13\x0b\x13\x13[\x0b\x0b\x1b\x13\x0b\x0b\x0f\x0b\x13\x0f\x0b\x13\x1b\x0f\x0bS\x0b\x0f\x0b\x13\x0b\x03g\x0b\x0b\x0b\x0b\x0f\x0b\x13\x13\x0b\x0b\x0b\x0f\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x0b\x0b/\x1f\x1f\x0b\x0b\x0f\x0bO3\x0b\x0b\x0b\x1f\x0b\x0b\x0f\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x1f\x0f\x0f3\x0f3\x07\x07\x07\x0f\x13\x1b#\x07\x1f\x13\x02\x0e\x08\x1d?\x13\x1f\x05\x1d\x17\x17\x8a\x0b\t\x05\x1f\x05!\x05#\x05%\x05\'\x17\x17:\x0b\x1b\x11\x03\x05\x05)\x05+\x05-\x05/\x051\x053\x055\x057\x059\x05;\x05=\x1de\x07\x03\t135\x157\x15\t9\x05?\x11\x01\x01\x05A\x05C\x05E\x03\x0b\x0b\x9b\r\x9d\x0f\x93\t\xa7\x11\xa9\x03\x0b\x0b\x85\r\xab\x0f\x85\t\x97\x11\x8b\x05G\x03\x0b\x0b\xad\r\xb5\x0f\x93\t\x99\x11\xb7\x1dE\x03\x05I\x1dI\x13\x05K\x03\x03\x05\xb9\x1dO\x13\x05M\x03\x03S\x8d\x05O\x03\x03\x05\xbb\x03\x03\x05\xbd\x03\x15\x19\xbf\x1b\x8b\x1d\xc1\x1f\xc3!\xc5[\xc7]\xc9#\x85%\x85\'\x85\x05Q\x05S\x03\x05)\xd9+\xdb\x03\x03c\x8d\x05U\x05W\x1di\x07\x05Y\x03\x03\x05\xdd\x1do\x07\x05[\x03\x03\x05\xdf\x03\x05)\xe1+\xe3\x1dw\x07\x05]\x03\x13\x19\xe5\x1b\x8b\x1d\xe7\x1f\x85{\xe9!\x8f#\x85%\x85\'\x85\x05_\x1d\x7f\x07\x05a\x03\x03\x83\x99\x05c\x03\x01\x1de\x1dg\x1di\x13\x0f\x01\x05\x03\r\x03\x87\x89\x03\x05\x9f\xa3\x1dk\x1dm\x1do\x03\x03\x91#\x19\r\x05\x95\xa1\x87\x89\x1dq\r\x05\x95\xa5\x87\x89\x1ds\x1du\x1dw#\x1b\x03\x05\xaf\x91\r\x05\xb1\xb3\x87\x89\x1dy\x1d{#\x1f\x1d}\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\xff\xff\xff\xff\x1f\x0b\t\x00\x00\x80\xff\x0b\x03\x1d\x7f\x03\x03\x97\x05\x01\x1f!!\x05\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00\r\x0b\xcb\x8f\xcd\x8f\xcf\xd1\xd3\x8d\xd5\xd7\x1d\x81\x1d\x83\x1d\x85\x11\x11\xd0\xcc\xcc\xdc\x0f\x1d\x87\x1d\x89\x13\x0f\x03\t\x01\x07\x07\x1f\x05\x11\xfc\xff\xff\xff\xff\xff\xff\xff\x1f\x05\x11\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x05\x0b\x05\x1d\x8b\x1d\x8d\x01\t\x01\x02\x02)\x01\x0f)\x01\x13)\x03\x00\xff\xff\xff\xff\xff\xff\xff\xff\x11)\x01\x11)\x03\x00\xff\xff\xff\xff\xff\xff\xff\xff\x13\x1d\t\x1b)\x01\x1d)\x03\x05\x13\x11\x03\t\x05\t\r\x11\t\x0b\x0b\x07\x07\x03\x15\x01\x11\x05\x05\t\x05\t\r)\x03\t\x0f\x04\xea\x03\x05\x01\x11\x03/\x07\x03\x01\r\x07\x11\x03;\x07\x03\x15+\x03\t\x03\x15\x07-a\x03\x07\x03\x01\x03\x06-\x03\x05\x03\x03\x03\x06g\x03\x05\x03\x05\x05\x03\x03k\x03\x05\r\x06m\x03\x05\x05\x07\t\x05\x03\x03q\x03\x05\x11\x07us\x03\x15\x05\x0b\r\x0f\x05}y\x07\x0f\x0b\x05\x17\x07\x03\x81\x05\t\r\x05\x0b\x01\x0b\x04\x03\x05\x11\x13\x07\x11\x01=\x07\x03\x0b\x0b\t\x0b\x01\x0b\x01\x07\x01\x07\x01\x11\x07\x01_\x03\x15\x05\x01\x03\x0b\x04\x01\x03\t\x07\x11\x03A\x07\x03#?\x05\x05\x03\tC\x03\x06G\x03\x05\x03\x01\x05\x03\x01K\x03\x05\r\x06M\x03\x05\x05\x05\x07\x03\x06\x01\x03\x07\x03\t\t\x06\x01\x03\x17\x03\x0b\x13\x07\x01Q\x03\r\x03\r\x05\x03\x01U\x03\x07\x05\x03\x01W\x03\x0b\x03\x06\x01\x03\x07\x03\x01\t\x06\x01\x03\x17\x03\x15\x03\x06\x01\x03\x07\x03\x01\t\x06\x01\x03\x17\x03\x19\x03\x06\x01\x03\x07\x03\x01\x0f\x07\x01Y\x05\t\r\x0f\x03\x0f\x13\x11\x1d\x17\x1b\x0b\x04\x03\x05\x1f!\x06\x03\x01\x05\x01\x00""\x8f\xb2\x06!=\x1d\x1d\x19%?\x11\x05)\x0f\x0b\t\t31!\x03\x11#\x0f2\x07\x1d\t\x0bo;\x15)5\x1f1\x95\x05Z\x02\x13%)9+\x1b\x1f/!!)#\x1f\x19i\x1f\x15\x1d\x15\x13\r\x11-!\x17\x1f\x0f\x15\x17\x11\x19\x17\x0f\x0b\x11builtin\x00vhlo\x00module\x00convert_v1\x00constant_v1\x00func_v1\x00reshape_v1\x00return_v1\x00add_v1\x00custom_call_v1\x00compare_v1\x00dynamic_iota_v1\x00get_dimension_size_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/approx_top_k[k=b reduction_dimension=-1 recall_target=0.95 is_max_k=True reduction_input_size_override=-1 aggregate_to_topk=True]\x00a\x00jit(func)/jit(main)/convert_element_type[new_dtype=int64 weak_type=False]\x00jit(func)/jit(main)/add\x00iota_dimension\x00indices_of_shape_operands\x00mhlo.backend_config\x00dimension\x00/dimension_size[dimension=0]\x00/convert_element_type[new_dtype=int64 weak_type=False]\x00/add\x00/ge\x00error_message\x00/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable \'b\'. Using the following polymorphic shapes specifications: args[0].shape = (b + 4,). Obtained dimension variables: \'b\' = {0} from specification \'b + 4\' for dimension args[0].shape[0] (= {1}), . Please see https://github.com/googlexjax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.]\x00callee\x00mhlo.layout_mode\x00default\x00\x00jax.result_info\x00top_k_gt_f32_comparator\x00_wrapped_jax_export_main\x00[0]\x00[1]\x00main\x00public\x00jax.global_constant\x00b\x00private\x00stablehlo.dynamic_approx_top_k\x00aggregate_to_topk\x00is_fallback\x00recall_target\x00reduction_dim\x00reduction_input_size_override\x00shape_assertion\x00Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable \'b\'. Using the following polymorphic shapes specifications: args[0].shape = (b + 4,). Obtained dimension variables: \'b\' = {0} from specification \'b + 4\' for dimension args[0].shape[0] (= {1}), . Please see https://github.com/googlexjax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 3cd4cad517dd..5a975e3c5a61 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -54,7 +54,7 @@ def func(...): ... file jax/_src/internal_test_util/export_back_compat_test_data/foo_call.py and paste the test data that you will see printed in the logs. -Name the literal `data_YYYYY_MM_DD` to include the date of serializaton +Name the literal `data_YYYYY_MM_DD` to include the date of serialization (for readability only). Then add to this file: from jax._src.internal_test_util.export_back_compat_test_data import foo_call @@ -70,13 +70,13 @@ def func(...): ... from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import dataclasses import datetime import os import re import sys -from typing import Any, Callable +from typing import Any from absl import logging @@ -86,7 +86,7 @@ def func(...): ... import jax from jax import tree_util -from jax.experimental import export +from jax import export from jax.experimental import pjit @@ -303,7 +303,7 @@ def serialize(self, module_str = str(exported.mlir_module()) serialized = exported.mlir_module_serialized - module_version = exported.mlir_module_serialization_version + module_version = exported.calling_convention_version nr_devices = exported.nr_devices return serialized, module_str, module_version, nr_devices @@ -330,19 +330,19 @@ def _get_vjp(_): in_avals=tuple(in_avals), out_tree=out_tree, out_avals=tuple(out_avals), - in_shardings=(None,) * len(in_avals), - out_shardings=(None,) * len(out_avals), - lowering_platforms=(data.platform,), + in_shardings_hlo=(None,) * len(in_avals), + out_shardings_hlo=(None,) * len(out_avals), + platforms=(data.platform,), ordered_effects=(), unordered_effects=(), disabled_safety_checks=(), mlir_module_serialized=data.mlir_module_serialized, - mlir_module_serialization_version=data.xla_call_module_version, + calling_convention_version=data.xla_call_module_version, nr_devices=data.nr_devices, module_kept_var_idx=tuple(range(len(in_avals))), - uses_shape_polymorphism=any(not core.is_constant_shape(a.shape) + uses_global_constants=any(not core.is_constant_shape(a.shape) for a in in_avals), _get_vjp=_get_vjp) # We use pjit in case there are shardings in the exported module. - return pjit.pjit(export.call_exported(exported))(*data.inputs) + return pjit.pjit(exported.call)(*data.inputs) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index cd65b702f603..2b22944c17b8 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -38,11 +38,11 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import operator import os from functools import partial -from typing import Any, Callable, NamedTuple, Union +from typing import Any, NamedTuple, Union from absl import testing import numpy as np @@ -62,6 +62,9 @@ from jax._src.lib import xla_client from jax._src import random as jax_random +# mypy generates a lot of false positive due to re-assigned variables. +# mypy: disable-error-code="assignment, no-redef" + # The code in this file relies on the values of some flags that are defined by # jtu. Note that the following can not always be moved to a test file since # then the test file has to import jtu first (to define the flags) which is not @@ -172,9 +175,9 @@ def __init__(self, self.group_name = jtu.sanitize_test_name(group_name) self.name = jtu.sanitize_test_name(name) self.fullname = self.name if self.group_name is None else f"{self.group_name}_{self.name}" - self.fun = fun # type: ignore[assignment] + self.fun = fun self.arg_descriptors = arg_descriptors - self.rng_factory = rng_factory # type: ignore[assignment] + self.rng_factory = rng_factory self.jax_unimplemented = jax_unimplemented self.dtype = dtype self.params = params @@ -651,7 +654,7 @@ def _make_device_put_harness(name, define( "device_put", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}", - lambda x: dispatch.device_put_p.bind(x, device=_device_fn(), src=None), + lambda x: dispatch.device_put_p.bind(x, devices=[_device_fn()], srcs=[None])[0], [RandArg(shape, dtype)], shape=shape, dtype=dtype, @@ -2060,18 +2063,17 @@ def _make_slice_harness(name, define( lax.slice_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_{strides=}", - # type: ignore lax.slice, [ - RandArg(shape, dtype), # type: ignore - StaticArg(start_indices), # type: ignore - StaticArg(limit_indices), # type: ignore + RandArg(shape, dtype), + StaticArg(start_indices), + StaticArg(limit_indices), StaticArg(strides) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - limit_indices=limit_indices) # type: ignore + shape=shape, + start_indices=start_indices, + limit_indices=limit_indices) # Test first all dtypes @@ -2161,17 +2163,16 @@ def _make_dynamic_slice_harness(name, define( lax.dynamic_slice_p, f"{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_{start_indices=}_{limit_indices=}_enablexla={enable_xla}", - # type: ignore lax.dynamic_slice, [ - RandArg(shape, dtype), # type: ignore + RandArg(shape, dtype), np.array(list(start_indices)), StaticArg(tuple(map(operator.sub, limit_indices, start_indices))) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - limit_indices=limit_indices, # type: ignore + shape=shape, + start_indices=start_indices, + limit_indices=limit_indices, enable_xla=enable_xla) @@ -2218,19 +2219,19 @@ def _make_dynamic_update_slice_harness(name, define( lax.dynamic_update_slice_p, ( - f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore + f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}" f"_update={jtu.format_shape_dtype_string(update_shape, dtype)}" f"_{start_indices=}_{enable_xla=}"), lax.dynamic_update_slice, [ - RandArg(shape, dtype), # type: ignore - RandArg(update_shape, dtype), # type: ignore + RandArg(shape, dtype), + RandArg(update_shape, dtype), np.array(start_indices) - ], # type: ignore + ], dtype=dtype, - shape=shape, # type: ignore - start_indices=start_indices, # type: ignore - update_shape=update_shape, # type: ignore + shape=shape, + start_indices=start_indices, + update_shape=update_shape, enable_xla=enable_xla) @@ -2261,12 +2262,12 @@ def _make_squeeze_harness(name, dtype=np.float32): define( lax.squeeze_p, - f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", # type: ignore + f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}_{dimensions=}", lax.squeeze, - [RandArg(shape, dtype), StaticArg(dimensions)], # type: ignore[has-type] + [RandArg(shape, dtype), StaticArg(dimensions)], dtype=dtype, arg_shape=shape, - dimensions=dimensions) # type: ignore[has-type] + dimensions=dimensions) # Test first all dtypes @@ -3312,6 +3313,7 @@ def _make_conv_harness(name, lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation) +key_types: list[tuple[tuple[int, ...], jax.typing.DTypeLike]] key_types = [((4,), np.uint32)] if config.enable_x64.value: key_types.append(((2,), np.uint64)) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 0106a5b77181..a527acb8db90 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -14,12 +14,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import functools import itertools as it from functools import partial -from typing import Any, Callable +from typing import Any import jax from jax._src import config diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 87add6b74567..3a87fffa5116 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -14,10 +14,10 @@ from __future__ import annotations import collections -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import dataclasses from functools import partial -from typing import Any, Callable, Union +from typing import Any, Union import numpy as np diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f5f0a2df42c9..7eb826c95a67 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -16,7 +16,7 @@ from __future__ import annotations import collections -from collections.abc import Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence import dataclasses import functools from functools import partial @@ -27,7 +27,7 @@ import re import types import typing -from typing import Any, Callable, NamedTuple, Protocol, Union, cast as type_cast +from typing import Any, NamedTuple, Protocol, Union, cast as type_cast import warnings import numpy as np @@ -47,6 +47,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout +from jax._src.sharding import Sharding as JSharding from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension from jax._src.lib.mlir import dialects @@ -54,7 +55,6 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir import register_jax_dialects -from jax._src.sharding_impls import XLACompatibleSharding from jax._src.state.types import AbstractRef map, unsafe_map = util.safe_map, map @@ -67,14 +67,14 @@ # mypy implicitly sets this variable to true when type checking. MYPY = False -_JAX_DUMP_IR_TO = config.DEFINE_string( +_JAX_DUMP_IR_TO = config.string_flag( 'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''), help="Path to which the IR that is emitted by JAX should be dumped as " "text files. If omitted, JAX will not dump IR. " "Supports the special value 'sponge' to pick the path from the " "environment variable TEST_UNDECLARED_OUTPUTS_DIR.") -_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.DEFINE_string( +_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.string_flag( 'jax_include_debug_info_in_dumps', os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"), help="Determine whether or not to keep debug symbols and location information " @@ -546,13 +546,6 @@ class LoweringParameters: # existing Jax rules. override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None - # The current lowering platforms, a non-empty tuple containing some of - # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are - # doing multi-platform lowering, otherwise it can specify cross-platform - # lowering. The value None specifies the default lowering platform. - # This is used only in export and jax2tf. - platforms: tuple[str, ...] | None = None - # Signals that the entire computation being lowered operates on global # constants. This will result in adding jax.global_constant attributes # to the arguments of all functions that are created, e.g., floor_divide. @@ -560,6 +553,9 @@ class LoweringParameters: # or multi-platform lowering. global_constant_computation: bool = False + # Signals that we are lowering for exporting. + for_export: bool = False + @dataclasses.dataclass class TracebackCaches: @@ -582,6 +578,8 @@ class ModuleContext: ip: ir.InsertionPoint symbol_table: ir.SymbolTable backend_or_name: str | xb.XlaBackend | None + # The lowering platforms for the module. Can be more than one only when + # exporting. platforms: Sequence[str] axis_context: AxisContext keepalives: list[Any] @@ -616,8 +614,7 @@ def __init__( module: ir.Module | None = None, ip: ir.InsertionPoint | None = None, symbol_table: ir.SymbolTable | None = None, - cached_primitive_lowerings: None | (dict[Any, - func_dialect.FuncOp]) = None, + cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, traceback_caches: None | TracebackCaches = None, shape_poly_state = None): @@ -686,6 +683,9 @@ class LoweringRuleContext: # module_context.shape_poly_state.dim_vars dim_var_values: Sequence[ir.Value] = () compute_type: str | None = None + # Override module_context.platforms if not None. Used during multi-platform + # lowering, when in a scope with a subset of the module_context.platforms. + platforms: Sequence[str] | None = None def set_tokens_out(self, tokens_out: TokenSet): assert self.tokens_out is None, 'Should only set `tokens_out` once.' @@ -735,7 +735,7 @@ def flatten_lowering_ir_args( _module_name_regex = re.compile(r"[^\w.-]") def sharded_aval(aval: core.AbstractValue, - sharding: XLACompatibleSharding | None) -> core.AbstractValue: + sharding: JSharding | None) -> core.AbstractValue: """Returns the new aval sharded based on sharding proto.""" if sharding is None: return aval @@ -809,16 +809,16 @@ class LoweringResult(NamedTuple): def _to_physical_op_sharding( - aval: core.AbstractValue, sharding: XLACompatibleSharding | None, + aval: core.AbstractValue, sharding: JSharding | None, ) -> xc.OpSharding | None: if sharding is None: return None - assert isinstance(sharding, sharding_impls.XLACompatibleSharding) + assert isinstance(sharding, JSharding) if isinstance(aval, AbstractRef): return _to_physical_op_sharding(aval.inner_aval, sharding) assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) if dtypes.issubdtype(aval.dtype, dtypes.extended): - sharding = aval.dtype._rules.physical_sharding(aval, sharding) + sharding = sharding_impls.physical_sharding(aval, sharding) aval = core.physical_aval(aval) return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore @@ -831,10 +831,10 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None: return layout._to_xla_layout() -def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None: +def _get_mem_kind(s: JSharding | None) -> str | None: if s is None: return None - assert isinstance(s, sharding_impls.XLACompatibleSharding) + assert isinstance(s, JSharding) return s.memory_kind @@ -849,8 +849,8 @@ def lower_jaxpr_to_module( name_stack: source_info_util.NameStack, donated_args: Sequence[bool], replicated_args: Sequence[bool] | None = None, - arg_shardings: Sequence[XLACompatibleSharding | None] | None = None, - result_shardings: Sequence[XLACompatibleSharding | None] | None = None, + arg_shardings: Sequence[JSharding | None] | None = None, + result_shardings: Sequence[JSharding | None] | None = None, in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, arg_names: Sequence[str | None] | None = None, @@ -886,7 +886,8 @@ def lower_jaxpr_to_module( platforms_with_donation = [p for p in platforms if p in _platforms_with_donation] if platforms_with_donation: - if len(platforms_with_donation) != len(platforms): + if len(platforms_with_donation) != len(platforms) and ( + xla_donated_args or any(donated_args)): raise NotImplementedError( "In multi-platform lowering either all or no lowering platforms " f"should support donation. Lowering for {platforms} of which " @@ -940,8 +941,7 @@ def lower_jaxpr_to_module( channel_iterator=channel_iter, host_callbacks=host_callbacks, lowering_parameters=lowering_parameters, - shape_poly_state=ShapePolyLoweringState( - dim_vars, lowering_parameters.platforms)) + shape_poly_state=ShapePolyLoweringState(dim_vars, platforms)) with ctx.context, ir.Location.unknown(ctx.context): # Remove module name characters that XLA would alter. This ensures that # XLA computation preserves the module name. @@ -970,7 +970,7 @@ def lower_jaxpr_to_module( try: if not ctx.module.operation.verify(): raise ValueError( - "Cannot lower jaxpr with verifier errors." + + "Cannot lower jaxpr with verifier errors. " + dump_module_message(ctx.module, "verification")) except ir.MLIRError as e: msg_lines = ["Cannot lower jaxpr with verifier errors:"] @@ -981,7 +981,7 @@ def emit_diagnostic_info(d): emit_diagnostic_info(n) for d in e.error_diagnostics: emit_diagnostic_info(d) - raise ValueError("\n".join(msg_lines) + + raise ValueError("\n".join(msg_lines) + "\n" + dump_module_message(ctx.module, "verification")) from e return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks, @@ -1090,8 +1090,8 @@ def lower_jaxpr_to_fun( *, public: bool = False, replicated_args: Sequence[bool] | None = None, - arg_shardings: Sequence[XLACompatibleSharding | None] | None = None, - result_shardings: Sequence[XLACompatibleSharding | None] | None = None, + arg_shardings: Sequence[JSharding | None] | None = None, + result_shardings: Sequence[JSharding | None] | None = None, use_sharding_annotations: bool = True, input_output_aliases: Sequence[int | None] | None = None, xla_donated_args: Sequence[bool] | None = None, @@ -1232,10 +1232,6 @@ def lower_jaxpr_to_fun( if pom is not None and mk is None: res.append([pom] * len(types)) else: - if pom is not None and mk is not None and pom != mk: - raise AssertionError( - f"propagated out memory kind ({pom}) does not match the memory" - f" kind specified in out_shardings of jit ({mk})") res.append([mk] * len(types)) # type: ignore # To add the custom call on the output to signal a transfer, only do it # if memory kind comes from out_shardings on `jit` and result_memory_kinds @@ -1380,7 +1376,7 @@ def lower_jaxpr_to_fun( if ir_arg_shardings is not None and name == "main": flat_args = [ - a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # pytype: disable=attribute-error + replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error for o, s, a in zip(flat_args, ir_arg_shardings, input_avals) @@ -1421,7 +1417,7 @@ def lower_jaxpr_to_fun( if ir_result_shardings is not None and name == "main": flat_outputs = [ - a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # pytype: disable=attribute-error + replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals) @@ -1445,6 +1441,19 @@ def wrap_with_memory_kind( return op.result +def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: + # Set the sharding of extended dtypes to be UNCONSTRAINED + # (i.e. XLA will choose) on aval.shape. + # For the trailing dims i.e. the dimension of key_shape on the base_array, + # the sharding is set to be REPLICATED always. + # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), + # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). + # The below custom call achieves the sharding like above example. + return wrap_with_sharding_op( + ctx, val, aval, xc.HloSharding.replicate().to_proto(), + unspecified_dims=set(range(aval.ndim))) + + def _emit_lowering_rule_as_fun(lowering_rule, ctx: LoweringRuleContext) -> func_dialect.FuncOp: """Emits the contents of a lowering rule as a private function.""" @@ -1650,7 +1659,7 @@ def lower_per_platform(ctx: LoweringRuleContext, rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. """ - platforms: Sequence[str] = ctx.module_context.platforms + platforms: Sequence[str] = ctx.platforms or ctx.module_context.platforms # Special case the common case (single-platform lowering) if len(platforms) == 1: rule = platform_rules.get(platforms[0], default_rule) @@ -1711,7 +1720,10 @@ def lower_per_platform(ctx: LoweringRuleContext, index=rule_idx_op, num_branches=len(kept_rules)) for i, rule in enumerate(kept_rules): - inner_ctx = ctx.replace() + platforms_for_this_rule = [p + for p, rule_idx in platform_to_kept_rules_idx.items() + if rule_idx == i] + inner_ctx = ctx.replace(platforms=platforms_for_this_rule) branch = case_op.regions[i].blocks.append() with ir.InsertionPoint(branch): output = rule(inner_ctx, *rule_args, **rule_kwargs) @@ -1752,7 +1764,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable: The returned function does not use `avals_out`, so callers may pass any value as `avals_out`.""" - def f_lowered(ctx, *args, **params): + def f_lowered(ctx: LoweringRuleContext, *args, **params): f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) wrapped_fun = lu.wrap_init(f, params) @@ -1762,11 +1774,12 @@ def f_lowered(ctx, *args, **params): # case, we need to form a jaxpr with leading binders for those axis size # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2), # and we need to call jaxpr_subcomp with these arguments made explicit. + assert ctx.axis_size_env is not None args = (*ctx.axis_size_env.values(), *args) idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)} i32_aval = core.ShapedArray((), np.dtype('int32')) implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env) - explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) + explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) # type: ignore if type(a) is core.DShapedArray else a, True) for a in ctx.avals_in] wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args)) @@ -1775,8 +1788,12 @@ def f_lowered(ctx, *args, **params): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out? + if ctx.platforms is not None: + sub_context = ctx.module_context.replace(platforms=ctx.platforms) + else: + sub_context = ctx.module_context out, tokens = jaxpr_subcomp( - ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in, + sub_context, jaxpr, ctx.name_stack, ctx.tokens_in, _ir_consts(consts), *map(wrap_singleton_ir_values, args), dim_var_values=ctx.dim_var_values) ctx.set_tokens_out(tokens) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 5eff845556ee..497c9ea129a8 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -14,13 +14,13 @@ from __future__ import annotations from collections import namedtuple -from collections.abc import Sequence, Hashable +from collections.abc import Callable, Sequence, Hashable from contextlib import contextmanager, AbstractContextManager from functools import partial import inspect import itertools as it import operator as op -from typing import Any, Callable, NamedTuple, Union +from typing import Any, NamedTuple, Union from weakref import ref import numpy as np @@ -45,7 +45,8 @@ InputType, OutputType, get_referent, JaxprEqnContext) from jax._src.state.types import AbstractRef from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten, - KeyPath, generate_key_paths, keystr) + tree_flatten, tree_structure, KeyPath, generate_key_paths, + keystr) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list) @@ -1259,7 +1260,7 @@ def has_effects(effects) -> bool: outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, - dict(device=TransferToMemoryKind(policy.dst), src=None), + dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None]), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) known_eqns.append(offload_eqn) @@ -1268,7 +1269,7 @@ def has_effects(effects) -> bool: residuals.update(resvars) reload_eqn = core.JaxprEqn( resvars, eqn.outvars, device_put_p, # type: ignore - dict(device=TransferToMemoryKind(policy.src), src=None), + dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None]), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) staged_eqns.append(reload_eqn) @@ -1738,7 +1739,10 @@ def get_referent(self): frame = self._trace.frame val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) return self if val is None else get_referent(val) -api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval") + +def _dynamic_jaxpr_tracer_shaped_abstractify(x): + return core.raise_to_shaped(x.aval) +api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: sentinel = object() @@ -1800,12 +1804,15 @@ def __init__(self): def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) - def to_jaxpr(self, out_tracers: Sequence[Tracer] - ) -> tuple[Jaxpr, list[Any], list[tuple[Any, str]]]: + def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer] + ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) invars = self.attrs_vars + self.invars - state_outvars = [self.tracer_to_var[id(t)] for t in get_states(self.attrs_tracked)] + state_ans, end_trees = unzip2( + tree_flatten(t) for t in get_states(self.attrs_tracked)) + state_outvars = [self.tracer_to_var[id(trace.full_raise(x))] + for xs in state_ans for x in xs] explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] outvars = state_outvars + explicit_outvars constvars, constvals = unzip2(self.constvar_to_val.items()) @@ -1813,8 +1820,9 @@ def to_jaxpr(self, out_tracers: Sequence[Tracer] jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects) jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore + init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] set_states(self.attrs_tracked, self.attrs_inits) - return jaxpr, list(constvals), self.attrs_tracked + return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers): # It's not necessary, but we keep the tracer-to-var mapping injective: @@ -1849,10 +1857,11 @@ def find_progenitors(self, tracer): produced = set(eqn.outvars) & active_vars if produced: active_vars.difference_update(produced) - active_vars.update(eqn.invars) + active_vars.update({v for v in eqn.invars if type(v) is Var}) invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars] + const_eqns = [eqn for eqn in self.eqns + if {v for v in eqn.invars if type(v) is Var} & constvars] return invar_positions, const_eqns def _const_folding_and_forwarding( @@ -1868,8 +1877,9 @@ def apply_var_sub(a: Atom) -> Atom: # if any inputs are constants and we have a constant-folding rule, apply it has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) for eff in eqn.effects) - if (eqn.primitive in const_fold_rules and any(v in consts for v in eqn.invars) - and not has_input_effect): + if (eqn.primitive in const_fold_rules and + any(v in consts for v in eqn.invars if isinstance(v, Var)) and + not has_input_effect): consts_in = [consts.get(v) if isinstance(v, Var) else None for v in eqn.invars] consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) @@ -1920,7 +1930,8 @@ def _inline_literals( has_input_effect) if type(c) in core.literalable_types and not np.shape(c) and not e} def lit(a: Atom) -> Literal | None: - return lits.get(a) if isinstance(a, Var) else None + return (a if isinstance(a, Literal) else lits.get(a) if isinstance(a, Var) + else None) newname: Callable[[AbstractValue], Var] = core.gensym() newvars: dict[Var, Var] = {} newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) @@ -1932,8 +1943,9 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: return [d for d in aval.shape if isinstance(d, Var)] return [] - used = {v for eqn in jaxpr.eqns for invar in eqn.invars - for v in it.chain([invar], vars_in_shape(invar.aval))} + used = {v for eqn in jaxpr.eqns for atom in eqn.invars + for v in it.chain([atom], vars_in_shape(atom.aval)) + if isinstance(atom, Var)} used |= {v for outvar in jaxpr.outvars for v in it.chain([outvar], vars_in_shape(outvar.aval))} new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] @@ -1942,7 +1954,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: new_invars = [var(v) for v in jaxpr.invars] new_eqns = [] for eqn in jaxpr.eqns: - invars = [lit(v) or var(v) for v in eqn.invars] + invars = [lit(x) or var(x) for x in eqn.invars] outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) new_outvars = [lit(v) or var(v) for v in jaxpr.outvars] @@ -2335,7 +2347,8 @@ def trace_to_jaxpr_dynamic( debug_info: DebugInfo | None = None, *, keep_inputs: list[bool] | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]: +) -> tuple[Jaxpr, list[AbstractValue], list[Any], + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: with core.new_main(DynamicJaxprTrace, dynamic=True) as main: main.jaxpr_stack = () # type: ignore jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic( @@ -2351,7 +2364,8 @@ def trace_to_subjaxpr_dynamic( *, keep_inputs: Sequence[bool] | None = None, debug_info: DebugInfo | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]: +) -> tuple[Jaxpr, list[AbstractValue], list[Any], + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs frame = JaxprStackFrame() @@ -2362,7 +2376,7 @@ def trace_to_subjaxpr_dynamic( in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] ans = fun.call_wrapped(*in_tracers_) out_tracers = map(trace.full_raise, ans) - jaxpr, consts, attrs_tracked = frame.to_jaxpr(out_tracers) + jaxpr, consts, attrs_tracked = frame.to_jaxpr(trace, out_tracers) del fun, main, trace, frame, in_tracers, out_tracers, ans config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked @@ -2784,38 +2798,32 @@ def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): else: return tracer -def inline_jaxpr_into_trace(trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts, - *args) -> list[Any]: +def inline_jaxpr_into_trace( + trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts: Sequence[Any], + *arg_tracers: DynamicJaxprTracer) -> list[Any]: # This function is conceptually the same thing as just calling eval_jaxpr, - # but doesn't redo abstract evaluation: we know the shapes from the jaxpr. - def read(v: Atom) -> Any: - return v.val if isinstance(v, Literal) else env[v] - - def write(v: Var, val: Any) -> None: - if config.enable_checks.value and not config.dynamic_shapes.value: - assert core.typecheck(v.aval, val), (v.aval, val) - env[v] = val + const_tracers = map(trace.new_const, consts) + constvars = map(trace.getvar, const_tracers) + argvars = map(trace.getvar, arg_tracers) + env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars], + [*constvars, *argvars])) - env: dict[Var, Any] = {} - map(write, jaxpr.constvars, consts) - map(write, jaxpr.invars, args) - lu = core.last_used(jaxpr) - source_info = source_info_util.current() + src = source_info_util.current() for eqn in jaxpr.eqns: - ins = map(read, eqn.invars) - out_tracers = [DynamicJaxprTracer(trace, a.aval, source_info) - for a in eqn.outvars] - invars = [trace.getvar(trace.full_raise(x)) for x in ins] - outvars = map(trace.makevar, out_tracers) - if eqn.source_info.name_stack: - eqn_source_info = source_info.replace( - name_stack=source_info.name_stack + eqn.source_info.name_stack) - else: - eqn_source_info = source_info - - new_eqn = core.new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params, - eqn.effects, eqn_source_info) - trace.frame.add_eqn(new_eqn) - map(write, eqn.outvars, out_tracers) - core.clean_up_dead_vars(eqn, env, lu) - return map(read, jaxpr.outvars) + invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] + outvars = [Var('', v.aval) for v in eqn.outvars] + src_ = (src if not eqn.source_info.name_stack else + src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) + trace.frame.add_eqn(core.new_jaxpr_eqn(invars, outvars, eqn.primitive, + eqn.params, eqn.effects, src_)) + map(env.setdefault, eqn.outvars, outvars) + + tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], + [*consts, *arg_tracers])) + def new_tracer(atom): + tracer = DynamicJaxprTracer(trace, atom.aval, src) + trace.frame.tracers.append(tracer) + trace.frame.tracer_to_var[id(tracer)] = env[atom] + return tracer + return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env + else new_tracer(x) for x in jaxpr.outvars] diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 3fb6e41ebbf9..69d7c619b0a6 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -17,16 +17,16 @@ import enum from contextlib import contextmanager +import collections from collections import namedtuple -from collections.abc import Sequence, Iterable +from collections.abc import Callable, Sequence, Iterable, Iterator import dataclasses from functools import partial, lru_cache, cached_property import itertools as it import logging import math import threading -from typing import Any, Callable, NamedTuple, TypeVar, Union, cast -from collections.abc import Iterator +from typing import Any, NamedTuple, TypeVar, Union, cast import warnings import numpy as np @@ -65,6 +65,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec +from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto, @@ -106,19 +107,40 @@ class WeakRefList(list): def identity(x): return x -def shard_arg(arg, sharding, canonicalize=True): - if canonicalize: - arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)](arg, sharding) - - @profiler.annotate_function -def shard_args( - shardings: Sequence[sharding_impls.XLACompatibleSharding], args, -) -> Sequence[jax.Array]: - return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)] - -shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {} +def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]: + # Fast path for one argument. + if len(args) == 1: + arg = args[0] + if canonicalize: + arg = xla.canonicalize_dtype(arg) + return shard_arg_handlers[type(arg)]([arg], shardings) + + # type(arg) -> (indices, args, shardings) + batches = collections.defaultdict(lambda: ([], [], [])) # type: ignore + for i, (arg, sharding) in enumerate(safe_zip(args, shardings)): + if canonicalize: + arg = xla.canonicalize_dtype(arg) + batch = batches[type(arg)] + batch[0].append(i) + batch[1].append(arg) + batch[2].append(sharding) + + # Call `shard_arg_handlers` per batch and build a flat list of arrays returned + # from each call in the same order as `args`. Since `batches` is grouped by + # types, we cannot simply flatten the results and we have to use the original + # indices to put each array back to its original position. + results: list[jax.Array | None] = [None] * len(args) + for t, (indices, a, s) in batches.items(): + outs = shard_arg_handlers[t](a, s) + for i, out in safe_zip(indices, outs): + results[i] = out + + assert all(result is not None for result in results) + return results + + +shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {} @lru_cache(maxsize=1024) @@ -126,35 +148,38 @@ def _get_replicated_slices(num_addressable_devices: int): return ((slice(None),),) * num_addressable_devices -def _masked_array_error(x, sharding): +def _masked_array_error(xs, shardings): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_array(x, sharding): - devices = sharding._addressable_device_assignment - if x.dtype == dtypes.float0: - x = np.zeros(x.shape, dtype=np.dtype(bool)) - aval = api_util.shaped_abstractify(x) - if sharding.is_fully_replicated: - shards = [x] * len(devices) - else: - indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) - shards = [x[i] for i in indices] - return batched_device_put(aval, sharding, shards, devices) +def _shard_array(xs, shardings): + results = [] + for x, sharding in safe_zip(xs, shardings): + devices = sharding._addressable_device_assignment + if x.dtype == dtypes.float0: + x = np.zeros(x.shape, dtype=np.dtype(bool)) + aval = api_util.shaped_abstractify(x) + if sharding.is_fully_replicated: + shards = [x] * len(devices) + else: + indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) + shards = [x[i] for i in indices] + results.append(batched_device_put(aval, sharding, shards, devices)) + return results for _t in array_types: shard_arg_handlers[_t] = _shard_array -def _shard_darray(x, sharding): - return shard_arg(x._data, sharding) +def _shard_darray(xs, shardings): + return shard_args(shardings, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(x, sharding): - return shard_arg(x._buf, sharding) +def _shard_mutable_array(xs, shardings): + return shard_args(shardings, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, - sharding: jax.sharding.Sharding, xs: Sequence[Any], + sharding: JSharding, xs: Sequence[Any], devices: Sequence[jax.Device], committed: bool = True): from jax._src import array @@ -190,7 +215,7 @@ def _shard_abstract_array(size, axis: int, x): def local_aval_to_result_handler( aval: core.AbstractValue, - sharding: sharding_impls.XLACompatibleSharding, + sharding: JSharding, indices: tuple[Index, ...] | None, ) -> Callable[[list[xc.ArrayImpl]], Any]: """Returns a function for handling the raw buffers of a single output aval. @@ -556,11 +581,18 @@ def parallel_callable(fun: lu.WrappedFun, donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, *avals): + closed_jaxpr, xc_backend, replicas, shards, pci = get_pmap_jaxpr( + fun, backend_name, axis_name, + axis_size=axis_size, global_axis_size=global_axis_size, + devices=devices, name=fun.__name__, in_axes=in_axes, + out_axes_thunk=out_axes_thunk, avals=avals) pmap_computation = lower_parallel_callable( - fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, - in_axes, out_axes_thunk, donated_invars, + fun, axis_name, axis_size, global_axis_size, devices, name, + in_axes, donated_invars, is_explicit_global_axis_size, avals, - lowering_parameters=mlir.LoweringParameters()) + lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(), + closed_jaxpr=closed_jaxpr, backend=xc_backend, replicas=replicas, + shards=shards, pci=pci) pmap_executable = pmap_computation.compile() return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint]) @@ -664,8 +696,7 @@ def stage_parallel_callable( return jaxpr, consts, replicas, shards -@profiler.annotate_function -def lower_parallel_callable( +def get_pmap_jaxpr( fun: lu.WrappedFun, backend_name: str | None, axis_name: core.AxisName, @@ -675,11 +706,41 @@ def lower_parallel_callable( name: str, in_axes: Iterable[int | None], out_axes_thunk: Callable[[], Sequence[int | None]], + avals: Sequence[core.AbstractValue]): + if devices is not None and backend_name is None: + backend = xb.get_device_backend(devices[0]) + else: + backend = xb.get_backend(backend_name) + + pci = ParallelCallableInfo( + name, backend, axis_name, axis_size, global_axis_size, devices, + in_axes, out_axes_thunk, avals) + jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) + jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + return closed_jaxpr, backend, replicas, shards, pci + + +@profiler.annotate_function +def lower_parallel_callable( + fun: lu.WrappedFun, + axis_name: core.AxisName, + axis_size: int, + global_axis_size: int, + devices: Sequence[xc.Device] | None, + name: str, + in_axes: Iterable[int | None], donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, avals: Sequence[core.AbstractValue], *, - lowering_parameters: mlir.LoweringParameters) -> PmapComputation: + lowering_platforms: tuple[str, ...] | None, + lowering_parameters: mlir.LoweringParameters, + closed_jaxpr: core.ClosedJaxpr, + backend: xc.Client, + replicas: ReplicaInfo, + shards: ShardInfo, + pci: ParallelCallableInfo) -> PmapComputation: # Determine global_axis_size for use in AxisEnv. # TODO(mattjj,skyewm): revive this check (inner_pmap always False now) # if xb.process_count() > 1 and global_axis_size is None and inner_pmap: @@ -690,10 +751,7 @@ def lower_parallel_callable( f"Specified axis_size {global_axis_size} doesn't match received " f"axis_size {axis_size}.") - if devices is not None and backend_name is None: - backend = xb.get_device_backend(devices[0]) - else: - backend = xb.get_backend(backend_name) + jaxpr = closed_jaxpr.jaxpr no_nested_sharding = False must_run_on_all_devices = False @@ -710,10 +768,6 @@ def lower_parallel_callable( # devices). Nested sharding is ok in this case. must_run_on_all_devices = True - pci = ParallelCallableInfo( - name, backend, axis_name, axis_size, global_axis_size, devices, - in_axes, out_axes_thunk, avals) - jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) if logger.isEnabledFor(logging.DEBUG): logger.debug("sharded_avals: %s", shards.sharded_avals) logger.debug("global_sharded_avals: %s", shards.global_sharded_avals) @@ -755,12 +809,11 @@ def lower_parallel_callable( axis_env = sharding_impls.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap')) - jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) replicated_args = [axis is None for axis in in_axes] tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals), backend.platform) module_name = f"pmap_{fun.__name__}" + platforms = lowering_platforms or (backend.platform,) with maybe_extend_axis_env(axis_name, global_axis_size, None): ordered_effects = list( effects.ordered_effects.filter_in(closed_jaxpr.effects)) @@ -776,7 +829,7 @@ def lower_parallel_callable( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platforms=lowering_parameters.platforms or (backend.platform,), + platforms=platforms, axis_context=sharding_impls.ReplicaAxisContext(axis_env), name_stack=name_stack, donated_args=donated_invars, @@ -787,7 +840,9 @@ def lower_parallel_callable( result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths, num_replicas=replicas.num_global_replicas, lowering_parameters=lowering_parameters) - return PmapComputation(lowering_result.module, pci=pci, replicas=replicas, + return PmapComputation(lowering_result.module, + platforms=platforms, + pci=pci, replicas=replicas, shards=shards, tuple_args=tuple_args, unordered_effects=unordered_effects, ordered_effects=ordered_effects, @@ -860,9 +915,9 @@ class UnloadedPmapExecutable: compiled: Any backend: xb.XlaBackend local_input_avals: Sequence[core.AbstractValue] - input_shardings: Sequence[sharding_impls.XLACompatibleSharding] + input_shardings: Sequence[JSharding] local_output_avals: Sequence[ShapedArray] - output_shardings: Sequence[sharding_impls.XLACompatibleSharding] + output_shardings: Sequence[JSharding] unordered_effects: list[core.Effect] ordered_effects: list[core.Effect] keepalive: Sequence[Any] @@ -907,10 +962,13 @@ def from_hlo(hlo: ir.Module, host_callbacks: list[Any], keepalive: Any, jaxpr_debug_info: core.JaxprDebugInfo, + platforms: Sequence[str], shape_poly_state: mlir.ShapePolyLoweringState | None = None, compiler_options=None): + del platforms if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) + devices = pci.devices if devices is None: if shards.num_global_shards > xb.device_count(pci.backend): @@ -1089,7 +1147,7 @@ def __call__(self, out_bufs): def local_avals_to_results_handler( unmapped_local_out_avals: Sequence[ShapedArray], - local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler: + local_shardings: Sequence[JSharding]) -> ResultsHandler: out_indices = [tuple(s.devices_indices_map(aval.shape).values()) for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)] handlers = [ @@ -1101,7 +1159,7 @@ def local_avals_to_results_handler( def global_avals_to_results_handler( global_out_avals: Sequence[ShapedArray], - shardings: Sequence[sharding_impls.XLACompatibleSharding], + shardings: Sequence[JSharding], committed: bool) -> ResultsHandler: handlers = [ global_aval_to_result_handler(global_aval, s, committed) @@ -1115,14 +1173,15 @@ class ExecuteReplicated: __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler', 'has_unordered_effects', 'ordered_effects', 'keepalive', 'has_host_callbacks', '_local_devices', 'kept_var_idx', - 'mut', '__weakref__'] + 'mut', 'pgle_profiler', '__weakref__'] def __init__(self, xla_executable, name, backend, in_handler: InputsHandler, out_handler: ResultsHandler, unordered_effects: list[core.Effect], ordered_effects: list[core.Effect], keepalive: Any, has_host_callbacks: bool, kept_var_idx: set[int], - mut: MutationData | None): + mut: MutationData | None, + pgle_profiler: profiler.PGLEProfiler | None = None): self.xla_executable = xla_executable self.name = name self.backend = backend @@ -1135,6 +1194,7 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler, self.has_host_callbacks = has_host_callbacks self.kept_var_idx = kept_var_idx self.mut = mut + self.pgle_profiler = pgle_profiler def _add_tokens_to_inputs(self, input_bufs): if self.ordered_effects: @@ -1175,25 +1235,33 @@ def __call__(self, *args): if self.mut: args = [*args, *self.mut.in_mut] input_bufs = self.in_handler(args) - if (self.ordered_effects or self.has_unordered_effects - or self.has_host_callbacks): - input_bufs = self._add_tokens_to_inputs(input_bufs) - results = self.xla_executable.execute_sharded( - input_bufs, with_tokens=True - ) - result_token_bufs = results.disassemble_prefix_into_single_device_arrays( - len(self.ordered_effects)) - sharded_runtime_token = results.consume_token() - self._handle_token_bufs(result_token_bufs, sharded_runtime_token) - else: - results = self.xla_executable.execute_sharded(input_bufs) - if dispatch.needs_check_special(): - out_arrays = results.disassemble_into_single_device_arrays() - for arrays in out_arrays: - dispatch.check_special(self.name, arrays) - out = self.out_handler(out_arrays) - else: - out = results.consume_with_handlers(self.out_handler.handlers) + with profiler.PGLEProfiler.trace(self.pgle_profiler): + if (self.ordered_effects or self.has_unordered_effects + or self.has_host_callbacks): + input_bufs = self._add_tokens_to_inputs(input_bufs) + results = self.xla_executable.execute_sharded( + input_bufs, with_tokens=True + ) + + result_token_bufs = results.disassemble_prefix_into_single_device_arrays( + len(self.ordered_effects)) + sharded_runtime_token = results.consume_token() + self._handle_token_bufs(result_token_bufs, sharded_runtime_token) + else: + results = self.xla_executable.execute_sharded(input_bufs) + + if dispatch.needs_check_special(): + out_arrays = results.disassemble_into_single_device_arrays() + for arrays in out_arrays: + dispatch.check_special(self.name, arrays) + out = self.out_handler(out_arrays) + else: + out = results.consume_with_handlers(self.out_handler.handlers) + + if (self.pgle_profiler is not None and self.pgle_profiler.is_running() + and len(out) > 0): + out[0].block_until_ready() + if self.mut is None: return out else: @@ -1528,11 +1596,11 @@ def manual_proto( tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes])) tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes])) - raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape) proto = xc.OpSharding() proto.type = xc.OpSharding.Type.OTHER proto.tile_assignment_dimensions = tad_shape - proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat) + proto.iota_reshape_dims = mesh_shape + proto.iota_transpose_perm = tad_perm proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL] return proto @@ -1600,8 +1668,7 @@ class TileManual: def check_if_any_auto( - shardings: Iterable[(sharding_impls.XLACompatibleSharding | - AUTO | UnspecifiedValue)]) -> bool: + shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool: for s in shardings: if is_auto(s): return True @@ -1666,7 +1733,7 @@ class DeviceAssignmentMismatchError(Exception): ShardingInfo = tuple[ - Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO], + Union[JSharding, UnspecifiedValue, AUTO], MismatchType, Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports ] @@ -1723,7 +1790,7 @@ def _get_and_check_device_assignment( final_device_assignment = first_sharding_info[0] return xb.get_device_backend(final_device_assignment[0]), final_device_assignment -MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue] +MaybeSharding = Union[JSharding, UnspecifiedValue] def prune_unused_inputs( @@ -1889,6 +1956,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, donated_invars, name_stack, all_default_mem_kind, inout_aliases: None | tuple[None | int, ...], propagated_out_mem_kinds: tuple[None | str, ...], + platforms: tuple[str, ...], lowering_parameters: mlir.LoweringParameters): jaxpr = closed_jaxpr.jaxpr in_shardings = semantic_in_shardings._gspmd_shardings @@ -1911,8 +1979,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, nreps = dispatch.jaxpr_replicas(jaxpr) _raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr) - in_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None - out_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None + in_mlir_shardings: list[JSharding | None] | None + out_mlir_shardings: list[JSharding | None] | None axis_ctx: mlir.AxisContext if nreps == 1: @@ -1941,7 +2009,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, "The following ordered effects are not supported for " f"more than 1 device: {unsupported_effects}") ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) - with dispatch.log_elapsed_time( "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): @@ -1950,8 +2017,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - # Optionally, override the lowering platform - platforms=lowering_parameters.platforms or (backend.platform,), + platforms=platforms, axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, @@ -1985,9 +2051,9 @@ def _create_da_object( # pytype: disable=invalid-annotation def jaxpr_transfer_mem_kinds( jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]: for eqn in jaxpr.eqns: - if (eqn.primitive is dispatch.device_put_p and - isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)): - yield eqn.params['device'] + if eqn.primitive is dispatch.device_put_p: + yield from (d for d in eqn.params['devices'] + if isinstance(d, sharding_impls.TransferToMemoryKind)) for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_transfer_mem_kinds(subjaxpr) @@ -2007,8 +2073,8 @@ def are_all_shardings_default_mem_kind(da_object, shardings): memory_kind_propagate_rule: dict[Any, Any] = {} @weakref_lru_cache -def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr - ) -> tuple[None | str]: +def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr, + in_shardings=None) -> tuple[None | str]: env = {} # type: ignore jaxpr = closed_jaxpr.jaxpr @@ -2023,7 +2089,12 @@ def write(var, val): def _default_rule(prim, num_outvars, *_, **__): return [None] * num_outvars if prim.multiple_results else None - safe_map(write, jaxpr.invars, [None] * len(jaxpr.invars)) + if in_shardings is None: + invar_mem_kind = [None] * len(jaxpr.invars) + else: + invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind + for s in in_shardings] + safe_map(write, jaxpr.invars, invar_mem_kind) safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars)) for eqn in jaxpr.eqns: @@ -2047,8 +2118,7 @@ class AllArgsInfo(NamedTuple): @lru_cache(maxsize=2048) -def to_gspmd_sharding(s: sharding_impls.XLACompatibleSharding, - ndim: int) -> GSPMDSharding: +def to_gspmd_sharding(s: JSharding, ndim: int) -> GSPMDSharding: if isinstance(s, GSPMDSharding): return s return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim), @@ -2096,8 +2166,10 @@ def lower_sharding_computation( *, keep_unused: bool, inline: bool, - devices_from_context: Sequence[xc.Device] | None = None, - lowering_parameters: mlir.LoweringParameters + devices_from_context: Sequence[xc.Device] | None, + lowering_platforms: tuple[str, ...] | None, + lowering_parameters: mlir.LoweringParameters, + pgle_profiler: profiler.PGLEProfiler | None, ) -> MeshComputation: """Lowers a computation to XLA. It can take arbitrary shardings as input. @@ -2141,6 +2213,7 @@ def lower_sharding_computation( for js, source_info in util.stable_unique(jaxpr_sharding))), devices_from_context) + platforms = lowering_platforms or (backend.platform,) # TODO(yashkatariya): Enable this when offload APIs are stable. # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) @@ -2160,7 +2233,8 @@ def lower_sharding_computation( # TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when # JAX puts memory kinds in the types of jaxpr. if not all_default_mem_kind: - propagated_out_mem_kinds = get_out_memory_kinds_via_propagation(closed_jaxpr) + propagated_out_mem_kinds = get_out_memory_kinds_via_propagation( + closed_jaxpr, in_shardings) else: propagated_out_mem_kinds = (None,) * len(global_out_avals) @@ -2179,7 +2253,8 @@ def lower_sharding_computation( semantic_out_shardings, in_layouts, out_layouts, len(da_object), tuple(da_object) if prim_requires_devices else None, donated_invars, name_stack, all_default_mem_kind, inout_aliases, - propagated_out_mem_kinds, lowering_parameters=lowering_parameters) + propagated_out_mem_kinds, platforms, + lowering_parameters=lowering_parameters) # backend and device_assignment is passed through to MeshExecutable because # if keep_unused=False and all in_shardings are pruned, then there is no way @@ -2190,6 +2265,7 @@ def lower_sharding_computation( str(name_stack), module, donated_invars, + platforms, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, @@ -2211,16 +2287,17 @@ def lower_sharding_computation( pmap_nreps=nreps, shape_poly_state=shape_poly_state, all_default_mem_kind=all_default_mem_kind, - all_args_info=all_args_info) + all_args_info=all_args_info, + pgle_profiler=pgle_profiler) def _to_logical_sharding( aval: core.AbstractValue, sharding: MaybeSharding | AUTO -) -> sharding_impls.XLACompatibleSharding | None: +) -> JSharding | None: if is_unspecified(sharding) or is_auto(sharding): return None elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)): - assert isinstance(sharding, sharding_impls.XLACompatibleSharding) + assert isinstance(sharding, JSharding) return sharding elif isinstance(aval, core.AbstractToken): return None @@ -2241,9 +2318,11 @@ def lower_mesh_computation( spmd_lowering: bool, global_in_avals: Sequence[core.ShapedArray], tiling_method: TilingMethod | None, + lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters) -> MeshComputation: assert not mesh.empty backend = xb.get_device_backend(mesh.devices.flat[0]) + platforms = lowering_platforms or (backend.platform,) name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) global_axis_sizes = mesh.shape @@ -2312,8 +2391,8 @@ def lower_mesh_computation( # 2. Build up the HLO tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform) - in_partitions: list[sharding_impls.XLACompatibleSharding | None] | None - out_partitions: list[sharding_impls.XLACompatibleSharding | None] | None + in_partitions: list[JSharding | None] | None + out_partitions: list[JSharding | None] | None axis_ctx: mlir.AxisContext if spmd_lowering: in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings) @@ -2352,7 +2431,7 @@ def lower_mesh_computation( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platforms=lowering_parameters.platforms or (backend.platform,), + platforms=platforms, axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, @@ -2369,6 +2448,7 @@ def lower_mesh_computation( str(name_stack), lowering_result.module, donated_invars, + platforms, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, @@ -2387,17 +2467,20 @@ def lower_mesh_computation( in_layouts=(None,) * len(global_in_avals), out_layouts=(None,) * len(global_out_avals), shape_poly_state=lowering_result.shape_poly_state, - all_args_info=None) + all_args_info=None, + pgle_profiler=None) class MeshComputation(stages.XlaLowering): _hlo: ir.Module _executable: MeshExecutable | None def __init__(self, name: str, hlo: ir.Module, - donated_invars: Sequence[bool], **compile_args): + donated_invars: Sequence[bool], platforms: Sequence[str], + **compile_args): self._name = name self._hlo = hlo self._donated_invars = donated_invars + self._platforms = platforms self.compile_args = compile_args self._executable = None @@ -2523,7 +2606,7 @@ def _get_mesh_pspec_shardings_from_executable( _orig_out_sharding_handlers = {} -_ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding) +_ShardingT = TypeVar("_ShardingT", bound=JSharding) def _register_out_sharding_handler( @@ -2532,20 +2615,10 @@ def _register_out_sharding_handler( ) -> None: _orig_out_sharding_handlers[sharding_cls] = handler - -def _gspmd_to_named_sharding_via_mesh( - out_s: sharding_impls.GSPMDSharding, - mesh: Mesh) -> sharding_impls.NamedSharding: - parsed_pspec = sharding_impls.parse_flatten_op_sharding( - out_s._hlo_sharding, mesh)[0] - return create_mesh_pspec_sharding( - mesh, parsed_pspec.get_partition_spec(), parsed_pspec, - out_s.memory_kind) - def _gspmd_to_named_sharding( out_s: sharding_impls.GSPMDSharding, orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding: - return _gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) + return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) _register_out_sharding_handler( sharding_impls.NamedSharding, _gspmd_to_named_sharding) @@ -2662,38 +2735,23 @@ def get_logical_mesh_ids(mesh_shape): return np.arange(math.prod(mesh_shape)).reshape(mesh_shape) -@weakref_lru_cache -def _cached_compilation(computation, name, mesh, spmd_lowering, - tuple_args, auto_spmd_lowering, allow_prop_to_inputs, - allow_prop_to_outputs, host_callbacks, backend, - da, pmap_nreps, compiler_options_keys, - compiler_options_values): - # TODO(phawkins): One would normally just write: - # dev = np.array(device_assignment) - # The formulation below is substantially faster if there are many devices. - # If we were to optimize __getattr__ on xc.Device we might not need this - # workaround. - dev = np.vectorize(lambda i: da[i], otypes=[object])( - np.arange(len(da)) - ) +def create_compile_options( + computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, + allow_prop_to_inputs, allow_prop_to_outputs, backend, + np_dev, pmap_nreps, compiler_options): if pmap_nreps > 1: num_replicas, num_partitions = pmap_nreps, 1 elif spmd_lowering: - num_replicas, num_partitions = 1, dev.size + num_replicas, num_partitions = 1, np_dev.size else: - num_replicas, num_partitions = dev.size, 1 + num_replicas, num_partitions = np_dev.size, 1 if pmap_nreps > 1: # In `jit` device_assignment is set to None when num_replicas > 1. Do # the same thing here too. xla_device_assignment = None else: - xla_device_assignment = dev.reshape((num_replicas, num_partitions)) - - if compiler_options_keys is None: - compiler_options = None - else: - compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values)) + xla_device_assignment = np_dev.reshape((num_replicas, num_partitions)) fdo_profile = (None if compiler_options is None else compiler_options.pop("fdo_profile", None)) @@ -2720,12 +2778,36 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, compile_options.parameter_is_tupled_arguments = tuple_args opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs) opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs) + return compile_options + + +@weakref_lru_cache +def _cached_compilation(computation, name, mesh, spmd_lowering, + tuple_args, auto_spmd_lowering, allow_prop_to_inputs, + allow_prop_to_outputs, host_callbacks, backend, + da, pmap_nreps, compiler_options_keys, + compiler_options_values, + pgle_profiler): + # One would normally just write: dev = np.array(device_assignment) + # The formulation below is substantially faster if there are many devices. + dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da))) + + if compiler_options_keys is None: + compiler_options = None + else: + compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values)) + + compile_options = create_compile_options( + computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, + allow_prop_to_inputs, allow_prop_to_outputs, backend, + dev, pmap_nreps, compiler_options) with dispatch.log_elapsed_time( "Finished XLA compilation of {fun_name} in {elapsed_time} sec", fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = compiler.compile_or_get_cached( - backend, computation, dev, compile_options, host_callbacks) + backend, computation, dev, compile_options, host_callbacks, + pgle_profiler) return xla_executable @@ -2753,7 +2835,7 @@ def _maybe_get_and_check_in_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) + xla_s = sharding_impls.logical_sharding(aval, xla_s) new_in_shardings.append(xla_s) else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) @@ -2789,7 +2871,7 @@ def _maybe_get_and_check_out_shardings( if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): - xla_s = aval.dtype._rules.logical_sharding(aval, xla_s) + xla_s = sharding_impls.logical_sharding(aval, xla_s) new_out_shardings.append(xla_s) else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) @@ -2819,9 +2901,9 @@ class UnloadedMeshExecutable: device_assignment: xc.DeviceList | Sequence[xc.Device] backend: xb.XlaBackend input_avals: Sequence[ShapedArray] - input_shardings: Sequence[sharding_impls.XLACompatibleSharding] + input_shardings: Sequence[JSharding] output_avals: Sequence[ShapedArray] - output_shardings: Sequence[sharding_impls.XLACompatibleSharding] + output_shardings: Sequence[JSharding] committed: bool name: str unordered_effects: list[core.Effect] @@ -2834,6 +2916,7 @@ class UnloadedMeshExecutable: in_layouts: Sequence[DeviceLocalLayout | None] out_layouts: Sequence[DeviceLocalLayout | None] all_args_info: AllArgsInfo | None + pgle_profiler: profiler.PGLEProfiler | None def build_unsafe_call(self): handle_args = InputsHandler(self.input_shardings) @@ -2843,7 +2926,8 @@ def build_unsafe_call(self): unsafe_call = ExecuteReplicated( self.xla_executable, self.name, self.backend, handle_args, handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive, - bool(self.host_callbacks), self.kept_var_idx, self.mut) + bool(self.host_callbacks), self.kept_var_idx, self.mut, + self.pgle_profiler) return unsafe_call def load(self) -> MeshExecutable: @@ -2859,9 +2943,8 @@ def from_hlo(name: str, hlo: ir.Module, global_in_avals: Sequence[ShapedArray], global_out_avals: Sequence[ShapedArray], - in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO], - out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO | - UnspecifiedValue)], + in_shardings: Sequence[JSharding | AUTO], + out_shardings: Sequence[(JSharding | AUTO | UnspecifiedValue)], spmd_lowering: bool, tuple_args: bool, auto_spmd_lowering: bool, @@ -2881,6 +2964,7 @@ def from_hlo(name: str, all_default_mem_kind: bool = True, all_args_info: AllArgsInfo | None = None, compiler_options=None, + pgle_profiler: profiler.PGLEProfiler | None = None ) -> MeshExecutable: if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) @@ -2910,7 +2994,7 @@ def from_hlo(name: str, hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, - compiler_options_keys, compiler_options_values) + compiler_options_keys, compiler_options_values, pgle_profiler) if auto_spmd_lowering: assert mesh is not None @@ -2960,14 +3044,15 @@ def from_hlo(name: str, auto_spmd_lowering=auto_spmd_lowering, in_layouts=in_layouts, out_layouts=out_layouts, - all_args_info=all_args_info).load() + all_args_info=all_args_info, + pgle_profiler=pgle_profiler).load() class MeshExecutableFastpathData(NamedTuple): xla_executable: xc.LoadedExecutable out_pytree_def: Any - in_shardings: Sequence[sharding_impls.XLACompatibleSharding] - out_shardings: Sequence[sharding_impls.XLACompatibleSharding] + in_shardings: Sequence[JSharding] + out_shardings: Sequence[JSharding] out_avals: Sequence[ShapedArray] out_committed: Sequence[bool] kept_var_bitvec: Iterable[bool] @@ -3043,10 +3128,10 @@ def call(self, *args): self._kept_var_idx) return self.unsafe_call(*args) # pylint: disable=not-callable - def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: + def input_shardings(self) -> Sequence[JSharding]: return self._in_shardings - def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: + def output_shardings(self) -> Sequence[JSharding]: return self._out_shardings def input_layouts(self): @@ -3076,7 +3161,7 @@ def aot_cache_miss(*args, **kwargs): kept_var_bitvec = [i in self._kept_var_idx for i in range(len(args_flat))] in_shardings = [ - a.dtype._rules.physical_sharding(a, s) + sharding_impls.physical_sharding(a, s) if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) else s for s, a in zip(self._in_shardings, self.in_avals) @@ -3088,11 +3173,11 @@ def aot_cache_miss(*args, **kwargs): self.unsafe_call.in_handler.input_indices) else: fastpath_data = None - return outs, fastpath_data + return outs, fastpath_data, False # Do not remove cache entry return xc._xla.pjit( self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, shard_arg) + tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0]) def check_arg_avals_for_call(ref_avals, arg_avals, @@ -3139,14 +3224,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): return in_shardings, out_shardings, committed, tuple(local_devices) -@util.cache() -def create_mesh_pspec_sharding( - mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None, - memory_kind: str | None = None) -> sharding_impls.NamedSharding: - if pspec is None: - pspec, parsed_pspec = PartitionSpec(), None - return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec, - memory_kind=memory_kind) +create_mesh_pspec_sharding = sharding_impls.create_mesh_pspec_sharding def check_device_backend_on_shardings(shardings) -> bool: @@ -3160,7 +3238,7 @@ def check_device_backend_on_shardings(shardings) -> bool: def check_array_xla_sharding_layout_match( args_after_dce, - in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding], + in_xla_shardings: Sequence[JSharding], in_xla_layouts: Sequence[DeviceLocalLayout], jaxpr_debug_info: core.JaxprDebugInfo | None, kept_var_idx: set[int]) -> None: diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 314fdccfb975..2db877d3f970 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -17,12 +17,12 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools from functools import partial import itertools as it -from typing import Any, Callable, Protocol, Union +from typing import Any, Protocol, Union import numpy as np diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index 1358980d04ac..3f3f677b069d 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -17,11 +17,12 @@ from __future__ import annotations from collections import Counter, defaultdict +from collections.abc import Callable import gzip import itertools import json import types -from typing import Any, Callable, Union +from typing import Any, Union from jax._src import core from jax._src import util diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 683c4754fcbf..f950cfeada92 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -229,11 +229,19 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, if not dtypes.issubdtype(operand.dtype, np.floating): raise ValueError('operand must be a floating type') reduction_input_size = dims[reduction_dimension] - dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( - reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, - reduction_input_size_override)[0] - return (operand.update( - shape=dims, dtype=operand.dtype, weak_type=operand.weak_type), + if aggregate_to_topk: + dims[reduction_dimension] = k + elif core.is_constant_shape((reduction_input_size, k)): + dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( + reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, + reduction_input_size_override)[0] + else: + raise NotImplementedError( + "approx_top_k with aggregate_to_topk=False not yet implemented when " + f"either the `k` ({k}) or the " + f" reduction dimension size ({reduction_input_size}) are symbolic") + return (operand.update(shape=dims, dtype=operand.dtype, + weak_type=operand.weak_type), operand.update(shape=dims, dtype=np.dtype(np.int32))) @@ -254,30 +262,6 @@ def _comparator_builder(op_type, is_max_k): def _get_init_val_literal(op_type, is_max_k): return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) -def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k, - reduction_dimension, recall_target, is_max_k, - reduction_input_size_override, - aggregate_to_topk): - c = ctx.builder - op_shape = c.get_shape(operand) - if not op_shape.is_array(): - raise ValueError(f'operand must be an array, but was {op_shape}') - op_dims = op_shape.dimensions() - op_type = op_shape.element_type() - if reduction_dimension < 0: - reduction_dimension = len(op_dims) + reduction_dimension - comparator = _comparator_builder(op_type, is_max_k) - init_val_literal = _get_init_val_literal(op_type, is_max_k) - iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims), - reduction_dimension) - init_val = xc.ops.Constant(c, init_val_literal) - init_arg = xc.ops.Constant(c, np.int32(-1)) - out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k, - reduction_dimension, comparator, recall_target, - aggregate_to_topk, reduction_input_size_override) - return xla.xla_destructure(c, out) - - def _comparator_builder_mlir(ctx, op_type, is_max_k): scalar = ir.RankedTensorType.get([], op_type) index = ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)) @@ -326,7 +310,6 @@ def _approx_top_k_lowering(ctx, operand, *, k, init_val = mlir.ir_constant(init_val_array.reshape(())) backend_config = { - "top_k" : mlir.i64_attr(k), "reduction_dim" : mlir.i64_attr(reduction_dimension), "recall_target" : mlir.ir.FloatAttr.get(recall_type, recall_target), "aggregate_to_topk" : mlir.ir.BoolAttr.get(aggregate_to_topk), @@ -342,13 +325,24 @@ def _approx_top_k_lowering(ctx, operand, *, k, mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape)) for aval_out in ctx.avals_out] - out = mlir.custom_call( - "ApproxTopK", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=[operand, iota, init_val, init_arg], - called_computations=[comparator.name.value], - backend_config=backend_config, - result_shapes=result_shapes) + if core.is_constant_dim(k): + backend_config["top_k"] = mlir.i64_attr(k) + out = mlir.custom_call( + "ApproxTopK", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=[operand, iota, init_val, init_arg], + called_computations=[comparator.name.value], + backend_config=backend_config, + result_shapes=result_shapes) + else: + k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,)) + out = mlir.custom_call( + "stablehlo.dynamic_approx_top_k", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=[operand, iota, init_val, init_arg, k_value], + called_computations=[comparator.name.value], + backend_config=backend_config, + result_shapes=result_shapes) return out.results diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index f75175b79bb6..b613193876b6 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -15,10 +15,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import os from functools import partial -from typing import Any, Callable +from typing import Any from jax._src import core from jax._src import linear_util as lu diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index ba957417db20..ffe5d086a98d 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -15,14 +15,15 @@ from __future__ import annotations import collections -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools from functools import partial import inspect import itertools import operator -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar +import jax from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import config @@ -247,6 +248,7 @@ def cond(pred, true_fun, false_fun, *operands): if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals): raise ValueError("Cannot pass `Ref`s into `cond`.") true_jaxpr, false_jaxpr = jaxprs + out_tree, false_out_tree = out_trees if any(isinstance(out_aval, AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals): @@ -255,6 +257,14 @@ def cond(pred, true_fun, false_fun, *operands): _check_tree_and_avals("true_fun and false_fun output", out_tree, true_jaxpr.out_avals, false_out_tree, false_jaxpr.out_avals) + # prune passhtrough outputs + true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) + false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr) + in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)] + keep = [f is None for f in in_fwd] + true_jaxpr = pe.prune_closed_jaxpr_outputs(true_jaxpr, keep) + false_jaxpr = pe.prune_closed_jaxpr_outputs(false_jaxpr, keep) + joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: @@ -273,6 +283,18 @@ def cond(pred, true_fun, false_fun, *operands): out = cond_p.bind( index, *consts, *ops, branches=(false_jaxpr, true_jaxpr), linear=tuple(linear)) + num_consts = len(consts) + out_ = iter(out) + + def _cast_to_array(x): + _copy = isinstance(x, np.bool_) + return jax.numpy.asarray(x, copy=_copy) + + out = [ + next(out_) if fwd is None else _cast_to_array(ops[fwd - num_consts]) + for fwd in in_fwd + ] + assert next(out_, None) is None return tree_unflatten(out_tree, out) @api_boundary @@ -953,7 +975,19 @@ def other_platforms_code(*args): ... # Use a switch, to get the proper transformation rules for free. Since # platform index has no dependence on the input data, it won't be vectorized # under vmap. - return switch(platform_index, branches, *args) + # If the switch and the platform_index_p above are in the same compilation + # unit then constant-folding will remove the unnecessary branches. However, + # if we run in eager mode the switch below cannot be constant-folded and + # the compilation may fail if some of the branches contain custom calls not + # recognized on the compilation platform. Detect eager mode and keep only the + # needed branch. + try: + platform_index_concrete = core.concrete_or_error(operator.index, platform_index) + except core.ConcretizationTypeError: + return switch(platform_index, branches, *args) + else: + assert 0 <= platform_index_concrete < len(branches) + return branches[platform_index_concrete](*args) # A primitive to compute the index of a platform into a list of platforms. # Args: diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index c20e22a2385b..936656b0e7df 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -15,10 +15,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import operator -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Generic, TypeVar import jax.numpy as jnp from jax import lax diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 04d02cc5f4bf..0c704b84475b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -14,12 +14,12 @@ """Module for the loop primitives.""" from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import inspect import itertools import operator -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar import weakref import jax @@ -50,7 +50,6 @@ _abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr, _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, _typecheck_param) -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.ufuncs import logaddexp @@ -314,12 +313,20 @@ def _create_jaxpr(init): def _set_states(attrs_tracked, vals): from jax.experimental.attrs import jax_setattr - for ((obj, attr), val) in zip(attrs_tracked, vals): + valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked]) + for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): + val = tree_unflatten(treedef, leaves) jax_setattr(obj, attr, val) def _get_states(attrs_tracked): from jax.experimental.attrs import jax_getattr - return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked] + vals = [] + for treedef, _, (obj, attr) in attrs_tracked: + tree = jax_getattr(obj, attr) + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + return vals def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals): try: @@ -671,12 +678,10 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, def _maybe_put(x): if isinstance(x, np.ndarray): - return dispatch._put_x( - x, - jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]), - shaped_abstractify(x), - False, - ) + aval = shaped_abstractify(x) + s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]) + result_handler = pxla.global_aval_to_result_handler(aval, s, False) + return result_handler(pxla.shard_args([s], [x])) else: return x @@ -2364,25 +2369,12 @@ def register_lowering(fn, platform=None): mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)), platform=platform) - if xla_extension_version >= 266: - # In XLA, there's a rewriter for an O(N^2) reduce-window implementation. - register_lowering( - partial(cumred_reduce_window_impl, reduce_window_fn) - ) - else: - # Older XLA versions only have this rewrite for TPU. - register_lowering( - partial(cumred_reduce_window_impl, reduce_window_fn), 'tpu' - ) - # Default for platforms not treated specially below. - register_lowering(partial(associative_scan, reduce_fn)) - - # On GPU, we choose between window reduction and associative scan - # based on the input size. - for platform in ['cuda', 'rocm']: - register_lowering( - partial(cumred_gpu_impl, reduce_window_fn, reduce_fn), platform - ) + # For jax-metal, until reduce_window legalization is better supported. + register_lowering(partial(associative_scan, reduce_fn), 'METAL') + # In XLA, there's a rewriter for an O(N^2) reduce-window implementation. + register_lowering( + partial(cumred_reduce_window_impl, reduce_window_fn) + ) return reducer_p diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1f621d685a15..52d8cac9a3dc 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -15,14 +15,14 @@ from __future__ import annotations import builtins -from collections.abc import Sequence +from collections.abc import Callable, Sequence import enum import functools from functools import partial import itertools import math import operator -from typing import Any, Callable, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING +from typing import Any, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING import warnings import numpy as np @@ -67,7 +67,7 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo from jax._src.sharding_impls import PmapSharding -from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape +from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis, split_list, NumpyComplexWarning) @@ -85,9 +85,9 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip -def _clip_int_to_valid_range(val: int, dtype) -> int: +def _clip_int_to_valid_range(val: DimSize, dtype) -> int: info = np.iinfo(dtype) - return builtins.max(info.min, builtins.min(int(val), info.max)) + return core.max_dim(info.min, core.min_dim(val, info.max)) def _validate_shapes(shapes: Sequence[Shape]): def _check_static_shape(shape: Shape): @@ -648,22 +648,26 @@ def value(self) -> int: else: class Precision(enum.Enum): - """Precision enum for lax functions + """Precision enum for lax matrix multiply related functions. - The `precision` argument to JAX functions generally controls the tradeoff - between speed and accuracy for array computations on accelerator backends, - (i.e. TPU and GPU). Members are: + The device-dependent `precision` argument to JAX functions generally + controls the tradeoff between speed and accuracy for array computations on + accelerator backends, (i.e. TPU and GPU). Has no impact on CPU backends. + This only has an effect on float32 computations, and does not affect the + input/output datatypes. Members are: DEFAULT: - Fastest mode, but least accurate. Performs computations in bfloat16. - Aliases: ``'default'``, ``'fastest'``, ``'bfloat16'``. + Fastest mode, but least accurate. On TPU: performs float32 computations in + bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100 + GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases: + ``'default'``, ``'fastest'``. HIGH: - Slower but more accurate. Performs float32 computations in 3 bfloat16 - passes, or using tensorfloat32 where available. Aliases: ``'high'``, - ``'bfloat16_3x'``, ``'tensorfloat32'``. + Slower but more accurate. On TPU: performs float32 computations in 3 + bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise + float32. Aliases: ``'high'``.. HIGHEST: - Slowest but most accurate. Performs computations in float32 or float64 - as applicable. Aliases: ``'highest'``, ``'float32'``. + Slowest but most accurate. On TPU: performs float32 computations in 6 + bfloat16. Aliases: ``'highest'``. On GPU: uses float32. """ DEFAULT = 0 @@ -801,7 +805,7 @@ def ragged_dot( group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group. precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`. preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`. - group_offset: Optional. (1,) shaped array that ndicates the group in group_sizes to start computing from. If not specified, defaults to [0]. + group_offset: Optional. (1,) shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0]. Results: (m, n) shaped array with preferred_element_type element type. @@ -1232,12 +1236,24 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k: integer specifying the number of top entries. Returns: - values: array containing the top k values along the last axis. - indices: array containing the indices corresponding to values. + A tuple ``(values, indices)`` where + + - ``values`` is an array containing the top k values along the last axis. + - ``indices`` is an array containing the indices corresponding to values. See also: - - :func:`jax.lax.approx_max_k` - - :func:`jax.lax.approx_min_k` + - :func:`jax.lax.approx_max_k` + - :func:`jax.lax.approx_min_k` + + Examples: + Find the largest three values, and their indices, within an array: + + >>> x = jnp.array([9., 3., 6., 4., 10.]) + >>> values, indices = jax.lax.top_k(x, 3) + >>> values + Array([10., 9., 6.], dtype=float32) + >>> indices + Array([4, 0, 2], dtype=int32) """ if core.is_constant_dim(k): k = int(k) @@ -1319,7 +1335,7 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array: return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), dimension=dimension) -def _eye(dtype: DTypeLike, shape: Shape, offset: int) -> Array: +def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: """Like numpy.eye, create a 2D array with ones on a diagonal.""" offset = _clip_int_to_valid_range(offset, np.int32) dtype = dtypes.canonicalize_dtype(dtype) @@ -1339,11 +1355,12 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array: new_dtype=dtype, weak_type=False) return broadcast_in_dim(result, shape, axes) -def _tri(dtype: DTypeLike, shape: Shape, offset: int) -> Array: +def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: """Like numpy.tri, create a 2D array with ones below a diagonal.""" offset = _clip_int_to_valid_range(offset, np.int32) dtype = dtypes.canonicalize_dtype(dtype) - bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)), + bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0), + asarray(core.dimension_as_value(offset)).astype(np.int32)), broadcasted_iota(np.int32, shape, 1)) return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False) @@ -1431,7 +1448,7 @@ def full_like(x: ArrayLike | DuckTypedArray, If not specified, the output will have the same sharding as the input, with a few exceptions/limitations in particular: 1. Sharding is not available during tracing, thus this will rely on jit. - 2. If x is weakly typed or uncomitted, will use default sharding. + 2. If x is weakly typed or uncommitted, will use default sharding. 3. Shape is not None and is different from x.shape, default will be used. Returns: @@ -2946,16 +2963,6 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): core.ShapedArray(rhs_aval.shape, aval_out.dtype)) lhs_dtype = rhs_dtype = aval_out.dtype - # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul - if platform == "cpu": - if lhs_dtype == np.float16: - lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, - core.ShapedArray(lhs_aval.shape, np.float32)) - - if rhs_dtype == np.float16: - rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, - core.ShapedArray(rhs_aval.shape, np.float32)) - dot_dnums = hlo.DotDimensionNumbers.get( lhs_batching_dimensions=list(lhs_batch), @@ -2983,10 +2990,10 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S m, k = lhs.shape group_count, rk, n = rhs.shape if k != rk: - raise TypeError("ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {} and {}.".format(k, rk)) + raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.") num_groups = group_sizes.shape[0] if group_count != num_groups: - raise TypeError("ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {} and {}.".format(group_count, num_groups)) + raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.") return (m, n) # DotDimensionNumbers used in the dot_general call for ragged_dot(). @@ -4789,7 +4796,9 @@ def _copy_impl(prim, *args, **kwargs): ad.deflinear(copy_p, lambda t: [copy_p.bind(t)]) pe.def_trivial_padding(copy_p) batching.defvectorized(copy_p) - +def _propagate_mem_kind_copy(in_mem_kind): + return in_mem_kind +pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy def rng_bit_generator(key, shape, dtype=np.uint32, algorithm=RandomAlgorithm.RNG_DEFAULT): @@ -5099,7 +5108,7 @@ def remaining(original, *removed_lists): def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None: - """Turns an API precision specification, into a pair of enumeration values. + """Turns an API precision specification into a pair of enumeration values. The API can take the precision as a string, or int, and either as a single value to apply to both operands, or as a sequence of two values. @@ -5206,14 +5215,6 @@ def handler(bufs): return core.DArray(aval, phys_handler(bufs)) return handler - @staticmethod - def logical_sharding(aval, phys_sharding): - return phys_sharding - - @staticmethod - def physical_sharding(aval, sharding): - return sharding - @staticmethod def convert_from(bint_dtype, other_dtype) -> bool: return other_dtype in (np.dtype('int32'), np.dtype('int64')) @@ -5222,12 +5223,5 @@ def convert_from(bint_dtype, other_dtype) -> bool: def convert_to(other_dtype, bint_dtype) -> bool: return other_dtype in (np.dtype('int32'), np.dtype('int64')) - @staticmethod - def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: - return val - - @staticmethod - def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval): - pass core.bint._rules = BIntRules diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3740555bfacf..d31bba99171c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -14,10 +14,11 @@ from __future__ import annotations +from collections.abc import Callable import functools from functools import partial import math -from typing import Any, Callable, Literal, TypeVar, overload +from typing import Any, Literal, TypeVar, overload import numpy as np @@ -44,7 +45,6 @@ from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo @@ -510,10 +510,6 @@ def _cholesky_update_cuda_lowering_rule(ctx, r_matrix, w_vector): raise NotImplementedError( "Can only lower fast cholesky_update on CUDA." ) - if jaxlib_version < (0, 4, 29): - raise NotImplementedError( - f"The jaxlib version {jaxlib_version} is too old." - "Please update to at least 0.4.29.") return gpu_linalg.cuda_cholesky_update( r_matrix, w_vector, r_matrix_aval.dtype) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index a5a17b8b1322..47386cb4a5f0 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -983,10 +983,8 @@ def source_to_front(group): return [group[source]] + list(group[:source]) + list(group[source + 1:]) replica_groups = [source_to_front(group) for group in replica_groups] channel = ctx.module_context.new_channel() - channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE) return hlo.CollectiveBroadcastOp( - x, replica_groups=_replica_groups_hlo(replica_groups), - channel_handle=channel_handle).results + x, replica_groups=_replica_groups_hlo(replica_groups)).results pbroadcast_p = core.AxisPrimitive('pbroadcast') pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) @@ -1542,10 +1540,10 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, When ``False`` (the default value), the size of dimension in ``scatter_dimension`` must match the size of axis ``axis_name`` (or the group size if ``axis_index_groups`` is given). After scattering the - all-reduce result along ``scatter_dimension``, the output is sequeezed by + all-reduce result along ``scatter_dimension``, the output is squeezed by removing ``scatter_dimension``, so the result has lower rank than the input. When ``True``, the size of dimension in ``scatter_dimension`` must - be dividible by the size of axis ``axis_name`` (or the group size if + be divisible by the size of axis ``axis_name`` (or the group size if ``axis_index_groups`` is given), and the ``scatter_dimension`` axis is preserved (so the result has the same rank as the input). diff --git a/jax/_src/lax/qdwh.py b/jax/_src/lax/qdwh.py index ef2001f5b6b9..bac3ea957955 100644 --- a/jax/_src/lax/qdwh.py +++ b/jax/_src/lax/qdwh.py @@ -111,7 +111,8 @@ def _use_cholesky(u, m, n, params): return e * u + a_minus_e * z -def _qdwh(x, m, n, is_hermitian, max_iterations, eps): + +def _qdwh(x, m, n, max_iterations, eps): """QR-based dynamically weighted Halley iteration for polar decomposition.""" # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of @@ -172,8 +173,6 @@ def iteration(k, state, update_fn, coefs, test_convergence): u_prev = u u = update_fn(u, m, n, params) - if is_hermitian: - u = (u + u.T.conj()) / 2.0 is_not_converged = True if test_convergence: @@ -243,11 +242,12 @@ def qdwh( """QR-based dynamically weighted Halley iteration for polar decomposition. Args: - x: A full-rank matrix, with shape `M x N`. The matrix may be - padded up to that size from a smaller true shape (``dynamic_shape``). - is_hermitian: True if `x` is Hermitian. Default to `False`. - eps: The final result will satisfy - ``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate. + x: A full-rank matrix, with shape `M x N`. The matrix may be padded up to + that size from a smaller true shape (``dynamic_shape``). + is_hermitian: True if `x` is Hermitian. Default to `False`. This parameter + is currently unused, but exists for backward compatibility. + eps: The final result will satisfy ``|x_k - x_k-1| < |x_k| * + (4*eps)**(1/3)`` where `x_k` is the iterate. max_iterations: Iterations will terminate after this many steps even if the above is unsatisfied. dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional. @@ -258,6 +258,7 @@ def qdwh( and `is_converged`, whose value is `True` when the convergence is achieved within the maximum number of iterations. """ + # TODO: Possibly take advantage of Hermitian inputs to speed up the QDWH step. is_hermitian = core.concrete_or_error( bool, is_hermitian, 'The `is_hermitian` argument must be statically ' 'specified to use `qdwh` within JAX transformations.') @@ -279,7 +280,6 @@ def qdwh( m, n = M, N with jax.default_matmul_precision('float32'): - u, h, num_iters, is_converged = _qdwh(x, m, n, is_hermitian, max_iterations, - eps) + u, h, num_iters, is_converged = _qdwh(x, m, n, max_iterations, eps) return u, h, num_iters, is_converged diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index e58d5a7c7909..b2bd30b3d364 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -14,12 +14,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import enum import operator from functools import partial import math -from typing import Callable, NamedTuple +from typing import NamedTuple import weakref import numpy as np @@ -1821,10 +1821,13 @@ def _gather_lower(ctx, operand, indices, *, assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS, GatherScatterMode.CLIP), mode dnums = hlo.GatherDimensionNumbers.get( - collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), - index_vector_dim=len(ctx.avals_in[1].shape) - 1, - offset_dims=list(dimension_numbers.offset_dims), - start_index_map=list(dimension_numbers.start_index_map)) + collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), + operand_batching_dims=[], + start_indices_batching_dims=[], + index_vector_dim=len(ctx.avals_in[1].shape) - 1, + offset_dims=list(dimension_numbers.offset_dims), + start_index_map=list(dimension_numbers.start_index_map), + ) if not core.is_constant_shape(slice_sizes): slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes) # TODO(burmako): Fix overly conservative type inference of DynamicGatherOp. @@ -2475,10 +2478,13 @@ def _scatter_lower(ctx, operand, indices, updates, *, dnums = dimension_numbers scatter_dnums = hlo.ScatterDimensionNumbers.get( - update_window_dims=list(dnums.update_window_dims), - inserted_window_dims=list(dnums.inserted_window_dims), - scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), - index_vector_dim=len(ctx.avals_in[1].shape) - 1) + update_window_dims=list(dnums.update_window_dims), + inserted_window_dims=list(dnums.inserted_window_dims), + input_batching_dims=[], + scatter_indices_batching_dims=[], + scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), + index_vector_dim=len(ctx.avals_in[1].shape) - 1, + ) result = mlir.aval_to_ir_types(aval_out) operand = [operand] updates = [updates] @@ -2532,10 +2538,13 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates, aval_out, = ctx.avals_out dnums = dimension_numbers scatter_dnums = hlo.ScatterDimensionNumbers.get( - update_window_dims=list(dnums.update_window_dims), - inserted_window_dims=list(dnums.inserted_window_dims), - scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), - index_vector_dim=len(ctx.avals_in[1].shape) - 1) + update_window_dims=list(dnums.update_window_dims), + inserted_window_dims=list(dnums.inserted_window_dims), + input_batching_dims=[], + scatter_indices_batching_dims=[], + scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), + index_vector_dim=len(ctx.avals_in[1].shape) - 1, + ) real_dtype = _real_dtype(aval_out.dtype) operand_type_part = mlir.aval_to_ir_types( core.ShapedArray(aval_out.shape, real_dtype)) diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index 607632ad0382..77ff4297e137 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -15,7 +15,7 @@ """A JIT-compatible library for QDWH-based singular value decomposition. QDWH is short for QR-based dynamically weighted Halley iteration. The Halley -iteration implemented through QR decmopositions is numerically stable and does +iteration implemented through QR decompositions is numerically stable and does not require solving a linear system involving the iteration matrix or computing its inversion. This is desirable for multicore and heterogeneous computing systems. @@ -59,7 +59,7 @@ def _svd_tall_and_square_input( Args: a: A matrix of shape `m x n` with `m >= n`. hermitian: True if `a` is Hermitian. - compute_uv: Whether to compute also `u` and `v` in addition to `s`. + compute_uv: Whether to also compute `u` and `v` in addition to `s`. max_iterations: The predefined maximum number of iterations of QDWH. Returns: @@ -126,11 +126,11 @@ def svd( full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`, respectively. If False, the shapes are `m x k` and `k x n`, respectively, where `k = min(m, n)`. - compute_uv: Whether to compute also `u` and `v` in addition to `s`. + compute_uv: Whether to also compute `u` and `v` in addition to `s`. hermitian: True if `a` is Hermitian. max_iterations: The predefined maximum number of iterations of QDWH. subset_by_index: Optional 2-tuple [start, end] indicating the range of - indices of singular componenets to compute. For example, if + indices of singular components to compute. For example, if ``subset_by_index`` = [0,2], then ``svd`` computes the two largest singular values (and their singular vectors if `compute_uv` is true. diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 8a3fbf2c37bb..096fce7deb3a 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -14,9 +14,8 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Callable import warnings from jax import tree_util diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py index 6041c77c65c0..cf6e68e49c81 100644 --- a/jax/_src/lazy_loader.py +++ b/jax/_src/lazy_loader.py @@ -14,9 +14,9 @@ """A LazyLoader class.""" -from collections.abc import Sequence +from collections.abc import Callable, Sequence import importlib -from typing import Any, Callable +from typing import Any def attach(package_name: str, submodules: Sequence[str]) -> tuple[ diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index d3ed72f4775e..c0d88759dcc0 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -15,7 +15,6 @@ load( "//jaxlib:jax.bzl", "if_building_jaxlib", - "if_building_mosaic_gpu", "jax_visibility", "py_library_providing_imports_info", "pytype_strict_library", @@ -60,5 +59,5 @@ py_library_providing_imports_info( "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", # xla_client - ]) + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]), + ]), ) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 54b4a9ea9b85..b2bcc53a53f8 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -87,8 +87,6 @@ def _parse_version(v: str) -> tuple[int, ...]: import jaxlib.xla_client as xla_client import jaxlib.lapack as lapack -import jaxlib.ducc_fft as ducc_fft - xla_extension = xla_client._xla pytree = xla_client._xla.pytree jax_jit = xla_client._xla.jax_jit @@ -102,7 +100,10 @@ def _xla_gc_callback(*args): try: import jaxlib.cuda._versions as cuda_versions # pytype: disable=import-error except ImportError: - cuda_versions = None + try: + import jax_cuda12_plugin._versions as cuda_versions # pytype: disable=import-error + except ImportError: + cuda_versions = None import jaxlib.gpu_solver as gpu_solver # pytype: disable=import-error import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error @@ -132,11 +133,6 @@ def _cuda_path() -> str | None: # both of the things XLA looks for in the cuda path, namely bin/ptxas and # nvvm/libdevice/libdevice.10.bc path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc" - if path.is_dir(): - return str(path) - # Failing that, we use the copy of libdevice.10.bc we include with jaxlib and - # hope that the user has ptxas in their PATH. - path = _jaxlib_path / "cuda" if path.is_dir(): return str(path) return None diff --git a/jax/_src/lib/mosaic_gpu.py b/jax/_src/lib/mosaic_gpu.py index 69d2a4af4779..494112093029 100644 --- a/jax/_src/lib/mosaic_gpu.py +++ b/jax/_src/lib/mosaic_gpu.py @@ -15,9 +15,9 @@ # ruff: noqa try: - from jaxlib.mlir._mlir_libs import _mosaic_gpu_ext # pytype: disable=import-error + try: + from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error + except ImportError: + from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error except ImportError as e: - raise ModuleNotFoundError( - "Cannot import the Mosaic GPU bindings. You may need to build jaxlib from" - " source." - ) from e + raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 8f431d160153..bc4cc242f055 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -63,15 +63,16 @@ def trans1(static_arg, *dynamic_args, **kwargs): """ from __future__ import annotations +from collections.abc import Callable from functools import partial -from typing import Any, Callable, NamedTuple +from typing import Any, NamedTuple import weakref from jax._src import config from jax._src import core from jax._src import traceback_util from jax._src.tree_util import tree_map -from jax._src.util import curry +from jax._src.util import curry, cache_clearing_funs traceback_util.register_exclusion(__file__) @@ -246,8 +247,9 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: return fun.wrap(gen, gen_static_args, None) @curry -def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args, - use_eq_store=False) -> tuple[WrappedFun, Any]: +def transformation_with_aux( + gen, fun: WrappedFun, *gen_static_args, use_eq_store: bool = False +) -> tuple[WrappedFun, Callable[[], Any]]: """Adds one more transformation with auxiliary output to a WrappedFun.""" out_store = Store() if not use_eq_store else EqualStore() out_thunk = lambda: out_store.val @@ -359,17 +361,9 @@ def _evict_function(f): memoized_fun.cache_clear = fun_caches.clear # type: ignore memoized_fun.evict_function = _evict_function # type: ignore - cache_clearing_funs.add(memoized_fun.cache_clear) - return memoized_fun -cache_clearing_funs = weakref.WeakSet() # type: ignore - -def clear_all_caches(): - global cache_clearing_funs - for clear in cache_clearing_funs: - clear() @partial(partial, tree_map) def _copy_main_traces(x): diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py new file mode 100644 index 000000000000..3b1f9df07210 --- /dev/null +++ b/jax/_src/lru_cache.py @@ -0,0 +1,184 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import heapq +import logging +import pathlib +import warnings + +from jax._src.compilation_cache_interface import CacheInterface + + +try: + import filelock +except ImportError: + filelock = None + + +logger = logging.getLogger(__name__) + + +class LRUCache(CacheInterface): + """Bounded cache with least-recently-used (LRU) eviction policy. + + This implementation includes cache reading, writing and eviction + based on the LRU policy. + + Notably, when ``max_size`` is set to -1, the cache eviction + is disabled, and the LRU cache functions as a normal cache + without any size limitations. + """ + + def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None = 10): + """Args: + + path: The path to the cache directory. + max_size: The maximum size of the cache in bytes. Caching will be + disabled if this value is set to ``0``. A special value of ``-1`` + indicates no limit, allowing the cache size to grow indefinitely. + lock_timeout_secs: (optional) The timeout for acquiring a file lock. + """ + # TODO(ayx): add support for cloud other filesystems such as GCS + if not self._is_local_filesystem(path): + raise NotImplementedError("LRUCache only supports local filesystem at this time.") + + self.path = pathlib.Path(path) + self.path.mkdir(parents=True, exist_ok=True) + + # TODO(ayx): having a `self._path` is required by the base class + # `CacheInterface`, but the base class can be removed after `LRUCache` + # and the original `GFileCache` are unified + self._path = self.path + + self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1 + + if self.eviction_enabled: + if filelock is None: + raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`") + + self.max_size = max_size + self.lock_timeout_secs = lock_timeout_secs + + self.lock_path = self.path / ".lockfile" + self.lock = filelock.FileLock(self.lock_path) + + def get(self, key: str) -> bytes | None: + """Retrieves the cached value for the given key. + + Args: + key: The key for which the cache value is retrieved. + + Returns: + The cached data as bytes if available; ``None`` otherwise. + """ + if not key: + raise ValueError("key cannot be empty") + + file = self.path / key + + if self.eviction_enabled: + self.lock.acquire(timeout=self.lock_timeout_secs) + + try: + if not file.exists(): + logger.debug(f"Cache miss for key: {key!r}") + return None + + logger.debug(f"Cache hit for key: {key!r}") + file.touch() # update mtime + return file.read_bytes() + + finally: + if self.eviction_enabled: + self.lock.release() + + def put(self, key: str, val: bytes) -> None: + """Adds a new entry to the cache. + + If a cache item with the same key already exists, no action + will be taken, even if the value is different. + + Args: + key: The key under which the data will be stored. + val: The data to be stored. + """ + if not key: + raise ValueError("key cannot be empty") + + # prevent adding entries that exceed the maximum size limit of the cache + if self.eviction_enabled and len(val) > self.max_size: + msg = (f"Cache value for key {key!r} of size {len(val)} bytes exceeds " + f"the maximum cache size of {self.max_size} bytes") + warnings.warn(msg) + return + + file = self.path / key + + if self.eviction_enabled: + self.lock.acquire(timeout=self.lock_timeout_secs) + + try: + if file.exists(): + return + + self._evict_if_needed(additional_size=len(val)) + file.write_bytes(val) + + finally: + if self.eviction_enabled: + self.lock.release() + + def _evict_if_needed(self, *, additional_size: int = 0) -> None: + """Evicts the least recently used items from the cache if necessary + to ensure the cache does not exceed its maximum size. + + Args: + additional_size: The size of the new entry being added to the cache. + This is included to account for the new entry when checking if + eviction is needed. + """ + if not self.eviction_enabled: + return + + # a priority queue, each element is a tuple `(file_mtime, file, file_size)` + h: list[tuple[int, pathlib.Path, int]] = [] + dir_size = 0 + for file in self.path.iterdir(): + if file.is_file() and file != self.lock_path: + file_size = file.stat().st_size + file_mtime = file.stat().st_mtime_ns + + dir_size += file_size + heapq.heappush(h, (file_mtime, file, file_size)) + + target_size = self.max_size - additional_size + # evict files until the directory size is less than or equal + # to `target_size` + while dir_size > target_size: + file_mtime, file, file_size = heapq.heappop(h) + msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, " + f"target cache size {target_size} bytes") + logger.debug(msg) + file.unlink() + dir_size -= file_size + + # See comments in `jax.src.compilation_cache.get_file_cache()` for details. + # TODO(ayx): This function has a duplicate in that place, and there is + # redundancy here. However, this code is temporary, and once the issue + # is fixed, this code can be removed. + @staticmethod + def _is_local_filesystem(path: str) -> bool: + return path.startswith("file://") or "://" not in path diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 595d86b58ef3..20fc54d8fe37 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -15,12 +15,12 @@ from __future__ import annotations from collections import OrderedDict, abc -from collections.abc import Iterable, Sequence, Mapping +from collections.abc import Callable, Iterable, Sequence, Mapping import contextlib from functools import wraps, partial, partialmethod, lru_cache import itertools as it import math -from typing import Callable, Any, NamedTuple, Union, cast as type_cast +from typing import Any, NamedTuple, Union, cast as type_cast import numpy as np @@ -62,7 +62,7 @@ from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3, as_hashable_function, distributed_debug_log, tuple_insert, moveaxis, split_list, wrap_name, - merge_lists, partition_list) + merge_lists, partition_list, fun_name) source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) @@ -116,7 +116,7 @@ class SerialLoop: jointly over chunks of multiple axes (with the usual requirement that they do not coincide in a named shape of any value in the program). - Example:: + Examples: # Processes `x` in a vectorized way, but in 20 micro-batches. xmap(f, in_axes=['i'], out_axes=[i], axis_resources={'i': SerialLoop(20)})(x) @@ -161,7 +161,7 @@ def serial_loop(name: ResourceAxisName, length: int): name: Name of the loop in the resource environment. length: Number of iterations. - Example:: + Examples: >>> x = jnp.linspace(0, jnp.pi, 4) ... @@ -577,7 +577,7 @@ def infer_params(*args): in_axes_flat, args_flat) params = dict( - name=getattr(fun, '__name__', ''), + name=fun_name(fun), in_axes=tuple(in_axes_flat), out_axes_thunk=out_axes_thunk, donated_invars=donated_invars, @@ -627,8 +627,7 @@ def lower(*args, **kwargs): in_tree = treedef_tuple([in_tree, tree_flatten({})[1]]) in_avals = in_tree.unflatten(avals_flat) return stages.Lowered.from_flat_info( - computation, in_tree, in_avals, donate_argnums, out_tree(), - no_kwargs=True) + computation, in_tree, in_avals, donate_argnums, out_tree()) fun_mapped.lower = lower return type_cast(stages.Wrapped, fun_mapped) @@ -637,11 +636,12 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_ global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk): in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args] - xmap_callable = make_xmap_callable( + computation = make_xmap_callable( fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, - mlir.LoweringParameters(), *in_avals).compile().unsafe_call + mlir.LoweringParameters(), *in_avals) + xmap_callable = computation.compile().unsafe_call distributed_debug_log(("Running xmapped function", name), ("python function", fun.f), ("mesh", resource_env.physical_mesh), @@ -707,7 +707,7 @@ def make_xmap_callable(fun: lu.WrappedFun, f, 'xmap', name, mesh, in_shardings, out_shardings, donated_invars, use_spmd_lowering, in_avals, - tiling_method=tiling_method, + tiling_method=tiling_method, lowering_platforms=None, lowering_parameters=lowering_parameters) else: jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals) @@ -716,7 +716,8 @@ def make_xmap_callable(fun: lu.WrappedFun, (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals), (None,) * len(in_avals), (None,) * len(out_avals), donated_invars, keep_unused=True, inline=False, - devices_from_context=None, lowering_parameters=lowering_parameters) + devices_from_context=None, lowering_platforms=None, + lowering_parameters=lowering_parameters, pgle_profiler=None) class EvaluationPlan(NamedTuple): @@ -1849,7 +1850,7 @@ def update(v): return update -SPMD_LOWERING = config.define_bool_state( +SPMD_LOWERING = config.bool_state( name="experimental_xmap_spmd_lowering", default=False, help=("When set, multi-device xmap computations will be compiled through " @@ -1857,7 +1858,7 @@ def update(v): "Not supported on CPU!"), update_global_hook=_clear_compilation_cache, update_thread_local_hook=_thread_local_flag_unsupported) -SPMD_LOWERING_MANUAL = config.define_bool_state( +SPMD_LOWERING_MANUAL = config.bool_state( name="experimental_xmap_spmd_lowering_manual", default=False, help=("When set, multi-device xmap computations will be compiled using " @@ -1866,7 +1867,7 @@ def update(v): "Requires experimental_xmap_spmd_lowering!"), update_global_hook=_ensure_spmd_and(_clear_compilation_cache), update_thread_local_hook=_thread_local_flag_unsupported) -_ENSURE_FIXED_SHARDING = config.define_bool_state( +_ENSURE_FIXED_SHARDING = config.bool_state( name="experimental_xmap_ensure_fixed_sharding", default=False, help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will " diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 7339bd6cf7c1..32138678561f 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -91,7 +91,7 @@ def __repr__(self): return f"ResourceEnv(mesh=Mesh({mesh_repr}), {self.loops!r})" -@functools.lru_cache(maxsize=128) +@util.cache(max_size=128, trace_context_in_key=False) def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: if global_mesh.empty: return global_mesh @@ -144,7 +144,7 @@ class Mesh(contextlib.ContextDecorator): dimensions of the ``devices`` argument. Its length should match the rank of ``devices``. - Example: + Examples: >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 830643643162..822fb548ed90 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -69,7 +69,7 @@ def relu(x: ArrayLike) -> Array: Returns: An array. - Example: + Examples: >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32) @@ -512,6 +512,10 @@ def log_softmax(x: ArrayLike, Returns: An array. + Note: + If any input values are ``+inf``, the result will be all ``NaN``: this reflects the + fact that ``inf / inf`` is not well-defined in the context of floating-point math. + See also: :func:`softmax` """ @@ -557,6 +561,10 @@ def softmax(x: ArrayLike, Returns: An array. + Note: + If any input values are ``+inf``, the result will be all ``NaN``: this reflects the + fact that ``inf / inf`` is not well-defined in the context of floating-point math. + See also: :func:`log_softmax` """ diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index d7353c396ae6..cf245f7927be 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -372,7 +372,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_uniform() @@ -410,7 +410,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_normal() @@ -448,7 +448,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_uniform() @@ -484,7 +484,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_normal() @@ -520,7 +520,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_uniform() @@ -558,7 +558,7 @@ def he_normal(in_axis: int | Sequence[int] = -2, Returns: An initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_normal() @@ -591,7 +591,7 @@ def orthogonal(scale: RealNumeric = 1.0, Returns: An orthogonal initializer. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.orthogonal() @@ -634,7 +634,7 @@ def delta_orthogonal( A `delta orthogonal initializer`_. The shape passed to the initializer must be 3D, 4D, or 5D. - Example: + Examples: >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.delta_orthogonal() diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 90b582c2c68b..ca6cec379878 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -27,13 +27,13 @@ import builtins import collections -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import importlib import math import operator import types -from typing import (cast, overload, Any, Callable, Literal, NamedTuple, +from typing import (cast, overload, Any, Literal, NamedTuple, Protocol, TypeVar, Union) from textwrap import dedent as _dedent import warnings @@ -127,8 +127,37 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: printoptions = np.printoptions set_printoptions = np.set_printoptions -@util.implements(np.iscomplexobj) def iscomplexobj(x: Any) -> bool: + """Check if the input is a complex number or an array containing complex elements. + + JAX implementation of :func:`numpy.iscomplexobj`. + + The function evaluates based on input type rather than value. + Inputs with zero imaginary parts are still considered complex. + + Args: + x: input object to check. + + Returns: + True if ``x`` is a complex number or an array containing at least one complex element, + False otherwise. + + See Also: + - :func:`jax.numpy.isrealobj` + - :func:`jax.numpy.iscomplex` + + Examples: + >>> jnp.iscomplexobj(True) + False + >>> jnp.iscomplexobj(0) + False + >>> jnp.iscomplexobj(jnp.array([1, 2])) + False + >>> jnp.iscomplexobj(1+2j) + True + >>> jnp.iscomplexobj(jnp.array([0, 1+2j])) + True + """ if x is None: return False try: @@ -347,6 +376,8 @@ def result_type(*args: Any) -> DType: @jit def trunc(x: ArrayLike) -> Array: util.check_arraylike('trunc', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax_internal.asarray(x) return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) @@ -397,23 +428,161 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, return result[0, 0, out_order] -@util.implements(np.convolve, lax_description=_PRECISION_DOC, - extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: + r"""Convolution of two one dimensional arrays. + + JAX implementation of :func:`numpy.convolve`. + + Convolution of one dimensional arrays is defined as: + + .. math:: + + c_k = \sum_j a_{k - j} v_j + + Args: + a: left-hand input to the convolution. Must have ``a.ndim == 1``. + v: right-hand input to the convolution. Must have ``v.ndim == 1``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: (default) output the full convolution of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``a``. + * ``"valid"``: return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + preferred_element_type: A datatype, indicating to accumulate results to and + return a result with that datatype. Default is ``None``, which means the + default accumulation type for the input types. + + Returns: + Array containing the convolved result. + + See Also: + - :func:`jax.scipy.signal.convolve`: ND convolution + - :func:`jax.numpy.correlate`: 1D correlation + + Examples: + A few 1D convolution examples: + + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([4, 1, 2]) + + ``jax.numpy.convolve``, by default, returns full convolution using implicit + zero-padding at the edges: + + >>> jnp.convolve(x, y) + Array([ 4., 9., 16., 15., 12., 5., 2.], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered convolution the same size + as the first input: + + >>> jnp.convolve(x, y, mode='same') + Array([ 9., 16., 15., 12., 5.], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion where the two arrays + fully overlap: + + >>> jnp.convolve(x, y, mode='valid') + Array([16., 15., 12.], dtype=float32) + + For complex-valued inputs: + + >>> x1 = jnp.array([3+1j, 2, 4-3j]) + >>> y1 = jnp.array([1, 2-3j, 4+5j]) + >>> jnp.convolve(x1, y1) + Array([ 3. +1.j, 11. -7.j, 15.+10.j, 7. -8.j, 31. +8.j], dtype=complex64) + """ util.check_arraylike("convolve", a, v) return _conv(asarray(a), asarray(v), mode=mode, op='convolve', precision=precision, preferred_element_type=preferred_element_type) -@util.implements(np.correlate, lax_description=_PRECISION_DOC, - extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: + r"""Correlation of two one dimensional arrays. + + JAX implementation of :func:`numpy.correlate`. + + Correlation of one dimensional arrays is defined as: + + .. math:: + + c_k = \sum_j a_{k + j} \overline{v_j} + + where :math:`\overline{v_j}` is the complex conjugate of :math:`v_j`. + + Args: + a: left-hand input to the correlation. Must have ``a.ndim == 1``. + v: right-hand input to the correlation. Must have ``v.ndim == 1``. + mode: controls the size of the output. Available operations are: + + * ``"full"``: output the full correlation of the inputs. + * ``"same"``: return a centered portion of the ``"full"`` output which + is the same size as ``a``. + * ``"valid"``: (default) return the portion of the ``"full"`` output which do not + depend on padding at the array edges. + + precision: Specify the precision of the computation. Refer to + :class:`jax.lax.Precision` for a description of available values. + + preferred_element_type: A datatype, indicating to accumulate results to and + return a result with that datatype. Default is ``None``, which means the + default accumulation type for the input types. + + Returns: + Array containing the cross-correlation result. + + See Also: + - :func:`jax.scipy.signal.correlate`: ND correlation + - :func:`jax.numpy.convolve`: 1D convolution + + Examples: + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([4, 5, 6]) + + Since default ``mode = 'valid'``, ``jax.numpy.correlate`` returns only the + portion of correlation where the two arrays fully overlap: + + >>> jnp.correlate(x, y) + Array([32., 35., 28.], dtype=float32) + + Specifying ``mode = 'full'`` returns full correlation using implicit + zero-padding at the edges. + + >>> jnp.correlate(x, y, mode='full') + Array([ 6., 17., 32., 35., 28., 13., 4.], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered correlation the same size + as the first input: + + >>> jnp.correlate(x, y, mode='same') + Array([17., 32., 35., 28., 13.], dtype=float32) + + If both the inputs arrays are real-valued and symmetric then the result will + also be symmetric and will be equal to the result of ``jax.numpy.convolve``. + + >>> x1 = jnp.array([1, 2, 3, 2, 1]) + >>> y1 = jnp.array([4, 5, 4]) + >>> jnp.correlate(x1, y1, mode='full') + Array([ 4., 13., 26., 31., 26., 13., 4.], dtype=float32) + >>> jnp.convolve(x1, y1, mode='full') + Array([ 4., 13., 26., 31., 26., 13., 4.], dtype=float32) + + For complex-valued inputs: + + >>> x2 = jnp.array([3+1j, 2, 2-3j]) + >>> y2 = jnp.array([4, 2-5j, 1]) + >>> jnp.correlate(x2, y2, mode='full') + Array([ 3. +1.j, 3.+17.j, 18.+11.j, 27. +4.j, 8.-12.j], dtype=complex64) + """ util.check_arraylike("correlate", a, v) return _conv(asarray(a), asarray(v), mode=mode, op='correlate', precision=precision, preferred_element_type=preferred_element_type) @@ -555,7 +724,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: """Return a transposed version of an N-dimensional array. - JAX implementation of :func:`jax.numpy.transpose`, implemented in terms of + JAX implementation of :func:`numpy.transpose`, implemented in terms of :func:`jax.lax.transpose`. Args: @@ -573,7 +742,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: - :func:`jax.numpy.matrix_transpose`: transpose the last two axes of an array. This is suitable for working with batched 2D matrices. - :func:`jax.numpy.swapaxes`: swap any two axes in an array. - - :func:`jax.numpy.moveaxis`: move an axis to another postion in the array. + - :func:`jax.numpy.moveaxis`: move an axis to another position in the array. Note: Unlike :func:`numpy.transpose`, :func:`jax.numpy.transpose` will return a copy rather @@ -639,7 +808,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose the last two dimensions of an array. - JAX implementation of :func:`jax.numpy.matrix_transpose`, implemented in terms of + JAX implementation of :func:`numpy.matrix_transpose`, implemented in terms of :func:`jax.lax.transpose`. Args: @@ -716,8 +885,62 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) -@util.implements(np.flip, lax_description=_ARRAY_VIEW_DOC) def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: + """Reverse the order of elements of an array along the given axis. + + JAX implementation of :func:`numpy.flip`. + + Args: + m: Array. + axis: integer or sequence of integers. Specifies along which axis or axes + should the array elements be reversed. Default is ``None``, which flips + along all axes. + + Returns: + An array with the elements in reverse order along ``axis``. + + See Also: + - :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right) + - :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down) + + Examples: + >>> x1 = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.flip(x1) + Array([[4, 3], + [2, 1]], dtype=int32) + + If ``axis`` is specified with an integer, then ``jax.numpy.flip`` reverses + the array along that particular axis only. + + >>> jnp.flip(x1, axis=1) + Array([[2, 1], + [4, 3]], dtype=int32) + + >>> x2 = jnp.arange(1, 9).reshape(2, 2, 2) + >>> x2 + Array([[[1, 2], + [3, 4]], + + [[5, 6], + [7, 8]]], dtype=int32) + >>> jnp.flip(x2) + Array([[[8, 7], + [6, 5]], + + [[4, 3], + [2, 1]]], dtype=int32) + + When ``axis`` is specified with a sequence of integers, then + ``jax.numpy.flip`` reverses the array along the specified axes. + + >>> jnp.flip(x2, axis=[1, 2]) + Array([[[4, 3], + [2, 1]], + + [[8, 7], + [6, 5]]], dtype=int32) + """ util.check_arraylike("flip", m) return _flip(asarray(m), reductions._ensure_optional_axes(axis)) @@ -729,32 +952,143 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) -@util.implements(np.fliplr, lax_description=_ARRAY_VIEW_DOC) def fliplr(m: ArrayLike) -> Array: + """Reverse the order of elements of an array along axis 1. + + JAX implementation of :func:`numpy.fliplr`. + + Args: + m: Array with at least two dimensions. + + Returns: + An array with the elements in reverse order along axis 1. + + See Also: + - :func:`jax.numpy.flip`: reverse the order along the given axis + - :func:`jax.numpy.flipud`: reverse the order along axis 0 + + Examples: + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.fliplr(x) + Array([[2, 1], + [4, 3]], dtype=int32) + """ util.check_arraylike("fliplr", m) return _flip(asarray(m), 1) -@util.implements(np.flipud, lax_description=_ARRAY_VIEW_DOC) def flipud(m: ArrayLike) -> Array: + """Reverse the order of elements of an array along axis 0. + + JAX implementation of :func:`numpy.flipud`. + + Args: + m: Array with at least one dimension. + + Returns: + An array with the elements in reverse order along axis 0. + + See Also: + - :func:`jax.numpy.flip`: reverse the order along the given axis + - :func:`jax.numpy.fliplr`: reverse the order along axis 1 + + Examples: + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.flipud(x) + Array([[3, 4], + [1, 2]], dtype=int32) + """ util.check_arraylike("flipud", m) return _flip(asarray(m), 0) -@util.implements(np.iscomplex) @jit def iscomplex(x: ArrayLike) -> Array: + """Return boolean array showing where the input is complex. + + JAX implementation of :func:`numpy.iscomplex`. + + Args: + x: Input array to check. + + Returns: + A new array containing boolean values indicating complex elements. + + See Also: + - :func:`jax.numpy.iscomplexobj` + - :func:`jax.numpy.isrealobj` + + Examples: + >>> jnp.iscomplex(jnp.array([True, 0, 1, 2j, 1+2j])) + Array([False, False, False, True, True], dtype=bool) + """ i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) -@util.implements(np.isreal) @jit def isreal(x: ArrayLike) -> Array: + """Return boolean array showing where the input is real. + + JAX implementation of :func:`numpy.isreal`. + + Args: + x: input array to check. + + Returns: + A new array containing boolean values indicating real elements. + + See Also: + - :func:`jax.numpy.iscomplex` + - :func:`jax.numpy.isrealobj` + + Examples: + >>> jnp.isreal(jnp.array([False, 0j, 1, 2.1, 1+2j])) + Array([ True, True, True, True, False], dtype=bool) + """ i = ufuncs.imag(x) return lax.eq(i, _lax_const(i, 0)) -@util.implements(np.angle) + @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: + """Return the angle of a complex valued number or array. + + JAX implementation of :func:`numpy.angle`. + + Args: + z: A complex number or an array of complex numbers. + deg: Boolean. If ``True``, returns the result in degrees else returns + in radians. Default is ``False``. + + Returns: + An array of counterclockwise angle of each element of ``z``, with the same + shape as ``z`` of dtype float. + + Examples: + + If ``z`` is a number + + >>> z1 = 2+3j + >>> jnp.angle(z1) + Array(0.98279375, dtype=float32, weak_type=True) + + If ``z`` is an array + + >>> z2 = jnp.array([[1+3j, 2-5j], + ... [4-3j, 3+2j]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.angle(z2)) + [[ 1.25 -1.19] + [-0.64 0.59]] + + If ``deg=True``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.angle(z2, deg=True)) + [[ 71.57 -68.2 ] + [-36.87 33.69]] + """ re = ufuncs.real(z) im = ufuncs.imag(z) dtype = _dtype(re) @@ -888,8 +1222,37 @@ def gradient_along_axis(a, h, axis): return a_grad[0] if len(axis_tuple) == 1 else a_grad -@util.implements(np.isrealobj) def isrealobj(x: Any) -> bool: + """Check if the input is not a complex number or an array containing complex elements. + + JAX implementation of :func:`numpy.isrealobj`. + + The function evaluates based on input type rather than value. + Inputs with zero imaginary parts are still considered complex. + + Args: + x: input object to check. + + Returns: + False if ``x`` is a complex number or an array containing at least one complex element, + True otherwise. + + See Also: + - :func:`jax.numpy.iscomplexobj` + - :func:`jax.numpy.isreal` + + Examples: + >>> jnp.isrealobj(0) + True + >>> jnp.isrealobj(1.2) + True + >>> jnp.isrealobj(jnp.array([1, 2])) + True + >>> jnp.isrealobj(1+2j) + False + >>> jnp.isrealobj(jnp.array([0, 1+2j])) + False + """ return not iscomplexobj(x) @@ -1031,7 +1394,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: >>> jnp.ravel(x, order='F') Array([1, 4, 2, 5, 3, 6], dtype=int32) - For convenience, the same functionality is availabel via the :meth:`jax.Array.ravel` + For convenience, the same functionality is available via the :meth:`jax.Array.ravel` method: >>> x.ravel() @@ -1068,7 +1431,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], See also: :func:`jax.numpy.unravel_index`: inverse of this function. - Example: + Examples: Define a 2-dimensional array and a sequence of indices of even values: >>> x = jnp.array([[2., 3., 4.], @@ -1255,7 +1618,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: >>> _.shape (3,) - Eqivalent while specifying the axes explicitly: + Equivalent while specifying the axes explicitly: >>> jnp.squeeze(x, axis=(1, 2)) Array([0, 1, 2], dtype=int32) @@ -1593,8 +1956,6 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return jitted_interp(x, xp, fp, left, right, period) -_DEPRECATED_WHERE_ARG = object() - @overload def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, /, *, size: int | None = None, @@ -1736,7 +2097,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, ``bincount`` to be used with :func:`jax.jit` and other JAX transformations. Returns: - An array of counts or summed weights reflecting the number of occurrances of values + An array of counts or summed weights reflecting the number of occurrences of values in ``x``. See Also: @@ -3028,11 +3389,40 @@ def copy(a: ArrayLike, order: str | None = None) -> Array: return array(a, copy=True, order=order) -@util.implements(np.zeros_like) def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of zeros with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.zeros_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.zeros_like(x) + Array([0, 0, 0, 0], dtype=int32) + >>> jnp.zeros_like(x, dtype=bool) + Array([False, False, False, False], dtype=bool) + >>> jnp.zeros_like(x, shape=(2, 3)) + Array([[0, 0, 0], + [0, 0, 0]], dtype=int32) + """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing util.check_arraylike("zeros_like", a) dtypes.check_user_dtype_supported(dtype, "zeros_like") @@ -3041,11 +3431,40 @@ def zeros_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device)) -@util.implements(np.ones_like) def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array of ones with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.ones_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.ones_like(x) + Array([1, 1, 1, 1], dtype=int32) + >>> jnp.ones_like(x, dtype=bool) + Array([ True, True, True, True], dtype=bool) + >>> jnp.ones_like(x, shape=(2, 3)) + Array([[1, 1, 1], + [1, 1, 1]], dtype=int32) + """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing util.check_arraylike("ones_like", a) dtypes.check_user_dtype_supported(dtype, "ones_like") @@ -3054,13 +3473,42 @@ def ones_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device)) -@util.implements(np.empty_like, lax_description="""\ -Because XLA cannot create uninitialized arrays, the JAX version will -return an array initialized with zeros.""") def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an empty array with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.empty_like`. Because XLA cannot create + an un-initialized array, :func:`jax.numpy.empty` will always return an + array full of zeros. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.full_like` + + Examples: + >>> x = jnp.arange(4) + >>> jnp.empty_like(x) + Array([0, 0, 0, 0], dtype=int32) + >>> jnp.empty_like(x, dtype=bool) + Array([False, False, False, False], dtype=bool) + >>> jnp.empty_like(x, shape=(2, 3)) + Array([[0, 0, 0], + [0, 0, 0]], dtype=int32) + """ if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing util.check_arraylike("empty_like", prototype) dtypes.check_user_dtype_supported(dtype, "empty_like") @@ -3074,10 +3522,43 @@ def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | No return device -@util.implements(np.full) def full(shape: Any, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of a specified value. + + JAX implementation of :func:`numpy.full`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + fill_value: scalar or array with which to fill the created array. + dtype: optional dtype for the created array; defaults to the dtype of the + fill value. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.full_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.ones` + + Examples: + >>> jnp.full(4, 2, dtype=float) + Array([2., 2., 2., 2.], dtype=float32) + >>> jnp.full((2, 3), 0, dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + + `fill_value` may also be an array that is broadcast to the specified shape: + + >>> jnp.full((2, 3), fill_value=jnp.arange(3)) + Array([[0, 1, 2], + [0, 1, 2]], dtype=int32) + """ dtypes.check_user_dtype_supported(dtype, "full") util.check_arraylike("full", fill_value) @@ -3089,11 +3570,46 @@ def full(shape: Any, fill_value: ArrayLike, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) -@util.implements(np.full_like) def full_like(a: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, shape: Any = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of a specified value with the same shape and dtype as an array. + + JAX implementation of :func:`numpy.full_like`. + + Args: + a: Array-like object with ``shape`` and ``dtype`` attributes. + fill_value: scalar or array with which to fill the created array. + shape: optionally override the shape of the created array. + dtype: optionally override the dtype of the created array. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.full` + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.ones_like` + + Examples: + >>> x = jnp.arange(4.0) + >>> jnp.full_like(x, 2) + Array([2., 2., 2., 2.], dtype=float32) + >>> jnp.full_like(x, 0, shape=(2, 3)) + Array([[0., 0., 0.], + [0., 0., 0.]], dtype=float32) + + `fill_value` may also be an array that is broadcast to the specified shape: + + >>> x = jnp.arange(6).reshape(2, 3) + >>> jnp.full_like(x, fill_value=jnp.array([[1], [2]])) + Array([[1, 1, 1], + [2, 2, 2]], dtype=int32) + """ if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing util.check_arraylike("full_like", 0, fill_value) else: @@ -3110,9 +3626,34 @@ def full_like(a: ArrayLike | DuckTypedArray, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) -@util.implements(np.zeros) def zeros(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of zeros. + + JAX implementation of :func:`numpy.zeros`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.zeros_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.ones` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.zeros(4) + Array([0., 0., 0., 0.], dtype=float32) + >>> jnp.zeros((2, 3), dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + """ if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m) @@ -3120,9 +3661,35 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, shape = canonicalize_shape(shape) return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) -@util.implements(np.ones) + def ones(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an array full of ones. + + JAX implementation of :func:`numpy.ones`. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.ones_like` + - :func:`jax.numpy.empty` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.ones(4) + Array([1., 1., 1., 1.], dtype=float32) + >>> jnp.ones((2, 3), dtype=bool) + Array([[ True, True, True], + [ True, True, True]], dtype=bool) + """ if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m) @@ -3130,11 +3697,37 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, dtypes.check_user_dtype_supported(dtype, "ones") return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) -@util.implements(np.empty, lax_description="""\ -Because XLA cannot create uninitialized arrays, the JAX version will -return an array initialized with zeros.""") + def empty(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: + """Create an empty array. + + JAX implementation of :func:`numpy.empty`. Because XLA cannot create an + un-initialized array, :func:`jax.numpy.empty` will always return an array + full of zeros. + + Args: + shape: int or sequence of ints specifying the shape of the created array. + dtype: optional dtype for the created array; defaults to floating point. + device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Array of the specified shape and dtype, on the specified device if specified. + + See also: + - :func:`jax.numpy.empty_like` + - :func:`jax.numpy.zeros` + - :func:`jax.numpy.ones` + - :func:`jax.numpy.full` + + Examples: + >>> jnp.empty(4) + Array([0., 0., 0., 0.], dtype=float32) + >>> jnp.empty((2, 3), dtype=bool) + Array([[False, False, False], + [False, False, False]], dtype=bool) + """ if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m) dtypes.check_user_dtype_supported(dtype, "empty") return zeros(shape, dtype, device=device) @@ -3242,8 +3835,66 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) -@util.implements(np.eye) def eye(N: DimSize, M: DimSize | None = None, + k: int | ArrayLike = 0, + dtype: DTypeLike | None = None, + *, device: xc.Device | Sharding | None = None) -> Array: + """Create a square or rectangular identity matrix + + JAX implementation of :func:`numpy.eye`. + + Args: + N: integer specifying the first dimension of the array. + M: optional integer specifying the second dimension of the array; + defaults to the same value as ``N``. + k: optional integer specifying the offset of the diagonal. Use positive + values for upper diagonals, and negative values for lower diagonals. + Default is zero. + dtype: optional dtype; defaults to floating point. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + Identity array of shape ``(N, M)``, or ``(N, N)`` if ``M`` is not specified. + + See also: + :func:`jax.numpy.identity`: Simpler API for generating square identity matrices. + + Examples: + A simple 3x3 identity matrix: + + >>> jnp.eye(3) + Array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]], dtype=float32) + + Integer identity matrices with offset diagonals: + + >>> jnp.eye(3, k=1, dtype=int) + Array([[0, 1, 0], + [0, 0, 1], + [0, 0, 0]], dtype=int32) + >>> jnp.eye(3, k=-1, dtype=int) + Array([[0, 0, 0], + [1, 0, 0], + [0, 1, 0]], dtype=int32) + + Non-square identity matrix: + + >>> jnp.eye(3, 5, k=1) + Array([[0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.], + [0., 0., 0., 1., 0.]], dtype=float32) + """ + # TODO(vfdev-5): optimize putting the array directly on the device specified + # instead of putting it on default device and then on the specific device + output = _eye(N, M=M, k=k, dtype=dtype) + if device is not None: + return jax.device_put(output, device=device) + return output + + +def _eye(N: DimSize, M: DimSize | None = None, k: int | ArrayLike = 0, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "eye") @@ -3262,13 +3913,40 @@ def eye(N: DimSize, M: DimSize | None = None, return (i + offset == j).astype(dtype) -@util.implements(np.identity) def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: + """Create a square identity matrix + + JAX implementation of :func:`numpy.identity`. + + Args: + n: integer specifying the size of each array dimension. + dtype: optional dtype; defaults to floating point. + + Returns: + Identity array of shape ``(n, n)``. + + See also: + :func:`jax.numpy.eye`: non-square and/or offset identity matrices. + + Examples: + A simple 3x3 identity matrix: + + >>> jnp.identity(3) + Array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]], dtype=float32) + + A 2x2 integer identity matrix: + + >>> jnp.identity(2, dtype=int) + Array([[1, 0], + [0, 1]], dtype=int32) + """ dtypes.check_user_dtype_supported(dtype, "identity") return eye(n, dtype=dtype) -@util.implements(np.arange,lax_description= """ +@util.implements(np.arange, lax_description= """ .. note:: Using ``arange`` with the ``step`` argument can lead to precision errors, @@ -3277,8 +3955,25 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: To avoid precision errors, consider using an expression like ``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision and then convert it to the desired lower precision. -""") +""", extra_params=""" +device : :py:class:`Device`, :py:class:`Sharding`, optional + The (optional) :py:class:`Device`, :py:class:`Sharding`, + representing the device(s) to which created array should be + transferred. If given, then the result is committed to the device(s). +""" +) def arange(start: DimSize, stop: DimSize | None = None, + step: DimSize | None = None, dtype: DTypeLike | None = None, + *, device: xc.Device | Sharding | None = None) -> Array: + # TODO(vfdev-5): optimize putting the array directly on the device specified + # instead of putting it on default device and then on the specific device + output = _arange(start, stop=stop, step=step, dtype=dtype) + if device is not None: + return jax.device_put(output, device=device) + return output + + +def _arange(start: DimSize, stop: DimSize | None = None, step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "arange") if not config.dynamic_shapes.value: @@ -3537,7 +4232,7 @@ def ix_(*args: ArrayLike) -> tuple[Array, ...]: - :obj:`jax.numpy.mgrid` - :func:`jax.numpy.meshgrid` - Example: + Examples: >>> rows = jnp.array([0, 2]) >>> cols = jnp.array([1, 3]) >>> open_mesh = jnp.ix_(rows, cols) @@ -3955,12 +4650,54 @@ def trim_zeros_tol(filt, tol, trim='fb'): end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] - -@util.implements(np.append) @partial(jit, static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None ) -> Array: + """Return a new array with values appended to the end of the original array. + + JAX implementation of :func:`numpy.append`. + + Args: + arr: original array. + values: values to be appended to the array. The ``values`` must have + the same number of dimensions as ``arr``, and all dimensions must + match except in the specified axis. + axis: axis along which to append values. If None (default), both ``arr`` + and ``values`` will be flattened before appending. + + Returns: + A new array with values appended to ``arr``. + + See also: + - :func:`jax.numpy.insert` + - :func:`jax.numpy.delete` + + Examples: + >>> a = jnp.array([1, 2, 3]) + >>> b = jnp.array([4, 5, 6]) + >>> jnp.append(a, b) + Array([1, 2, 3, 4, 5, 6], dtype=int32) + + Appending along a specific axis: + + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> b = jnp.array([[5, 6]]) + >>> jnp.append(a, b, axis=0) + Array([[1, 2], + [3, 4], + [5, 6]], dtype=int32) + + Appending along a trailing axis: + + >>> a = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> b = jnp.array([[7], [8]]) + >>> jnp.append(a, b, axis=1) + Array([[1, 2, 3, 7], + [4, 5, 6, 8]], dtype=int32) + """ if axis is None: return concatenate([ravel(arr), ravel(values)], 0) else: @@ -4696,7 +5433,7 @@ def einsum( >>> jnp.vecdot(x, y) Array(16, dtype=int32) - Here are some alternative ``einsum`` calling conventions to comput the same + Here are some alternative ``einsum`` calling conventions to compute the same result: >>> jnp.einsum('i,i->', x, y) # explicit form @@ -4916,7 +5653,7 @@ def einsum_path( ) -> tuple[list[tuple[int, ...]], Any]: """Evaluates the optimal contraction path without evaluating the einsum. - JAX implementation of :func:`jax.numpy.einsum_path`. This function calls into + JAX implementation of :func:`numpy.einsum_path`. This function calls into the opt_einsum_ package, and makes use of its optimization routines. Args: @@ -4931,7 +5668,7 @@ def einsum_path( A tuple containing the path that may be passed to :func:`~jax.numpy.einsum`, and a printable object representing this optimal path. - Example: + Examples: >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) >>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3)) >>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100)) @@ -5517,17 +6254,57 @@ def argsort( return lax.rev(indices, dimensions=[dimension]) if descending else indices -@util.implements(np.partition, lax_description=""" -The JAX version requires the ``kth`` argument to be a static integer rather than -a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If -you're only accessing the top or bottom k values of the output, it may be more -efficient to call :func:`jax.lax.top_k` directly. - -The JAX version differs from the NumPy version in the treatment of NaN entries; -NaNs which have the negative bit set are sorted to the beginning of the array. -""") @partial(jit, static_argnames=['kth', 'axis']) def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: + """Returns a partially-sorted copy of an array. + + JAX implementation of :func:`numpy.partition`. The JAX version differs from + NumPy in the treatment of NaN entries: NaNs which have the negative bit set + are sorted to the beginning of the array. + + Args: + a: array to be partitioned. + kth: static integer index about which to partition the array. + axis: static integer axis along which to partition the array; default is -1. + + Returns: + A copy of ``a`` partitioned at the ``kth`` value along ``axis``. The entries + before ``kth`` are values smaller than ``take(a, kth, axis)``, and entries + after ``kth`` are indices of values larger than ``take(a, kth, axis)`` + + Note: + The JAX version requires the ``kth`` argument to be a static integer rather than + a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If + you're only accessing the top or bottom k values of the output, it may be more + efficient to call :func:`jax.lax.top_k` directly. + + See Also: + - :func:`jax.numpy.sort`: full sort + - :func:`jax.numpy.argpartition`: indirect partial sort + - :func:`jax.lax.top_k`: directly find the top k entries + - :func:`jax.lax.approx_max_k`: compute the approximate top k entries + - :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries + + Examples: + >>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) + >>> kth = 4 + >>> x_partitioned = jnp.partition(x, kth) + >>> x_partitioned + Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32) + + The result is a partially-sorted copy of the input. All values before ``kth`` + are of smaller than the pivot value, and all values after ``kth`` are larger + than the pivot value: + + >>> smallest_values = x_partitioned[:kth] + >>> pivot_value = x_partitioned[kth] + >>> largest_values = x_partitioned[kth + 1:] + >>> print(smallest_values, pivot_value, largest_values) + [1 2 3 3] 4 [9 8 7 6 5] + + Notice that among ``smallest_values`` and ``largest_values``, the returned + order is arbitrary and implementation-dependent. + """ # TODO(jakevdp): handle NaN values like numpy. util.check_arraylike("partition", a) arr = asarray(a) @@ -5543,17 +6320,58 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: return swapaxes(out, -1, axis) -@util.implements(np.argpartition, lax_description=""" -The JAX version requires the ``kth`` argument to be a static integer rather than -a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If -you're only accessing the top or bottom k values of the output, it may be more -efficient to call :func:`jax.lax.top_k` directly. - -The JAX version differs from the NumPy version in the treatment of NaN entries; -NaNs which have the negative bit set are sorted to the beginning of the array. -""") @partial(jit, static_argnames=['kth', 'axis']) def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: + """Returns indices that partially sort an array. + + JAX implementation of :func:`numpy.argpartition`. The JAX version differs from + NumPy in the treatment of NaN entries: NaNs which have the negative bit set are + sorted to the beginning of the array. + + Args: + a: array to be partitioned. + kth: static integer index about which to partition the array. + axis: static integer axis along which to partition the array; default is -1. + + Returns: + Indices which partition ``a`` at the ``kth`` value along ``axis``. The entries + before ``kth`` are indices of values smaller than ``take(a, kth, axis)``, and + entries after ``kth`` are indices of values larger than ``take(a, kth, axis)`` + + Note: + The JAX version requires the ``kth`` argument to be a static integer rather than + a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If + you're only accessing the top or bottom k values of the output, it may be more + efficient to call :func:`jax.lax.top_k` directly. + + See Also: + - :func:`jax.numpy.partition`: direct partial sort + - :func:`jax.numpy.argsort`: full indirect sort + - :func:`jax.lax.top_k`: directly find the top k entries + - :func:`jax.lax.approx_max_k`: compute the approximate top k entries + - :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries + + Examples: + >>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) + >>> kth = 4 + >>> idx = jnp.argpartition(x, kth) + >>> idx + Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32) + + The result is a sequence of indices that partially sort the input. All indices + before ``kth`` are of values smaller than the pivot value, and all indices + after ``kth`` are of values larger than the pivot value: + + >>> x_partitioned = x[idx] + >>> smallest_values = x_partitioned[:kth] + >>> pivot_value = x_partitioned[kth] + >>> largest_values = x_partitioned[kth + 1:] + >>> print(smallest_values, pivot_value, largest_values) + [1 2 3 3] 4 [6 8 9 7 5] + + Notice that among ``smallest_values`` and ``largest_values``, the returned + order is arbitrary and implementation-dependent. + """ # TODO(jakevdp): handle NaN values like numpy. util.check_arraylike("partition", a) arr = asarray(a) @@ -5788,7 +6606,7 @@ def take( - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax. - :func:`jax.numpy.take_along_axis`: take values along an axis - Example: + Examples: >>> x = jnp.array([[1., 2., 3.], ... [4., 5., 6.]]) >>> indices = jnp.array([2, 0]) @@ -5921,7 +6739,7 @@ def take_along_axis( a: array from which to take values. indices: array of integer indices. If ``axis`` is ``None``, must be one-dimensional. If ``axis`` is not None, must have ``a.ndim == indices.ndim``, and ``a`` must be - broadcast-compaible with ``indices`` along dimensions other than ``axis``. + broadcast-compatible with ``indices`` along dimensions other than ``axis``. axis: the axis along which to take values. If not specified, the array will be flattened before indexing is applied. mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default @@ -6820,7 +7638,7 @@ def extract(condition: ArrayLike, arr: ArrayLike, Notes: This function does not require strict shape agreement between ``condition`` and ``arr``. - If ``condition.size > ``arr.size``, then ``condition`` will be truncated, and if + If ``condition.size > arr.size``, then ``condition`` will be truncated, and if ``arr.size > condition.size``, then ``arr`` will be truncated. See also: @@ -6848,7 +7666,7 @@ def extract(condition: ArrayLike, arr: ArrayLike, Notice that unlike with boolean indexing, ``extract`` does not require strict agreement between the sizes of the array and condition, and will effectively - truncate both to the minimium size: + truncate both to the minimum size: >>> short_mask = jnp.array([False, True]) >>> jnp.extract(short_mask, x) @@ -7078,19 +7896,71 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt return comparisons.sum(dtype=dtype, axis=0) -@util.implements(np.searchsorted, skip_params=['sorter'], - extra_params=_dedent(""" - method : str - One of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. Controls the method used by the - implementation: 'scan' tends to be more performant on CPU (particularly when ``a`` is - very large), 'scan_unrolled' is more performant on GPU at the expense of additional compile time, - 'sort' is often more performant on accelerator backends like GPU and TPU - (particularly when ``v`` is very large), and 'compare_all' can be most performant - when ``a`` is very small.""")) -@partial(jit, static_argnames=('side', 'sorter', 'method')) +@partial(jit, static_argnames=('side', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', - sorter: None = None, *, method: str = 'scan') -> Array: - util.check_arraylike("searchsorted", a, v) + sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: + """Perform a binary search within a sorted array. + + JAX implementation of :func:`numpy.searchsorted`. + + This will return the indices within a sorted array ``a`` where values in ``v`` + can be inserted to maintain its sort order. + + Args: + a: one-dimensional array, assumed to be in sorted order unless ``sorter`` is specified. + v: N-dimensional array of query values + side: ``'left'`` (default) or ``'right'``; specifies whether insertion indices will be + to the left or the right in case of ties. + sorter: optional array of indices specifying the sort order of ``a``. If specified, + then the algorithm assumes that ``a[sorter]`` is in sorted order. + method: one of ``'scan'`` (default), ``'scan_unrolled'``, ``'sort'`` or ``'compare_all'``. + See *Note* below. + + Returns: + Array of insertion indices of shape ``v.shape``. + + Note: + The ``method`` argument controls the algorithm used to compute the insertion indices. + + - ``'scan'`` (the default) tends to be more performant on CPU, particularly when ``a`` is + very large. + - ``'scan_unrolled'`` is more performant on GPU at the expense of additional compile time. + - ``'sort'`` is often more performant on accelerator backends like GPU and TPU, particularly + when ``v`` is very large. + - ``'compare_all'`` tends to be the most performant when ``a`` is very small. + + Examples: + Searching for a single value: + + >>> a = jnp.array([1, 2, 2, 3, 4, 5, 5]) + >>> jnp.searchsorted(a, 2) + Array(1, dtype=int32) + >>> jnp.searchsorted(a, 2, side='right') + Array(3, dtype=int32) + + Searching for a batch of values: + + >>> vals = jnp.array([0, 3, 8, 1.5, 2]) + >>> jnp.searchsorted(a, vals) + Array([0, 3, 7, 1, 1], dtype=int32) + + Optionally, the ``sorter`` argument can be used to find insertion indices into + an array sorted via :func:`jax.numpy.argsort`: + + >>> a = jnp.array([4, 3, 5, 1, 2]) + >>> sorter = jnp.argsort(a) + >>> jnp.searchsorted(a, vals, sorter=sorter) + Array([0, 2, 5, 1, 1], dtype=int32) + + The result is equivalent to passing the sorted array: + + >>> jnp.searchsorted(jnp.sort(a), vals) + Array([0, 2, 5, 1, 1], dtype=int32) + """ + if sorter is None: + util.check_arraylike("searchsorted", a, v) + else: + util.check_arraylike("searchsorted", a, v, sorter) if side not in ['left', 'right']: raise ValueError(f"{side!r} is an invalid value for keyword 'side'. " "Expected one of ['left', 'right'].") @@ -7098,11 +7968,11 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', raise ValueError( f"{method!r} is an invalid value for keyword 'method'. " "Expected one of ['sort', 'scan', 'scan_unrolled', 'compare_all'].") - if sorter is not None: - raise NotImplementedError("sorter is not implemented") if ndim(a) != 1: raise ValueError("a should be 1-dimensional") a, v = util.promote_dtypes(a, v) + if sorter is not None: + a = a[sorter] dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64 if len(a) == 0: return zeros_like(v, dtype=dtype) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index e15788b0cc12..63aca76e6098 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -103,7 +103,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: - :func:`jax.scipy.linalg.cholesky`: SciPy-style Cholesky API - :func:`jax.lax.linalg.cholesky`: XLA-style Cholesky API - Example: + Examples: A small real Hermitian positive-definite matrix: >>> x = jnp.array([[2., 1.], @@ -250,7 +250,7 @@ def svd( - :func:`jax.scipy.linalg.svd`: SciPy-style SVD API - :func:`jax.lax.linalg.svd`: XLA-style SVD API - Example: + Examples: Consider the SVD of a small real-valued array: >>> x = jnp.array([[1., 2., 3.], @@ -496,7 +496,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: @partial(jit, static_argnames=('method',)) def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ - Computes the sign and (natural) logarithm of the determinant of an array. + Compute the sign and (natural) logarithm of the determinant of an array. JAX implementation of :func:`numpy.linalg.slotdet`. @@ -662,7 +662,7 @@ def _det_3x3(a: Array) -> Array: @jit def det(a: ArrayLike) -> Array: """ - Computes the determinant of an array. + Compute the determinant of an array. JAX implementation of :func:`numpy.linalg.det`. @@ -706,7 +706,7 @@ def _det_jvp(primals, tangents): def eig(a: ArrayLike) -> tuple[Array, Array]: """ - Computes the eigenvalues and eigenvectors of a square array. + Compute the eigenvalues and eigenvectors of a square array. JAX implementation of :func:`numpy.linalg.eig`. @@ -750,7 +750,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: @jit def eigvals(a: ArrayLike) -> Array: """ - Computes the eigenvalues of a general matrix. + Compute the eigenvalues of a general matrix. JAX implementation of :func:`numpy.linalg.eigvals`. @@ -788,14 +788,14 @@ def eigvals(a: ArrayLike) -> Array: def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: """ - Computes the eigenvalues and eigenvectors of a Hermitian matrix. + Compute the eigenvalues and eigenvectors of a Hermitian matrix. JAX implementation of :func:`numpy.linalg.eigh`. Args: a: array of shape ``(..., M, M)``, containing the Hermitian (if complex) or symmetric (if real) matrix. - UPLO: specifies whether the calculation isdone with the lower triangular + UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. @@ -842,7 +842,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None, @partial(jit, static_argnames=('UPLO',)) def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ - Computes the eigenvalues of a Hermitian matrix. + Compute the eigenvalues of a Hermitian matrix. JAX implementation of :func:`numpy.linalg.eigvalsh`. @@ -1007,7 +1007,7 @@ def inv(a: ArrayLike) -> Array: - :func:`jax.scipy.linalg.inv`: SciPy-style API for matrix inverse - :func:`jax.numpy.linalg.solve`: direct linear solver - Example: + Examples: Compute the inverse of a 3x3 matrix >>> a = jnp.array([[1., 2., 3.], @@ -1249,7 +1249,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: See also: - :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API - - :func:`jax.lax.linalg.qr`: XLA-style QR decompositon API + - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API Examples: Compute the QR decomposition of a matrix: @@ -1316,7 +1316,7 @@ def solve(a: ArrayLike, b: ArrayLike) -> Array: - :func:`jax.scipy.linalg.solve`: SciPy-style API for solving linear systems. - :func:`jax.lax.custom_linear_solve`: matrix-free linear solver. - Example: + Examples: A simple 3x3 linear system: >>> A = jnp.array([[1., 2., 3.], @@ -1422,7 +1422,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, - ``rank`` is the rank of the matrix ``a``. - ``s`` is the singular values of the matrix ``a``. - Example: + Examples: >>> a = jnp.array([[1, 2], ... [3, 4]]) >>> b = jnp.array([5, 6]) @@ -1438,12 +1438,12 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): - r"""Compute the corss-product of two 3D vectors + r"""Compute the cross-product of two 3D vectors JAX implementation of :func:`numpy.linalg.cross` Args: - x1: N-dimesional array, with ``x1.shape[axis] == 3`` + x1: N-dimensional array, with ``x1.shape[axis] == 3`` x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes broadcast-compatible with ``x1``. axis: axis along which to take the cross product (default: -1). @@ -1454,7 +1454,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): See Also: :func:`jax.numpy.cross`: more flexible cross-product API. - Example: + Examples: Showing that :math:`\hat{x} \times \hat{y} = \hat{z}`: @@ -1497,7 +1497,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: See also: :func:`jax.numpy.outer`: similar function in the main :mod:`jax.numpy` module. - Example: + Examples: >>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> jnp.linalg.outer(x1, x2) @@ -1599,7 +1599,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: - """Computes the vector norm of a vector or batch of vectors. + """Compute the vector norm of a vector or batch of vectors. JAX implementation of :func:`numpy.linalg.vector_norm`. @@ -1846,7 +1846,7 @@ def svdvals(x: ArrayLike, /) -> Array: See also: :func:`jax.numpy.linalg.svd`: compute singular values and singular vectors - Example: + Examples: >>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.linalg.svdvals(x) @@ -1916,7 +1916,7 @@ def tensorinv(a: ArrayLike, ind: int = 2) -> Array: - :func:`jax.numpy.linalg.tensordot` - :func:`jax.numpy.linalg.tensorsolve` - Example: + Examples: >>> key = jax.random.key(1337) >>> x = jax.random.normal(key, shape=(2, 2, 4)) >>> xinv = jnp.linalg.tensorinv(x, 2) @@ -2136,3 +2136,47 @@ def cond(x: ArrayLike, p=None): r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1)) # Convert NaNs to infs where original array has no NaNs. return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r) + + +def trace(x: ArrayLike, /, *, + offset: int = 0, dtype: DTypeLike | None = None) -> Array: + """Compute the trace of a matrix. + + JAX implementation of :func:`numpy.linalg.trace`. + + Args: + x: array of shape ``(..., M, N)`` and whose innermost two + dimensions form MxN matrices for which to take the trace. + offset: positive or negative offset from the main diagonal + (default: 0). + dtype: data type of the returned array (default: ``None``). If ``None``, + then output dtype will match the dtype of ``x``, promoted to default + precision in the case of integer types. + + Returns: + array of batched traces with shape ``x.shape[:-2]`` + + See also: + - :func:`jax.numpy.trace`: similar API in the ``jax.numpy`` namespace. + + Examples: + Trace of a single matrix: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8], + ... [9, 10, 11, 12]]) + >>> jnp.linalg.trace(x) + Array(18, dtype=int32) + >>> jnp.linalg.trace(x, offset=1) + Array(21, dtype=int32) + >>> jnp.linalg.trace(x, offset=-1, dtype="float32") + Array(15., dtype=float32) + + Batched traces: + + >>> x = jnp.arange(24).reshape(2, 3, 4) + >>> jnp.linalg.trace(x) + Array([15, 51], dtype=int32) + """ + check_arraylike('jnp.linalg.trace', x) + return jnp.trace(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 9e82284f7cc4..45595c4387a2 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -57,35 +57,49 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array: return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan)) -@implements(np.roots, lax_description="""\ -Unlike the numpy version of this function, the JAX version returns the roots in -a complex array regardless of the values of the roots. Additionally, the jax -version of this function adds the ``strip_zeros`` function which must be set to -False for the function to be compatible with JIT and other JAX transformations. -With ``strip_zeros=False``, if your coefficients have leading zeros, the -roots will be padded with NaN values: - ->>> coeffs = jnp.array([0, 1, 2]) - -# The default behavior matches numpy and strips leading zeros: ->>> jnp.roots(coeffs) -Array([-2.+0.j], dtype=complex64) - -# With strip_zeros=False, extra roots are set to NaN: ->>> jnp.roots(coeffs, strip_zeros=False) -Array([-2. +0.j, nan+nanj], dtype=complex64) -""", -extra_params=""" -strip_zeros : bool, default=True - If set to True, then leading zeros in the coefficients will be stripped, similar - to :func:`numpy.roots`. If set to False, leading zeros will not be stripped, and - undefined roots will be represented by NaN values in the function output. - ``strip_zeros`` must be set to ``False`` for the function to be compatible with - :func:`jax.jit` and other JAX transformations. -""") def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: + r"""Returns the roots of a polynomial given the coefficients ``p``. + + JAX implementations of :func:`numpy.roots`. + + Args: + p: Array of polynomial coefficients having rank-1. + strip_zeros : bool, default=True. If True, then leading zeros in the + coefficients will be stripped, similar to :func:`numpy.roots`. If set to + False, leading zeros will not be stripped, and undefined roots will be + represented by NaN values in the function output. ``strip_zeros`` must be + set to ``False`` for the function to be compatible with :func:`jax.jit` and + other JAX transformations. + + Returns: + An array containing the roots of the polynomial. + + Note: + Unlike ``np.roots`` of this function, the ``jnp.roots`` returns the roots + in a complex array regardless of the values of the roots. + + See Also: + - :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given + sequence of roots. + - :func:`jax.numpy.polyfit`: Least squares polynomial fit to data. + - :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values. + + Examples: + >>> coeffs = jnp.array([0, 1, 2]) + + The default behavior matches numpy and strips leading zeros: + + >>> jnp.roots(coeffs) + Array([-2.+0.j], dtype=complex64) + + With ``strip_zeros=False``, extra roots are set to NaN: + + >>> jnp.roots(coeffs, strip_zeros=False) + Array([-2. +0.j, nan+nanj], dtype=complex64) + """ check_arraylike("roots", p) p_arr = atleast_1d(promote_dtypes_inexact(p)[0]) + del p if p_arr.ndim != 1: raise ValueError("Input must be a rank-1 array.") if p_arr.size < 2: @@ -102,51 +116,149 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: return _roots_with_zeros(p_arr, num_leading_zeros) -_POLYFIT_DOC = """\ -Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix -Also, it works best on rcond <= 10e-3 values. -""" -@implements(np.polyfit, lax_description=_POLYFIT_DOC) @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) -def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, - full: bool = False, w: Array | None = None, cov: bool = False +def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, + full: bool = False, w: ArrayLike | None = None, cov: bool = False ) -> Array | tuple[Array, ...]: - check_arraylike("polyfit", x, y) + r"""Least squares polynomial fit to data. + + Jax implementation of :func:`numpy.polyfit`. + + Given a set of data points ``(x, y)`` and degree of polynomial ``deg``, the + function finds a polynomial equation of the form: + + .. math:: + + y = p(x) = p[0] x^{deg} + p[1] x^{deg - 1} + ... + p[deg] + + Args: + x: Array of data points of shape ``(M,)``. + y: Array of data points of shape ``(M,)`` or ``(M, K)``. + deg: Degree of the polynomials. It must be specified statically. + rcond: Relative condition number of the fit. Default value is ``len(x) * eps``. + It must be specified statically. + full: Switch that controls the return value. Default is ``False`` which + restricts the return value to the array of polynomail coefficients ``p``. + If ``True``, the function returns a tuple ``(p, resids, rank, s, rcond)``. + It must be specified statically. + w: Array of weights of shape ``(M,)``. If None, all data points are considered + to have equal weight. If not None, the weight :math:`w_i` is applied to the + unsquared residual of :math:`y_i - \widehat{y}_i` at :math:`x_i`, where + :math:`\widehat{y}_i` is the fitted value of :math:`y_i`. Default is None. + cov: Boolean or string. If ``True``, returns the covariance matrix scaled + by ``resids/(M-deg-1)`` along with ploynomial coefficients. If + ``cov='unscaled'``, returns the unscaaled version of covariance matrix. + Default is ``False``. ``cov`` is ignored if ``full=True``. It must be + specified statically. + + Returns: + - An array polynomial coefficients ``p`` if ``full=False`` and ``cov=False``. + + - A tuple of arrays ``(p, resids, rank, s, rcond)`` if ``full=True``. Where + + - ``p`` is an array of shape ``(M,)`` or ``(M, K)`` containing the polynomial + coefficients. + - ``resids`` is the sum of squared residual of shape () or (K,). + - ``rank`` is the rank of the matrix ``x``. + - ``s`` is the singular values of the matrix ``x``. + - ``rcond`` as the array. + - A tuple of arrays ``(p, C)`` if ``full=False`` and ``cov=True``. Where + + - ``p`` is an array of shape ``(M,)`` or ``(M, K)`` containing the polynomial + coefficients. + - ``C`` is the covariance matrix of polynomial coefficients of shape + ``(deg + 1, deg + 1)`` or ``(deg + 1, deg + 1, 1)``. + + Note: + Unlike :func:`numpy.polyfit` implementation of polyfit, :func:`jax.numpy.polyfit` + will not warn on rank reduction, which indicates an ill conditioned matrix. + + See Also: + - :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given + sequence of roots. + - :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values. + - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given + coefficients. + + Examples: + >>> x = jnp.array([3., 6., 9., 4.]) + >>> y = jnp.array([[0, 1, 2], + ... [2, 5, 7], + ... [8, 4, 9], + ... [1, 6, 3]]) + >>> p = jnp.polyfit(x, y, 2) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(p) + [[ 0.2 -0.35 -0.14] + [-1.17 4.47 2.96] + [ 1.95 -8.21 -5.93]] + + If ``full=True``, returns a tuple of arrays as follows: + + >>> p, resids, rank, s, rcond = jnp.polyfit(x, y, 2, full=True) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print("Polynomial Coefficients:", "\n", p, "\n", + ... "Residuals:", resids, "\n", + ... "Rank:", rank, "\n", + ... "s:", s, "\n", + ... "rcond:", rcond) + Polynomial Coefficients: + [[ 0.2 -0.35 -0.14] + [-1.17 4.47 2.96] + [ 1.95 -8.21 -5.93]] + Residuals: [0.37 5.94 0.61] + Rank: 3 + s: [1.67 0.47 0.04] + rcond: 4.7683716e-07 + + If ``cov=True`` and ``full=False``, returns a tuple of arrays having + polynomial coefficients and covariance matrix. + + >>> p, C = jnp.polyfit(x, y, 2, cov=True) + >>> p.shape, C.shape + ((3, 3), (3, 3, 1)) + """ + if w is None: + check_arraylike("polyfit", x, y) + else: + check_arraylike("polyfit", x, y, w) deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 # check arguments + x_arr, y_arr = asarray(x), asarray(y) + del x, y if deg < 0: raise ValueError("expected deg >= 0") - if x.ndim != 1: + if x_arr.ndim != 1: raise TypeError("expected 1D vector for x") - if x.size == 0: + if x_arr.size == 0: raise TypeError("expected non-empty vector for x") - if y.ndim < 1 or y.ndim > 2: + if y_arr.ndim < 1 or y_arr.ndim > 2: raise TypeError("expected 1D or 2D array for y") - if x.shape[0] != y.shape[0]: + if x_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected x and y to have same length") # set rcond if rcond is None: - rcond = len(x) * finfo(x.dtype).eps + rcond = len(x_arr) * finfo(x_arr.dtype).eps rcond = core.concrete_or_error(float, rcond, "rcond must be float") # set up least squares equation for powers of x - lhs = vander(x, order) - rhs = y + lhs = vander(x_arr, order) + rhs = y_arr # apply weighting if w is not None: - check_arraylike("polyfit", w) w, = promote_dtypes_inexact(w) - if w.ndim != 1: + w_arr = asarray(w) + if w_arr.ndim != 1: raise TypeError("expected a 1-d array for weights") - if w.shape[0] != y.shape[0]: + if w_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected w and y to have the same length") - lhs *= w[:, np.newaxis] + lhs *= w_arr[:, np.newaxis] if rhs.ndim == 2: - rhs *= w[:, np.newaxis] + rhs *= w_arr[:, np.newaxis] else: - rhs *= w + rhs *= w_arr # scale lhs to improve condition number and solve scale = sqrt((lhs*lhs).sum(axis=0)) @@ -162,12 +274,12 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, if cov == "unscaled": fac = 1 else: - if len(x) <= order: + if len(x_arr) <= order: raise ValueError("the number of data points must exceed order " "to scale the covariance matrix") - fac = resids / (len(x) - order) + fac = resids / (len(x_arr) - order) fac = fac[0] #making np.array() of shape (1,) to int - if y.ndim == 1: + if y_arr.ndim == 1: return c, Vbase * fac else: return c, Vbase[:, :, np.newaxis] * fac @@ -175,80 +287,214 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, return c -_POLY_DOC = """\ -This differs from np.poly when an integer array is given. -np.poly returns a result with dtype float64 in this case. -jax returns a result with an inexact type, but not necessarily -float64. +@jit +def poly(seq_of_zeros: ArrayLike) -> Array: + r"""Returns the coefficients of a polynomial for the given sequence of roots. -This also differs from np.poly when the input array strictly -contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j]. -np.poly returns an array with a real dtype in such cases. -jax returns an array with a complex dtype in such cases. -""" + JAX implementation of :func:`numpy.poly`. -@implements(np.poly, lax_description=_POLY_DOC) -@jit -def poly(seq_of_zeros: Array) -> Array: + Args: + seq_of_zeros: A scalar or an array of roots of the polynomial of shape ``(M,)`` + or ``(M, M)``. + + Returns: + An array containing the coefficients of the polynomial. The dtype of the + output is always promoted to inexact. + + Note: + + :func:`jax.numpy.poly` differs from :func:`numpy.poly`: + + - When the input is a scalar, ``np.poly`` raises a ``TypeError``, whereas + ``jnp.poly`` treats scalars the same as length-1 arrays. + - For complex-valued or square-shaped inputs, ``jnp.poly`` always returns + complex coefficients, whereas ``np.poly`` may return real or complex + depending on their values. + + See also: + - :func:`jax.numpy.polyfit`: Least squares polynomial fit. + - :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values. + - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given + coefficients. + + Example: + + Scalar inputs: + + >>> jnp.poly(1) + Array([ 1., -1.], dtype=float32) + + Input array with integer values: + + >>> x = jnp.array([1, 2, 3]) + >>> jnp.poly(x) + Array([ 1., -6., 11., -6.], dtype=float32) + + Input array with complex conjugates: + + >>> x = jnp.array([2, 1+2j, 1-2j]) + >>> jnp.poly(x) + Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64) + + Input array as square matrix with real valued inputs: + + >>> x = jnp.array([[2, 1, 5], + ... [3, 4, 7], + ... [1, 3, 5]]) + >>> jnp.round(jnp.poly(x)) + Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64) + """ check_arraylike('poly', seq_of_zeros) seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros) - seq_of_zeros = atleast_1d(seq_of_zeros) + seq_of_zeros_arr = atleast_1d(seq_of_zeros) + del seq_of_zeros - sh = seq_of_zeros.shape + sh = seq_of_zeros_arr.shape if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0: # import at runtime to avoid circular import from jax._src.numpy import linalg - seq_of_zeros = linalg.eigvals(seq_of_zeros) + seq_of_zeros_arr = linalg.eigvals(seq_of_zeros_arr) - if seq_of_zeros.ndim != 1: + if seq_of_zeros_arr.ndim != 1: raise ValueError("input must be 1d or non-empty square 2d array.") - dt = seq_of_zeros.dtype - if len(seq_of_zeros) == 0: + dt = seq_of_zeros_arr.dtype + if len(seq_of_zeros_arr) == 0: return ones((), dtype=dt) a = ones((1,), dtype=dt) - for k in range(len(seq_of_zeros)): - a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full') + for k in range(len(seq_of_zeros_arr)): + a = convolve(a, array([1, -seq_of_zeros_arr[k]], dtype=dt), mode='full') return a -@implements(np.polyval, lax_description="""\ -The ``unroll`` parameter is JAX specific. It does not effect correctness but can -have a major impact on performance for evaluating high-order polynomials. The -parameter controls the number of unrolled steps with ``lax.scan`` inside the -``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to -improve runtime performance on accelerators, at the cost of increased -compilation time. -""") @partial(jit, static_argnames=['unroll']) -def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: +def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: + r"""Evaluates the polynomial at specific values. + + JAX implementations of :func:`numpy.polyval`. + + For the 1D-polynomial coefficients ``p`` of length ``M``, the function returns + the value: + + .. math:: + + p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1} + + Args: + p: An array of polynomial coefficients of shape ``(M,)``. + x: A number or an array of numbers. + unroll: A number used to control the number of unrolled steps with + ``lax.scan``. It must be specified statically. + + Returns: + An array of same shape as ``x``. + + Note: + + The ``unroll`` parameter is JAX specific. It does not affect correctness but + can have a major impact on performance for evaluating high-order polynomials. + The parameter controls the number of unrolled steps with ``lax.scan`` inside + the ``jnp.polyval`` implementation. Consider setting ``unroll=128`` (or even + higher) to improve runtime performance on accelerators, at the cost of + increased compilation time. + + See also: + - :func:`jax.numpy.polyfit`: Least squares polynomial fit. + - :func:`jax.numpy.poly`: Finds the coefficients of a polynomial with given + roots. + - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given + coefficients. + + Example: + >>> p = jnp.array([2, 5, 1]) + >>> jnp.polyval(p, 3) + Array(34., dtype=float32) + + If ``x`` is a 2D array, ``polyval`` returns 2D-array with same shape as + that of ``x``: + + >>> x = jnp.array([[2, 1, 5], + ... [3, 4, 7], + ... [1, 3, 5]]) + >>> jnp.polyval(p, x) + Array([[ 19., 8., 76.], + [ 34., 53., 134.], + [ 8., 34., 76.]], dtype=float32) + """ check_arraylike("polyval", p, x) - p, x = promote_dtypes_inexact(p, x) - shape = lax.broadcast_shapes(p.shape[1:], x.shape) - y = lax.full_like(x, 0, shape=shape, dtype=x.dtype) - y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll) + p_arr, x_arr = promote_dtypes_inexact(p, x) + del p, x + shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape) + y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype) + y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) return y @implements(np.polyadd) @jit -def polyadd(a1: Array, a2: Array) -> Array: +def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: check_arraylike("polyadd", a1, a2) - a1, a2 = promote_dtypes(a1, a2) - if a2.shape[0] <= a1.shape[0]: - return a1.at[-a2.shape[0]:].add(a2) + a1_arr, a2_arr = promote_dtypes(a1, a2) + del a1, a2 + if a2_arr.shape[0] <= a1_arr.shape[0]: + return a1_arr.at[-a2_arr.shape[0]:].add(a2_arr) else: - return a2.at[-a1.shape[0]:].add(a1) + return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr) -@implements(np.polyint) @partial(jit, static_argnames=('m',)) -def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: +def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: + r"""Returns the coefficients of the integration of specified order of a polynomial. + + JAX implementation of :func:`numpy.polyint`. + + Args: + p: An array of polynomial coefficients. + m: Order of integration. Default is 1. It must be specified statically. + k: Scalar or array of ``m`` integration constant (s). + + Returns: + An array of coefficients of integrated polynomial. + + See also: + - :func:`jax.numpy.polyder`: Computes the coefficients of the derivative of + a polynomial. + - :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values. + + Examples: + + The first order integration of the polynomial :math:`12 x^2 + 12 x + 6` is + :math:`4 x^3 + 6 x^2 + 6 x`. + + >>> p = jnp.array([12, 12, 6]) + >>> jnp.polyint(p) + Array([4., 6., 6., 0.], dtype=float32) + + Since the constant ``k`` is not provided, the result included ``0`` at the end. + If the constant ``k`` is provided: + + >>> jnp.polyint(p, k=4) + Array([4., 6., 6., 4.], dtype=float32) + + and the second order integration is :math:`x^4 + 2 x^3 + 3 x`: + + >>> jnp.polyint(p, m=2) + Array([1., 2., 3., 0., 0.], dtype=float32) + + When ``m>=2``, the constants ``k`` should be provided as an array having + ``m`` elements. The second order integration of the polynomial + :math:`12 x^2 + 12 x + 6` with the constants ``k=[4, 5]`` is + :math:`x^4 + 2 x^3 + 3 x^2 + 4 x + 5`: + + >>> jnp.polyint(p, m=2, k=jnp.array([4, 5])) + Array([1., 2., 3., 4., 5.], dtype=float32) + """ m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") k = 0 if k is None else k check_arraylike("polyint", p, k) - p, k_arr = promote_dtypes_inexact(p, k) + p_arr, k_arr = promote_dtypes_inexact(p, k) + del p, k if m < 0: raise ValueError("Order of integral must be positive (see polyder)") k_arr = atleast_1d(k_arr) @@ -257,27 +503,62 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: if k_arr.shape != (m,): raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.") if m == 0: - return p + return p_arr else: - grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis] - - arange(m, dtype=p.dtype)[:, np.newaxis]) + grid = (arange(len(p_arr) + m, dtype=p_arr.dtype)[np.newaxis] + - arange(m, dtype=p_arr.dtype)[:, np.newaxis]) coeff = maximum(1, grid).prod(0)[::-1] - return true_divide(concatenate((p, k_arr)), coeff) + return true_divide(concatenate((p_arr, k_arr)), coeff) -@implements(np.polyder) @partial(jit, static_argnames=('m',)) -def polyder(p: Array, m: int = 1) -> Array: +def polyder(p: ArrayLike, m: int = 1) -> Array: + r"""Returns the coefficients of the derivative of specified order of a polynomial. + + JAX implementation of :func:`numpy.polyder`. + + Args: + p: Array of polynomials coefficients. + m: Order of differentiation (positive integer). Default is 1. It must be + specified statically. + + Returns: + An array of polynomial coefficients representing the derivative. + + Note: + :func:`jax.numpy.polyder` differs from :func:`numpy.polyder` when an integer + array is given. NumPy returns the result with dtype ``int`` whereas JAX + returns the result with dtype ``float``. + + See also: + - :func:`jax.numpy.polyint`: Computes the integral of polynomial. + - :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values. + + Examples: + + The first order derivative of the polynomial :math:`2 x^3 - 5 x^2 + 3 x - 1` + is :math:`6 x^2 - 10 x +3`: + + >>> p = jnp.array([2, -5, 3, -1]) + >>> jnp.polyder(p) + Array([ 6., -10., 3.], dtype=float32) + + and its second order derivative is :math:`12 x - 10`: + + >>> jnp.polyder(p, m=2) + Array([ 12., -10.], dtype=float32) + """ check_arraylike("polyder", p) m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder") - p, = promote_dtypes_inexact(p) + p_arr, = promote_dtypes_inexact(p) + del p if m < 0: raise ValueError("Order of derivative must be positive") if m == 0: - return p - coeff = (arange(m, len(p), dtype=p.dtype)[np.newaxis] - - arange(m, dtype=p.dtype)[:, np.newaxis]).prod(0) - return p[:-m] * coeff[::-1] + return p_arr + coeff = (arange(m, len(p_arr), dtype=p_arr.dtype)[np.newaxis] + - arange(m, dtype=p_arr.dtype)[:, np.newaxis]).prod(0) + return p_arr[:-m] * coeff[::-1] _LEADING_ZEROS_DOC = """\ @@ -292,6 +573,7 @@ def polyder(p: Array, m: int = 1) -> Array: def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: check_arraylike("polymul", a1, a2) a1_arr, a2_arr = promote_dtypes_inexact(a1, a2) + del a1, a2 if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1): a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f') if len(a1_arr) == 0: @@ -304,6 +586,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: check_arraylike("polydiv", u, v) u_arr, v_arr = promote_dtypes_inexact(u, v) + del u, v m = len(u_arr) - 1 n = len(v_arr) - 1 scale = 1. / v_arr[0] @@ -319,7 +602,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> @implements(np.polysub) @jit -def polysub(a1: Array, a2: Array) -> Array: +def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: check_arraylike("polysub", a1, a2) a1, a2 = promote_dtypes(a1, a2) return polyadd(a1, -a2) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 23432b45b102..bcf6885dc468 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -15,11 +15,11 @@ from __future__ import annotations import builtins -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import math import operator -from typing import overload, Any, Callable, Literal, Protocol, Union +from typing import overload, Any, Literal, Protocol, Union import warnings import numpy as np @@ -204,15 +204,6 @@ def force(x): force, x, "The axis argument must be known statically.") -# TODO(jakevdp) change promote_integers default to False -_PROMOTE_INTEGERS_DOC = """ -promote_integers : bool, default=True - If True, then integer inputs will be promoted to the widest available integer - dtype, following numpy's behavior. If False, the result will have the same dtype - as the input. ``promote_integers`` is ignored if ``dtype`` is specified. -""" - - @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, @@ -224,10 +215,76 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, initial=initial, where_=where, parallel_reduce=lax.psum, promote_integers=promote_integers) -@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) + def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: + r"""Sum of the elements of the array over a given axis. + + JAX implementation of :func:`numpy.sum`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the sum to be computed. + If None, the sum is computed along all the axes. + dtype: The type of the output array. Default=None. + out: Unused by JAX + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the sum. + where: int or array, default=None. The elements to be used in the sum. Array + should be broadcast compatible to the input. + promote_integers : bool, default=True. If True, then integer inputs will be + promoted to the widest available integer dtype, following numpy's behavior. + If False, the result will have the same dtype as the input. + ``promote_integers`` is ignored if ``dtype`` is specified. + + Returns: + An array of the sum along the given axis. + + See also: + - :func:`jax.numpy.prod`: Compute the product of array elements over a given + axis. + - :func:`jax.numpy.max`: Compute the maximum of array elements over given axis. + - :func:`jax.numpy.min`: Compute the minimum of array elements over given axis. + + Examples: + + By default, the sum is computed along all the axes. + + >>> x = jnp.array([[1, 3, 4, 2], + ... [5, 2, 6, 3], + ... [8, 1, 3, 9]]) + >>> jnp.sum(x) + Array(47, dtype=int32) + + If ``axis=1``, the sum is computed along axis 1. + + >>> jnp.sum(x, axis=1) + Array([10, 16, 21], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output is equal to that of the input. + + >>> jnp.sum(x, axis=1, keepdims=True) + Array([[10], + [16], + [21]], dtype=int32) + + To include only specific elements in the sum, you can use a``where``. + + >>> where=jnp.array([[0, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.sum(x, axis=1, keepdims=True, where=where) + Array([[ 4], + [ 9], + [12]], dtype=int32) + >>> where=jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.sum(x, axis=0, keepdims=True, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @@ -243,11 +300,77 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) -@implements(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) + def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: + r"""Return product of the array elements over a given axis. + + JAX implementation of :func:`numpy.prod`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the product to be computed. + If None, the product is computed along all the axes. + dtype: The type of the output array. Default=None. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the product. + where: int or array, default=None. The elements to be used in the product. + Array should be broadcast compatible to the input. + promote_integers : bool, default=True. If True, then integer inputs will be + promoted to the widest available integer dtype, following numpy's behavior. + If False, the result will have the same dtype as the input. + ``promote_integers`` is ignored if ``dtype`` is specified. + out: Unused by JAX. + + Returns: + An array of the product along the given axis. + + See also: + - :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis. + - :func:`jax.numpy.max`: Compute the maximum of array elements over given axis. + - :func:`jax.numpy.min`: Compute the minimum of array elements over given axis. + + Examples: + By default, ``jnp.prod`` computes along all the axes. + + >>> x = jnp.array([[1, 3, 4, 2], + ... [5, 2, 1, 3], + ... [2, 1, 3, 1]]) + >>> jnp.prod(x) + Array(4320, dtype=int32) + + If ``axis=1``, product is computed along axis 1. + + >>> jnp.prod(x, axis=1) + Array([24, 30, 6], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output is equal to that of the input. + + >>> jnp.prod(x, axis=1, keepdims=True) + Array([[24], + [30], + [ 6]], dtype=int32) + + To include only specific elements in the sum, you can use a``where``. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.prod(x, axis=1, keepdims=True, where=where) + Array([[4], + [3], + [6]], dtype=int32) + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.prod(x, axis=1, keepdims=True, where=where) + Array([[1], + [1], + [1]], dtype=int32) + """ return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @@ -261,10 +384,77 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) -@implements(np.max, skip_params=['out']) + def max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + r"""Return the maximum of the array elements along a given axis. + + JAX implementation of :func:`numpy.max`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the maximum to be computed. + If None, the maximum is computed along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, default=None. Initial value for the maximum. + where: int or array of boolean dtype, default=None. The elements to be used + in the maximum. Array should be broadcast compatible to the input. + ``initial`` must be specified when ``where`` is used. + out: Unused by JAX. + + Returns: + An array of maximum values along the given axis. + + See also: + - :func:`jax.numpy.min`: Compute the minimum of array elements along a given + axis. + - :func:`jax.numpy.sum`: Compute the sum of array elements along a given axis. + - :func:`jax.numpy.prod`: Compute the product of array elements along a given + axis. + + Examples: + + By default, ``jnp.max`` computes the maximum of elements along all the axes. + + >>> x = jnp.array([[9, 3, 4, 5], + ... [5, 2, 7, 4], + ... [8, 1, 3, 6]]) + >>> jnp.max(x) + Array(9, dtype=int32) + + If ``axis=1``, the maximum will be computed along axis 1. + + >>> jnp.max(x, axis=1) + Array([9, 7, 8], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. + + >>> jnp.max(x, axis=1, keepdims=True) + Array([[9], + [7], + [8]], dtype=int32) + + To include only specific elements in computing the maximum, you can use + ``where``. It can either have same dimension as input + + >>> where=jnp.array([[0, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.max(x, axis=1, keepdims=True, initial=0, where=where) + Array([[4], + [7], + [8]], dtype=int32) + + or must be broadcast compatible with input. + + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @@ -276,10 +466,76 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) -@implements(np.min, skip_params=['out']) + def min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + r"""Return the minimum of array elements along a given axis. + + JAX implementation of :func:`numpy.min`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the minimum to be computed. + If None, the minimum is computed along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the minimum. + where: int or array, default=None. The elements to be used in the minimum. + Array should be broadcast compatible to the input. ``initial`` must be + specified when ``where`` is used. + out: Unused by JAX. + + Returns: + An array of minimum values along the given axis. + + See also: + - :func:`jax.numpy.max`: Compute the maximum of array elements along a given + axis. + - :func:`jax.numpy.sum`: Compute the sum of array elements along a given axis. + - :func:`jax.numpy.prod`: Compute the product of array elements along a given + axis. + + Examples: + By default, the minimum is computed along all the axes. + + >>> x = jnp.array([[2, 5, 1, 6], + ... [3, -7, -2, 4], + ... [8, -4, 1, -3]]) + >>> jnp.min(x) + Array(-7, dtype=int32) + + If ``axis=1``, the minimum is computed along axis 1. + + >>> jnp.min(x, axis=1) + Array([ 1, -7, -4], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. + + >>> jnp.min(x, axis=1, keepdims=True) + Array([[ 1], + [-7], + [-4]], dtype=int32) + + To include only specific elements in computing the minimum, you can use + ``where``. ``where`` can either have same dimension as input. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where) + Array([[ 0], + [-2], + [-4]], dtype=int32) + + or must be broadcast compatible with input. + + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @@ -289,9 +545,53 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) -@implements(np.all, skip_params=['out']) + def all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + r"""Test whether all array elements along a given axis evaluate to True. + + JAX implementation of :func:`numpy.all`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which to be tested. If None, + tests along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + where: int or array of boolean dtype, default=None. The elements to be used + in the test. Array should be broadcast compatible to the input. + out: Unused by JAX. + + Returns: + An array of boolean values. + + Examples: + By default, ``jnp.all`` tests for True values along all the axes. + + >>> x = jnp.array([[True, True, True, False], + ... [True, False, True, False], + ... [True, True, False, False]]) + >>> jnp.all(x) + Array(False, dtype=bool) + + If ``axis=0``, tests for True values along axis 0. + + >>> jnp.all(x, axis=0) + Array([ True, False, False, False], dtype=bool) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. + + >>> jnp.all(x, axis=0, keepdims=True) + Array([[ True, False, False, False]], dtype=bool) + + To include specific elements in testing for True values, you can use a``where``. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.all(x, axis=0, keepdims=True, where=where) + Array([[ True, True, False, False]], dtype=bool) + """ return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) @@ -301,14 +601,69 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) -@implements(np.any, skip_params=['out']) + def any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + r"""Test whether any of the array elements along a given axis evaluate to True. + + JAX implementation of :func:`numpy.any`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which to be tested. If None, + tests along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + where: int or array of boolean dtype, default=None. The elements to be used + in the test. Array should be broadcast compatible to the input. + out: Unused by JAX. + + Returns: + An array of boolean values. + + Examples: + By default, ``jnp.any`` tests along all the axes. + + >>> x = jnp.array([[True, True, True, False], + ... [True, False, True, False], + ... [True, True, False, False]]) + >>> jnp.any(x) + Array(True, dtype=bool) + + If ``axis=0``, tests along axis 0. + + >>> jnp.any(x, axis=0) + Array([ True, True, True, False], dtype=bool) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. + + >>> jnp.any(x, axis=0, keepdims=True) + Array([[ True, True, True, False]], dtype=bool) + + To include specific elements in testing for True values, you can use a``where``. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 1, 0, 1], + ... [1, 0, 1, 0]], dtype=bool) + >>> jnp.any(x, axis=0, keepdims=True, where=where) + Array([[ True, False, True, False]], dtype=bool) + """ return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) -amin = min -amax = max +def amin(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Alias of :func:`jax.numpy.min`.""" + return min(a, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) + +def amax(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Alias of :func:`jax.numpy.max`.""" + return max(a, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) def _axis_size(a: ArrayLike, axis: int | Sequence[int]): if not isinstance(axis, (tuple, list)): diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index a04b6a85cbe2..6ac7ce804d8f 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -516,7 +516,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal *, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None): """Return the unique values from an array. - JAX implementation of :func:`jax.numpy.unique`. + JAX implementation of :func:`numpy.unique`. Because the size of the output of ``unique`` is data-dependent, the function semantics are not typically compatible with :func:`~jax.jit` and other JAX @@ -528,7 +528,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal return_index: if True, also return the indices in ``ar`` where each value occurs return_inverse: if True, also return the indices that can be used to reconstruct ``ar`` from the unique values. - return_counts: if True, also return the number of occurances of each unique value. + return_counts: if True, also return the number of occurrences of each unique value. axis: if specified, compute unique values along the specified axis. If None (default), then flatten ``ar`` before computing the unique values. equal_nan: if True, consider NaN values equivalent when determining uniqueness. @@ -546,8 +546,8 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``. - ``unique_index``: *(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains - the indices of the first occurance of each unique value in ``ar``. For 1D inputs, - ``ar[unique_index]`` is equivlent to ``unique_values``. + the indices of the first occurrence of each unique value in ``ar``. For 1D inputs, + ``ar[unique_index]`` is equivalent to ``unique_values``. - ``unique_inverse``: *(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis`` is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified. @@ -555,7 +555,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal ``unique_values[unique_inverse]`` is equivalent to ``ar``. - ``unique_counts``: *(returned only if return_counts is True)* An array of shape ``(n_unique,)``. - Contains the number of occurances of each unique value in ``ar``. + Contains the number of occurrences of each unique value in ``ar``. See also: - :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``. @@ -619,7 +619,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal **Returning indices** If you set ``return_index=True``, then ``unique`` returns the indices of the - first occurance of each unique value: + first occurrence of each unique value: >>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, indices = jnp.unique(x, return_index=True) @@ -660,7 +660,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal **Returning counts** - If you set ``return_counts=True``, then ``unique`` returns the number of occurances + If you set ``return_counts=True``, then ``unique`` returns the number of occurrences within the input for every unique value: >>> x = jnp.array([3, 4, 1, 3, 1]) @@ -671,7 +671,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal [2 2 1] For multi-dimensional arrays, this also returns a 1D array of counts - indicating number of occurances along the specified axis: + indicating number of occurrences along the specified axis: >>> values, counts = jnp.unique(M, axis=0, return_counts=True) >>> print(values) @@ -748,13 +748,13 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, - ``values``: an array of shape ``(n_unique,)`` containing the unique values from ``x``. - ``indices``: - An array of shape ``(n_unique,)``. Contains the indices of the first occurance of - each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivlent to ``values``. + An array of shape ``(n_unique,)``. Contains the indices of the first occurrence of + each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivalent to ``values``. - ``inverse_indices``: An array of shape ``x.shape``. Contains the indices within ``values`` of each value in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``. - ``counts``: - An array of shape ``(n_unique,)``. Contains the number of occurances of each unique + An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique value in ``x``. See also: @@ -770,7 +770,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, >>> result = jnp.unique_all(x) The result is a :class:`~typing.NamedTuple` with four named attributes. - The ``values`` attribue contains the unique values from the array: + The ``values`` attribute contains the unique values from the array: >>> result.values Array([1, 3, 4], dtype=int32) @@ -829,7 +829,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, - ``values``: an array of shape ``(n_unique,)`` containing the unique values from ``x``. - ``counts``: - An array of shape ``(n_unique,)``. Contains the number of occurances of each unique + An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique value in ``x``. See also: @@ -846,7 +846,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, >>> result = jnp.unique_counts(x) The result is a :class:`~typing.NamedTuple` with two named attributes. - The ``values`` attribue contains the unique values from the array: + The ``values`` attribute contains the unique values from the array: >>> result.values Array([1, 3, 4], dtype=int32) @@ -906,7 +906,7 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None, >>> result = jnp.unique_inverse(x) The result is a :class:`~typing.NamedTuple` with two named attributes. - The ``values`` attribue contains the unique values from the array: + The ``values`` attribute contains the unique values from the array: >>> result.values Array([1, 3, 4], dtype=int32) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 29f5278bc096..2e114193af13 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -16,10 +16,11 @@ from __future__ import annotations +from collections.abc import Callable from functools import partial import math import operator -from typing import Any, Callable +from typing import Any import jax from jax._src.typing import Array, ArrayLike, DTypeLike @@ -354,7 +355,7 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, """Create a JAX ufunc from an arbitrary JAX-compatible scalar function. Args: - func : a callable that takes `nin` scalar arguments and return `nout` outputs. + func : a callable that takes `nin` scalar arguments and returns `nout` outputs. nin: integer specifying the number of scalar inputs nout: integer specifying the number of scalar outputs identity: (optional) a scalar specifying the identity of the operation, if any. diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 6633321e9646..1a75e413379a 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -18,10 +18,10 @@ from __future__ import annotations +from collections.abc import Callable from functools import partial import operator -from textwrap import dedent -from typing import Any, Callable, overload + import warnings import numpy as np @@ -45,189 +45,404 @@ 64: np.int64, } -UnOp = Callable[[ArrayLike], Array] -BinOp = Callable[[ArrayLike, ArrayLike], Array] - - def _constant_like(x, const): return np.array(const, dtype=dtypes.dtype(x)) - def _replace_inf(x: ArrayLike) -> Array: return lax.select(isposinf(real(x)), lax._zeros(x), x) +def _to_bool(x: Array) -> Array: + return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0)) -def _one_to_one_unop( - numpy_fn: Callable[..., Any], lax_fn: UnOp, - promote_to_inexact: bool = False, lax_doc: bool = False) -> UnOp: - if promote_to_inexact: - fn = lambda x, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x)) - else: - fn = lambda x, /: lax_fn(*promote_args(numpy_fn.__name__, x)) - fn.__name__ = numpy_fn.__name__ - fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" - fn = jit(fn, inline=True) - if lax_doc: - doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return implements(numpy_fn, lax_description=doc, module='numpy')(fn) - else: - return implements(numpy_fn, module='numpy')(fn) +@implements(np.fabs, module='numpy') +@partial(jit, inline=True) +def fabs(x: ArrayLike, /) -> Array: + return lax.abs(*promote_args_inexact('fabs', x)) +@implements(getattr(np, 'bitwise_invert', np.invert), module='numpy') +@partial(jit, inline=True) +def bitwise_invert(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*promote_args('bitwise_invert', x)) -def _one_to_one_binop( - numpy_fn: Callable[..., Any], lax_fn: BinOp, - promote_to_inexact: bool = False, lax_doc: bool = False, - promote_to_numeric: bool = False) -> BinOp: - if promote_to_inexact: - fn = lambda x1, x2, /: lax_fn(*promote_args_inexact(numpy_fn.__name__, x1, x2)) - elif promote_to_numeric: - fn = lambda x1, x2, /: lax_fn(*promote_args_numeric(numpy_fn.__name__, x1, x2)) - else: - fn = lambda x1, x2, /: lax_fn(*promote_args(numpy_fn.__name__, x1, x2)) - fn.__name__ = numpy_fn.__name__ - fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" - fn = jit(fn, inline=True) - if lax_doc: - doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return implements(numpy_fn, lax_description=doc, module='numpy')(fn) - else: - return implements(numpy_fn, module='numpy')(fn) - - -def _maybe_bool_binop( - numpy_fn: Callable[..., Any], lax_fn: BinOp, bool_lax_fn: BinOp, - lax_doc: bool = False) -> BinOp: - def fn(x1, x2, /): - x1, x2 = promote_args(numpy_fn.__name__, x1, x2) - return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2) - fn.__name__ = numpy_fn.__name__ - fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" - fn = jit(fn, inline=True) - if lax_doc: - doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return implements(numpy_fn, lax_description=doc, module='numpy')(fn) - else: - return implements(numpy_fn, module='numpy')(fn) - - -def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp: - def fn(x1, x2, /): - x1, x2 = promote_args(numpy_fn.__name__, x1, x2) - # Comparison on complex types are defined as a lexicographic ordering on - # the (real, imag) pair. - if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): - rx = lax.real(x1) - ry = lax.real(x2) - return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)), - lax_fn(rx, ry)) - return lax_fn(x1, x2) - fn.__name__ = numpy_fn.__name__ - fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" - fn = jit(fn, inline=True) - return implements(numpy_fn, module='numpy')(fn) - -@overload -def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ... -@overload -def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ... -@overload -def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: ... - -def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: - @implements(np_op, update_doc=False, module='numpy') - @partial(jit, inline=True) - def op(*args): - zero = lambda x: lax.full_like(x, shape=(), fill_value=0) - args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_) else lax.ne(x, zero(x)) - for x in args) - return bitwise_op(*promote_args(np_op.__name__, *args)) - return op +@implements(np.bitwise_not, module='numpy') +@partial(jit, inline=True) +def bitwise_not(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*promote_args('bitwise_not', x)) + +@implements(np.invert, module='numpy') +@partial(jit, inline=True) +def invert(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*promote_args('invert', x)) + +@implements(np.negative, module='numpy') +@partial(jit, inline=True) +def negative(x: ArrayLike, /) -> Array: + return lax.neg(*promote_args('negative', x)) + +@implements(np.positive, module='numpy') +@partial(jit, inline=True) +def positive(x: ArrayLike, /) -> Array: + return lax.asarray(*promote_args('positive', x)) + +@implements(np.sign, module='numpy') +@partial(jit, inline=True) +def sign(x: ArrayLike, /) -> Array: + return lax.sign(*promote_args('sign', x)) + +@implements(np.floor, module='numpy') +@partial(jit, inline=True) +def floor(x: ArrayLike, /) -> Array: + check_arraylike('floor', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax.asarray(x) + return lax.floor(*promote_args_inexact('floor', x)) + +@implements(np.ceil, module='numpy') +@partial(jit, inline=True) +def ceil(x: ArrayLike, /) -> Array: + check_arraylike('ceil', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax.asarray(x) + return lax.ceil(*promote_args_inexact('ceil', x)) + +@implements(np.exp, module='numpy') +@partial(jit, inline=True) +def exp(x: ArrayLike, /) -> Array: + return lax.exp(*promote_args_inexact('exp', x)) + +@implements(np.log, module='numpy') +@partial(jit, inline=True) +def log(x: ArrayLike, /) -> Array: + return lax.log(*promote_args_inexact('log', x)) + +@implements(np.expm1, module='numpy') +@partial(jit, inline=True) +def expm1(x: ArrayLike, /) -> Array: + return lax.expm1(*promote_args_inexact('expm1', x)) + +@implements(np.log1p, module='numpy') +@partial(jit, inline=True) +def log1p(x: ArrayLike, /) -> Array: + return lax.log1p(*promote_args_inexact('log1p', x)) + +@implements(np.sin, module='numpy') +@partial(jit, inline=True) +def sin(x: ArrayLike, /) -> Array: + return lax.sin(*promote_args_inexact('sin', x)) + +@implements(np.cos, module='numpy') +@partial(jit, inline=True) +def cos(x: ArrayLike, /) -> Array: + return lax.cos(*promote_args_inexact('cos', x)) +@implements(np.tan, module='numpy') +@partial(jit, inline=True) +def tan(x: ArrayLike, /) -> Array: + return lax.tan(*promote_args_inexact('tan', x)) + +@implements(np.arcsin, module='numpy') +@partial(jit, inline=True) +def arcsin(x: ArrayLike, /) -> Array: + return lax.asin(*promote_args_inexact('arcsin', x)) + +@implements(np.arccos, module='numpy') +@partial(jit, inline=True) +def arccos(x: ArrayLike, /) -> Array: + return lax.acos(*promote_args_inexact('arccos', x)) + +@implements(np.arctan, module='numpy') +@partial(jit, inline=True) +def arctan(x: ArrayLike, /) -> Array: + return lax.atan(*promote_args_inexact('arctan', x)) + +@implements(np.sinh, module='numpy') +@partial(jit, inline=True) +def sinh(x: ArrayLike, /) -> Array: + return lax.sinh(*promote_args_inexact('sinh', x)) + +@implements(np.cosh, module='numpy') +@partial(jit, inline=True) +def cosh(x: ArrayLike, /) -> Array: + return lax.cosh(*promote_args_inexact('cosh', x)) + +@implements(np.arcsinh, module='numpy') +@partial(jit, inline=True) +def arcsinh(x: ArrayLike, /) -> Array: + return lax.asinh(*promote_args_inexact('arcsinh', x)) + +@implements(np.arccosh, module='numpy') @jit -def _arccosh(x: ArrayLike, /) -> Array: - # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different - # convention than np.arccosh. - out = lax.acosh(*promote_args_inexact("arccosh", x)) - if dtypes.issubdtype(out.dtype, np.complexfloating): - out = _where(real(out) < 0, lax.neg(out), out) - return out - -fabs = _one_to_one_unop(np.fabs, lax.abs, True) -bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not) -bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not) -bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) -invert = _one_to_one_unop(np.invert, lax.bitwise_not) -negative = _one_to_one_unop(np.negative, lax.neg) -positive = _one_to_one_unop(np.positive, lambda x: lax.asarray(x)) -floor = _one_to_one_unop(np.floor, lax.floor, True) -ceil = _one_to_one_unop(np.ceil, lax.ceil, True) -exp = _one_to_one_unop(np.exp, lax.exp, True) -log = _one_to_one_unop(np.log, lax.log, True) -expm1 = _one_to_one_unop(np.expm1, lax.expm1, True) -log1p = _one_to_one_unop(np.log1p, lax.log1p, True) -sin = _one_to_one_unop(np.sin, lax.sin, True) -cos = _one_to_one_unop(np.cos, lax.cos, True) -tan = _one_to_one_unop(np.tan, lax.tan, True) -arcsin = _one_to_one_unop(np.arcsin, lax.asin, True) -arccos = _one_to_one_unop(np.arccos, lax.acos, True) -arctan = _one_to_one_unop(np.arctan, lax.atan, True) -sinh = _one_to_one_unop(np.sinh, lax.sinh, True) -cosh = _one_to_one_unop(np.cosh, lax.cosh, True) -arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) -arccosh = _one_to_one_unop(np.arccosh, _arccosh, True) -tanh = _one_to_one_unop(np.tanh, lax.tanh, True) -arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True) -sign = _one_to_one_unop(np.sign, lax.sign) -sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) -cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True) - -add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) -bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) -bitwise_left_shift = _one_to_one_binop(getattr(np, "bitwise_left_shift", np.left_shift), lax.shift_left, promote_to_numeric=True) -bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) -bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) -left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True) -equal = _one_to_one_binop(np.equal, lax.eq) -multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and) -not_equal = _one_to_one_binop(np.not_equal, lax.ne) -subtract = _one_to_one_binop(np.subtract, lax.sub) -arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True) -minimum = _one_to_one_binop(np.minimum, lax.min) -maximum = _one_to_one_binop(np.maximum, lax.max) -float_power = _one_to_one_binop(np.float_power, lax.pow, True) -nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True) - -greater_equal = _comparison_op(np.greater_equal, lax.ge) -greater = _comparison_op(np.greater, lax.gt) -less_equal = _comparison_op(np.less_equal, lax.le) -less = _comparison_op(np.less, lax.lt) - -logical_and: BinOp = _logical_op(np.logical_and, lax.bitwise_and) -logical_not: UnOp = _logical_op(np.logical_not, lax.bitwise_not) -logical_or: BinOp = _logical_op(np.logical_or, lax.bitwise_or) -logical_xor: BinOp = _logical_op(np.logical_xor, lax.bitwise_xor) +def arccosh(x: ArrayLike, /) -> Array: + # Note: arccosh is multi-valued for complex input, and lax.acosh + # uses a different convention than np.arccosh. + result = lax.acosh(*promote_args_inexact("arccosh", x)) + if dtypes.issubdtype(result.dtype, np.complexfloating): + result = _where(real(result) < 0, lax.neg(result), result) + return result + +@implements(np.tanh, module='numpy') +@partial(jit, inline=True) +def tanh(x: ArrayLike, /) -> Array: + return lax.tanh(*promote_args_inexact('tanh', x)) + +@implements(np.arctanh, module='numpy') +@partial(jit, inline=True) +def arctanh(x: ArrayLike, /) -> Array: + return lax.atanh(*promote_args_inexact('arctanh', x)) + +@implements(np.sqrt, module='numpy') +@partial(jit, inline=True) +def sqrt(x: ArrayLike, /) -> Array: + return lax.sqrt(*promote_args_inexact('sqrt', x)) + +@implements(np.cbrt, module='numpy') +@partial(jit, inline=True) +def cbrt(x: ArrayLike, /) -> Array: + return lax.cbrt(*promote_args_inexact('cbrt', x)) + +@implements(np.add, module='numpy') +@partial(jit, inline=True) +def add(x: ArrayLike, y: ArrayLike, /) -> Array: + x, y = promote_args("add", x, y) + return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + +@implements(np.multiply, module='numpy') +@partial(jit, inline=True) +def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: + x, y = promote_args("multiply", x, y) + return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) + +@implements(np.bitwise_and, module='numpy') +@partial(jit, inline=True) +def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_and(*promote_args("bitwise_and", x, y)) + +@implements(np.bitwise_or, module='numpy') +@partial(jit, inline=True) +def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_or(*promote_args("bitwise_or", x, y)) + +@implements(np.bitwise_xor, module='numpy') +@partial(jit, inline=True) +def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) + +@implements(np.left_shift, module='numpy') +@partial(jit, inline=True) +def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.shift_left(*promote_args_numeric("left_shift", x, y)) + +@implements(getattr(np, "bitwise_left_shift", np.left_shift), module='numpy') +@partial(jit, inline=True) +def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y)) + +@implements(np.equal, module='numpy') +@partial(jit, inline=True) +def equal(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.eq(*promote_args("equal", x, y)) + +@implements(np.not_equal, module='numpy') +@partial(jit, inline=True) +def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.ne(*promote_args("not_equal", x, y)) + +@implements(np.subtract, module='numpy') +@partial(jit, inline=True) +def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.sub(*promote_args("subtract", x, y)) + +@implements(np.arctan2, module='numpy') +@partial(jit, inline=True) +def arctan2(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.atan2(*promote_args_inexact("arctan2", x, y)) + +@implements(np.minimum, module='numpy') +@partial(jit, inline=True) +def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.min(*promote_args("minimum", x, y)) + +@implements(np.maximum, module='numpy') +@partial(jit, inline=True) +def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.max(*promote_args("maximum", x, y)) + +@implements(np.float_power, module='numpy') +@partial(jit, inline=True) +def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.pow(*promote_args_inexact("float_power", x, y)) + +@implements(np.nextafter, module='numpy') +@partial(jit, inline=True) +def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.nextafter(*promote_args_inexact("nextafter", x, y)) + +# Logical ops +@implements(np.logical_and, module='numpy') +@partial(jit, inline=True) +def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) + +@implements(np.logical_or, module='numpy') +@partial(jit, inline=True) +def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) + +@implements(np.logical_xor, module='numpy') +@partial(jit, inline=True) +def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) + +@implements(np.logical_not, module='numpy') +@partial(jit, inline=True) +def logical_not(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*map(_to_bool, promote_args("logical_not", x))) + +# Comparison ops +def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], + x: Array, y: Array): + if dtypes.issubdtype(x.dtype, np.complexfloating): + return lax.select(lax.eq(x.real, y.real), + lax_op(x.imag, y.imag), + lax_op(x.real, y.real)) + return lax_op(x, y) + +@implements(np.greater_equal, module='numpy') +@partial(jit, inline=True) +def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) + +@implements(np.greater, module='numpy') +@partial(jit, inline=True) +def greater(x: ArrayLike, y: ArrayLike, /) -> Array: + return _complex_comparison(lax.gt, *promote_args("greater", x, y)) + +@implements(np.less_equal, module='numpy') +@partial(jit, inline=True) +def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) + +@implements(np.less, module='numpy') +@partial(jit, inline=True) +def less(x: ArrayLike, y: ArrayLike, /) -> Array: + return _complex_comparison(lax.lt, *promote_args("less", x, y)) # Array API aliases -# TODO(jakevdp): directly reference np_fun when minimum numpy version is 2.0 -acos = _one_to_one_unop(getattr(np, "acos", np.arccos), lax.acos, True) -acosh = _one_to_one_unop(getattr(np, "acosh", np.arccosh), _arccosh, True) -asin = _one_to_one_unop(getattr(np, "asin", np.arcsin), lax.asin, True) -asinh = _one_to_one_unop(getattr(np, "asinh", np.arcsinh), lax.asinh, True) -atan = _one_to_one_unop(getattr(np, "atan", np.arctan), lax.atan, True) -atanh = _one_to_one_unop(getattr(np, "atanh", np.arctanh), lax.atanh, True) -atan2 = _one_to_one_binop(getattr(np, "atan2", np.arctan2), lax.atan2, True) +@partial(jit, inline=True) +def acos(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arccos`""" + return arccos(*promote_args('acos', x)) +@partial(jit, inline=True) +def acosh(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arccosh`""" + return arccosh(*promote_args('acosh', x)) + +@partial(jit, inline=True) +def asin(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arcsin`""" + return arcsin(*promote_args('asin', x)) + +@partial(jit, inline=True) +def asinh(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arcsinh`""" + return arcsinh(*promote_args('asinh', x)) + +@partial(jit, inline=True) +def atan(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arctan`""" + return arctan(*promote_args('atan', x)) + +@partial(jit, inline=True) +def atanh(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arctanh`""" + return arctanh(*promote_args('atanh', x)) + +@partial(jit, inline=True) +def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.arctan2`""" + return arctan2(*promote_args('atan2', x, y)) -@implements(getattr(np, 'bitwise_count', None), module='numpy') @jit def bitwise_count(x: ArrayLike, /) -> Array: + r"""Counts the number of 1 bits in the binary representation of the absolute value + of each element of ``x``. + + LAX-backend implementation of :func:`numpy.bitwise_count`. + + Args: + x: Input array, only accepts integer subtypes + + Returns: + An array-like object containing the binary 1 bit counts of the absolute value of + each element in ``x``, with the same shape as ``x`` of dtype uint8. + + Examples: + >>> x1 = jnp.array([64, 32, 31, 20]) + >>> # 64 = 0b1000000, 32 = 0b100000, 31 = 0b11111, 20 = 0b10100 + >>> jnp.bitwise_count(x1) + Array([1, 1, 5, 2], dtype=uint8) + + >>> x2 = jnp.array([-16, -7, 7]) + >>> # |-16| = 0b10000, |-7| = 0b111, 7 = 0b111 + >>> jnp.bitwise_count(x2) + Array([1, 3, 3], dtype=uint8) + + >>> x3 = jnp.array([[2, -7],[-9, 7]]) + >>> # 2 = 0b10, |-7| = 0b111, |-9| = 0b1001, 7 = 0b111 + >>> jnp.bitwise_count(x3) + Array([[1, 3], + [2, 3]], dtype=uint8) + """ x, = promote_args_numeric("bitwise_count", x) # Following numpy we take the absolute value and return uint8. return lax.population_count(abs(x)).astype('uint8') -@implements(np.right_shift, module='numpy') @partial(jit, inline=True) def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: + r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. + + LAX-backend implementation of :func:`numpy.right_shift`. + + Args: + x1: Input array, only accepts unsigned integer subtypes + x2: The amount of bits to shift each element in ``x1`` to the right, only accepts + integer subtypes + + Returns: + An array-like object containing the right shifted elements of ``x1`` by the + amount specified in ``x2``, with the same shape as the broadcasted shape of + ``x1`` and ``x2``. + + Note: + If ``x1.shape != x2.shape``, they must be compatible for broadcasting to a + shared shape, this shared shape will also be the shape of the output. Right shifting + a scalar x1 by scalar x2 is equivalent to ``x1 // 2**x2``. + + Examples: + >>> def print_binary(x): + ... return [bin(int(val)) for val in x] + + >>> x1 = jnp.array([1, 2, 4, 8]) + >>> print_binary(x1) + ['0b1', '0b10', '0b100', '0b1000'] + >>> x2 = 1 + >>> result = jnp.right_shift(x1, x2) + >>> result + Array([0, 1, 2, 4], dtype=int32) + >>> print_binary(result) + ['0b0', '0b1', '0b10', '0b100'] + + >>> x1 = 16 + >>> print_binary([x1]) + ['0b10000'] + >>> x2 = jnp.array([1, 2, 3, 4]) + >>> result = jnp.right_shift(x1, x2) + >>> result + Array([8, 4, 2, 1], dtype=int32) + >>> print_binary(result) + ['0b1000', '0b100', '0b10', '0b1'] + """ x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2) lax_fn = lax.shift_right_logical if \ np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic @@ -241,18 +456,78 @@ def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) -@implements(np.absolute, module='numpy') + @partial(jit, inline=True) def absolute(x: ArrayLike, /) -> Array: + r"""Calculate the absolute value element-wise. + + LAX-backend implementation of :func:`numpy.absolute`. + + This is the same function as :func:`jax.numpy.abs`. + + Args: + x: Input array + + Returns: + An array-like object containing the absolute value of each element in ``x``, + with the same shape as ``x``. For complex valued input, :math:`a + ib`, + the absolute value is :math:`\sqrt{a^2+b^2}`. + + Examples: + >>> x1 = jnp.array([5, -2, 0, 12]) + >>> jnp.absolute(x1) + Array([ 5, 2, 0, 12], dtype=int32) + + >>> x2 = jnp.array([[ 8, -3, 1],[ 0, 9, -6]]) + >>> jnp.absolute(x2) + Array([[8, 3, 1], + [0, 9, 6]], dtype=int32) + + >>> x3 = jnp.array([8 + 15j, 3 - 4j, -5 + 0j]) + >>> jnp.absolute(x3) + Array([17., 5., 5.], dtype=float32) + """ check_arraylike('absolute', x) dt = dtypes.dtype(x) return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) -abs = implements(np.abs, module='numpy')(absolute) -@implements(np.rint, module='numpy') +@partial(jit, inline=True) +def abs(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.absolute`.""" + return absolute(x) + + @jit def rint(x: ArrayLike, /) -> Array: + """Rounds the elements of x to the nearest integer + + LAX-backend implementation of :func:`numpy.rint`. + + Args: + x: Input array + + Returns: + An array-like object containing the rounded elements of ``x``. Always promotes + to inexact. + + Note: + If an element of x is exactly half way, e.g. ``0.5`` or ``1.5``, rint will round + to the nearest even integer. + + Example: + >>> x1 = jnp.array([5, 4, 7]) + >>> jnp.rint(x1) + Array([5., 4., 7.], dtype=float32) + + >>> x2 = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) + >>> jnp.rint(x2) + Array([-2., -2., -0., 0., 2., 2., 4., 4.], dtype=float32) + + >>> x3 = jnp.array([-2.5+3.5j, 4.5-0.5j]) + >>> jnp.rint(x3) + Array([-2.+4.j, 4.-0.j], dtype=complex64) + """ check_arraylike('rint', x) dtype = dtypes.dtype(x) if dtype == bool or dtypes.issubdtype(dtype, np.integer): @@ -262,9 +537,39 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) -@implements(np.copysign, module='numpy') @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. + + LAX-backend implementation of :func:`numpy.copysign`. + + Args: + x1: Input array + x2: The array whose elements will be used to determine the sign, must be + broadcast-compatible with ``x1`` + + Returns: + An array object containing the potentially changed elements of ``x1``, always promotes + to inexact dtype, and has a shape of ``jnp.broadcast_shapes(x1.shape, x2.shape)`` + + Examples: + >>> x1 = jnp.array([5, 2, 0]) + >>> x2 = -1 + >>> jnp.copysign(x1, x2) + Array([-5., -2., -0.], dtype=float32) + + >>> x1 = jnp.array([6, 8, 0]) + >>> x2 = 2 + >>> jnp.copysign(x1, x2) + Array([6., 8., 0.], dtype=float32) + + >>> x1 = jnp.array([2, -3]) + >>> x2 = jnp.array([[1],[-4], [5]]) + >>> jnp.copysign(x1, x2) + Array([[ 2., 3.], + [-2., -3.], + [ 2., 3.]], dtype=float32) + """ x1, x2 = promote_args_inexact("copysign", x1, x2) if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): raise TypeError("copysign does not support complex-valued inputs") @@ -280,9 +585,43 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: divide = true_divide -@implements(np.floor_divide, module='numpy') @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculates the floor division of x1 by x2 element-wise + + LAX-backend implementation of :func:`numpy.floor_divide`. + + Args: + x1: Input array, the dividend + x2: Input array, the divisor + + Returns: + An array-like object containing each of the quotients rounded down + to the nearest integer towards negative infinity. This is equivalent + to ``x1 // x2`` in Python. + + Examples: + >>> x1 = jnp.array([10, 20, 30]) + >>> x2 = jnp.array([3, 4, 7]) + >>> jnp.floor_divide(x1, x2) + Array([3, 5, 4], dtype=int32) + + >>> x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]) + >>> x2 = 3 + >>> jnp.floor_divide(x1, x2) + Array([-2, -2, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=int32) + + >>> x1 = jnp.array([6, 6, 6], dtype=jnp.int32) + >>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32) + >>> jnp.floor_divide(x1, x2) + Array([3., 2., 2.], dtype=float32) + + Note: + ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` and ``x2`` + + See Also: + :func:`jnp.divide` and :func:`jnp.true_divide` for floating point division + """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) dtype = dtypes.dtype(x1) if dtypes.issubdtype(dtype, np.unsignedinteger): @@ -704,14 +1043,14 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: return lax.full_like(x, False, dtype=np.bool_) -isposinf: UnOp = implements(np.isposinf, skip_params=['out'])( - lambda x, /, out=None: _isposneginf(np.inf, x, out) -) +@implements(np.isposinf, module='numpy') +def isposinf(x, /, out=None): + return _isposneginf(np.inf, x, out) -isneginf: UnOp = implements(np.isneginf, skip_params=['out'])( - lambda x, /, out=None: _isposneginf(-np.inf, x, out) -) +@implements(np.isposinf, module='numpy') +def isneginf(x, /, out=None): + return _isposneginf(-np.inf, x, out) @implements(np.isnan, module='numpy') diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index fb3b7e4e9dc9..21b96deea3c6 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import re import textwrap -from typing import Any, Callable, NamedTuple, TypeVar +from typing import Any, NamedTuple, TypeVar import warnings diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index 3fe99131e6da..2c517467e287 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -13,10 +13,10 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Collection, Sequence +from collections.abc import Callable, Collection, Sequence import functools import re -from typing import Any, Callable +from typing import Any from jax._src import api from jax import lax diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 667c2707d68e..2bcfe96ad2f0 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -16,9 +16,8 @@ from __future__ import annotations -from collections.abc import Sequence -import sys -from typing import Callable, Union +from collections.abc import Callable, Sequence +from typing import Union import warnings import numpy as np @@ -36,11 +35,8 @@ from jax._src.typing import Array, ArrayLike -if sys.version_info >= (3, 10): - from types import EllipsisType - SingleIndex = int | slice | Sequence[int] | Array | EllipsisType | None -else: - SingleIndex = Union[int, slice, Sequence[int], Array, None] +from types import EllipsisType +SingleIndex = int | slice | Sequence[int] | Array | EllipsisType | None Index = Union[SingleIndex, tuple[SingleIndex, ...]] Scalar = Union[complex, float, int, np.number] diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 85ea5a13dba6..c0fa02131bc8 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -27,13 +27,13 @@ package( py_library( name = "pallas", - srcs = glob( - include = ["**/*.py"], - exclude = [ - "triton/*.py", - "mosaic/*.py", - ], - ), + srcs = [ + "__init__.py", + "core.py", + "pallas_call.py", + "primitives.py", + "utils.py", + ], deps = [ "//jax", "//jax:ad_util", @@ -46,21 +46,3 @@ py_library( "//jax/_src/lib", ] + py_deps("numpy"), ) - -py_library( - name = "gpu", - visibility = [], - deps = [ - ":pallas", - "//jax/_src/pallas/triton", - ], -) - -py_library( - name = "tpu", - visibility = [], - deps = [ - ":pallas", - "//jax/_src/pallas/mosaic", - ], -) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 5e044a0a65ee..38866b082da6 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -15,14 +15,15 @@ """Module for pallas-core functionality.""" from __future__ import annotations +from collections.abc import Callable, Iterator, Sequence import copy -from collections.abc import Sequence import contextlib import dataclasses import functools -from typing import Any, Callable, Union -from collections.abc import Iterator +import threading +from typing import Any, Union +import jax from jax._src import api_util from jax._src import core as jax_core from jax._src import linear_util as lu @@ -33,13 +34,16 @@ from jax._src.state import discharge as state_discharge import jax.numpy as jnp -# TODO(sharadmv): enable type checking -# mypy: ignore-errors + +class DynamicGridDim: + pass +dynamic_grid_dim = DynamicGridDim() + partial = functools.partial -Grid = tuple[Union[int, jax_core.Array, None], ...] # None indicates that the bound is dynamic. -DynamicGrid = tuple[Union[int, jax_core.Array], ...] +Grid = tuple[Union[int, jax_core.Array], ...] StaticGrid = tuple[int, ...] +GridMappingGrid = tuple[Union[int, DynamicGridDim], ...] split_list = util.split_list map, unsafe_map = util.safe_map, map @@ -87,25 +91,60 @@ def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type): jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped +@dataclasses.dataclass(frozen=True) +class PallasGridContext: + grid: GridMappingGrid + mapped_dims: tuple[int, ...] + + def size(self, axis: int) -> int | DynamicGridDim: + valid_grid = tuple( + s for i, s in enumerate(self.grid) if i not in self.mapped_dims + ) + try: + size = valid_grid[axis] + except IndexError as e: + raise ValueError( + f"Axis {axis} is out of bounds for grid {self.grid}" + ) from e + return size + + @dataclasses.dataclass -class GridEnv: - axis_index: Any - axis_size: int +class PallasTracingEnv(threading.local): + grid_context: PallasGridContext | None = None +_pallas_tracing_env = PallasTracingEnv() + -_grid_env_stack: list[tuple[GridEnv, ...]] = [] +def axis_frame() -> PallasGridContext: + # This is like jax_core.axis_frame, except there should only ever be one + # active PallasGridAxisName for a particular main_trace because we cannot + # nest pallas_calls. + env = _pallas_tracing_env + assert env.grid_context is not None + return env.grid_context + + +@dataclasses.dataclass(frozen=True) +class GridAxis: + index: jax.Array + size: int + +# Stores the kernel execution position and the size along grid axes. +GridEnv = Sequence[GridAxis] + +_grid_env_stack: list[GridEnv] = [] @contextlib.contextmanager -def grid_env(env: tuple[tuple[Any, int], ...]) -> Iterator[None]: - _grid_env_stack.append(tuple(GridEnv(axis_index, axis_size) - for axis_index, axis_size in env)) +def grid_env(env: GridEnv) -> Iterator[None]: + _grid_env_stack.append(env) try: yield finally: _grid_env_stack.pop() -def current_grid_env() -> tuple[GridEnv, ...] | None: +def current_grid_env() -> GridEnv | None: if not _grid_env_stack: return None return _grid_env_stack[-1] @@ -130,22 +169,12 @@ class Blocked: IndexingMode = Union[Blocked, Unblocked] -@dataclasses.dataclass(init=False, unsafe_hash=True) +@dataclasses.dataclass(unsafe_hash=True) class BlockSpec: - index_map: Callable[..., Any] | None - block_shape: tuple[int | None, ...] | None - memory_space: Any - indexing_mode: IndexingMode - - def __init__(self, index_map: Callable[..., Any] | None = None, - block_shape: tuple[int | None, ...] | None = None, - memory_space: Any = None, indexing_mode: IndexingMode = blocked): - self.index_map = index_map - if block_shape is not None and not isinstance(block_shape, tuple): - block_shape = tuple(block_shape) - self.block_shape = block_shape - self.memory_space = memory_space - self.indexing_mode = indexing_mode + index_map: Callable[..., Any] | None = None + block_shape: tuple[int | None, ...] | None = None + memory_space: Any | None = None + indexing_mode: IndexingMode = blocked def compute_index(self, *args): assert self.index_map is not None @@ -156,6 +185,10 @@ def compute_index(self, *args): return out +# A PyTree of BlockSpec | NoBlockSpec. +BlockSpecTree = Any + + @dataclasses.dataclass(frozen=True) class BlockMapping: block_shape: tuple[Mapped | int, ...] @@ -183,25 +216,43 @@ def compute_start_indices(self, loop_idx, *args): replace = dataclasses.replace +@contextlib.contextmanager +def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]): + assert all(i is dynamic_grid_dim or isinstance(i, int) for i in grid) + old_grid_context = _pallas_tracing_env.grid_context + try: + _pallas_tracing_env.grid_context = PallasGridContext(grid, mapped_dims) + yield + finally: + _pallas_tracing_env.grid_context = old_grid_context + + @dataclasses.dataclass(frozen=True) class GridMapping: - grid: Grid + grid: GridMappingGrid block_mappings: tuple[BlockMapping | None, ...] - mapped_dims: tuple[int, ...] - num_index_operands: int - num_scratch_operands: int + mapped_dims: tuple[int, ...] = () + num_index_operands: int = 0 + num_scratch_operands: int = 0 + # Number of constants hoisted to operands by ``_hoist_consts_to_refs``. + num_constant_operands: int = 0 replace = dataclasses.replace @property def num_dynamic_grid_bounds(self): - return sum(b is None for b in self.grid) + return sum(b is dynamic_grid_dim for b in self.grid) @property def static_grid(self) -> StaticGrid: if self.num_dynamic_grid_bounds: raise ValueError("Expected a grid with fully static bounds") - return self.grid # typing: ignore + return self.grid # type: ignore + + @contextlib.contextmanager + def trace_env(self): + with tracing_grid_env(self.grid, self.mapped_dims): + yield def _preprocess_grid(grid: Grid | int | None) -> Grid: @@ -213,9 +264,13 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid: def _convert_block_spec_to_block_mapping( - in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None, - aval: jax_core.ShapedArray, in_tree: Any, - ) -> BlockSpec | None: + in_avals: Sequence[jax_core.ShapedArray], + block_spec: BlockSpec, + aval: jax_core.ShapedArray, + in_tree: Any, + grid: GridMappingGrid, + mapped_dims: tuple[int, ...], +) -> BlockMapping | None: if block_spec is no_block_spec: return None if block_spec.index_map is None: @@ -227,11 +282,13 @@ def _convert_block_spec_to_block_mapping( block_shape = tuple( mapped if s is None else s for s in block_shape) flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + with tracing_grid_env(grid, mapped_dims): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) return BlockMapping( block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode ) + def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None ) -> state.AbstractRef: if block_shape is None: @@ -272,6 +329,7 @@ class NoBlockSpec: pass no_block_spec = NoBlockSpec() + @dataclasses.dataclass(init=False, unsafe_hash=True) class GridSpec: grid: Grid @@ -283,12 +341,8 @@ class GridSpec: def __init__( self, grid: Grid | None = None, - in_specs: BlockSpec - | Sequence[BlockSpec | NoBlockSpec] - | NoBlockSpec = no_block_spec, - out_specs: BlockSpec - | Sequence[BlockSpec | NoBlockSpec] - | NoBlockSpec = no_block_spec, + in_specs: BlockSpecTree = no_block_spec, + out_specs: BlockSpecTree = no_block_spec, ): # Be more lenient for in/out_specs if isinstance(in_specs, list): @@ -332,6 +386,10 @@ def _get_in_out_specs(self, in_avals, in_tree, out_avals, out_tree): def get_grid_mapping( self, in_avals, in_tree, out_avals, out_tree ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: + assert all(i is None or isinstance(i, int) for i in self.grid) + grid_mapping_grid = tuple( + dynamic_grid_dim if d is None else d for d in self.grid + ) flat_in_specs, flat_out_specs = self._get_in_out_specs( in_avals, in_tree, out_avals, out_tree) in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals( @@ -341,25 +399,45 @@ def get_grid_mapping( # Create args, kwargs pytree def grid_tree = tree_util.tree_structure((tuple(grid_avals), {})) in_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, grid_avals, - in_tree=grid_tree), in_specs, in_ref_avals) + partial( + _convert_block_spec_to_block_mapping, + grid_avals, + in_tree=grid_tree, + grid=grid_mapping_grid, + mapped_dims=(), + ), + in_specs, + in_ref_avals, + ) out_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, grid_avals, - in_tree=grid_tree), out_specs, out_ref_avals) + partial( + _convert_block_spec_to_block_mapping, + grid_avals, + in_tree=grid_tree, + grid=grid_mapping_grid, + mapped_dims=(), + ), + out_specs, + out_ref_avals, + ) grid_mapping = GridMapping( - self.grid, (*in_block_mappings, *out_block_mappings), (), - num_index_operands=0, num_scratch_operands=0) + grid_mapping_grid, (*in_block_mappings, *out_block_mappings) # type: ignore + ) jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals) jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals) if not isinstance(jaxpr_out_avals, (tuple, list)): jaxpr_out_avals = (jaxpr_out_avals,) return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping - def unzip_dynamic_grid_bounds(self) -> tuple[GridSpec, tuple[Any, ...]]: - static_grid = tuple(d if isinstance(d, int) else None for d in self.grid) + def unzip_dynamic_grid_bounds( + self, + ) -> tuple[GridSpec, tuple[Any, ...]]: + static_grid = tuple( + d if isinstance(d, int) else None for d in self.grid + ) dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int)) # We can't use dataclasses.replace, because our fields are incompatible # with __init__'s signature. static_self = copy.copy(self) - static_self.grid = static_grid + static_self.grid = static_grid # type: ignore return static_self, dynamic_bounds diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index b16a819ab914..4c849dfba267 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -15,11 +15,7 @@ # Package for Mosaic-specific Pallas extensions load("@rules_python//python:defs.bzl", "py_library") -load( - "//jaxlib:jax.bzl", - "py_deps", - "py_library_providing_imports_info", -) +load("//jaxlib:jax.bzl", "py_deps") package( default_applicable_licenses = [], @@ -28,20 +24,6 @@ package( ], ) -py_library_providing_imports_info( - name = "mosaic", - srcs = ["__init__.py"], - lib_rule = py_library, - deps = [ - ":core", - ":kernel_regeneration_util", - ":lowering", - ":pallas_call_registration", - ":pipeline", - ":primitives", - ], -) - py_library( name = "core", srcs = ["core.py"], @@ -95,14 +77,6 @@ py_library( ] + py_deps("numpy"), ) -py_library( - name = "kernel_regeneration_util", - srcs = ["kernel_regeneration_util.py"], - deps = [ - "//jax:mlir", - ], -) - py_library( name = "pipeline", srcs = ["pipeline.py"], @@ -113,5 +87,15 @@ py_library( "//jax:api_util", "//jax:util", "//jax/_src/pallas", - ], + ] + py_deps("numpy"), +) + +py_library( + name = "random", + srcs = ["random.py"], + deps = [ + ":primitives", + "//jax", + "//jax:typing", + ] + py_deps("numpy"), ) diff --git a/jax/_src/pallas/mosaic/__init__.py b/jax/_src/pallas/mosaic/__init__.py index 2ab64462b102..38d13f42da99 100644 --- a/jax/_src/pallas/mosaic/__init__.py +++ b/jax/_src/pallas/mosaic/__init__.py @@ -11,42 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Module for Mosaic lowering of Pallas call.""" - -from jax._src.pallas.mosaic import core -from jax._src.pallas.mosaic.core import dma_semaphore -from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import semaphore -from jax._src.pallas.mosaic.core import SemaphoreType -from jax._src.pallas.mosaic.core import TPUMemorySpace -from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata -from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata -from jax._src.pallas.mosaic.lowering import LoweringException -from jax._src.pallas.mosaic.pipeline import BufferedRef -from jax._src.pallas.mosaic.pipeline import emit_pipeline -from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations -from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule -from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations -from jax._src.pallas.mosaic.primitives import async_copy -from jax._src.pallas.mosaic.primitives import async_remote_copy -from jax._src.pallas.mosaic.primitives import bitcast -from jax._src.pallas.mosaic.primitives import delay -from jax._src.pallas.mosaic.primitives import device_id -from jax._src.pallas.mosaic.primitives import DeviceIdType -from jax._src.pallas.mosaic.primitives import get_barrier_semaphore -from jax._src.pallas.mosaic.primitives import make_async_copy -from jax._src.pallas.mosaic.primitives import make_async_remote_copy -from jax._src.pallas.mosaic.primitives import repeat -from jax._src.pallas.mosaic.primitives import roll -from jax._src.pallas.mosaic.primitives import run_scoped -from jax._src.pallas.mosaic.primitives import semaphore_read -from jax._src.pallas.mosaic.primitives import semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait -from jax._src.pallas.mosaic.primitives import prng_seed -from jax._src.pallas.mosaic.primitives import prng_random_bits - -ANY = TPUMemorySpace.ANY -CMEM = TPUMemorySpace.CMEM -SMEM = TPUMemorySpace.SMEM -VMEM = TPUMemorySpace.VMEM diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 76ad43d596f8..f4a794792253 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -19,7 +19,7 @@ import dataclasses import enum import functools -from typing import Any, Union +from typing import Any from jax._src import core as jax_core from jax._src import dtypes @@ -28,20 +28,17 @@ import jax.numpy as jnp from jax._src.pallas import core as pallas_core -# TODO(sharadmv): enable type checking -# mypy: ignore-errors - map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip partial = functools.partial Grid = pallas_core.Grid BlockSpec = pallas_core.BlockSpec +BlockSpecTree = pallas_core.BlockSpecTree GridMapping = pallas_core.GridMapping NoBlockSpec = pallas_core.NoBlockSpec AbstractMemoryRef = pallas_core.AbstractMemoryRef no_block_spec = pallas_core.no_block_spec -_preprocess_grid = pallas_core._preprocess_grid _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping split_list = util.split_list @@ -65,8 +62,14 @@ class semaphore(semaphore_dtype): pass class dma_semaphore(semaphore_dtype): pass class barrier_semaphore(semaphore_dtype): pass +class AbstractSemaphoreTyRules: + @staticmethod + def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.dtype('int32')) + class AbstractSemaphoreTy(dtypes.ExtendedDType): name: str + _rules = AbstractSemaphoreTyRules def __repr__(self) -> str: return self.name @@ -75,7 +78,7 @@ def __eq__(self, other): return self.__class__ == other.__class__ def __hash__(self) -> int: - return hash((self.__class__)) + return hash(self.__class__) # TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy @@ -97,6 +100,7 @@ class SemaphoreType(enum.Enum): BARRIER = "barrier" def __call__(self, shape: tuple[int, ...]): + dtype: Any if self == SemaphoreType.DMA: dtype = DmaSemaphoreTy() elif self == SemaphoreType.BARRIER: @@ -105,7 +109,7 @@ def __call__(self, shape: tuple[int, ...]): dtype = SemaphoreTy() return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) - def get_aval(self) -> "AbstractMemoryRef": + def get_aval(self) -> AbstractMemoryRef: return self(()).get_aval() @dataclasses.dataclass(frozen=True) @@ -143,9 +147,6 @@ def _make_aval(obj: object) -> jax_core.AbstractValue: "Only VMEM and SemaphoreType are supported.") -BlockSpecTree = Union[BlockSpec, NoBlockSpec, Sequence["BlockSpecTree"]] - - @dataclasses.dataclass(init=False, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): grid: Grid @@ -171,6 +172,10 @@ def __init__( def get_grid_mapping( self, in_avals, in_tree, out_avals, out_tree ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: + assert all(i is None or isinstance(i, int) for i in self.grid) + grid_mapping_grid = tuple( + pallas_core.dynamic_grid_dim if d is None else d for d in self.grid + ) all_avals = tree_util.tree_unflatten(in_tree, in_avals) flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( self.scratch_shapes) @@ -196,15 +201,29 @@ def get_grid_mapping( ((*grid_avals, *scalar_avals), {}) ) in_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, - (*grid_avals, *scalar_ref_avals), - in_tree=index_map_in_tree), in_specs, in_ref_avals) + partial( + _convert_block_spec_to_block_mapping, + (*grid_avals, *scalar_ref_avals), + in_tree=index_map_in_tree, + grid=grid_mapping_grid, + mapped_dims=(), + ), + in_specs, + in_ref_avals, + ) out_block_mappings = map( - partial(_convert_block_spec_to_block_mapping, - (*grid_avals, *scalar_ref_avals), - in_tree=index_map_in_tree), out_specs, out_ref_avals) + partial( + _convert_block_spec_to_block_mapping, + (*grid_avals, *scalar_ref_avals), + in_tree=index_map_in_tree, + grid=grid_mapping_grid, + mapped_dims=(), + ), + out_specs, + out_ref_avals, + ) grid_mapping = GridMapping( - grid=self.grid, + grid=grid_mapping_grid, # type: ignore block_mappings=(*in_block_mappings, *out_block_mappings), mapped_dims=(), num_index_operands=num_flat_scalar_prefetch, diff --git a/jax/_src/pallas/mosaic/kernel_regeneration_util.py b/jax/_src/pallas/mosaic/kernel_regeneration_util.py deleted file mode 100644 index 087c1fe3ab9a..000000000000 --- a/jax/_src/pallas/mosaic/kernel_regeneration_util.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Helpers to encode and decode Mosaic kernel regeneration metadata.""" - -import base64 -import json -from typing import Any -from jaxlib.mlir import ir - - -def encode_kernel_regeneration_metadata( - metadata: dict[str, Any] -) -> dict[str, bytes]: - """Serializes the given kernel regeneration metadata. - - This function hides the serialization details from the end user. - - Args: - metadata: dictionary with user-defined data to be serialized in the backend - config. - - Returns: - A dict that can be passed to pallas_call via - compiler_params=dict(mosaic=...)). - - Raises: - TypeError: when the input metadata is not serializable in json format. - """ - serialized_metadata = bytes(json.dumps(metadata), encoding="utf-8") - return dict(kernel_regeneration_metadata=serialized_metadata) - - -def extract_kernel_regeneration_metadata(op: ir.Operation) -> dict[str, Any]: - """Extract kernel regeneration metadata from the given Operation. - - This function hides the serialization details from the end user. - - Args: - op: the tpu custom_call mlir Operation that contains the kernel metadata. - - Returns: - The decoded metadata in the form of a dict. This corresponds to the dict - in input to the 'encode' function. - """ - kernel_regeneration_metadata = ir.StringAttr( - op.attributes["kernel_regeneration_metadata"] - ).value - return json.loads(base64.b64decode(kernel_regeneration_metadata)) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 18f7e35365d0..3ebbfdb51b5b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -15,11 +15,11 @@ """Module for lowering JAX to Mosaic-compatible MLIR dialects.""" from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools import string -from typing import Any, Callable +from typing import Any import jax from jax import core as jax_core @@ -28,9 +28,11 @@ from jax._src import ad_util from jax._src import custom_derivatives from jax._src import debugging +from jax._src import dtypes from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import pjit +from jax._src import prng from jax._src import source_info_util from jax._src import state from jax._src.interpreters import mlir @@ -76,6 +78,12 @@ map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin +UNSIGNED_TO_SIGNED = { + np.dtype('uint8'): np.dtype('int8'), + np.dtype('uint16'): np.dtype('int16'), + np.dtype('uint32'): np.dtype('int32'), + np.dtype('uint64'): np.dtype('int64'), +} @dataclasses.dataclass class MeshContext: @@ -123,7 +131,13 @@ def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type: return ir.Type.parse("!tpu.semaphore") else: raise NotImplementedError - return mlir.dtype_to_ir_type(dtype) + # TODO(justinfu): Remove after mosaic supports unsigned types. + # This conversion makes mosaic interpret all unsigned types as signed types. + type = mlir.dtype_to_ir_type(dtype) + if isinstance(type, ir.IntegerType): + return ir.IntegerType.get_signless(type.width) + else: + return type def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None): if isinstance(aval, tpu_core.AbstractSemaphore): @@ -137,6 +151,15 @@ def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None raise ValueError(f"Cannot allocate {aval.sem_type}.") memspace = _memory_space_to_tpu_memspace(TPUMemorySpace.SEMAPHORE) return ir.MemRefType.get((), sem_type, memory_space=memspace) + if dtypes.issubdtype(aval.dtype, dtypes.prng_key): + shape = aval.dtype._impl.key_shape + if memory_space is None: + memory_space = TPUMemorySpace.SMEM + if memory_space != TPUMemorySpace.SMEM: + raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}") + memspace = _memory_space_to_tpu_memspace(memory_space) + return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)), + memory_space=memspace) if isinstance(aval, state.AbstractRef): if shape is None: shape = aval.shape @@ -428,7 +451,9 @@ def lower_jaxpr_to_module( m.body.append(mlir_func) sym_tab.insert(mlir_func) func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params) - static_grid = [MLIR_DYNAMIC if b is None else b for b in grid] + static_grid = [ + MLIR_DYNAMIC if b is pl_core.dynamic_grid_dim else b for b in grid + ] func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid) func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get( @@ -613,16 +638,18 @@ def read_env(atom: jax_core.Atom): return atom.val if isinstance(atom, jax_core.Literal) else env[atom] def write_env(var: jax_core.Var, val): - assert isinstance(val, ir.Value), type(val) + is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle)) + assert is_valid_type, type(val) env[var] = val for invar, bs in zip(jaxpr.invars, ctx.block_shapes): block_shape_env[invar] = bs map(write_env, jaxpr.invars, args) + initial_name_stack = [scope.name for scope in ctx.name_stack.stack] current_name_stack: list[str] = [] # TODO(justinfu): Handle transform scopes. - current_name_stack.extend([scope.name for scope in ctx.name_stack.stack]) + current_name_stack.extend(initial_name_stack) for eqn in jaxpr.eqns: invals = map(read_env, eqn.invars) source_info = eqn.source_info.replace( @@ -679,6 +706,14 @@ def write_env(var: jax_core.Var, val): map(write_env, eqn.outvars, ans) else: write_env(eqn.outvars[0], ans) + + # Drain the name stack at the end of a jaxpr and insert trace_stop ops. + popped, pushed = _compute_name_stack_updates( + current_name_stack, initial_name_stack) + for _ in popped: + tpu.TraceStopOp() + assert len(pushed) == 0 + outvals = map(read_env, jaxpr.outvars) outvals = [ ir_constant(x) if isinstance(var, jax_core.Literal) else x @@ -690,6 +725,8 @@ def write_env(var: jax_core.Var, val): def _ensure_mlir_value(val, aval): if isinstance(val, ir.Value): return val + if isinstance(val, KeyScalarBundle): + return val elif isinstance(val, (np.generic, np.ndarray, int, float)): return ir_constant(val, _dtype_to_ir_type(aval.dtype)) else: @@ -885,6 +922,21 @@ def _index_ref(ref, ref_aval, ref_block_shape, indexers): ref_block_shape) return ref, ref_block_shape +@dataclasses.dataclass(frozen=True) +class KeyScalarBundle: + """A container class for PRNG key data. + + We pass around keys as a KeyScalarBundle in the lowering pass rather than + as a vector, since we want the key data to live in scalar registers rather + than vector registers. This special dataclass exists so we can return + multiple scalar values from load_op, because the load_op primitive does + not allow multiple results. + + Attributes: + scalars: A list of OpResults representing scalar key data during the + lowering pass. + """ + scalars: list[ir.OpResult] def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ref, indexers, mask, _ = args_tree.unflatten(args_flat) @@ -903,6 +955,12 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space" ref_aval, *_ = ctx.avals_in (aval_out,) = ctx.avals_out + if isinstance(aval_out.dtype, prng.KeyTy): + if not is_smem_load: + raise ValueError("PRNG keys must be loaded from SMEM. Did you set " + "the memory space to TPUMemorySpace.SMEM in the " + "BlockSpec for the PRNG key input?") + return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree) if not is_smem_load and not ref_block_shape: raise NotImplementedError( "Indexing into a ()-shaped Ref not yet supported on TPU.") @@ -934,6 +992,37 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): _dtype_to_ir_type(aval_out.dtype)) return vector.ShapeCastOp(vec_type, load_val).result +def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: + """Lowering rule for loading PRNG keys from SMEM. + + PRNG key loads are currently lowered as a list of scalar loads from SMEM, + rather than a single vector load. + We store these scalars in a bundle type called KeyScalarBundle, which has + special case handling for functions that consume the key such as set_seed. + """ + ref, _, _, _ = args_tree.unflatten(args_flat) + (aval_out,) = ctx.avals_out + assert isinstance(aval_out.dtype, prng.KeyTy) + ref_block_shape = aval_out.dtype._impl.key_shape + + if len(ref_block_shape) != 2: + raise NotImplementedError("Seed key_data must be 2D.") + if tuple(ref_block_shape) != (1, 1): + raise NotImplementedError( + f"Seed key_data of shape != (1, 1) not supported. Got: {ref_block_shape}") + + load_ops = [] + for i in range(ref_block_shape[0]): + idx = NDIndexer(indices=(0, i), shape=ref_block_shape, + int_indexer_shape=tuple()) + starts, _, _, _, _ = _indexer_to_start_size_stride( + idx, + ref_block_shape, + cast_to_index=True, + ) + load_ops.append(memref.LoadOp(ref, starts).result) + return KeyScalarBundle(scalars=load_ops) + lowering_rules[primitives.load_p] = _load_lowering_rule skip_mlir_conversions.add(primitives.load_p) @@ -1035,11 +1124,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): out_type = aval_to_ir_type(ctx.avals_out[0]) if jnp.issubdtype(x_aval.dtype, jnp.floating): - # TODO(apaszke): Remove in 03/2024. - if hasattr(vector.CombiningKind, "MAXIMUMF"): - kind = vector.CombiningKind.MAXIMUMF - else: - kind = vector.CombiningKind.MAXF + kind = vector.CombiningKind.MAXIMUMF val = ir.FloatAttr.get(ir.F32Type.get(), float("-inf")) identity = ir.DenseElementsAttr.get_splat(out_type, val) elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger): @@ -1129,7 +1214,15 @@ def _dot_general_lowering_rule( (aval_out,) = ctx.avals_out out_type = aval_to_ir_type(aval_out) val_type = out_type.element_type - if any(cls.isinstance(val_type) for cls in [ir.BF16Type, ir.F32Type]): + if any( + cls.isinstance(val_type) + for cls in [ + ir.BF16Type, + ir.F32Type, + ir.Float8E5M2Type, + ir.Float8E4M3FNType, + ] + ): val = ir.FloatAttr.get(val_type, 0.0) elif ir.IntegerType.isinstance(val_type): val = ir.IntegerAttr.get(val_type, 0) @@ -1210,14 +1303,14 @@ def _convert_helper(x, *, to_dtype): if jnp.issubdtype(from_dtype, jnp.dtype("bool")): x = x.astype(jnp.int32) return _convert_helper(x, to_dtype=to_dtype) - if jnp.issubdtype(from_dtype, jnp.integer): + if jnp.issubdtype(from_dtype, jnp.signedinteger): if from_dtype.itemsize < 4: x = x.astype(jnp.int32) if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4: x = x.astype(jnp.float32) return x.astype(to_dtype) if jnp.issubdtype(from_dtype, jnp.floating): - if jnp.issubdtype(to_dtype, jnp.integer): + if jnp.issubdtype(to_dtype, jnp.signedinteger): if from_dtype.itemsize < 4: x = x.astype(jnp.float32) if to_dtype.itemsize < 4: @@ -1238,6 +1331,11 @@ def _convert_element_type_lowering_rule( out_aval = ctx.avals_out[0] old_dtype = ctx.avals_in[0].dtype out_type = aval_to_ir_type(out_aval) + + # TODO(justinfu): Remove after mosaic supports unsigned types. + # This conversion makes mosaic interpret all unsigned types as signed types. + if np.issubdtype(new_dtype, jnp.unsignedinteger): + new_dtype = UNSIGNED_TO_SIGNED[new_dtype] if old_dtype == new_dtype: return x if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( @@ -2091,9 +2189,11 @@ def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): def _roll_lowering_rule( - ctx: LoweringRuleContext, x, *, shift, axis, stride, stride_axis + ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis ): - return tpu.RotateOp( + (out_aval,) = ctx.avals_out + return tpu.DynamicRotateOp( + aval_to_ir_type(out_aval), x, shift, axis, @@ -2154,9 +2254,16 @@ def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): (out_aval,) = ctx.avals_out return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result - lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule +def _bitcast_convert_type_lowering_rule( + ctx: LoweringRuleContext, x, *, new_dtype): + (in_aval, ) = ctx.avals_in + (out_aval,) = ctx.avals_out + if in_aval.dtype.itemsize != new_dtype.itemsize: + raise NotImplementedError("Changing bitwidths not supported.") + return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result +lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value: if isinstance(aval, pl_core.AbstractMemoryRef): @@ -2364,6 +2471,16 @@ def _debug_print_rule( def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): del ctx + # In the KeyScalarBundle case we unpack the bundle and set the seed with + # the list of scalars. + if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle): + return tpu.PRNGSeed32Op(seeds[0].scalars).results + # For integer seeds, we can set the seed directly as PRNGSeed32Op natively + # takes in a list of integers as input. + all_integers = all(isinstance(seed.type, ir.IntegerType) for seed in seeds) + if not all_integers: + seed_types = [seed.type for seed in seeds] + raise ValueError(f"All seed data must be scalar integers. Got {seed_types}") return tpu.PRNGSeed32Op(seeds).results lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule @@ -2376,3 +2493,58 @@ def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): out_type = aval_to_ir_type(out_aval) return tpu.PRNGRandomBitsOp(out_type).result lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule + + +def random_seed_lowering(ctx, seeds, *, impl): + seed_lowering = lower_fun( + impl.seed, multiple_results=False) + return seed_lowering(ctx, seeds) +lowering_rules[prng.random_seed_p] = random_seed_lowering + + +def random_bits_lowering(ctx, keys, *, bit_width, shape): + assert bit_width == 32, "Only 32-bit PRNG supported." + aval, = ctx.avals_in + impl = aval.dtype._impl + bits_lowering = lower_fun( + impl.random_bits, multiple_results=False) + return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape) +lowering_rules[prng.random_bits_p] = random_bits_lowering + + +def random_fold_in_lowering(ctx, keys, msgs): + keys_aval, _ = ctx.avals_in + impl = keys_aval.dtype._impl + fold_in_lowering = lower_fun( + impl.fold_in, multiple_results=False) + return fold_in_lowering(ctx, keys, msgs) +lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering + + +def random_unwrap_lowering(ctx, key): + del ctx, key + raise NotImplementedError("key_data not implemented.") +lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering + + +def random_wrap_lowering(ctx, key_data, *, impl): + del ctx, impl + if isinstance(key_data.type, ir.VectorType): + # If the key data lives in vregs, need to unpack it to sregs. + key_data_list = [] + key_data_shape = key_data.type.shape + if len(key_data_shape) != 2: + raise NotImplementedError("Seed key_data must be 2D.") + if tuple(key_data_shape) != (1, 1): + raise NotImplementedError( + "Seed key_data of shape != (1, 1) not supported. " + f"Got: {key_data_shape}") + for i in range(key_data_shape[1]): + key_data_list.append(vector.ExtractOp(key_data, [], [0, i])) + return KeyScalarBundle(scalars=key_data_list) + if isinstance(key_data, KeyScalarBundle): + return key_data + else: + raise NotImplementedError(f"key_data wrap {type(key_data)}") + +lowering_rules[prng.random_wrap_p] = random_wrap_lowering diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 1af9853593fb..8f307e560bf0 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -21,6 +21,7 @@ import jax from jax import core as jax_core +from jax._src import core as jax_src_core from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -35,7 +36,6 @@ def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, name: str, - which_linear: tuple[bool, ...], grid_mapping: core.GridMapping, input_output_aliases: tuple[tuple[int, int], ...], in_shapes: tuple[jax.ShapeDtypeStruct, ...], @@ -48,7 +48,6 @@ def pallas_call_tpu_lowering_rule( return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes, in_shapes=in_shapes, - which_linear=which_linear, interpret=interpret, debug=debug, input_output_aliases=input_output_aliases, grid_mapping=grid_mapping, @@ -78,9 +77,6 @@ def pallas_call_tpu_lowering_rule( mlir_ctx.load_all_available_dialects() tpu.register_dialect(mlir_ctx) dimension_semantics = mosaic_params.get("dimension_semantics", None) - kernel_regeneration_metadata = mosaic_params.get( - "kernel_regeneration_metadata" - ) mosaic_module, extra_args = lowering.lower_jaxpr_to_module( mlir_ctx, grid_mapping, in_shapes, out_shapes, jaxpr, dimension_semantics=dimension_semantics, mesh=mesh) @@ -93,20 +89,29 @@ def pallas_call_tpu_lowering_rule( for a in input_output_aliases ) out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes] + + # Replace in_avals to physical avals. + # This step is required for mapping logical types to physical types. + # (e.g. PRNG key -> uint32[2]) + physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in] + ctx = ctx.replace(avals_in=physical_avals) + def _lower_fun(*args): # Dynamic grid bounds have to go at the front. dynamic_grid_args, args = args[:num_dyn_bounds], args[num_dyn_bounds:], return mosaic.as_tpu_kernel( mosaic_module, out_avals, - backend=ctx.module_context.backend, + backend="tpu", kernel_name=name, - kernel_regeneration_metadata=kernel_regeneration_metadata, - cost_estimate=mosaic_params.get("cost_estimate", None), - vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes", None), - flags=mosaic_params.get("flags", None), - allow_input_fusion=mosaic_params.get("allow_input_fusion", None), + cost_estimate=mosaic_params.get("cost_estimate"), + vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"), + flags=mosaic_params.get("flags"), + allow_input_fusion=mosaic_params.get("allow_input_fusion"), input_output_aliases=input_output_aliases, + internal_scratch_in_bytes=mosaic_params.get( + "internal_scratch_in_bytes" + ), )( *dynamic_grid_args, *extra_args, diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 77b224557281..0d778a60c711 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -13,22 +13,26 @@ # limitations under the License. """Module for emitting custom TPU pipelines within a Pallas call.""" +from __future__ import annotations +from collections.abc import Sequence import dataclasses import enum import functools import itertools import operator -from typing import Optional, Union, Any, Sequence +from typing import Union, Any import jax from jax import lax from jax import tree_util +from jax._src import util as jax_util from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import primitives as tpu_primitives from jax.experimental import pallas as pl import jax.numpy as jnp +import numpy as np SMEM = tpu_core.TPUMemorySpace.SMEM @@ -44,6 +48,9 @@ PipelineRefs = Union[Sequence[REF], Any] +# TODO(sharadmv): make this a parameter and make it queryable from the Device. +_TILING = (8, 128) + def _broadcast_pytree_to(from_pytree, to_pytree): """Broadcast a prefix pytree to a given full tree.""" proxy = object() @@ -63,14 +70,73 @@ def add_leaves(i, x): return tree_util.tree_unflatten(treedef, broadcast_leaves) +def _get_tpu_generation() -> int: + kind = jax.devices()[0].device_kind + if kind.endswith(' lite'): + kind = kind[:-len(' lite')] + assert kind[:5] == "TPU v", kind + return int(kind[5]) + +def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: + # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions + # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and + # (2, 3, 128, 128) -> (1, 1, 8, 128). + if len(shape) < 2: + raise ValueError(f"Shape must have at least 2 dimensions: {shape=}") + leading_dims, final_dims = shape[:-2], shape[-2:] + # We want to find the minimum power of 2 that fits the second-minor dimension + # of shape, with maximum value 8. + second_minor, _ = final_dims + packing = 4 // dtype.itemsize + max_tiling = _TILING[0] + second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing + while second_minor_tiling < min(second_minor, max_tiling): + second_minor_tiling *= 2 + return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1]) + + def _mod(a, n): """"Calculates a mod n for positive and negative a with |a| <= n.""" return lax.rem(a + n, n) -def _make_ds(idx, size): +def _round_up_to_nearest_multiple(s: int, multiple: int) -> int: + if s % multiple == 0: + return s + # Subtract off the remainder, then add multiple + return s - s % multiple + multiple + + +def _make_ds( + idx: jax.Array | int, size: jax.Array | int +) -> pl.Slice: """Make a DMA slice with mosaic size hints.""" - return pl.ds(pl.multiple_of(idx * size, size), size) + out = pl.ds(idx * size, size) + assert isinstance(out, pl.Slice) + return out + + +def _make_block_slice( + block_index: jax.Array, block_size: int, size: int, tiling: int +) -> pl.Slice | slice: + # Computes a slice given a block index and block size. In the default case, + # we return slice(block_index * block_size, (block_index + 1) * block_size). + # However, if the total size of the ref does not divide block size and we are + # selecting the last block, we need to pick the lowest tiling size multiple + # that contains the block. + if size % block_size == 0: + return _make_ds(block_index, block_size) + if block_size % tiling != 0: + raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") + num_blocks = pl.cdiv(size, block_size) + is_last = block_index == num_blocks - 1 + rounded_size = jnp.where( + is_last, + _round_up_to_nearest_multiple(size % block_size, tiling), + block_size, + ) + rounded_size = pl.multiple_of(rounded_size, tiling) + return pl.ds(block_index * block_size, rounded_size) def _tuples_differ(xs, ys): @@ -87,15 +153,16 @@ def _grid_size(grid): return size -def _get_indices(step, grid): +def _get_indices(step, grid, offsets): """Get indices for a given step and grid.""" extended_grid = grid + (1,) strides = tuple( itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1] - return tuple( + indices = tuple( lax.div(lax.rem(step, a), b) for a, b in zip(strides[:-1], strides[1:]) ) + return tuple(a + b for a, b in zip(indices, offsets, strict=True)) class BufferType(enum.Enum): @@ -135,12 +202,12 @@ class BufferedRef: spec: pl.BlockSpec # static metadata dtype: Any # static metadata buffer_type: BufferType # static metadata - vmem_ref: Optional[REF] - accum_ref: Optional[REF] - current_slot: Optional[ArrayRef] - next_slot: Optional[ArrayRef] - sem_recv: Optional[SemaphoreType] - sem_send: Optional[SemaphoreType] + vmem_ref: REF | None + accum_ref: REF | None + current_slot: ArrayRef | None + next_slot: ArrayRef | None + sem_recv: SemaphoreType | None + sem_send: SemaphoreType | None def tree_flatten(self): return ((self.vmem_ref, self.accum_ref, self.current_slot, @@ -152,7 +219,7 @@ def tree_unflatten(cls, meta, data): return cls(*meta, *data) @classmethod - def create(cls, spec, dtype, buffer_type) -> 'BufferedRef': + def create(cls, spec, dtype, buffer_type) -> BufferedRef: """Create a BufferedRef. Args: @@ -259,49 +326,110 @@ def swap_slots(self): if self.memory_space == VMEM: return self.current_slot[0] = self.next_slot[0] + def get_dma_slice(self, src_shape, src_dtype, grid_indices): + # We need to handle blocks that might go OOB in the src array. An in bounds + # block looks like this (for array shape (600, 600) and block shape + # (256, 256)): + # + # +--------------+------------------| + # | Block (0,0) | | + # | (256, 256) | | + # +--------------+ | + # | A (600, 600) | + # | | + # +---------------------------------+ + # + # For in-bounds blocks, we don't need to do anything special. + # An out-of-bounds block looks like this: + # + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # | XXXXXXXXXX | + # +--------------+ + # where the X's indicate where the block is out of bounds. + # + # When we have an out of bounds block like this, we need to truncate it to + # a tile boundary (tiles are (8, 128) along the two minormost dimensions). + # In this case, we'll have a block that is indexing the + # 512:768 elements of A along the first dimension. We need to convert 768 + # into 600 (600 % 8 == 0), so our indexing will look like this: + + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # where it is now a (88, 256) sized block. + # + # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block + # for the last iteration on that dimension, we will pick the next highest + # tile multiple, i.e. (96, 256). + if len(src_shape) < 2: + raise NotImplementedError("Must use >1D values.") + + tiling = _make_tiling(src_shape, src_dtype) + block_shape = tuple(1 if b is None else b for b in self.block_shape) + block_indices = self.compute_index(*grid_indices) + return jax.tree.map( + _make_block_slice, block_indices, block_shape, src_shape, tiling + ) + def copy_in(self, src_ref, grid_indices): """Starts copy of HBM dma slice into the current slot.""" assert self.is_input if self.memory_space == VMEM: return - dma_slice = self.compute_slice(grid_indices) next_slot = lax.rem(self.current_slot[0] + 1, 2) self.next_slot[0] = next_slot + src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) + dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( - src_ref.at[dma_slice], - self.vmem_ref.at[next_slot], + src_ref.at[src_slice], + self.vmem_ref.at[next_slot].at[dst_slice], self.sem_recv).start() def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" assert self.is_output if self.memory_space == VMEM: return - dma_slice = self.compute_slice(grid_indices) slot = self.current_slot[0] self.next_slot[0] = lax.rem(slot + 1, 2) + dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) + src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.vmem_ref.at[slot], - dst_ref.at[dma_slice], + self.vmem_ref.at[slot].at[src_slice], + dst_ref.at[dst_slice], self.sem_send).start() def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" assert self.is_input if self.memory_space == VMEM: return - dma_slice = self.compute_slice(grid_indices) + src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) + dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( - src_ref.at[dma_slice], # nb: doesn't matter - self.vmem_ref.at[self.current_slot[0]], # only dst shape is important + src_ref.at[src_slice], # nb: doesn't matter + self.vmem_ref.at[self.current_slot[0]].at[dst_slice], # only dst shape is important self.sem_recv).wait() def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" assert self.is_output if self.memory_space == VMEM: return - dma_slice = self.compute_slice(grid_indices) prev_slot = lax.rem(self.current_slot[0] + 1, 2) + dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) + src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.vmem_ref.at[prev_slot], # nb: doesn't matter - dst_ref.at[dma_slice], # only dst shape is important + self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + dst_ref.at[dst_slice], # only dst shape is important self.sem_send).wait() # Accumulator methods @@ -350,8 +478,9 @@ class Scheduler: """Sequences input and output copies and waits for a pipeline.""" def __init__(self, - step, - grid, + step: jax.Array, + grid: tuple[int | jax.Array, ...], + grid_offsets: tuple[int | jax.Array, ...], first_cycle=None, last_cycle=None, init_accumulators=None, @@ -361,6 +490,7 @@ def __init__(self, Args: step: inner step number. grid: pallas grid for BufferedRefs. + grid_offsets: offsets for grid indices (used for megacore). first_cycle: whether this is the first invocation of the pipeline. last_cycle: whether this is the last invocation of the pipeline. init_accumulators: do we zero-initialize accumulator state for this @@ -388,12 +518,17 @@ def __init__(self, self.next_step = _mod(step + 1, self.num_steps) # Derived grid indices for present, previous, and next steps. - self.indices = _get_indices(step, grid) - self.prev_indices = _get_indices(self.prev_step, self.grid) - self.next_indices = _get_indices(self.next_step, self.grid) + self.indices = _get_indices(step, grid, grid_offsets) + self.prev_indices = _get_indices( + self.prev_step, grid, grid_offsets + ) + self.next_indices = _get_indices( + self.next_step, grid, grid_offsets + ) def grid_env(self): - return pallas_core.grid_env(zip(self.indices, self.grid)) + return pallas_core.grid_env( + list(map(pallas_core.GridAxis, self.indices, self.grid))) def has_changed(self, buffered_ref): indices = buffered_ref.compute_index(*self.indices) @@ -627,13 +762,111 @@ def make_output_bref(out_spec, out_ref, accumulate): return (*in_brefs, *out_brefs) +class GridDimensionSemantics: + pass +PARALLEL = GridDimensionSemantics() +ARBITRARY = GridDimensionSemantics() + + +def _partition_grid( + grid: tuple[int | jax.Array, ...], + core_axis: int | None, + dimension_semantics: tuple[GridDimensionSemantics, ...] | None, +) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]: + if core_axis is None: + # We aren't partitioning the grid + return grid, (0,) * len(grid) + num_cores = pl.num_programs(core_axis) + # Check that num_cores is statically known + if not isinstance(num_cores, int): + raise NotImplementedError( + f"Cannot partition grid over dynamic number of cores: {core_axis=}" + ) + if num_cores == 1: + # We aren't partitioning the grid + return grid, (0,) * len(grid) + + # If dimension_semantics aren't provided, we assume it is all arbitrary. + if dimension_semantics is None: + dimension_semantics = (ARBITRARY,) * len(grid) + if len(dimension_semantics) != len(grid): + raise ValueError("dimension_semantics must be the same length as grid.") + + parallel_dimensions = {i for i, d in enumerate(dimension_semantics) + if d == PARALLEL} + # If there are no parallel dimensions, we can't partition the grid + if not parallel_dimensions: + # TODO(sharadmv): enable running kernel on just one core + raise NotImplementedError( + "Cannot partition over cores without parallel grid dimensions:" + f" {dimension_semantics=}" + ) + if all(not isinstance(grid[i], int) for i in parallel_dimensions): + raise NotImplementedError( + f"Cannot partition cores over only dynamic grid dimensions: {grid=}" + ) + # Try to find a divisible dimension to partition the grid on + divisible_dimensions = { + i for i in parallel_dimensions + if isinstance(grid[i], int) and grid[i] % num_cores == 0 + } + if divisible_dimensions: + first_divisible_dimension, *_ = ( + i for i in range(len(dimension_semantics)) if i in divisible_dimensions + ) + partitioned_dim_size = grid[first_divisible_dimension] // num_cores + partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size + new_grid = jax_util.tuple_update( + grid, first_divisible_dimension, partitioned_dim_size + ) + offsets = jax_util.tuple_update( + (0,) * len(grid), first_divisible_dimension, partitioned_dim_offset + ) + else: + # No divisible dimensions, so we can't evenly partition the grid. Let's pick + # the largest dimension and try to divide it as evenly as possible. + # TODO(sharadmv): take the product of many nondivisible dimensions to + # potentially divide it more evenly + largest_parallel_dimension = max(grid[i] for i in parallel_dimensions + if isinstance(grid[i], int)) # type: ignore + partition_dimension, *_ = ( + i + for i, d in enumerate(grid) + if isinstance(d, int) and d == largest_parallel_dimension + ) + base_num_iters, rem = divmod(grid[partition_dimension], num_cores) + assert rem > 0, rem + # We have some remainder iterations that we need to assign somewhere. We + # know that rem < num_cores, so we can assign one extra iteration to each + # core except for the last (num_cores - rem). + core_index = pl.program_id(core_axis) + num_iters = jnp.where(core_index < rem, base_num_iters + 1, + base_num_iters) + new_grid = jax_util.tuple_update(grid, partition_dimension, num_iters) + # Ordinarily, we would compute the offset as: + # grid_offset = pl.program_id(core_axis) * num_iters + # However, since we have some cores that don't have an extra iteration, we + # need to adjust the offset by `rem`. + grid_offset = jnp.where( + core_index < rem, + core_index * num_iters, + core_index * base_num_iters + rem, + ) + offsets = jax_util.tuple_update( + (0,) * len(grid), partition_dimension, grid_offset + ) + return new_grid, offsets + + def emit_pipeline( body, *, - grid, + grid: tuple[int | jax.Array, ...], in_specs=None, out_specs=None, should_accumulate_out=False, + core_axis: int | None = None, + dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None ): """Creates a function to emit a manual pallas pipeline. @@ -652,7 +885,18 @@ def emit_pipeline( out_specs: output pallas block specs should_accumulate_out: booleans to indicate which outputs should be treated as accumulators. + core_axis: optional int, indicates whether or not to partition the grid + along the core axis. + dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL + or ARBITRARY). """ + if any(not isinstance(d, (int, jax.Array)) for d in grid): + grid_types = tuple(type(d) for d in grid) + raise ValueError( + f"Grid must consist of Python integers and JAX Arrays: {grid_types}" + ) + grid, grid_offsets = _partition_grid(grid, core_axis, dimension_semantics) + num_steps = _grid_size(grid) if not isinstance(in_specs, (list, tuple)): in_specs = (in_specs,) @@ -736,6 +980,7 @@ def loop_body(step, _): scheduler = Scheduler( step, grid, + grid_offsets=grid_offsets, first_cycle=first_cycle, last_cycle=last_cycle, init_accumulators=init_accumulators) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index a065f0386344..f4c24e4e5e16 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -15,9 +15,10 @@ """Module for Pallas:TPU-specific JAX primitives and functions.""" from __future__ import annotations +from collections.abc import Callable import dataclasses import enum -from typing import Any, Callable +from typing import Any import jax from jax._src import api_util @@ -35,9 +36,12 @@ from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pl_core from jax._src.pallas.mosaic import core as tpu_core +from jax._src.state import discharge as state_discharge from jax._src.typing import DTypeLike import jax.numpy as jnp +Slice = indexing.Slice + map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -105,13 +109,13 @@ def _bitcast(x): def roll( x, - shift: int, + shift, axis: int, *, stride: int | None = None, stride_axis: int | None = None, ): - if shift < 0: + if isinstance(shift, int) and shift < 0: raise ValueError("shift must be non-negative.") if axis < 0 or axis >= len(x.shape): raise ValueError("axis is out of range.") @@ -125,19 +129,20 @@ def roll( if axis == stride_axis: raise ValueError("expected axis and stride_axis are different.") return roll_p.bind( - x, shift=shift, axis=axis, stride=stride, stride_axis=stride_axis + x, shift, axis=axis, stride=stride, stride_axis=stride_axis ) @roll_p.def_abstract_eval -def _roll_abstract_eval(x, **_): +def _roll_abstract_eval(x, shift, **_): + del shift return jax_core.raise_to_shaped(x) def _roll_lowering_rule( - ctx: mlir.LoweringRuleContext, x, *, shift, axis, stride, stride_axis + ctx: mlir.LoweringRuleContext, x, shift, *, axis, stride, stride_axis ): - def _roll(x): + def _roll(x, shift): if stride is None: return jnp.roll(x, shift, axis) outputs = [ @@ -146,7 +151,7 @@ def _roll(x): ] return jnp.concatenate(outputs, stride_axis) - return mlir.lower_fun(_roll, multiple_results=False)(ctx, x) + return mlir.lower_fun(_roll, multiple_results=False)(ctx, x, shift) mlir.register_lowering(roll_p, _roll_lowering_rule) @@ -460,6 +465,117 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn +def dma_start_discharge_rule(in_avals, out_avals, + *args, tree, device_id_type): + ( + src_ref, + src_indexers, + dst_ref, + dst_indexers, + dst_sem, + dst_sem_indexers, + src_sem, + src_sem_indexers, + device_id, + ) = tree_util.tree_unflatten(tree, args) + del out_avals, dst_sem, dst_sem_indexers + is_remote = src_sem is not None and device_id is not None + if is_remote: + if device_id_type == DeviceIdType.MESH: + raise NotImplementedError("Mesh device_id_type not supported.") + else: + assert src_sem is None + assert src_sem_indexers is None + assert device_id is None + + def _find_slice_start_size(indexer): + num_scalar_idxs = 0 + # TODO(b/329733289): support strided load/store in interpret mode. + for s in indexer.indices: + if isinstance(s, Slice) and s.stride > 1: + raise NotImplementedError("Strides not supported in discharge" + " rule of dma_start.") + if not isinstance(s, Slice): + num_scalar_idxs += 1 + indices = indexer.indices + scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] + slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] + slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) + return scalar_dims, slice_starts, slice_sizes, num_scalar_idxs + + num_src_index_vals = 0 + if src_indexers: + if len(src_indexers) != 1: + raise NotImplementedError("Multiple indexers not supported.") + idx = src_indexers[0] + if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): + (_, slice_starts, + slice_sizes, num_scalar_idxs) = _find_slice_start_size(idx) + num_src_index_vals += num_scalar_idxs + updates = jax.lax.dynamic_slice( + src_ref, slice_starts, slice_sizes=slice_sizes) + else: + updates = src_ref[idx.indices] + else: + updates = src_ref + + if is_remote: + # Note that this code only works in SPMD mode. If not all devices execute + # the DMA then the devices that do will hang. + # TODO(justinfu): Verify that code only works in SPMD mode. + axis_env = jax_core.thread_local_state.trace_state.axis_env + axis_names = tuple(frame.name for frame in axis_env) + nonempty_axis_names = tuple(name for name in axis_names if name is not None) + if len(nonempty_axis_names) > 1: + raise NotImplementedError("Sharding with more than one named axis not " + "implemented in dma_start_p.") + shard_axis = nonempty_axis_names[0] + my_axis = jax.lax.axis_index(shard_axis) + # Update dst_ref from the perspective of the current device as the + # receiver. + who_copy_to_me = jax.lax.all_gather(device_id, shard_axis) == my_axis + # TODO(justinfu): Add a checkify for verifying there is at most one source. + # TODO(justinfu): Handle the case where no other device is copying to + # this device. + index = jnp.argmax(who_copy_to_me, axis=0) + global_updates = jax.lax.all_gather(updates, shard_axis) + updates = jax.lax.dynamic_index_in_dim( + global_updates, index, axis=0, keepdims=False) + + num_dst_index_vals = 0 + if dst_indexers: + if len(dst_indexers) != 1: + raise NotImplementedError("Multiple indexers not supported.") + idx = dst_indexers[0] + if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): + (_, slice_starts, slice_sizes, + num_scalar_idxs) = _find_slice_start_size(idx) + num_dst_index_vals += num_scalar_idxs + if updates.shape != slice_sizes: + raise ValueError("DMA src and dst slices must have same shape. " + f"Got src={updates.shape}, dst={slice_sizes}") + new_dst = jax.lax.dynamic_update_slice( + dst_ref, updates, slice_starts) + else: + new_dst = dst_ref.at[idx.indices].set(updates) + else: + new_dst = updates + + # TODO(b/345505876): Implement semaphore counting. + new_avals = (None,) # src_aval + new_avals += (None,) * num_src_index_vals + new_avals += (new_dst,) # dst_aval + new_avals += (None,) * num_dst_index_vals + new_avals += (None,) # dst_sem_aval + if is_remote: + new_avals += (None, None) # src_sem_aval, device_id + assert (len(new_avals) == + len(in_avals)), f"{len(new_avals), new_avals} != {len(in_avals)}" + return new_avals, [] + +state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule) + + dma_wait_p = jax_core.Primitive('dma_wait') dma_wait_p.multiple_results = True @@ -485,6 +601,13 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn +def dma_wait_discharge_rule(in_avals, out_avals, + *args, tree, device_id_type): + del out_avals, args, tree, device_id_type + # TODO(justinfu): Implement semaphore counting. + return (None,) * len(in_avals), [] +state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule) + def _get_ref_and_indexers(ref): if isinstance(ref, state.RefView): return ref.ref, ref.indexers @@ -604,7 +727,7 @@ def _(*_): return [] -def prng_seed(*seeds: tuple[int | jax.Array, ...]) -> None: +def prng_seed(*seeds: int | jax.Array) -> None: """Sets the seed for PRNG. Args: diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py new file mode 100644 index 000000000000..cc864c56f4e3 --- /dev/null +++ b/jax/_src/pallas/mosaic/random.py @@ -0,0 +1,219 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +import jax +import numpy as np +from jax import numpy as jnp +from jax import random as jax_api_random +from jax._src import blocked_sampler +from jax._src import typing +from jax._src.pallas.mosaic.primitives import prng_seed +from jax._src.pallas.mosaic.primitives import prng_random_bits +from jax._src.pallas import primitives +from jax._src import prng as jax_prng + + +Shape = jax_prng.Shape +SampleFnType = blocked_sampler.SampleFn +KeylessSampleFnType = Callable[..., jax.Array] + +set_seed = prng_seed + +FOLD_IN_ROUNDS = 128 +SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas_tpu"] + + +def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray: + """Helper function for converting non-Pallas PRNG keys into Pallas keys.""" + + # Only allow conversion from RBG -> Pallas keys. + # There is no technical reason why we cannot support Threefry here, but + # this reduces the chance of unintended behavior where the pallas PRNG + # produces different random bits than Threefry. RBG has fewer guarantees + # so users of RBG should be more aware of the consequences. + if key._impl.name not in SUPPORTED_CONVERSION_KEYS: + raise ValueError(f"Unsupported key type: {key._impl.name}" + f"Supported keys are: {SUPPORTED_CONVERSION_KEYS}") + + key_data = jax_api_random.key_data(key) + pallas_key_size = np.prod(tpu_key_impl.key_shape) + if key_data.size < pallas_key_size: + raise ValueError(f"Key data must be at least {pallas_key_size} bytes.") + pallas_key_data = jnp.ravel(key_data)[:pallas_key_size] + pallas_key_data = jnp.reshape(pallas_key_data, tpu_key_impl.key_shape) + return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu") + +def _seed_func(seed: jnp.int32): + seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32) + return (seed_data + seed).astype(jnp.uint32) + +def _random_bits(key: typing.Array, bit_width: int, shape: Shape): + if bit_width != 32: + raise ValueError("Bit width must be 32") + prng_seed(key) + return prng_random_bits(shape) + +def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array): + # Roughly, we compute the new key as follows: + # new_key = random_bits(data)[..., 127] ^ random_bits(old_key)[..., 127] + # Because the TPU generates random numbers in (8, 128) blocks at once, we + # can generate that many values without additional cost which will reduce + # correlation between the old and new keys. + key_shape = tpu_key_impl.key_shape + + prng_seed(data) + data_bits = prng_random_bits( + key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) + prng_seed(key) + key_bits = prng_random_bits( + key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) + + mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1] + assert mixed.shape == key_shape + return jax.random.wrap_key_data(mixed, impl="pallas_tpu") + +def _split(key: typing.Array, shape: Shape): + del key, shape + raise NotImplementedError() + +tpu_key_impl = jax_prng.PRNGImpl( + # Pallas currently only supports 2D+ windows, so set the key_shape + # to be 2D to have better compatibility with setting BlockSpecs. + key_shape=(1, 1), + seed=_seed_func, + split=_split, + random_bits=_random_bits, + fold_in=_fold_in, + name="pallas_tpu", + tag="pl" +) +jax_prng.register_prng(tpu_key_impl) + +# Implementation of the stateful Pallas PRNG API. +# Users should set the seed using the `set_seed` function, +# and call the appropriate stateful sampling functions. +# The actual key impl should never be used. The impl +# serves as internal boilerplate code because JAX's existing +# random functions expect a key as an argument, and +# the keys are only generated as part of unused arguments. + +def _pl_stateful_seed_func(seed: jnp.int32): + del seed + # Unused. Return the correct shape and dtype. + return jnp.empty((), dtype=jnp.int32) + +def _pl_stateful_random_bits(key: typing.Array, bit_width: int, shape: Shape): + del key + assert bit_width == 32, "Bit width must be 32" + return prng_random_bits(shape) + +def _pl_stateful_fold_in(key: typing.Array, data: typing.Array): + del key, data + raise NotImplementedError() + +def _pl_stateful_split(key: typing.Array, shape: Shape): + del key, shape + raise NotImplementedError() + + +tpu_internal_stateful_impl = jax_prng.PRNGImpl( + key_shape=(), + seed=_pl_stateful_seed_func, + split=_pl_stateful_split, + random_bits=_pl_stateful_random_bits, + fold_in=_pl_stateful_fold_in, + name="_pallas_internal_stateful", + tag="_pl_stateful" +) +jax_prng.register_prng(tpu_internal_stateful_impl) + +def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType: + """Converts a jax.random sampling function to a stateful version. + + Args: + sampler: A sampling function that consumes a key and returns + random samples. + + Returns: + A stateful sampling function with the key argument removed. + """ + def new_sampler(*args, **kwargs): + # Pass in a placeholder key into the sampling function. + # The key is ignored by the stateful random_bits function, but all jax + # sampling functions expect a key as input so we must pass one in here. + placeholder_key = jax_api_random.key(0, impl=tpu_internal_stateful_impl) + return sampler(placeholder_key, *args, **kwargs) + # Remove key argument from docstring. + if sampler.__doc__: + doc_lines = filter( + lambda line: "key:" not in line, sampler.__doc__.split("\n")) + new_sampler.__doc__ = "\n".join(doc_lines) + return new_sampler + +bits = _make_stateful_sampler(jax_api_random.bits) # type: ignore +uniform = _make_stateful_sampler(jax_api_random.uniform) # type: ignore +bernoulli = _make_stateful_sampler(jax_api_random.bernoulli) # type: ignore + + +def sample_block(sampler_fn: SampleFnType, + global_key: jax_prng.PRNGKeyArray, + block_size: Shape, + tile_size: Shape, + total_size: Shape, + block_index: tuple[typing.ArrayLike, ...] | None = None, + **kwargs) -> jax.Array: + """Samples a block of random values with invariance guarantees. + + `sample_block` allows the sampling of identical blocks of random values + across kernels with different block shapes and iteration orders. Each call + to `sample_block` returns a `block_size`-shaped array of random samples + corresponding to the `block_index`. + + `tile_size` should be chosen such that it is a divisor to all block sizes + one needs to be invariant to. The larger the `tile_size`, the more + efficient the sampling process wil be and therefore the best choice is + typically the greatest common divisor between all possible block sizes. + + Args: + sampler_fn: A sampling function that consumes a key and returns + random samples. + global_key: The global key to use for sampling. + block_size: The shape of an individual block. + tile_size: The shape of a `tile`, which is the smallest unit at + which samples are generated. This should be selected to be a divisor + of all block sizes one needs to be invariant to. + total_size: The total size of the array to sample. + block_index: The index denoting which block to generate keys for. Defaults + to the program_id for each block axis. + **kwargs: Additional arguments to pass to the sampler_fn. + + Returns: + A `block_size` shaped array of samples for the current block corresponding + to `block_index`. + """ + if len(block_size) != len(tile_size): + raise ValueError(f"block_size ({len(block_size)}) and tile_size " + f"({len(tile_size)}) must have the same length.") + + if block_index is None: + num_axes = len(block_size) + block_index = tuple( + primitives.program_id(axis) for axis in range(num_axes)) + + keys = blocked_sampler.blocked_fold_in( + global_key, total_size, block_size, tile_size, block_index) + return blocked_sampler.sample_block( + sampler_fn, keys, block_size, tile_size, **kwargs) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 76ab118f8dcf..6f39f2686d4f 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -14,13 +14,17 @@ # Package for Mosaic-specific Pallas extensions -load("//jaxlib:jax.bzl", "pytype_strict_library") load("@rules_python//python:defs.bzl", "py_library") +load( + "//jaxlib:jax.bzl", + "py_deps", + "pytype_strict_library", +) package( default_applicable_licenses = [], default_visibility = [ - "//third_party/py/jax:internal", + "//:__subpackages__", ], ) @@ -37,10 +41,10 @@ pytype_strict_library( srcs = ["pallas_call_registration.py"], deps = [ ":lowering", - "//third_party/py/jax", - "//third_party/py/jax:mlir", - "//third_party/py/jax:mosaic_gpu", - "//third_party/py/jax/_src/pallas", + "//jax", + "//jax:mlir", + "//jax:mosaic_gpu", + "//jax/_src/pallas", ], ) @@ -48,13 +52,12 @@ pytype_strict_library( name = "lowering", srcs = ["lowering.py"], deps = [ - "//third_party/py/jax", - "//third_party/py/jax:core", - "//third_party/py/jax:mlir", - "//third_party/py/jax:mosaic_gpu", - "//third_party/py/jax:util", - "//third_party/py/jax/_src/lib", - "//third_party/py/jax/_src/pallas", - "//third_party/py/numpy", - ], + "//jax", + "//jax:core", + "//jax:mlir", + "//jax:mosaic_gpu", + "//jax:util", + "//jax/_src/lib", + "//jax/_src/pallas", + ] + py_deps("numpy"), ) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a2ab59267a05..5b4db68f2552 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -151,7 +151,7 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): # TODO(slebedev): Consider enforcing this in the mgpu.BarrierArray. [barrier] = mgpu.BarrierArray(1, arrival_count=1) - with mgpu.once(): + with mgpu.single_thread(): nvgpu_dialect.mbarrier_arrive_expect_tx( barrier.barrier_array.value, _index( @@ -329,6 +329,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): @register_lowering_rule(lax.integer_pow_p) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): + x = _ensure_fa(x, *ctx.avals_in) if y == 2: return x * x return NotImplementedError diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 16633fccabe9..740f0c31ebb7 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -35,7 +35,6 @@ def pallas_call_lowering( name: str, in_shapes: tuple[jax.ShapeDtypeStruct, ...], out_shapes: tuple[jax.ShapeDtypeStruct, ...], - which_linear: tuple[bool, ...], interpret: bool, debug: bool, input_output_aliases: tuple[tuple[int, int], ...], @@ -50,7 +49,6 @@ def pallas_call_lowering( name=name, out_shapes=out_shapes, in_shapes=in_shapes, - which_linear=which_linear, interpret=interpret, debug=debug, input_output_aliases=input_output_aliases, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d3f301c4f971..0748e78a2db9 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -15,43 +15,50 @@ """Module for calling pallas functions from JAX.""" from __future__ import annotations -from collections.abc import Sequence -import itertools +from collections.abc import Callable, Sequence from functools import partial, reduce -from typing import Any, Callable +import itertools +from typing import Any import jax from jax import api_util -from jax import tree_util from jax import lax +from jax import tree_util +from jax._src import ad_util +from jax._src import checkify from jax._src import config +from jax._src import core as jax_core +from jax._src import effects +from jax._src import linear_util as lu from jax._src import state from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla -from jax._src import ad_util -from jax._src import core as jax_core -from jax._src.state import primitives as sp -from jax._src import linear_util as lu +from jax._src.pallas import core as pallas_core +from jax._src.pallas.primitives import uninitialized_value from jax._src.state import discharge as state_discharge +from jax._src.state import primitives as sp from jax._src.util import ( - split_list, safe_map, safe_zip, weakref_lru_cache, - tuple_insert, partition_list, merge_lists) + safe_map, + safe_zip, + split_list, + tuple_insert, + weakref_lru_cache, +) import jax.numpy as jnp import numpy as np -from jax._src.pallas import core as pallas_core - map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip Grid = pallas_core.Grid -BlockSpec = pallas_core.BlockSpec GridSpec = pallas_core.GridSpec BlockMapping = pallas_core.BlockMapping GridMapping = pallas_core.GridMapping +BlockSpec = pallas_core.BlockSpec +BlockSpecTree = pallas_core.BlockSpecTree NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec @@ -104,16 +111,48 @@ def _pad_values_to_block_dimension(value, ) if padded_shape != value.shape: pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) - pad_value = _uninitialized_value(shape=(), dtype=value.dtype) + pad_value = uninitialized_value(shape=(), dtype=value.dtype) value = jnp.pad(value, pad_width, constant_values=pad_value) return value -def _uninitialized_value(shape, dtype): - if jnp.issubdtype(dtype, jnp.floating): - return jnp.full(shape, jnp.nan, dtype) - elif jnp.issubdtype(dtype, jnp.integer): - return jnp.full(shape, jnp.iinfo(dtype).min, dtype) - raise NotImplementedError(dtype) +def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]: + scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals) + return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals) + +def _initialize_output_vals( + out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]: + oi_map = {v: k for k, v in input_output_aliases} + output_vals = [] + for i, out_shape in enumerate(out_shapes): + if i in oi_map: + output_vals.append(input_args[oi_map[i]]) + else: + output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype)) + return output_vals + +def _logical_to_interpret_mode_dtype(dtype): + """Converts logical dtypes into JAX dtypes for interpret mode. + + This function is used to convert device-specific dtypes that have no + corresponding equivalent in JAX/XLA into a type that can be executed + by the XLA interpreter (e.g. TPU semaphores -> int32). + """ + if (hasattr(dtype, "_rules") and + hasattr(dtype._rules, "pallas_interpret_element_aval")): + return dtype._rules.pallas_interpret_element_aval(dtype).dtype + return dtype + +def _logical_aval_to_interpret_mode_aval(aval): + """Logical to interpret mode aval conversion.""" + if isinstance(aval, pallas_core.AbstractMemoryRef): + inner_aval = _logical_aval_to_interpret_mode_aval(aval.inner_aval) + return aval.update(inner_aval=inner_aval) + if isinstance(aval, jax_core.ShapedArray): + inner_dtype = _logical_to_interpret_mode_dtype(aval.dtype) + return jax_core.ShapedArray(aval.shape, + inner_dtype, + weak_type=aval.weak_type, named_shape=aval.named_shape) + return aval def _get_next_indices(grid, indices): next_indices = [] @@ -124,7 +163,7 @@ def _get_next_indices(grid, indices): next_indices.append(jnp.where(carry, 0, i)) return tuple(reversed(next_indices)) -def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, +def _pallas_call_impl(*args, jaxpr, name, out_shapes, interpret, debug: bool, in_shapes, input_output_aliases: tuple[tuple[int, int], ...], @@ -139,21 +178,16 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, # will do. dynamic_grid_args_iter = iter(dynamic_grid_args) grid = tuple( - a if a is not None else next(dynamic_grid_args_iter) + a if a is not pallas_core.dynamic_grid_dim + else next(dynamic_grid_args_iter) for a in grid_mapping.grid ) assert next(dynamic_grid_args_iter, None) is None - discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ()) + with grid_mapping.trace_env(): + discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ()) if debug: print(discharged_jaxpr) - oi_map = {v: k for k, v in input_output_aliases} - out = [] - for i, out_shape in enumerate(out_shapes): - if i in oi_map: - out.append(args[oi_map[i]]) - else: - # TODO(sharadmv): use unitialized values for outputs - out.append(jnp.zeros(out_shape.shape, out_shape.dtype)) + out = _initialize_output_vals(out_shapes, args, input_output_aliases) scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore # invars: [*scalar_prefetch, *inputs, *outputs, *scratch] num_invars = len(jaxpr.invars) @@ -166,12 +200,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear, jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs] ) scratch_avals = [v.aval for v in scratch_invars] - if not all( - hasattr(a, "shape") and hasattr(a, "dtype") for a in scratch_avals - ): - raise NotImplementedError(f"Cannot initialize scratch: {scratch_avals}") - scratch_values = [_uninitialized_value(a.shape, a.dtype) - for a in scratch_avals] + scratch_values = _initialize_scratch_vals(scratch_avals) carry = [] for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings): @@ -218,7 +247,7 @@ def cond(carry): def body(carry): i, loop_idx, *carry = carry local_grid_env = tuple( - (idx, b) + pallas_core.GridAxis(idx, b) for dim, (idx, b) in enumerate(zip(loop_idx, grid)) if dim not in grid_mapping.mapped_dims ) @@ -259,13 +288,13 @@ def body(carry): if input_output_aliases: raise NotImplementedError("Padding with aliasing not supported.") pad_low, pad_high = zip(*padding) - limit_indices = [s - p for s, p in zip(out.shape, pad_high)] + limit_indices = [s - p for s, p in zip(o.shape, pad_high)] o = lax.slice(o, pad_low, limit_indices) out_nopad.append(o) return out_nopad return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name, in_shapes=in_shapes, - out_shapes=out_shapes, which_linear=which_linear, + out_shapes=out_shapes, grid_mapping=grid_mapping, interpret=interpret, debug=debug, input_output_aliases=input_output_aliases, @@ -276,7 +305,7 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_): return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes) pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) -def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, +def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, input_output_aliases: tuple[tuple[int, int], ...], in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any): if grid_mapping.num_dynamic_grid_bounds: @@ -302,8 +331,14 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)] ) invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs) - # TODO(sharadmv): Fix state effect tracking after invar switch. - jvp_jaxpr = jvp_jaxpr.replace(invars=invars) + effs = [] + for eff in jvp_jaxpr.effects: + if isinstance(eff, effects.JaxprInputEffect): + eff = eff.replace( + input_index=invars.index(jvp_jaxpr.invars[eff.input_index]) + ) + effs.append(eff) + jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs) if debug: print(jvp_jaxpr) in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)]) @@ -316,7 +351,6 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, in_shapes=(*in_shapes, *in_shapes), out_shapes=(*out_shapes, *out_shapes), grid_mapping=grid_mapping.replace(block_mappings=jvp_bms), - which_linear=which_linear + (True,) * len(tangents), interpret=interpret, debug=debug, input_output_aliases=(), @@ -326,7 +360,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, return out_primals, out_tangents ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule -def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray, +def _batch_block_mapping(grid_mapping: GridMapping, aval: jax_core.ShapedArray, dim: int | batching.NotMapped, block_mapping: BlockMapping | None) -> BlockMapping: def _block_map_function(new_idx, *args): @@ -341,11 +375,12 @@ def _block_map_function(new_idx, *args): return tuple(indices) i32_aval = jax_core.ShapedArray((), jnp.int32) if block_mapping is None: - idx_avals = [i32_aval] * (len(grid) + 1) + idx_avals = [i32_aval] * (len(grid_mapping.grid) + 1) else: idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals] - block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(_block_map_function), idx_avals) + with grid_mapping.trace_env(): + block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(_block_map_function), idx_avals) shape = aval.shape if block_mapping is None else block_mapping.block_shape if dim is batching.not_mapped: new_block_shape = shape @@ -361,17 +396,149 @@ def _block_map_function(new_idx, *args): return block_mapping.replace(block_shape=new_block_shape, index_map_jaxpr=jaxpr) -def _pallas_call_batching_rule(args, dims, *, - jaxpr: jax_core.Jaxpr, - name: str, - in_shapes: tuple[jax.ShapeDtypeStruct, ...], - out_shapes: tuple[jax.ShapeDtypeStruct, ...], - grid_mapping: GridMapping, - input_output_aliases: tuple[tuple[int, int], ...], - debug: bool, - interpret: bool, - which_linear: tuple[bool, ...], - compiler_params: Any): + +def _broadcast_input_output_aliases( + args: Sequence[jax.Array], + dims: Sequence[int | batching.NotMapped], + *, + input_output_aliases: tuple[tuple[int, int], ...], + axis_size: int, +) -> tuple[tuple[jax.Array, ...], tuple[int | batching.NotMapped, ...]]: + """Broadcast input/output operands. + + When we have input/output aliasing, since the output will be mapped, we need + to make sure to broadcast the input across that dimension if it is not + mapped. If the input is mapped, but on a different axis, we tranpose the input + to match the output. + """ + + args_ = list(args) + dims_ = list(dims) + for input_index, _ in input_output_aliases: + dim = dims_[input_index] + dims_[input_index] = 0 + if dim is batching.not_mapped: + args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) + elif dim != 0: + # TODO(cjfj): Change output batching axis instead? + args_[input_index] = jnp.moveaxis(args[input_index], dim, 0) + + return tuple(args_), tuple(dims_) + + +def _batch_with_explicit_loop( + args: Sequence[jax.Array], + dims: Sequence[int | batching.NotMapped], + *, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], + grid_mapping: GridMapping, + input_output_aliases: tuple[tuple[int, int], ...], + debug: bool, + interpret: bool, + compiler_params: Any, +): + """Batch the pallas_call by calling it in loop over the batch size. + + This function provides a fallback implementation of batching a pallas_call + for the cases in which adding a batch dimension to the pallas grid is not + supported. This is currently the case when the batched dimension corresponds + to a dynamic axis or a scalar prefetch argument. + + This implementation builds a HLO loop that dynamic_slices the inputs according + to the current iteration index and dynamic_updates an (initially empty) output + allocation. + """ + + if not dims: + raise NotImplementedError("vmapping pallas_call with no arguments.") + + (axis_size,) = { + arg.shape[dim] + for arg, dim in zip(args, dims) + if dim is not batching.not_mapped + } + + args, dims = _broadcast_input_output_aliases( + args, + dims, + input_output_aliases=input_output_aliases, + axis_size=axis_size, + ) + + # The output arrays are completelly overwritten, so we can just initialize + # empty arrays. + initial_state = [ + jnp.empty( + tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype + ) + for out_shape in out_shapes + ] + + def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: + batch_args = [] + + for arg, dim in zip(args, dims): + # If the argument is mapped, extract a slice of size 1 in the mapped + # dimension at the current index. + if dim is batching.not_mapped: + batch_args.append(arg) + else: + batch_args.append( + jnp.squeeze( + jax.lax.dynamic_slice_in_dim( + operand=arg, + start_index=batch_index, + slice_size=1, + axis=dim, + ), + axis=dim, + ) + ) + + batch_out = pallas_call_p.bind( + *batch_args, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + for i, batch_out_array in enumerate(batch_out): + state[i] = jax.lax.dynamic_update_index_in_dim( + state[i], + batch_out_array, + batch_index, + axis=0, + ) + + return state + + result = jax.lax.fori_loop(0, axis_size, body, initial_state, unroll=False) + + return result, (0,) * len(result) + + +def _pallas_call_batching_rule( + args, + dims, + *, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], + grid_mapping: GridMapping, + input_output_aliases: tuple[tuple[int, int], ...], + debug: bool, + interpret: bool, + compiler_params: Any, +): def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped @@ -380,6 +547,27 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) + axis_size, = {x.shape[d] for x, d in zip(args, dims) + if d is not batching.not_mapped} + if axis_size == 1: + # Why are we even vmapping? + args = map(_maybe_squeeze_out_bdim, args, dims) + out = pallas_call_p.bind( + *args, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out) + + # The first num_dynamic_grid_bounds arguments are size-1 arrays that store + # the size of the dynamic bounds. dynamic_grid_args, args = split_list( args, [grid_mapping.num_dynamic_grid_bounds] ) @@ -391,10 +579,23 @@ def _maybe_squeeze_out_bdim( for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims) ): dynamic_grid_args = safe_map( - _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims) + _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims + ) elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims): - raise NotImplementedError( - f"Batched dynamic grid bounds unsupported: {dynamic_grid_dims}" + # TODO(amagni, sharadmv): Explore possibility of batching dynamic grid + # bounds. + return _batch_with_explicit_loop( + args=dynamic_grid_args + args, + dims=dynamic_grid_dims + dims, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, ) else: pass # No dynamic grid dimensions @@ -415,12 +616,24 @@ def _maybe_squeeze_out_bdim( args = (*scalar_args, *args) dims = (*scalar_bdims, *bdims) else: - # TODO(sharadmv,apaszke): enable batching over prefetched scalar args - raise NotImplementedError + # TODO(amagni,sharadmv,apaszke): enable efficient batching over + # prefetched scalar args. + return _batch_with_explicit_loop( + args=scalar_args + args, + dims=scalar_bdims + bdims, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + if not dims: raise NotImplementedError("vmapping pallas_call with no arguments.") - axis_size, = {x.shape[d] for x, d in zip(args, dims) - if d is not batching.not_mapped} block_mappings = grid_mapping.block_mappings avals = [v.aval for v in jaxpr.invars] # How should we pick output dimensions? This actually matters because XLA @@ -430,18 +643,9 @@ def _maybe_squeeze_out_bdim( # TODO(sharadmv): explore inferring better output dimensions via a heuristic # TODO(sharadmv): explore a long term solution to output dim inference - # When we have input/output aliasing, since the output will be mapped, we need - # to make sure to broadcast the input across that dimension if it is not - # mapped. - dims_ = list(dims) - args_ = list(args) - for input_index, _ in input_output_aliases: - dim = dims_[input_index] - if dim is batching.not_mapped: - dims_[input_index] = 0 - args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) - args = tuple(args_) - dims = tuple(dims_) + args, dims = _broadcast_input_output_aliases( + args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size + ) all_dims = list(dims) + [0] * len(out_shapes) @@ -453,7 +657,7 @@ def _maybe_squeeze_out_bdim( # operands (the last in the list). avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( - partial(_batch_block_mapping, grid_mapping.grid), + partial(_batch_block_mapping, grid_mapping), avals_to_batch, all_dims[num_index_operands:], block_mappings, @@ -479,7 +683,6 @@ def _maybe_squeeze_out_bdim( name=f"batched_{name}", in_shapes=batched_in_shapes, out_shapes=batched_out_shapes, - which_linear=which_linear, grid_mapping=batched_grid_mapping, input_output_aliases=input_output_aliases, debug=debug, @@ -487,42 +690,229 @@ def _maybe_squeeze_out_bdim( compiler_params=compiler_params, ) return out, (0,) * len(out) + + batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr: - all_const_avals = [var.aval for var in jaxpr.constvars] - is_const_ref = [isinstance(var.aval, state.AbstractRef) for var in - jaxpr.constvars] - const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals) - const_avals = map(state.AbstractRef, const_avals) - merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals) - arg_avals = [var.aval for var in jaxpr.invars] - in_avals = [*merged_const_avals, *arg_avals] - num_consts = len(merged_const_avals) + """Hoists the constants in the given jaxpr into invars. + + Args: + jaxpr: The jaxpr. + + Returns: + A new jaxpr where the constants were hoisted into invars as ``Ref``s. + The invars for the constants are added *before* any existing invars. + """ + if not jaxpr.constvars: + return jaxpr # Nothing to hoist. + + is_const_ref = [ + isinstance(var.aval, state.AbstractRef) for var in jaxpr.constvars + ] + const_avals = [ + var.aval if is_ref else state.AbstractRef(var.aval) + for is_ref, var in zip(is_const_ref, jaxpr.constvars) + ] + in_avals = const_avals + [var.aval for var in jaxpr.invars] def _hoist(*consts_args): - all_consts, args = split_list(consts_args, [num_consts]) - consts, const_refs = partition_list(is_const_ref, all_consts) + all_consts, args = split_list(consts_args, [len(const_avals)]) # We immediately read the const values out of the `Ref`s. - consts = map(lambda x: sp.ref_get(x, ()), consts) - all_consts = merge_lists(is_const_ref, consts, const_refs) + all_consts = [ + c if is_ref else sp.ref_get(c, ()) + for is_ref, c in zip(is_const_ref, all_consts) + ] return jax_core.eval_jaxpr(jaxpr, all_consts, *args) + hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_hoist), in_avals) assert not consts, "All consts should have been converted to refs" return hoisted_jaxpr + +def checkify_pallas_kernel_body_jaxpr( + body_jaxpr: jax_core.ClosedJaxpr, + enabled_errors, + error: checkify.Error, + grid_mapping: GridMapping) -> tuple[ + jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]: + err_vals, err_tree = tree_util.tree_flatten(error) + err_vals = map(checkify.get_shaped_aval, err_vals) + flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals] + + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + checked_jaxpr, out_tree, error_effects = checkify.jaxpr_to_checkify_jaxpr( + body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) + return checked_jaxpr, out_tree, error_effects + +def pallas_call_checkify_rule(error: checkify.Error, + enabled_errors, + *args: jax_core.Value, + jaxpr: jax_core.Jaxpr, + interpret: bool, + input_output_aliases: tuple[tuple[int, int], ...], + grid_mapping: GridMapping, + out_shapes, + **kwargs): + # TODO(b/346651778): Support TPU/GPU checkify. + if not interpret: + raise NotImplementedError( + "Checkify for pallas_call only supports interpret mode.") + # We implement the checkify rule in 4 steps: + # 1) First, trace the kernel body to get the expected error shapes. + # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs + # and outputs. + # 3) Create a new kernel which stores the errors in output memrefs instead of + # returning them, since pallas kernels do not return outputs. + # 4) Create block specs for the error state and call pallas_call with + # the new kernel. + dynamic_grid_bounds, scalars, args = split_list( # type: ignore + args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands] + ) + num_scalars = len(scalars) + num_invars = len(jaxpr.invars) + num_inputs_outputs = ( + num_invars + - grid_mapping.num_index_operands + - grid_mapping.num_scratch_operands + ) + num_kernel_inputs = len(args) + num_scratch = num_invars - num_inputs_outputs + num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs + + # Trace the jaxpr to get an initial error value so the kernel jaxpr has all of + # the required inputs. + closed_jaxpr = pe.close_jaxpr(jaxpr) + _jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr( + closed_jaxpr, enabled_errors, error, grid_mapping) + error = error._add_placeholder_effects(error_effects) + err_vals, err_tree = jax.tree.flatten(error) + shaped_err_avals = map(checkify.get_shaped_aval, err_vals) + + # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have + # all enabled errors removed, but have the error as inputs and return values. + input_avals = [v.aval for v in jaxpr.invars] + num_err_vals = len(err_vals) + shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals) + checkify_in_avals = [*shaped_err_avals, + *shaped_input_avals] + closed_kernel_jaxpr = pe.close_jaxpr(jaxpr) + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr( + closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals) + + # Create a new kernel to remove the error as an return value and instead + # write them to a memref. This is because pallas kernels are expected + # to have no return values but instead write their outputs to a ref. + def checked_kernel_fn(*args): + (scalars, _, inputs, out_error_refs, outputs, scratch + ) = split_list( + args, + [num_scalars, num_err_vals, + num_kernel_inputs, num_err_vals, num_kernel_outputs]) + input_error_vals = [err_ref[...] for err_ref in out_error_refs] + # We need to re-order the inputs here. A checkified jaxpr always expects + # errors before other arguments. + jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch] + assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args) + result_flat = jax.core.eval_jaxpr( + checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args) + output_errors, _ = split_list(result_flat, [num_err_vals]) + # Store new errors back in the error refs. + for out_ref, error in zip(out_error_refs, output_errors): + out_ref[...] = error + return [] + + # Trace the new checked_kernel_fn with Memref inputs so that + # we can replace the old kernel jaxpr with the new checked jaxpr in + # pallas_call. + # TODO(justinfu): Place errors in scalar memory for non-interpret mode. + error_mem_space = None + error_memref_aval = [pallas_core.AbstractMemoryRef( + err_val, error_mem_space) for err_val in shaped_err_avals] + shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list( + shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs]) + retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval, + *error_memref_aval, *output_aval, *scratch_aval] + jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) + wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(checked_kernel_fn), jaxpr_in_tree) + debug = pe.debug_info( + checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas") + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + wrapped_kernel_with_err, jaxpr_flat_avals, debug) + + # Prepare pallas_call inputs. We need to create new block specs + # for the new error inputs and outputs. + scalar_avals = map(checkify.get_shaped_aval, scalars) + error_block_specs = [no_block_spec] * num_err_vals + grid_avals = [ + jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid) + # TODO(justinfu): Place these in device-specific scalar memory. + scalar_ref_avals = [ + pallas_core.AbstractMemoryRef( + jax_core.ShapedArray(aval.shape, aval.dtype), None) + for aval in scalar_avals] + grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {})) + error_block_mappings = map( + partial( + pallas_core._convert_block_spec_to_block_mapping, + (*grid_avals, *scalar_ref_avals), + in_tree=grid_tree, + grid=grid_mapping.grid, + mapped_dims=grid_mapping.mapped_dims), + error_block_specs, error_memref_aval) + input_block_mappings, output_block_mappings = split_list( + grid_mapping.block_mappings, [num_kernel_inputs,]) + grid_mapping_with_error = grid_mapping.replace( + block_mappings=(*error_block_mappings, *input_block_mappings, + *error_block_mappings, *output_block_mappings) + ) + error_out_shapes = tuple( + jax.ShapeDtypeStruct(e.shape, e.dtype) for e in shaped_err_avals) + # Bump all input_output_aliases by num_err_vals to make room for error + # TODO(justinfu): Don't bump scalars here. + input_output_aliases = tuple( + (i+num_err_vals, o+num_err_vals) for (i, o) in input_output_aliases) + input_output_aliases_with_error = tuple( + (i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases + + new_vals_in = [*scalars, *err_vals, *args] + result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in, + jaxpr=final_jaxpr, + interpret=interpret, + grid_mapping=grid_mapping_with_error, + input_output_aliases=input_output_aliases_with_error, + out_shapes=error_out_shapes + out_shapes, + **kwargs) + errors, results = split_list(result, [num_err_vals]) + new_error, _ = jax.tree.unflatten(out_tree, errors) + return new_error, results +checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule + @weakref_lru_cache def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals, - flat_out_avals, in_tree, out_tree): + flat_out_avals, in_tree, out_tree, interpret: bool): avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree, flat_out_avals, out_tree) + if interpret: + avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals) jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals) wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(fun), jaxpr_in_tree) debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call") - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals, debug) - jaxpr = _hoist_consts_to_refs(jaxpr) + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, + jaxpr_flat_avals, debug) + if consts: + jaxpr = _hoist_consts_to_refs(jaxpr) + # Pad ``block_mappings`` to account for the hoisted constants. + grid_mapping = grid_mapping.replace( + block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)), + num_constant_operands=len(consts), + ) return grid_mapping, jaxpr, consts, out_tree_thunk() def _extract_function_name(f: Callable, name: str | None) -> str: @@ -531,7 +921,7 @@ def _extract_function_name(f: Callable, name: str | None) -> str: return name -_PALLAS_USE_MOSAIC_GPU = config.DEFINE_bool( +_PALLAS_USE_MOSAIC_GPU = config.bool_flag( "jax_pallas_use_mosaic_gpu", default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), help=( @@ -558,38 +948,48 @@ def _pallas_call_lowering( impl = partial(_pallas_call_impl, **params, interpret=True) return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes) - try: - [platform] = ctx.module_context.platforms - except ValueError: - raise ValueError( - "Can only lower pallas_call on a single platform." - ) from None - - if platform == "cpu": + def cpu_lowering(ctx: mlir.LoweringRuleContext, + *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], + **params): raise ValueError("Only interpret mode is supported on CPU backend.") - elif platform == "cuda" or platform == "rocm": + + def tpu_lowering(ctx: mlir.LoweringRuleContext, + *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], + **params): + try: + from jax._src.pallas.mosaic import pallas_call_registration + except ImportError: + raise _unsupported_lowering_error("tpu") + else: + return pallas_call_registration.pallas_call_tpu_lowering_rule( + ctx, *in_nodes, **params + ) + + def gpu_lowering(ctx: mlir.LoweringRuleContext, + *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value], + **params): try: if _PALLAS_USE_MOSAIC_GPU.value: from jax._src.pallas.mosaic_gpu import pallas_call_registration else: from jax._src.pallas.triton import pallas_call_registration # type: ignore except ImportError: - pass + raise _unsupported_lowering_error("gpu") else: return pallas_call_registration.pallas_call_lowering( - ctx, *in_nodes, interpret=interpret, **params - ) - elif platform == "tpu": - try: - from jax._src.pallas.mosaic import pallas_call_registration # type: ignore - except ImportError: - pass - else: - return pallas_call_registration.pallas_call_tpu_lowering_rule( - ctx, *in_nodes, interpret=interpret, **params + ctx, *in_nodes, **params ) - raise _unsupported_lowering_error(platform) + return mlir.lower_per_platform(ctx, "pallas_call", + dict(cpu=cpu_lowering, + tpu=tpu_lowering, + cuda=gpu_lowering, + rocm=gpu_lowering), + None, # default_rule + effects.no_effects, + *in_nodes, + interpret=interpret, + **params) mlir.register_lowering(pallas_call_p, _pallas_call_lowering) @@ -602,14 +1002,13 @@ def pallas_call( grid_spec: GridSpec | None = None, debug: bool = False, grid: Grid | None = None, - in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec, - out_specs: BlockSpec | NoBlockSpec - | Sequence[BlockSpec | NoBlockSpec] = no_block_spec, + in_specs: BlockSpecTree = no_block_spec, + out_specs: BlockSpecTree = no_block_spec, input_output_aliases: dict[int, int] = {}, interpret: bool = False, name: str | None = None, compiler_params: dict[str, Any] | None = None, -): +) -> Callable[..., Any]: name = _extract_function_name(f, name) if compiler_params is None: compiler_params = {} @@ -632,11 +1031,10 @@ def wrapped(*args): for v in flat_out_shapes) grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr( f, grid_spec, flat_in_avals, flat_out_avals, in_tree, - out_tree) - which_linear = (False,) * len(flat_args) + out_tree, interpret=interpret) out_flat = pallas_call_p.bind( *dynamic_grid_bounds, *consts, *flat_args, - jaxpr=jaxpr, name=name, which_linear=which_linear, + jaxpr=jaxpr, name=name, in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in flat_args), out_shapes=tuple(flat_out_shapes), debug=debug, diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index a8264d858026..4abc5ced1af0 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -39,10 +39,6 @@ from jax.interpreters import mlir import jax.numpy as jnp - -# TODO(sharadmv): enable type checking -# mypy: ignore-errors - partial = functools.partial Slice = indexing.Slice NDIndexer = indexing.NDIndexer @@ -53,41 +49,42 @@ program_id_p = jax_core.Primitive("program_id") def program_id(axis: int) -> jax.Array: + """Returns the kernel execution position along the given axis of the grid.""" return program_id_p.bind(axis=axis) def program_id_bind(*, axis: int): grid_env = pallas_core.current_grid_env() if grid_env: - return grid_env[axis].axis_index + return grid_env[axis].index + frame = pallas_core.axis_frame() + # Query the size of the axis to make sure its a valid axis (and error + # otherwise). + _ = frame.size(axis) return jax_core.Primitive.bind(program_id_p, axis=axis) program_id_p.def_custom_bind(program_id_bind) -def _program_id_impl(*, axis: int): - grid_env = pallas_core.current_grid_env() - return grid_env[axis].axis_index -program_id_p.def_impl(_program_id_impl) - def _program_id_abstract_eval(**_): return jax_core.ShapedArray((), jnp.int32) program_id_p.def_abstract_eval(_program_id_abstract_eval) - num_programs_p = jax_core.Primitive("num_programs") -def num_programs(axis: int) -> jax.Array: +def num_programs(axis: int) -> int | jax.Array: + """Returns the size of the grid along the given axis.""" return num_programs_p.bind(axis=axis) @num_programs_p.def_custom_bind def _num_programs_bind(*, axis: int): + # We might be using a local grid env grid_env = pallas_core.current_grid_env() if grid_env: - return jnp.asarray(grid_env[axis].axis_size, dtype=jnp.int32) - return jax_core.Primitive.bind(num_programs_p, axis=axis) - -@num_programs_p.def_impl -def _num_programs_impl(*, axis: int): - grid_env = pallas_core.current_grid_env() - return jnp.asarray(grid_env[axis].axis_size, dtype=jnp.int32) + return grid_env[axis].size + # Otherwise, we look up the size of the grid in the axis env + frame = pallas_core.axis_frame() + size = frame.size(axis) + if size is pallas_core.dynamic_grid_dim: + return jax_core.Primitive.bind(num_programs_p, axis=axis) + return size @num_programs_p.def_abstract_eval def _num_programs_abstract_eval(**_): @@ -298,6 +295,41 @@ def _load_jvp(primals, tangents, args_tree, **params): ad.primitive_jvps[load_p] = _load_jvp +def uninitialized_value(shape, dtype): + if jnp.issubdtype(dtype, jnp.floating): + return jnp.full(shape, jnp.nan, dtype) + elif jnp.issubdtype(dtype, jnp.integer): + return jnp.full(shape, jnp.iinfo(dtype).min, dtype) + elif jnp.issubdtype(dtype, jnp.bool): + return jnp.full(shape, False, dtype) + raise NotImplementedError(dtype) + +def _pad_values_to_avoid_dynamic_slice_oob_shift(value, + slice_sizes, unpad=False): + """ + DynamicSlice and DynamicUpdateSlice adjust the start index in cases where the + requested slice overruns the bounds of the array. This pads the array with + uninitialised values such that the requested slice will never overrun. + + For example, if arr is [1.,2.,3.,4.] and a slice of size 4, start index 2 is + requested then the result will be [3.,4.,NaN,NaN] after padding, rather than + [1.,2.,3.,4.] from the unpadded array + + unpad=True performs the inverse operation + """ + + padding_config = tuple((0, slice_size, 0) for slice_size in slice_sizes) + if unpad: + padding_config = tuple((-low, -high, -interior) + for (low, high, interior) in padding_config) + padding_value = uninitialized_value(shape=(), dtype=value.dtype) + value = lax.pad(value, + padding_config=padding_config, + padding_value=padding_value) + return value + +_unpad_values_to_avoid_dynamic_slice_oob_shift = partial( + _pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True) def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): del out_avals # Unused. @@ -315,6 +347,10 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) + # fixes an inconstency with lax.dynamic_slice where if the slice goes out + # of bounds, it will instead move the start_index backwards so the slice + # will fit in memory. + ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes) out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes) out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims) out = out_ones[out_indexer] @@ -424,6 +460,10 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): ] slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) + # Fixes an inconsistency with lax.dynamic_update_slice where if the slice + # goes out of bounds, it will instead move the start_index backwards so the + # slice will fit in memory. + ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes) out = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes) out = jnp.squeeze(out, scalar_dims) if mask is not None: @@ -432,6 +472,7 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): val = jnp.where(mask, val, out_) val = jnp.expand_dims(val, scalar_dims) x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts) + x_new = _unpad_values_to_avoid_dynamic_slice_oob_shift(x_new, slice_sizes) elif all(not isinstance(s, Slice) for s in idx.indices): out = ref[idx.indices] if mask is not None: @@ -525,7 +566,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike): """ # fmt: skip has_placeholders = False if fmt: - _, field_name, *_ = next(string.Formatter().parse(fmt)) + _, field_name, *_ = next(iter(string.Formatter().parse(fmt))) has_placeholders = field_name is not None return debug_print_p.bind(*args, fmt=fmt, has_placeholders=has_placeholders) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index d8b4c6ed3497..370cbb713ac5 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -17,7 +17,6 @@ load( "//jaxlib:jax.bzl", "py_deps", - "py_library_providing_imports_info", "pytype_strict_library", ) @@ -28,18 +27,6 @@ package( ], ) -py_library_providing_imports_info( - name = "triton", - srcs = ["__init__.py"], - lib_rule = pytype_strict_library, - deps = [ - ":lowering", - ":pallas_call_registration", - ":primitives", - "//jax/_src/lib", - ], -) - pytype_strict_library( name = "primitives", srcs = ["primitives.py"], diff --git a/jax/_src/pallas/triton/__init__.py b/jax/_src/pallas/triton/__init__.py index adade4e8a72c..38d13f42da99 100644 --- a/jax/_src/pallas/triton/__init__.py +++ b/jax/_src/pallas/triton/__init__.py @@ -11,8 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Triton-specific Pallas APIs.""" - -from jax._src.pallas.triton.primitives import approx_tanh -from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 92ca3cd63b75..c270e8084f42 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -16,12 +16,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools import math import operator -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar import jax from jax import lax @@ -78,6 +78,7 @@ class ModuleContext: grid_mapping: GridMapping program_ids: Sequence[ir.Value] traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False) + platform: str @dataclasses.dataclass @@ -136,6 +137,10 @@ def _bcast_to(a: ir.Value, shape: tuple[int, ...]) -> ir.Value: a_type = ir.RankedTensorType(a.type) if a_type.shape == [*shape]: return a + if a_type.rank != len(shape) or not all( + a_type.shape[i] in (dim, 1) for i, dim in enumerate(shape) + ): + raise ValueError(f"Cannot broadcast from {a_type.shape} to {[*shape]}") return tt_dialect.broadcast( ir.RankedTensorType.get(shape, a_type.element_type, a_type.encoding), a ) @@ -249,9 +254,8 @@ def lower_jaxpr_to_triton_module( in_shapes, grid_mapping: GridMapping, name: str, - cuda_options: Any, + platform: str ) -> LoweringResult: - # TODO(slebedev): Use cuda_options= during lowering. jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True) with _new_ir_context(), ir.Location.unknown(): module = ir.Module.create() @@ -282,7 +286,7 @@ def lower_jaxpr_to_triton_module( if i not in grid_mapping.mapped_dims ] ctx = ModuleContext( - name, grid_mapping, local_program_ids, mlir.TracebackCaches() + name, grid_mapping, local_program_ids, mlir.TracebackCaches(), platform ) if grid_mapping.num_index_operands: raise NotImplementedError( @@ -306,7 +310,9 @@ def lower_jaxpr_to_triton_module( if block_mapping is not None else None for shape_dtype, block_mapping, start_idx in zip( - in_shapes, grid_mapping.block_mappings, start_indices + (*in_shapes, *[()] * grid_mapping.num_constant_operands), + grid_mapping.block_mappings, + start_indices, ) ] () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *entry.arguments) @@ -334,11 +340,12 @@ def read_block_info_env(atom: jax_core.Atom): def write_env(var: jax_core.Var, val): env[var] = val - if block_infos is None: - block_infos = [None] * len(jaxpr.invars) - for invar, block_info in zip(jaxpr.invars, block_infos): - block_info_env[invar] = block_info + if block_infos is not None: + for invar, block_info in zip(jaxpr.invars, block_infos): + block_info_env[invar] = block_info + map(write_env, jaxpr.invars, args) + for eqn in jaxpr.eqns: invals = map(read_env, eqn.invars) if eqn.primitive not in triton_lowering_rules: @@ -369,6 +376,7 @@ def write_env(var: jax_core.Var, val): map(write_env, eqn.outvars, outvals) else: write_env(eqn.outvars[0], outvals) + return map(read_env, jaxpr.outvars) @@ -393,7 +401,6 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis): raise ValueError(f"axis must be in [0, 3), but got: {axis}") return tt_dialect.get_num_programs(axis) - def _atomic_rmw( op: tt_dialect.RMWOp, ptr: ir.Value, @@ -561,10 +568,11 @@ class _Fallback: def _make_dispatch_table( - name: str, table: Sequence[_Extern | _Fallback] + name: str, **tables: Sequence[_Extern | _Fallback] ) -> Callable[..., ir.Value]: def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: + table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in) @@ -586,12 +594,18 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: _abs_dispatch_table = _make_dispatch_table( "abs", - [ + cuda=[ _Extern(["int32"], "__nv_abs", "int32"), _Extern(["int64"], "__nv_llabs", "int64"), _Extern(["float32"], "__nv_fabsf", "float32"), _Extern(["float64"], "__nv_fabs", "float64"), ], + rocm=[ + _Fallback(["int32"], lambda ctx, x: math_dialect.absi(x)), + _Fallback(["int64"], lambda ctx, x: math_dialect.absi(x)), + _Fallback(["float32"], lambda ctx, x: math_dialect.absf(x)), + _Fallback(["float64"], lambda ctx, x: math_dialect.absf(x)), + ], ) @@ -613,208 +627,332 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): lax.neg_p: lambda ctx, x: _minus(x), lax.ceil_p: _make_dispatch_table( "ceil", - [ + cuda=[ _Extern(["float32"], "__nv_ceilf", "float32"), _Extern(["float64"], "__nv_ceil", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_ceil_f32", "float32"), + _Extern(["float64"], "__ocml_ceil_f64", "float64"), + ], ), lax.floor_p: _make_dispatch_table( "floor", - [ + cuda=[ _Extern(["float32"], "__nv_floorf", "float32"), _Extern(["float64"], "__nv_floor", "float64"), _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), ], + rocm=[ + _Extern(["float32"], "__ocml_floor_f32", "float32"), + _Extern(["float64"], "__ocml_floor_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), + ], ), lax.exp_p: _make_dispatch_table( "exp", - [ + cuda=[ _Extern(["float32"], "__nv_expf", "float32"), _Extern(["float64"], "__nv_exp", "float64"), _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), ], + rocm=[ + _Fallback(["float32"], lambda ctx, x: math_dialect.exp(x)), + _Fallback(["float64"], lambda ctx, x: math_dialect.exp(x)), + _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), + ], ), lax.exp2_p: _make_dispatch_table( "exp2", - [ + cuda=[ _Extern(["float32"], "__nv_exp2f", "float32"), _Extern(["float64"], "__nv_exp2", "float64"), _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), ], + rocm=[ + _Extern(["float32"], "__ocml_exp2_f32", "float32"), + _Extern(["float64"], "__ocml_exp2_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), + ], ), lax.expm1_p: _make_dispatch_table( "expm1", - [ + cuda=[ _Extern(["float32"], "__nv_expm1f", "float32"), _Extern(["float64"], "__nv_expm1", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_expm1_f32", "float32"), + _Extern(["float64"], "__ocml_expm1_f64", "float64"), + ], ), lax.log_p: _make_dispatch_table( "log", - [ + cuda=[ _Extern(["float32"], "__nv_logf", "float32"), _Extern(["float64"], "__nv_log", "float64"), _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), ], + rocm=[ + _Extern(["float32"], "__ocml_log_f32", "float32"), + _Extern(["float64"], "__ocml_log_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), + ], ), lax.log1p_p: _make_dispatch_table( "log1p", - [ + cuda=[ _Extern(["float32"], "__nv_log1pf", "float32"), _Extern(["float64"], "__nv_log1p", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_log1p_f32", "float32"), + _Extern(["float64"], "__ocml_log1p_f64", "float64"), + ], ), lax.sqrt_p: _make_dispatch_table( "sqrt", - [ + cuda=[ _Extern(["float32"], "__nv_sqrtf", "float32"), _Extern(["float64"], "__nv_sqrt", "float64"), _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), ], + rocm=[ + _Extern(["float32"], "__ocml_sqrt_f32", "float32"), + _Extern(["float64"], "__ocml_sqrt_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), + ], ), lax.pow_p: _make_dispatch_table( "pow", - [ + cuda=[ _Extern(["float32", "int32"], "__nv_powif", "float32"), _Extern(["float64", "int32"], "__nv_powi", "float64"), _Extern(["float32", "float32"], "__nv_powf", "float32"), _Extern(["float64", "float64"], "__nv_pow", "float64"), ], + rocm=[ + _Extern(["float32", "int32"], "__ocml_pown_f32", "float32"), + _Extern(["float64", "int32"], "__ocml_pown_f64", "float64"), + _Extern(["float32", "float32"], "__ocml_pow_f32", "float32"), + _Extern(["float64", "float64"], "__ocml_pow_f64", "float64"), + ], ), lax.cbrt_p: _make_dispatch_table( "cbrt", - [ + cuda=[ _Extern(["float32"], "__nv_cbrtf", "float32"), _Extern(["float64"], "__nv_cbrt", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_cbrt_f32", "float32"), + _Extern(["float64"], "__ocml_cbrt_f64", "float64"), + ], ), lax.rsqrt_p: _make_dispatch_table( "rsqrt", - [ + cuda=[ _Extern(["float32"], "__nv_rsqrtf", "float32"), _Extern(["float64"], "__nv_rsqrt", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_rsqrt_f32", "float32"), + _Extern(["float64"], "__ocml_rsqrt_f64", "float64"), + ], ), lax.sin_p: _make_dispatch_table( "sin", - [ + cuda=[ _Extern(["float32"], "__nv_sinf", "float32"), _Extern(["float64"], "__nv_sin", "float64"), _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), ], + rocm=[ + _Extern(["float32"], "__ocml_sin_f32", "float32"), + _Extern(["float64"], "__ocml_sin_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), + ], ), lax.cos_p: _make_dispatch_table( "cos", - [ + cuda=[ _Extern(["float32"], "__nv_cosf", "float32"), _Extern(["float64"], "__nv_cos", "float64"), _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), ], + rocm=[ + _Extern(["float32"], "__ocml_cos_f32", "float32"), + _Extern(["float64"], "__ocml_cos_f64", "float64"), + _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), + _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), + ], ), lax.tan_p: _make_dispatch_table( "tan", - [ + cuda=[ _Extern(["float32"], "__nv_tanf", "float32"), _Extern(["float64"], "__nv_tan", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_tan_f32", "float32"), + _Extern(["float64"], "__ocml_tan_f64", "float64"), + ], ), lax.asin_p: _make_dispatch_table( "asin", - [ + cuda=[ _Extern(["float32"], "__nv_asinf", "float32"), _Extern(["float64"], "__nv_asin", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_asin_f32", "float32"), + _Extern(["float64"], "__ocml_asin_f64", "float64"), + ], ), lax.acos_p: _make_dispatch_table( "acos", - [ + cuda=[ _Extern(["float32"], "__nv_acosf", "float32"), _Extern(["float64"], "__nv_acos", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_acos_f32", "float32"), + _Extern(["float64"], "__ocml_acos_f64", "float64"), + ], ), lax.atan_p: _make_dispatch_table( "atan", - [ + cuda=[ _Extern(["float32"], "__nv_atanf", "float32"), _Extern(["float64"], "__nv_atan", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_atan_f32", "float32"), + _Extern(["float64"], "__ocml_atan_f64", "float64"), + ], ), lax.atan2_p: _make_dispatch_table( "atan2", - [ + cuda=[ _Extern(["float32", "float32"], "__nv_atan2f", "float32"), _Extern(["float64", "float64"], "__nv_atan2", "float64"), ], + rocm=[ + _Extern(["float32", "float32"], "__ocml_atan2_f32", "float32"), + _Extern(["float64", "float64"], "__ocml_atan2_f64", "float64"), + ], ), lax.sinh_p: _make_dispatch_table( "sinh", - [ + cuda=[ _Extern(["float32"], "__nv_sinhf", "float32"), _Extern(["float64"], "__nv_sinh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_sinh_f32", "float32"), + _Extern(["float64"], "__ocml_sinh_f64", "float64"), + ], ), lax.cosh_p: _make_dispatch_table( "cosh", - [ + cuda=[ _Extern(["float32"], "__nv_coshf", "float32"), _Extern(["float64"], "__nv_cosh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_cosh_f32", "float32"), + _Extern(["float64"], "__ocml_cosh_f64", "float64"), + ], ), lax.tanh_p: _make_dispatch_table( "tanh", - [ + cuda=[ _Extern(["float32"], "__nv_tanhf", "float32"), _Extern(["float64"], "__nv_tanh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_tanh_f32", "float32"), + _Extern(["float64"], "__ocml_tanh_f64", "float64"), + ], ), lax.asinh_p: _make_dispatch_table( "asinh", - [ + cuda=[ _Extern(["float32"], "__nv_asinhf", "float32"), _Extern(["float64"], "__nv_asinh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_asinh_f32", "float32"), + _Extern(["float64"], "__ocml_asinh_f64", "float64"), + ], ), lax.acosh_p: _make_dispatch_table( "acosh", - [ + cuda=[ _Extern(["float32"], "__nv_acoshf", "float32"), _Extern(["float64"], "__nv_acosh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_acosh_f32", "float32"), + _Extern(["float64"], "__ocml_acosh_f64", "float64"), + ], ), lax.atanh_p: _make_dispatch_table( "atanh", - [ + cuda=[ _Extern(["float32"], "__nv_atanhf", "float32"), _Extern(["float64"], "__nv_atanh", "float64"), ], + rocm=[ + _Extern(["float32"], "__ocml_atanh_f32", "float32"), + _Extern(["float64"], "__ocml_atanh_f64", "float64"), + ], ), lax.population_count_p: _make_dispatch_table( "population_count", - [ + cuda=[ _Extern(["int32"], "__nv_popc", "int32"), _Extern(["int64"], "__nv_popcll", "int32"), ], + rocm=[ + _Fallback(["int32"], lambda ctx, x: math_dialect.ctpop(x)), + _Fallback(["int64"], lambda ctx, x: math_dialect.ctpop(x)), + ], ), lax.clz_p: _make_dispatch_table( "clz", - [ + cuda=[ _Extern(["int32"], "__nv_clz", "int32"), _Extern(["int64"], "__nv_clzll", "int32"), ], + rocm=[ + _Fallback(["int32"], lambda ctx, x: math_dialect.ctlz(x)), + _Fallback(["int64"], lambda ctx, x: math_dialect.ctlz(x)), + ], ), lax.nextafter_p: _make_dispatch_table( "nextafter", - [ + cuda=[ _Extern(["float32", "float32"], "__nv_nextafterf", "float32"), _Extern(["float64", "float64"], "__nv_nextafter", "float64"), ], + rocm=[ + _Extern(["float32", "float32"], "__ocml_nextafter_f32", "float32"), + _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"), + ], ), }) @@ -1018,7 +1156,7 @@ def debug_print_lowering_rule( "pl.debug_print() does not support placeholders when lowering to Triton" ) - tt_dialect.print_(f" {fmt}", hex=False, args=args) + tt_dialect.print_(f" {fmt} ", hex=False, args=args) return () @@ -1067,14 +1205,32 @@ def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]): return _bcast_to(_ensure_ir_value(x, x_aval), shape) -def _integer_pow(a, *, y): - if y == 2: - return a * a - if y == 3: - return a * a * a - if y == -2: - return 1.0 / (a * a) - return jax.lax.pow(a, y) +@register_lowering(lax.integer_pow_p) +def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): + if y == 0: + return _full(x.type, 1) + + is_reciprocal = y < 0 + if is_reciprocal: + y = -y + + acc = None + while y > 0: + y, mod = divmod(y, 2) + if mod: + acc = x if acc is None else _mul(acc, x) + if y > 0: + x = _mul(x, x) + assert acc is not None + + [x_aval] = ctx.avals_in + [out_aval] = ctx.avals_out + acc = _cast(acc, x_aval.dtype, out_aval.dtype) + if is_reciprocal: + signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger) + return _truediv(_full(acc.type, 1), acc, signed=signed) + else: + return acc def lower_fun( @@ -1094,7 +1250,6 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.integer_pow_p: _integer_pow, lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), } @@ -1501,14 +1656,22 @@ def _compute_pointers_from_indices( else: index = next(indexer_iter) if isinstance(index, primitives.Slice): - # Handle slices with static and dynamic indices and static sizes - if isinstance(index.start, int): - ptr_dim_offset = _make_range(index.start, index.start + index.size) - else: + if index.is_dynamic_start: + # Compute the offset as start + range(0, size). ptr_dim_offset = _add( _bcast_to(index.start, [index.size]), _ir_cast(_make_range(0, index.size), index.start.type, signed=False), ) + elif index.stride > 1: + # Compute the offset as start + range(0, size) * stride. + iota = _make_range(0, index.size) + ptr_dim_offset = _add( + _bcast_to(_i32_constant(index.start), [index.size]), + _mul(iota, _full(iota.type, index.stride)), + ) + else: + ptr_dim_offset = _make_range(index.start, index.start + index.size) + # We need to add broadcastable dimensions for the advanced int indexing # and for previous slices num_left_expand_dims = len(int_indexer_shape) + other_shape_idx diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index f3584e6f59fe..e6d521692ec2 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -14,9 +14,6 @@ """Module registering a lowering rule for pallas_call on GPU.""" -# TODO(sharadmv): Enable type checking. -# mypy: ignore-errors - from __future__ import annotations import io @@ -36,38 +33,59 @@ def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]: grid = (grid,) elif len(grid) > 3: raise ValueError("`grid` should have three or fewer dimensions.") - return tuple(grid) + (1,) * (3 - len(grid)) + return tuple(grid) + (1,) * (3 - len(grid)) # type: ignore def avals_to_layouts(avals): return [list(reversed(range(aval.ndim))) for aval in avals] -def _pallas_call_ttir_lowering( +def pallas_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, name: str, in_shapes: tuple[jax.ShapeDtypeStruct, ...], out_shapes: tuple[jax.ShapeDtypeStruct, ...], + interpret: bool, debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, - triton_params: dict[str, Any] | None = None, - num_warps: int, - num_stages: int, + compiler_params: dict[str, Any], ): - # TODO(sharadmv): Handle multiple devices with different capabilities. - d, *_ = jax.local_devices(backend="gpu") - cuda_options = dict( - compute_capability=d.compute_capability, - num_warps=num_warps, - num_stages=num_stages, - debug=debug, - ) + if interpret: + return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( + ctx, + *in_nodes, + jaxpr=jaxpr, + name=name, + out_shapes=out_shapes, + in_shapes=in_shapes, + interpret=interpret, + debug=debug, + input_output_aliases=input_output_aliases, + grid_mapping=grid_mapping, + compiler_params=compiler_params, + ) + + if grid_mapping.num_dynamic_grid_bounds: + raise NotImplementedError( + "dynamic grid bounds not supported in the Triton backend" + ) + triton_params = compiler_params.get("triton", compiler_params) + num_warps = triton_params.pop("num_warps", 4) + [lowering_platform] = ctx.platforms or ctx.module_context.platforms + if lowering_platform == "rocm": + num_stages = triton_params.pop("num_stages", 1) + else: + num_stages = triton_params.pop("num_stages", 3) + + if debug: + print(jaxpr) + print(grid_mapping) lowering_result = lowering.lower_jaxpr_to_triton_module( - jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, cuda_options + jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, lowering_platform ) module_op = lowering_result.module.operation if debug: @@ -82,7 +100,7 @@ def _pallas_call_ttir_lowering( module_op.write_bytecode(buf) backend_config = dict( name=ir.StringAttr.get(name), - ir=ir.StringAttr.get(buf.getvalue()), + ir=ir.StringAttr.get(buf.getvalue()), # type: ignore num_stages=mlir.i32_attr(num_stages), num_warps=mlir.i32_attr(num_warps), grid_x=mlir.i32_attr(grid_x), @@ -105,66 +123,3 @@ def _pallas_call_ttir_lowering( result_layouts=avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), ).results - - -def pallas_call_lowering( - ctx: mlir.LoweringRuleContext, - *in_nodes, - jaxpr: jax_core.Jaxpr, - name: str, - in_shapes: tuple[jax.ShapeDtypeStruct, ...], - out_shapes: tuple[jax.ShapeDtypeStruct, ...], - which_linear: tuple[bool, ...], - interpret: bool, - debug: bool, - input_output_aliases: tuple[tuple[int, int], ...], - grid_mapping: pallas_core.GridMapping, - compiler_params: dict[str, Any], -): - if interpret: - return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( - ctx, - *in_nodes, - jaxpr=jaxpr, - name=name, - out_shapes=out_shapes, - in_shapes=in_shapes, - which_linear=which_linear, - interpret=interpret, - debug=debug, - input_output_aliases=input_output_aliases, - grid_mapping=grid_mapping, - compiler_params=compiler_params, - ) - - if grid_mapping.num_dynamic_grid_bounds: - raise NotImplementedError( - "dynamic grid bounds not supported in the Triton backend" - ) - triton_params = compiler_params.get("triton", compiler_params) - num_warps = triton_params.pop("num_warps", 4) - if len(ctx.module_context.platforms) > 1: - raise NotImplementedError("multi-platform lowering for Pallas kernels") - if ctx.module_context.platforms[0] == "rocm": - num_stages = triton_params.pop("num_stages", 1) - else: - num_stages = triton_params.pop("num_stages", 3) - - if debug: - print(jaxpr) - print(grid_mapping) - - return _pallas_call_ttir_lowering( - ctx, - *in_nodes, - jaxpr=jaxpr, - name=name, - in_shapes=in_shapes, - out_shapes=out_shapes, - debug=debug, - input_output_aliases=input_output_aliases, - grid_mapping=grid_mapping, - triton_params=triton_params, - num_warps=num_warps, - num_stages=num_stages, - ) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index fae408275c21..41466be0822d 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -15,7 +15,9 @@ """Pallas utility functions.""" from __future__ import annotations +from typing import overload +import jax from jax import lax from jax._src import core as jax_core from jax._src.util import split_list @@ -32,9 +34,26 @@ def _wrapped(f): lax.cond(condition, f, lambda: None) return _wrapped - +@overload def cdiv(a: int, b: int) -> int: - return (a + b - 1) // b + ... + +@overload +def cdiv(a: int, b: jax.Array) -> jax.Array: + ... + +@overload +def cdiv(a: jax.Array, b: int) -> jax.Array: + ... + +@overload +def cdiv(a: jax.Array, b: jax.Array) -> jax.Array: + ... + +def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array: + if isinstance(a, int) and isinstance(b, int): + return (a + b - 1) // b + return lax.div(a + b - 1, b) def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]: diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 05200fd347a7..18e7d18d931d 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -14,7 +14,7 @@ class _UnconstrainedPartitionSingleton: - def __str__(self): + def __repr__(self): return "UNCONSTRAINED" diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 43bc77836613..204c288d6993 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -14,15 +14,16 @@ from __future__ import annotations -from collections.abc import Sequence, Iterable +from collections import defaultdict +from collections.abc import Callable, Sequence, Iterable import dataclasses -from functools import partial, lru_cache +from functools import partial import inspect import itertools as it import logging import operator as op import weakref -from typing import Callable, NamedTuple, Any, Union, Optional +from typing import NamedTuple, Any, Union, cast import threading import warnings @@ -38,6 +39,7 @@ from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import op_shardings +from jax._src import profiler from jax._src import sharding_impls from jax._src import source_info_util from jax._src import stages @@ -60,9 +62,12 @@ from jax._src.interpreters import pxla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version +from jax._src import sharding from jax._src.sharding_impls import ( - NamedSharding, XLACompatibleSharding, GSPMDSharding, + NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified, is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) @@ -70,13 +75,13 @@ from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( - tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, + tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves, treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, PyTreeDef, none_leaf_registry as none_lr) from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, weakref_lru_cache, - merge_lists, flatten, unflatten, subs_list) + merge_lists, flatten, unflatten, subs_list, fun_name) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -139,7 +144,6 @@ class PjitInfo(NamedTuple): In other words, this structure contains arguments to jit()/pjit(), preprocessed and validated. """ - fun: Callable fun_sourceinfo: str | None fun_signature: inspect.Signature | None # Shardings, as specified by the user. These can either be UNSPECIFIED or they @@ -165,35 +169,40 @@ class PjitInfo(NamedTuple): has_explicit_sharding: bool use_resource_env: bool # False for jit, True for pjit + # Hash and compare PjitInfo by identity when used as a cache key. + def __hash__(self): + return id(self) + + def __eq__(self, other): + return self is other + -def _python_pjit_helper(jit_info, *args, **kwargs): - (args_flat, _, params, _, out_tree, _, arg_names, - attrs_tracked) = _infer_params(jit_info, args, kwargs) +def _python_pjit_helper(fun, jit_info, *args, **kwargs): + p, args_flat = _infer_params(fun, jit_info, args, kwargs) for arg in args_flat: dispatch.check_arg(arg) - if attrs_tracked: - init_states = _get_states(attrs_tracked) + if p.attrs_tracked: + init_states = _get_states(p.attrs_tracked) args_flat = [*init_states, *args_flat] try: - out_flat = pjit_p.bind(*args_flat, **params) + out_flat = pjit_p.bind(*args_flat, **p.params) except pxla.DeviceAssignmentMismatchError as e: fails, = e.args - api_name = 'jit' if params['resource_env'] is None else 'pjit' - fun = jit_info.fun + api_name = 'jit' if p.params['resource_env'] is None else 'pjit' fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, api_name, arg_names) + fun_name, fails, args_flat, api_name, p.arg_names) raise ValueError(msg) from None except xla.InvalidInputException as e: - arg_names = [''] * len(args_flat) if arg_names is None else arg_names + arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names # Run canonicalization again to figure out which arg failed. - if params['jaxpr'].consts: + if p.params['jaxpr'].consts: raise TypeError(e.args[0]) from e else: - for arg, name, aval in zip(args_flat, arg_names, params['jaxpr'].in_avals): + for arg, name, aval in zip(args_flat, arg_names, p.in_avals): try: xla.canonicalize_dtype(arg) except xla.InvalidInputException as _: @@ -203,28 +212,40 @@ def _python_pjit_helper(jit_info, *args, **kwargs): f' {type(arg)} is not a valid JAX type.') from e raise AssertionError("Unreachable") from e - if attrs_tracked: - final_states, out_flat = split_list(out_flat, [len(attrs_tracked)]) - _set_states(attrs_tracked, final_states) + if p.attrs_tracked: + num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked) + final_states, out_flat = split_list(out_flat, [num_states_out]) + _set_states(p.attrs_tracked, final_states) - outs = tree_unflatten(out_tree, out_flat) - return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked + outs = tree_unflatten(p.out_tree, out_flat) + return outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], p.attrs_tracked def _set_states(attrs_tracked, vals): from jax.experimental.attrs import jax_setattr - for ((obj, attr), val) in zip(attrs_tracked, vals): + valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) + for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): + val = tree_unflatten(treedef, leaves) jax_setattr(obj, attr, val) def _get_states(attrs_tracked): from jax.experimental.attrs import jax_getattr - return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked] - + vals = [] + for treedef, _, (obj, attr) in attrs_tracked: + tree = jax_getattr(obj, attr) + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + return vals + +def _need_to_rebuild_with_fdo(pgle_profiler): + return (pgle_profiler is not None and pgle_profiler.is_enabled() + and not pgle_profiler.is_fdo_consumed()) def _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, effects, - consts, abstracted_axes, -) -> Optional[pxla.MeshExecutableFastpathData]: + consts, abstracted_axes, pgle_profiler +) -> pxla.MeshExecutableFastpathData | None: out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) use_fastpath = ( @@ -245,6 +266,7 @@ def _get_fastpath_data( and not (config.debug_key_reuse.value and any( hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key) for arg in (*args_flat, *out_flat, *consts))) + and not _need_to_rebuild_with_fdo(pgle_profiler) ) if use_fastpath: @@ -253,7 +275,7 @@ def _get_fastpath_data( kept_var_bitvec = [i in executable._kept_var_idx for i in range(len(args_flat))] in_shardings = [ - a.dtype._rules.physical_sharding(a, s) + sharding_impls.physical_sharding(a, s) if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) else s for s, a in zip(executable._in_shardings, executable.in_avals) @@ -271,6 +293,7 @@ def _get_fastpath_data( class _MostRecentPjitCallExecutable(threading.local): def __init__(self): self.weak_key_dict = weakref.WeakKeyDictionary() + self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary() _most_recent_pjit_call_executable = _MostRecentPjitCallExecutable() @@ -279,9 +302,15 @@ def _read_most_recent_pjit_call_executable(jaxpr): return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None) +def _read_pgle_profiler(jaxpr): + return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get( + jaxpr, None + ) + def _cpp_pjit_evict_fn(self): self._clear_cache() _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error + _infer_params_cached.cache_clear() # The entries are doubled here from the default 4096 because _pjit_call_impl @@ -297,24 +326,26 @@ def _get_cpp_global_cache(pjit_has_explicit_sharding): return _cpp_pjit_cache -def _cpp_pjit(jit_info: PjitInfo): +def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @api_boundary def cache_miss(*args, **kwargs): outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper( - jit_info, *args, **kwargs) + fun, jit_info, *args, **kwargs) executable = _read_most_recent_pjit_call_executable(jaxpr) + pgle_profiler = _read_pgle_profiler(jaxpr) maybe_fastpath_data = _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, - jaxpr.consts, jit_info.abstracted_axes) - return outs, maybe_fastpath_data + jaxpr.consts, jit_info.abstracted_axes, + pgle_profiler) + + return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - fun = jit_info.fun cpp_pjit_f = xc._xla.pjit( - getattr(fun, "__name__", ""), + fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, jit_info.donate_argnums, tree_util.dispatch_registry, - pxla.shard_arg, + lambda x, sharding: pxla.shard_args([sharding], [x])[0], _get_cpp_global_cache(jit_info.has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) @@ -423,7 +454,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, in_shardings, out_shardings, device, backend) return PjitInfo( - fun=fun, fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, user_specified_in_shardings=user_specified_in_shardings, @@ -444,45 +474,45 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, use_resource_env=use_resource_env) -def _make_jit_wrapper(jit_info: PjitInfo): +def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): @api_boundary def lower(*args, **kwargs): - lowering_parameters = kwargs.pop( - '_experimental_lowering_parameters', mlir.LoweringParameters()) - - (args_flat, flat_global_in_avals, params, in_tree, out_tree, - donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs) + traced = trace(*args, **kwargs) try: - lowering = _resolve_and_lower( - args_flat, **params, lowering_parameters=lowering_parameters) + return traced.lower() except pxla.DeviceAssignmentMismatchError as e: fails, = e.args - api_name = 'jit' if params['resource_env'] is None else 'pjit' - fun = jit_info.fun fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, api_name, arg_names) + fun_name, fails, traced._args_flat, 'jit', traced._arg_names) raise ValueError(msg) from None - donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d) - return stages.Lowered.from_flat_info( - lowering, in_tree, flat_global_in_avals, donate_argnums, - out_tree) - @api_boundary def eval_shape(*args, **kwargs): - _, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs) - out_s = [None if is_unspecified(s) else s for s in params['out_shardings']] + p, _ = _infer_params(fun, jit_info, args, kwargs) + out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']] # TODO(yashkatariya): Add `Layout` to SDS. out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s) - for x, s in zip(params['jaxpr'].out_avals, out_s)] - return tree_unflatten(out_tree, out) + for x, s in zip(p.params['jaxpr'].out_avals, out_s)] + return tree_unflatten(p.out_tree, out) - wrapped = _cpp_pjit(jit_info) + @api_boundary + def trace(*args, **kwargs) -> stages.Traced: + p, args_flat = _infer_params(fun, jit_info, args, kwargs) + donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) + args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) + lower_callable = partial(_resolve_and_lower, args_flat, **p.params, + pgle_profiler=None) + return stages.Traced( + p.params['jaxpr'], args_info, p.params["name"],p.out_tree, + lower_callable, args_flat, p.arg_names, p.num_consts) + + wrapped = _cpp_pjit(fun, jit_info) wrapped.lower = lower wrapped.eval_shape = eval_shape + wrapped.trace = trace return wrapped @@ -499,82 +529,90 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, keep_unused, inline, use_resource_env) - return _make_jit_wrapper(jit_info) + return _make_jit_wrapper(fun, jit_info) + +class PjitParams(NamedTuple): + consts: list[Any] # Only jaxpr constants, we can't keep other arguments alive + params: dict[str, Any] + in_avals: tuple[core.AbstractValue, ...] + in_tree: PyTreeDef + out_tree: PyTreeDef + donated_invars: tuple[bool, ...] + arg_names: tuple[str, ...] | None + num_consts: int + attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] -def _infer_params(jit_info, args, kwargs): - (fun, fun_sourceinfo, fun_signature, user_specified_in_shardings, - in_shardings_treedef, in_shardings_leaves, out_shardings_treedef, - out_shardings_leaves, in_layouts_treedef, in_layouts_leaves, - out_layouts_treedef, out_layouts_leaves, static_argnums, static_argnames, - donate_argnums, donate_argnames, device, backend, keep_unused, inline, - abstracted_axes, _, use_resource_env) = jit_info +def _infer_params_impl( + fun: Callable, + ji: PjitInfo, + pjit_mesh: mesh_lib.Mesh | None, + resource_env: mesh_lib.ResourceEnv | None, + args: tuple[Any, ...], + kwargs: dict[str, Any], + in_avals: tuple[core.AbstractValue, ...] | None, +) -> tuple[PjitParams, list[Any]]: have_kwargs = bool(kwargs) - if have_kwargs and user_specified_in_shardings: + if have_kwargs and ji.user_specified_in_shardings: raise ValueError( "pjit does not support kwargs when in_shardings is specified.") - if use_resource_env: - # We need to fetch the mesh from inside the wrapped function, because - # meshes are dynamically scoped (i.e., with a context manager). - resource_env = mesh_lib.thread_resources.env - pjit_mesh = resource_env.physical_mesh + if pjit_mesh is not None: jit_name = 'pjit' - if (backend or device) and not pjit_mesh.empty: + if (ji.backend or ji.device) and not pjit_mesh.empty: raise ValueError( "Mesh context manager should not be used with jit when backend or " "device is also specified as an argument to jit.") else: - resource_env = None - pjit_mesh = None jit_name = 'jit' + axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs) - axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs) - - dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs, - static_argnums, static_argnames) + dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs, + ji.static_argnums, ji.static_argnames) f = lu.wrap_init(fun) f, res_paths = result_paths(f) - f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=True) + f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) del args - f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs) + f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs) explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs)) flat_fun, out_tree = flatten_fun(f, in_tree) flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args) - if (donate_argnums or donate_argnames) and not config.debug_nans.value: - donated_invars = donation_vector(donate_argnums, donate_argnames, in_tree) + if (ji.donate_argnums or ji.donate_argnames) and not config.debug_nans.value: + donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames, in_tree) else: donated_invars = (False,) * len(explicit_args) - del donate_argnums, donate_argnames # If backend or device is set as an arg on jit, then resolve them to # in_shardings and out_shardings as if user passed in in_shardings # and out_shardings. - device_or_backend_set = bool(backend or device) + device_or_backend_set = bool(ji.backend or ji.device) if device_or_backend_set: - sharding = _create_sharding_with_device_backend(device, backend) + sharding = _create_sharding_with_device_backend(ji.device, ji.backend) leaves, treedef = tree_flatten(sharding) in_shardings_leaves = out_shardings_leaves = tuple(leaves) in_shardings_treedef = out_shardings_treedef = treedef else: in_shardings_leaves = tuple( _create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name) - for x in in_shardings_leaves) + for x in ji.in_shardings_leaves) + in_shardings_treedef = ji.in_shardings_treedef out_shardings_leaves = tuple( _create_sharding_for_array(pjit_mesh, x, 'out_shardings', jit_name) - for x in out_shardings_leaves) + for x in ji.out_shardings_leaves) + out_shardings_treedef = ji.out_shardings_treedef assert None not in in_shardings_leaves assert None not in out_shardings_leaves + in_type: core.InputType | tuple[core.AbstractValue, ...] if config.dynamic_shapes.value: in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) in_avals = tuple(a for a, e in in_type if e) - else: + elif in_avals is None: avals = [] for i, a in enumerate(explicit_args): try: @@ -587,32 +625,41 @@ def _infer_params(jit_info, args, kwargs): f"computation, whose {arg_path}." ) from e in_type = in_avals = tuple(avals) + else: + in_type = in_avals in_shardings_flat, in_layouts_flat = _process_in_axis_resources( in_shardings_treedef, in_shardings_leaves, - in_layouts_treedef, in_layouts_leaves, + ji.in_layouts_treedef, ji.in_layouts_leaves, in_avals, in_tree, dbg, device_or_backend_set, have_kwargs) - jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr( - flat_fun, out_shardings_treedef, out_shardings_leaves, - out_layouts_treedef, out_layouts_leaves, in_type, dbg, - device_or_backend_set, HashableFunction(out_tree, closure=()), - HashableFunction(res_paths, closure=()), inline) + attr_token = _attr_token(flat_fun, in_type) + jaxpr, consts, out_type, attrs_tracked = _create_pjit_jaxpr( + flat_fun, in_type, attr_token, dbg, + HashableFunction(res_paths, closure=()), + IgnoreKey(ji.inline)) + _attr_update(flat_fun, in_type, attr_token, attrs_tracked) + out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( + out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef, + ji.out_layouts_leaves, HashableFunction(out_tree, closure=()), + tuple(out_type), jaxpr.jaxpr.debug_info, device_or_backend_set) assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat) if config.dynamic_shapes.value: - implicit_args = _extract_implicit_args(in_type, explicit_args) + implicit_args = _extract_implicit_args( + cast(core.InputType, in_type), explicit_args) else: implicit_args = [] args_flat = [*implicit_args, *explicit_args] - num_extra_args = len(implicit_args) + len(attrs_tracked) + len(consts) + num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked) + num_extra_args = len(implicit_args) + num_states_in + len(consts) in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars assert (len(in_shardings_flat) == len(in_layouts_flat) == - len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat)) + len(donated_invars) == num_states_in + len(consts) + len(args_flat)) params = dict( jaxpr=jaxpr, @@ -622,12 +669,83 @@ def _infer_params(jit_info, args, kwargs): out_layouts=out_layouts_flat, resource_env=resource_env, donated_invars=donated_invars, - name=getattr(flat_fun, '__name__', ''), - keep_unused=keep_unused, - inline=inline, + name=fun_name(flat_fun), + keep_unused=ji.keep_unused, + inline=ji.inline, ) - return (consts + args_flat, in_type, params, in_tree, out_tree(), - donated_invars, dbg.arg_names if dbg else None, attrs_tracked) + return PjitParams(consts, params, in_avals, in_tree, out_tree(), + donated_invars, dbg.arg_names if dbg else None, len(consts), + attrs_tracked), args_flat + + + +class InferParamsCacheEntry: + """Mutable value object for _infer_params_cached.""" + __slots__ = ['pjit_params'] + + pjit_params: PjitParams | None + + def __init__(self): + self.pjit_params = None + + +# We use an outer cache that is keyed on the signature of the arguments, but +# when populating a cache entry using _infer_params_impl, we need to provide +# actual arguments. In principle we could refactor _infer_params_impl to look +# only at an argument signature instead of args/kwargs in those cases that we +# cache, but this was a more minimal change. +@util.weakref_lru_cache +def _infer_params_cached( + fun: Callable, + jit_info: PjitInfo, + signature: jax_jit.ArgumentSignature, + in_avals: tuple[core.AbstractValue, ...], + pjit_mesh: mesh_lib.Mesh | None, + resource_env: mesh_lib.ResourceEnv | None, +) -> InferParamsCacheEntry: + return InferParamsCacheEntry() + + +def _infer_params( + fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> tuple[PjitParams, list[Any]]: + if ji.use_resource_env: + # We need to fetch the mesh from inside the wrapped function, because + # meshes are dynamically scoped (i.e., with a context manager). + resource_env = mesh_lib.thread_resources.env + pjit_mesh = resource_env.physical_mesh + else: + resource_env = None + pjit_mesh = None + + skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value + if not skip_cache: + signature, dynargs = jax_jit.parse_arguments( + args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, + ji.static_argnames, tree_util.default_registry) + try: + avals = tuple(shaped_abstractify(a) for a in dynargs) + except (OverflowError, TypeError): + # If we see something we don't understand, use the slow path. + skip_cache = True + + if skip_cache: + p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args, + kwargs, in_avals=None) + return p, p.consts + args_flat + + entry = _infer_params_cached( + fun, ji, signature, avals, pjit_mesh, resource_env) + if entry.pjit_params is None: + p, args_flat = _infer_params_impl( + fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals) + if p.attrs_tracked: + # If there are attrs_tracked, don't use the cache. + return p, p.consts + args_flat + else: + entry.pjit_params = p + return entry.pjit_params, entry.pjit_params.consts + dynargs + def _extract_implicit_args( in_type: Sequence[tuple[core.AbstractValue, bool]], @@ -675,6 +793,9 @@ def eval_shape(self, *args, **kwargs): """See ``jax.eval_shape``.""" raise NotImplementedError + def trace(self, *args, **kwargs) -> stages.Traced: + raise NotImplementedError + # in_shardings and out_shardings can't be None as the default value # because `None` means that the input is fully replicated. @@ -754,7 +875,7 @@ def pjit( The valid resource assignment specifications are: - - :py:class:`XLACompatibleSharding`, which will decide how the value + - :py:class:`Sharding`, which will decide how the value will be partitioned. With this, using a mesh context manager is not required. - :py:obj:`None` is a special case whose semantics are: @@ -876,10 +997,10 @@ def hashable_pytree(pytree): def _create_sharding_for_array(mesh, x, name, api_name): if x is None and (mesh is None or mesh.empty): return UNSPECIFIED - if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x): + if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x): return x if mesh is None: - msg = ('jax.jit only supports `XLACompatibleSharding`s being passed to' + msg = ('jax.jit only supports `Sharding`s being passed to' f' {name}. Looks like you are passing either `PartitionSpec` or `None`' f' which is not allowed in jax.jit.\n') if name == 'in_shardings': @@ -895,7 +1016,7 @@ def _create_sharding_for_array(mesh, x, name, api_name): raise RuntimeError( f'{api_name} requires a non-empty mesh if you are passing' f' `PartitionSpec`s or `None` to {name}! Is a mesh defined at the call' - f' site? Alternatively, provide `XLACompatibleSharding`s to {name} and' + f' site? Alternatively, provide `Sharding`s to {name} and' ' then the mesh context manager is not required.') # A nice user error is raised in prepare_axis_resources. assert x is None or isinstance(x, ParsedPartitionSpec), x @@ -974,7 +1095,7 @@ class PytreeLeaf: def __repr__(self): return "pytree leaf" -@lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, in_layouts_treedef, in_layouts_leaves, in_avals, in_tree, debug_info, @@ -1013,7 +1134,7 @@ def explain_tracing_cache_miss( if config.check_tracer_leaks.value: return def unpack(key): - transforms, (), _, (in_type, debug_info, _, inline), *_, ctx = key + transforms, (), _, (in_type, _, debug_info, _, inline), *_, ctx = key # TODO(dougalm,mattjj): enable cache miss explanation with attrs _, (_, (in_tree,)), *_ = transforms return in_tree, in_type, debug_info, inline.val, ctx @@ -1059,7 +1180,9 @@ def unpack(key): f" {', '.join(map(repr, kwarg_keys))}") dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore if t != [args_tree, kwargs_tree]] - close_kwargs = min(dont_match, key=set(kwarg_keys).symmetric_difference) + close_kwargs = min( + dont_match, key=set(kwarg_keys).symmetric_difference, default=None + ) if not close_kwargs: p(" closest seen is passing no keyword args") else: @@ -1137,7 +1260,15 @@ def unpack(key): return done() @partial(lu.cache, explain=explain_tracing_cache_miss) -def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline): +def _create_pjit_jaxpr( + fun: lu.WrappedFun, + in_type: core.InputType | Sequence[core.AbstractValue], + attr_data: int, + debug_info: lu.TracingDebugInfo, + out_paths: Callable, + ignored_inline: IgnoreKey +) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: del ignored_inline # just for explain_cache_miss with dispatch.log_elapsed_time( "Finished tracing + transforming {fun_name} for pjit in {elapsed_time} sec", @@ -1145,11 +1276,12 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline): pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for) if config.dynamic_shapes.value: jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( - lu.annotate(fun, in_type), debug_info=pe_debug) + lu.annotate(fun, cast(core.InputType, in_type)), debug_info=pe_debug) attrs_tracked = [] else: jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( fun, in_type, debug_info=pe_debug) + # assert attr_data is sentinel or attr_data matches attrs_tracked # TODO(dougalm,mattjj): enable debug info with attrs_tracked if not config.dynamic_shapes.value and not attrs_tracked: @@ -1169,13 +1301,13 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline): return closed_jaxpr, final_consts, global_out_avals, attrs_tracked -@lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set): orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) if (is_unspecified(orig_out_shardings) or - isinstance(orig_out_shardings, XLACompatibleSharding)): + isinstance(orig_out_shardings, sharding.Sharding)): out_shardings_flat = (orig_out_shardings,) * len(out_type) else: out_shardings_flat = flatten_axis_resources( @@ -1197,17 +1329,43 @@ def _check_and_canonicalize_out_shardings( return out_shardings_flat, out_layouts_flat -def _pjit_jaxpr(fun, out_shardings_treedef, out_shardings_leaves, - out_layouts_treedef, out_layouts_leaves, in_type, debug_info, - device_or_backend_set, out_tree, result_paths, inline): - jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr( - fun, in_type, debug_info, result_paths, IgnoreKey(inline)) - canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( - out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, - out_layouts_leaves, out_tree, tuple(out_type), - jaxpr.jaxpr.debug_info, device_or_backend_set) - return (jaxpr, final_consts, canonicalized_out_shardings_flat, - out_layouts_flat, attrs_tracked) +AttrRecord = tuple[object, str, PyTreeDef, list[core.AbstractValue]] +_seen_attrs = weakref.WeakKeyDictionary() # type: ignore + +def seen_attrs_get( + fun: lu.WrappedFun, + in_type: core.InputType | tuple[core.AbstractValue, ...] +) -> list: + cache = _seen_attrs.setdefault(fun.f, defaultdict(list)) + assert fun.in_type is None or fun.in_type == in_type + return cache[(fun.transforms, fun.params, in_type)] + +def _attr_token( + fun: lu.WrappedFun, + in_type: core.InputType | tuple[core.AbstractValue, ...] +) -> int: + from jax.experimental.attrs import jax_getattr + cases = seen_attrs_get(fun, in_type) + for i, records in enumerate(cases): + for obj, attr, treedef, avals in records: + val = jax_getattr(obj, attr) + vals, treedef_ = tree_flatten(val) + avals_ = map(shaped_abstractify, vals) + if treedef != treedef_ or avals != avals_: break + else: + return i + return len(cases) + +def _attr_update(fun, in_type, i, attrs_tracked): + from jax.experimental.attrs import jax_getattr + leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr)) + records = [(obj, attr, init_tree, map(shaped_abstractify, leaves(obj, attr))) + for init_tree, _, (obj, attr) in attrs_tracked] + cases = seen_attrs_get(fun, in_type) + if i == len(cases): + cases.append(records) + else: + assert i < len(cases) and cases[i] == records @dataclasses.dataclass(frozen=True) @@ -1229,10 +1387,10 @@ def pjit_check_aval_sharding( name_str = f' with pytree key path {name}' if name else '' shape = aval.shape try: - # Sharding interfaces can implement `is_compatible_aval` as an optional + # Sharding interfaces can implement `check_compatible_aval` as an optional # method to raise a more meaningful error. - if hasattr(s, 'is_compatible_aval'): - s.is_compatible_aval(shape) + if hasattr(s, 'check_compatible_aval'): + s.check_compatible_aval(shape) else: s._to_xla_hlo_sharding(len(shape)) except ValueError as e: @@ -1241,7 +1399,7 @@ def pjit_check_aval_sharding( f'annotation {s}: {e}') # Use the `OpSharding` proto to find out how many ways each dimension of # the aval is sharded. This approach will work across all - # XLACompatibleSharding. + # Sharding. hlo_sharding = s._to_xla_hlo_sharding(len(shape)) assert hlo_sharding is not None num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(hlo_sharding) @@ -1291,11 +1449,21 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): # `Layout(None, sharding)` if (committed and not is_pmap_sharding and arg_layout is not None and arg_layout != jit_in_l): + extra_msg = '' + if isinstance(jit_in_l, AutoLayout): + extra_msg = ( + ' The layout given to `jax.jit` is `DeviceLocalLayout.AUTO` but' + ' the corresponding argument passed is a `jax.Array` with a' + ' concrete layout. Consider passing a `jax.ShapeDtypeStruct`' + ' instead of `jax.Array` as an argument to the jitted function ' + ' when using `DeviceLocalLayout.AUTO`.' + ) raise ValueError('Layout passed to jit does not match the layout ' 'on the respective arg. ' f'Got pjit layout: {jit_in_l},\n' f'arg layout: {arg_layout} for ' - f'arg shape: {shaped_abstractify(arg).str_short()}') + f'arg shape: {shaped_abstractify(arg).str_short()}.' + f'{extra_msg}') resolved_in_layouts.append(jit_in_l) return tuple(resolved_in_layouts) @@ -1320,9 +1488,6 @@ def _resolve_in_shardings( # not allow None as the sharding. if arg_s is None: continue - if not isinstance(arg_s, XLACompatibleSharding): - raise ValueError(f'One of the argument to pjit got sharding {arg_s} ' - 'which is not a subclass of XLACompatibleSharding.') # Don't consider PmapSharding inputs as committed. They will get resharded # unconditionally. if isinstance(arg_s, PmapSharding): @@ -1410,7 +1575,7 @@ def _resolve_in_shardings( def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, - lowering_parameters): + lowering_platforms, lowering_parameters, pgle_profiler): in_shardings = _resolve_in_shardings( args, in_shardings, out_shardings, resource_env.physical_mesh if resource_env is not None else None) @@ -1419,20 +1584,45 @@ def _resolve_and_lower( lowered = _pjit_lower( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, - lowering_parameters=lowering_parameters) + lowering_platforms=lowering_platforms, + lowering_parameters=lowering_parameters, + pgle_profiler=pgle_profiler) return lowered - def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline): global _most_recent_pjit_call_executable + compile_options = None + pgle_profiler = None + pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict + if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: + if jaxpr not in pgle_profiler_dict: + pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler( + config.pgle_profiling_runs.value, + config.pgle_aggregation_percentile.value) + + pgle_profiler = pgle_profiler_dict[jaxpr] + # The method below will return FDO profile when module was profiled + # config.jax_pgle_profiling_runs amount of times, otherwise the result will + # be None. + fdo_profile = pgle_profiler.consume_fdo_profile() + if fdo_profile is not None: + compile_options = {'fdo_profile': fdo_profile} + + # TODO(patrios): Do not pass mutable profile session through cached lowering + # chain. Instead we need to move profilers dictionary to pxla module and use + # module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode. compiled = _resolve_and_lower( - args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, - in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env, + args, jaxpr=jaxpr, in_shardings=in_shardings, + out_shardings=out_shardings, in_layouts=in_layouts, + out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline, lowering_parameters=mlir.LoweringParameters()).compile() + inline=inline, lowering_platforms=None, + lowering_parameters=mlir.LoweringParameters(), + pgle_profiler=pgle_profiler + ).compile(compile_options) _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled # This check is expensive so only do it if enable_checks is on. @@ -1508,10 +1698,11 @@ def call_impl_cache_miss(*args_, **kwargs_): out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline) + pgle_profiler = _read_pgle_profiler(jaxpr) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, - jaxpr.consts, None) - return out_flat, fastpath_data + jaxpr.consts, None, pgle_profiler) + return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, @@ -1522,7 +1713,7 @@ def call_impl_cache_miss(*args_, **kwargs_): return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], donated_argnums, tree_util.dispatch_registry, - pxla.shard_arg, + lambda x, sharding: pxla.shard_args([sharding], [x])[0], _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) @@ -1545,7 +1736,9 @@ def _pjit_lower_cached( keep_unused: bool, inline: bool, *, - lowering_parameters: mlir.LoweringParameters): + lowering_platforms: tuple[str, ...] | None, + lowering_parameters: mlir.LoweringParameters, + pgle_profiler: profiler.PGLEProfiler | None): if resource_env is not None: pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit") @@ -1564,6 +1757,7 @@ def _pjit_lower_cached( jaxpr, api_name, name, mesh, in_shardings, out_shardings, donated_invars, True, jaxpr.in_avals, tiling_method=None, + lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters) else: return pxla.lower_sharding_computation( @@ -1572,7 +1766,9 @@ def _pjit_lower_cached( keep_unused=keep_unused, inline=inline, devices_from_context=( None if mesh is None or mesh.empty else list(mesh.devices.flat)), - lowering_parameters=lowering_parameters) + lowering_platforms=lowering_platforms, + lowering_parameters=lowering_parameters, + pgle_profiler=pgle_profiler) def pjit_staging_rule(trace, *args, **params): @@ -1796,7 +1992,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name, pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None) def _pjit_batcher_for_sharding( - s: XLACompatibleSharding | UnspecifiedValue, + s: sharding.Sharding | UnspecifiedValue, dim: int, val: tuple[str, ...], mesh, ndim: int): if is_unspecified(s): return s @@ -2364,18 +2560,35 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size, axis_name, main_type, vals_in, dims_in, sharding, resource_env, unconstrained_dims): + if spmd_axis_name is not None and isinstance(sharding, NamedSharding): + used = {n for ns in sharding.spec + for n in (ns if isinstance(ns, tuple) else (ns,))} + if set(spmd_axis_name) & used: + raise ValueError(f"vmap spmd_axis_name {spmd_axis_name} cannot appear in " + "with_sharding_constraint spec, but got spec " + f"{sharding.spec}") x, = vals_in d, = dims_in # None means unconstrained in ParsedPartitionSpec new_parts = (axis_name,) if insert_axis else ( None if spmd_axis_name is None else spmd_axis_name) unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims} + if new_parts is None: unconstrained_dims.add(d) + + vmapped_sharding = _pjit_batcher_for_sharding( + sharding, d, new_parts, resource_env.physical_mesh, x.ndim) + if unconstrained_dims and isinstance(vmapped_sharding, NamedSharding): + new_spec = list(vmapped_sharding.spec) + [None] * (x.ndim - len(vmapped_sharding.spec)) + for u in unconstrained_dims: + new_spec[u] = PartitionSpec.UNCONSTRAINED + vmapped_sharding = NamedSharding( + vmapped_sharding.mesh, PartitionSpec(*new_spec)) + y = sharding_constraint_p.bind( x, - sharding=_pjit_batcher_for_sharding( - sharding, d, new_parts, resource_env.physical_mesh, x.ndim), + sharding=vmapped_sharding, resource_env=resource_env, unconstrained_dims=unconstrained_dims) return y, d diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index 0614bb8a8d9b..5c1e7e1198e8 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -42,7 +42,7 @@ colorama = None -_PPRINT_USE_COLOR = config.DEFINE_bool( +_PPRINT_USE_COLOR = config.bool_flag( 'jax_pprint_use_color', config.bool_env('JAX_PPRINT_USE_COLOR', True), help='Enable jaxpr pretty-printing with colorful syntax highlighting.' diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 57038e4630b7..d585b312fafc 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence from functools import partial, reduce import math import operator as op -from typing import Any, Callable, NamedTuple +from typing import Any, NamedTuple import numpy as np @@ -33,11 +33,9 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import pretty_printer as pp -from jax._src import sharding_specs from jax._src import source_info_util from jax._src import tree_util as tree_util_internal from jax._src import typing -from jax._src import op_shardings from jax._src.api import jit, vmap from jax._src.dtypes import float0 from jax._src.interpreters import ad @@ -53,9 +51,8 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.array_methods import ( _array_operators, _set_array_base_attributes, _IndexUpdateHelper) -from jax._src.partition_spec import PartitionSpec from jax._src.sharding_impls import ( - NamedSharding, PmapSharding, GSPMDSharding, XLACompatibleSharding) + NamedSharding, PmapSharding, physical_sharding, logical_sharding) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip @@ -235,8 +232,7 @@ def global_shards(self) -> list[Shard]: @property def sharding(self): - phys_sharding = self._base_array.sharding - return KeyTyRules.logical_sharding(self.aval, phys_sharding) + return logical_sharding(self.aval, self._base_array.sharding) def _is_scalar(self): base_ndim = len(self._impl.key_shape) @@ -324,53 +320,6 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape): base_ndim = len(impl.key_shape) return base_arr_shape[:-base_ndim] -def make_key_array_phys_sharding(aval, sharding): - if dispatch.is_single_device_sharding(sharding): - return sharding - elif isinstance(sharding, PmapSharding): - key_shape = aval.dtype._impl.key_shape - trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape) - phys_sharding_spec = sharding_specs.ShardingSpec( - sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), - mesh_mapping=sharding.sharding_spec.mesh_mapping) - return PmapSharding(devices=sharding.devices, - sharding_spec=phys_sharding_spec) - elif isinstance(sharding, NamedSharding): - key_shape = aval.dtype._impl.key_shape - trailing_spec = [None] * len(key_shape) - return NamedSharding( - sharding.mesh, - PartitionSpec(*sharding.spec, *trailing_spec)) - else: - hlos = sharding._to_xla_hlo_sharding(aval.ndim) - return GSPMDSharding( - sharding._device_assignment, physical_hlo_sharding(aval, hlos)) - - -def get_logical_gspmd_sharding(aval, phys_sharding): - key_shape = aval.dtype._impl.key_shape - phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( - aval.ndim + len(key_shape)) - partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( - phys_hlo_sharding) - suffix = [] if num_replicas == 1 else [num_replicas] - # Create logical sharding by cutting off the replicated trailing dims. - logical_op_sharding = phys_hlo_sharding.to_proto().clone() - tad = partitions[:-len(key_shape)] + suffix - logical_op_sharding.tile_assignment_dimensions = tad - return GSPMDSharding(phys_sharding._device_assignment, - xc.HloSharding.from_proto(logical_op_sharding)) - - -def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: - key_shape = aval.dtype._impl.key_shape - new_op_sharding = hlo_sharding.to_proto().clone() - partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(hlo_sharding) - suffix = [] if num_replicas == 1 else [num_replicas] - tad = partitions + [1] * len(key_shape) + suffix - new_op_sharding.tile_assignment_dimensions = tad - return xc.HloSharding.from_proto(new_op_sharding) - class KeyTyRules: @@ -393,32 +342,6 @@ def physical_element_aval(dtype) -> core.ShapedArray: def physical_const(val) -> Array: return val._base_array - @staticmethod - def physical_sharding( - aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding: - return make_key_array_phys_sharding(aval, sharding) - - @staticmethod - def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding: - # The trailing dims should always be replicated. - aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval) - - if dispatch.is_single_device_sharding(phys_sharding): - return phys_sharding - elif isinstance(phys_sharding, PmapSharding): - key_shape = aval.dtype._impl.key_shape - logical_sharding_spec = sharding_specs.ShardingSpec( - sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)], - mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) - return PmapSharding(devices=phys_sharding.devices, - sharding_spec=logical_sharding_spec) - elif isinstance(phys_sharding, NamedSharding): - logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) - return pxla._gspmd_to_named_sharding_via_mesh( - logical_gs, phys_sharding.mesh) - else: - return get_logical_gspmd_sharding(aval, phys_sharding) - @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): @@ -434,7 +357,7 @@ def local_sharded_result_handler(aval, sharding, indices): # set up a grounded sharding (with a grounded sharding spec) if isinstance(sharding, (PmapSharding, NamedSharding)): - phys_sharding = make_key_array_phys_sharding(aval, sharding) + phys_sharding = physical_sharding(aval, sharding) else: assert False, f'impossible sharding {sharding} in local sharded result handler' @@ -456,7 +379,7 @@ def global_sharded_result_handler(aval, out_sharding, committed): phys_aval = core.physical_aval(aval) phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] - phys_sharding = make_key_array_phys_sharding(aval, out_sharding) + phys_sharding = physical_sharding(aval, out_sharding) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) def handler(bufs): return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) @@ -468,7 +391,7 @@ def make_sharded_array(aval, sharding, arrays, committed): phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] phys_arrays = [random_unwrap(arr) for arr in arrays] - phys_sharding = make_key_array_phys_sharding(aval, sharding) + phys_sharding = physical_sharding(aval, sharding) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) phys_result = phys_handler(phys_arrays) return PRNGKeyArray(aval.dtype._impl, phys_result) @@ -477,8 +400,9 @@ def make_sharded_array(aval, sharding, arrays, committed): def device_put_sharded(vals, aval, sharding, devices): physical_aval = core.physical_aval(aval) physical_buffers = tree_util.tree_map(random_unwrap, vals) - physical_sharding = make_key_array_phys_sharding(aval, sharding) - physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices)) + phys_sharding = physical_sharding(aval, sharding) + physical_result = pxla.batched_device_put(physical_aval, phys_sharding, + physical_buffers, list(devices)) return random_wrap(physical_result, impl=aval.dtype._impl) @staticmethod @@ -486,37 +410,11 @@ def device_put_replicated(val, aval, sharding, devices): physical_aval = core.physical_aval(aval) assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) - physical_sharding = make_key_array_phys_sharding(aval, sharding) - physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices) + phys_sharding = physical_sharding(aval, sharding) + physical_result = pxla.batched_device_put( + physical_aval, phys_sharding, [physical_buf] * len(devices), devices) return random_wrap(physical_result, impl=aval.dtype._impl) - @staticmethod - def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval): - if isinstance(sharding, PmapSharding): - return - phys_aval = core.physical_aval(aval) - hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim) - partitions, _ = op_shardings.get_num_ways_dim_sharded(hlo_s) - num_trailing_dims = phys_aval.ndim - aval.ndim - if not all(i == 1 for i in partitions[-num_trailing_dims:]): - raise AssertionError( - "The trailing dims of extended dtypes should be replicated. Got" - f" sharding: {sharding}, partitions: {partitions}, " - f"num_trailing_dims: {num_trailing_dims}") - - @staticmethod - def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: - # Set the sharding of extended dtypes to be UNCONSTRAINED - # (i.e. XLA will choose) on aval.shape. - # For the trailing dims i.e. the dimension of key_shape on the base_array, - # the sharding is set to be REPLICATED always. - # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), - # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). - # The below custom call achieves the sharding like above example. - return mlir.wrap_with_sharding_op( - ctx, val, aval, xc.HloSharding.replicate().to_proto(), - unspecified_dims=set(range(aval.ndim))) - @staticmethod def tangent_dtype(_): return dtypes.float0 @@ -569,10 +467,11 @@ def __hash__(self) -> int: xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(x: PRNGKeyArray, sharding): - arr = x._base_array - phys_sharding = make_key_array_phys_sharding(x.aval, sharding) - return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings): + arrs = [x._base_array for x in xs] + phys_shardings = [physical_sharding(x.aval, sharding) + for x, sharding in zip(xs, shardings)] + return pxla.shard_args(phys_shardings, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index f802ea7974cb..cad4826ba801 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections.abc import Callable from contextlib import contextmanager from functools import wraps import glob @@ -24,7 +25,7 @@ import os import socketserver import threading -from typing import Callable, Union +from typing import Any from jax._src import traceback_util traceback_util.register_exclusion(__file__) @@ -210,7 +211,7 @@ def stop_trace(): _profile_state.reset() -def stop_and_get_fdo_profile() -> Union[bytes, str]: +def stop_and_get_fdo_profile() -> bytes | str: """Stops the currently-running profiler trace and export fdo_profile. Currently, this is only supported for GPU. @@ -380,3 +381,62 @@ def save_device_memory_profile(filename, backend: str | None = None) -> None: profile = device_memory_profile(backend) with open(filename, "wb") as f: f.write(profile) + + +# Allows to run model with profiler given amount of times. After required amount +# of retries achived client can collect FDO data. +class PGLEProfiler: + + def __init__(self, retries: int, percentile: int): + self.retries: int = retries + self.percentile: int = percentile + self.collected_fdo: str | None = None + self.called_times: int = 0 + self.fdo_profiles: list[Any] = [] + self.current_session: xla_client.profiler.ProfilerSession | None = None + + def consume_fdo_profile(self) -> str | None: + if self.collected_fdo is not None: + return self.collected_fdo + + if not self.is_enabled() or self.called_times != self.retries: + return None + + self.collected_fdo = xla_client.profiler.aggregate_profiled_instructions( + self.fdo_profiles, self.percentile + ) + return self.collected_fdo + + def is_fdo_consumed(self): + return self.collected_fdo is not None + + def disable(self): + self.retries = 0 + + def is_enabled(self): + return self.retries > 0 + + def is_running(self): + return self.current_session is not None + + @classmethod + @contextmanager + def trace(cls, runner: PGLEProfiler | None): + if (runner is None or runner.is_running() + or not runner.is_enabled() or runner.is_fdo_consumed()): + yield + else: + options = xla_client.profiler.ProfileOptions() + options.enable_hlo_proto = True + runner.current_session = xla_client.profiler.ProfilerSession(options) + + try: + yield + finally: + xspace = runner.current_session.stop() + runner.fdo_profiles.append( + xla_client.profiler.get_fdo_profile(xspace) + ) + runner.current_session = None + + runner.called_times += 1 diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index c7cb4ee20e1b..bc18226f07a7 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -104,6 +104,15 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) + # TODO(reedwm): Upcasting int4 to int8 will no longer be neccessary once + # ml_dtypes has a stable release with commit + # https://github.com/jax-ml/ml_dtypes/commit/348fd3704306cae97f617c38045cee6bc416bf10. + # Remove these checks once JAX depends on a version on ml_dtypes with that + # commit. + if x.dtype == _dtypes.int4: + return x.astype(np.int8) + if x.dtype == _dtypes.uint4: + return x.astype(np.uint8) return x a = maybe_upcast(a) diff --git a/jax/_src/random.py b/jax/_src/random.py index f0ea398e77e2..6a0a3c0f9932 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -403,8 +403,10 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: finfo = jnp.finfo(dtype) nbits, nmant = finfo.bits, finfo.nmant - if nbits not in (16, 32, 64): - raise TypeError(f"uniform only accepts 16-, 32-, or 64-bit dtypes, got {dtype}.") + if nbits not in (8, 16, 32, 64): + raise TypeError( + f"uniform only accepts 8-, 16-, 32-, or 64-bit dtypesgot {dtype}." + ) rng_bits = nbits if nmant < 8: @@ -1860,7 +1862,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array: def rademacher(key: KeyArrayLike, - shape: Shape, + shape: Shape = (), dtype: DTypeLikeInt = int) -> Array: r"""Sample from a Rademacher distribution. @@ -1873,7 +1875,7 @@ def rademacher(key: KeyArrayLike, Args: key: a PRNG key. - shape: The shape of the returned samples. + shape: The shape of the returned samples. Default (). dtype: The type used for samples. Returns: @@ -2354,7 +2356,6 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array: return tri - def lognormal(key: KeyArrayLike, sigma: RealArray = np.float32(1), shape: Shape | None = None, @@ -2619,7 +2620,7 @@ def clone(key): Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`) this function operates as an identity. - Example: + Examples: >>> import jax >>> key = jax.random.key(0) diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index 63f1d1a552d7..f1d907cf3f3b 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -64,6 +64,37 @@ def dct(x: Array, type: int = 2, n: int | None = None, - :func:`jax.scipy.fft.dctn`: multidimensional DCT - :func:`jax.scipy.fft.idct`: inverse DCT - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT + + Examples: + >>> x = jax.random.normal(jax.random.key(0), (3, 3)) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dct(x)) + [[-0.58 -0.33 -1.08] + [-0.88 -1.01 -1.79] + [-1.06 -2.43 1.24]] + + When ``n`` smaller than ``x.shape[axis]`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dct(x, n=2)) + [[-0.22 -0.9 ] + [-0.57 -1.68] + [-2.52 -0.11]] + + When ``n`` smaller than ``x.shape[axis]`` and ``axis=0`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dct(x, n=2, axis=0)) + [[-2.22 1.43 -0.67] + [ 0.52 -0.26 -0.04]] + + When ``n`` larger than ``x.shape[axis]`` and ``axis=1`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dct(x, n=4, axis=1)) + [[-0.58 -0.35 -0.64 -1.11] + [-0.88 -0.9 -1.46 -1.68] + [-1.06 -2.25 -1.15 1.93]] """ if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -125,6 +156,43 @@ def dctn(x: Array, type: int = 2, - :func:`jax.scipy.fft.dct`: one-dimensional DCT - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT + + Examples: + + ``jax.scipy.fft.dctn`` computes the transform along both the axes by default + when ``axes`` argument is ``None``. + + >>> x = jax.random.normal(jax.random.key(0), (3, 3)) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dctn(x)) + [[-5.04 -7.54 -3.26] + [ 0.83 3.64 -4.03] + [ 0.12 -0.73 3.74]] + + When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2`` + and dimension along ``axis 1`` will be same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dctn(x, s=[2])) + [[-2.92 -2.68 -5.74] + [ 0.42 0.97 1. ]] + + When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will + be ``2`` and dimension along ``axis 0`` will be same as that of input. + Also when ``axes=[1]``, transform will be computed only along ``axis 1``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dctn(x, s=[2], axes=[1])) + [[-0.22 -0.9 ] + [-0.57 -1.68] + [-2.52 -0.11]] + + When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.dctn(x, s=[2, 4])) + [[-2.92 -2.49 -4.21 -5.57] + [ 0.42 0.79 1.16 0.8 ]] """ if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -171,6 +239,46 @@ def idct(x: Array, type: int = 2, n: int | None = None, - :func:`jax.scipy.fft.dct`: DCT - :func:`jax.scipy.fft.dctn`: multidimensional DCT - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT + + Examples: + + >>> x = jax.random.normal(jax.random.key(0), (3, 3)) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idct(x)) + [[-0.02 -0. -0.17] + [-0.02 -0.07 -0.28] + [-0.16 -0.36 0.18]] + + When ``n`` smaller than ``x.shape[axis]`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idct(x, n=2)) + [[ 0. -0.19] + [-0.03 -0.34] + [-0.38 0.04]] + + When ``n`` smaller than ``x.shape[axis]`` and ``axis=0`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idct(x, n=2, axis=0)) + [[-0.35 0.23 -0.1 ] + [ 0.17 -0.09 0.01]] + + When ``n`` larger than ``x.shape[axis]`` and ``axis=0`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idct(x, n=4, axis=0)) + [[-0.34 0.03 0.07] + [ 0. 0.18 -0.17] + [ 0.14 0.09 -0.14] + [ 0. -0.18 0.14]] + + ``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result + of ``jax.scipy.fft.dct`` + + >>> x_dct = jax.scipy.fft.dct(x) + >>> jnp.allclose(x, jax.scipy.fft.idct(x_dct)) + Array(True, dtype=bool) """ if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') @@ -223,6 +331,50 @@ def idctn(x: Array, type: int = 2, - :func:`jax.scipy.fft.dct`: one-dimensional DCT - :func:`jax.scipy.fft.dctn`: multidimensional DCT - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT + + Examples: + + ``jax.scipy.fft.idctn`` computes the transform along both the axes by default + when ``axes`` argument is ``None``. + + >>> x = jax.random.normal(jax.random.key(0), (3, 3)) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idctn(x)) + [[-0.03 -0.08 -0.08] + [ 0.05 0.12 -0.09] + [-0.02 -0.04 0.08]] + + When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2`` + and dimension along ``axis 1`` will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idctn(x, s=[2])) + [[-0.01 -0.03 -0.14] + [ 0. 0.03 0.06]] + + When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will + be ``2`` and dimension along ``axis 0`` will be same as that of input. + Also when ``axes=[1]``, transform will be computed only along ``axis 1``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idctn(x, s=[2], axes=[1])) + [[ 0. -0.19] + [-0.03 -0.34] + [-0.38 0.04]] + + When ``s=[2, 4]``, shape of the transform will be ``(2, 4)`` + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jax.scipy.fft.idctn(x, s=[2, 4])) + [[-0.01 -0.01 -0.05 -0.11] + [ 0. 0.01 0.03 0.04]] + + ``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result + of ``jax.scipy.fft.dctn`` + + >>> x_dctn = jax.scipy.fft.dctn(x) + >>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn)) + Array(True, dtype=bool) """ if type != 2: raise NotImplementedError('Only DCT type 2 is implemented.') diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index d064f28a5d81..5458d71dedf4 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -78,7 +78,7 @@ def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, - :func:`jax.scipy.linalg.cho_factor` - :func:`jax.scipy.linalg.cho_solve` - Example: + Examples: A small real Hermitian positive-definite matrix: >>> x = jnp.array([[2., 1.], @@ -130,7 +130,7 @@ def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, - :func:`jax.scipy.linalg.cholesky` - :func:`jax.scipy.linalg.cho_solve` - Example: + Examples: A small real Hermitian positive-definite matrix: >>> x = jnp.array([[2., 1.], @@ -174,7 +174,7 @@ def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike, Args: c_and_lower: ``(c, lower)``, where ``c`` is an array of shape ``(..., N, N)`` representing the lower or upper cholesky decomposition of the matrix, and - ``lower`` is a boolean specifying whethe this is the lower or upper decomposition. + ``lower`` is a boolean specifying whether this is the lower or upper decomposition. b: right-hand-side of linear system. Must have shape ``(..., N)`` overwrite_a: unused by JAX check_finite: unused by JAX @@ -186,7 +186,7 @@ def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike, - :func:`jax.scipy.linalg.cholesky` - :func:`jax.scipy.linalg.cho_factor` - Example: + Examples: A small real Hermitian positive-definite matrix: >>> x = jnp.array([[2., 1.], @@ -288,7 +288,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, - :func:`jax.numpy.linalg.svd`: NumPy-style SVD API - :func:`jax.lax.linalg.svd`: XLA-style SVD API - Example: + Examples: Consider the SVD of a small real-valued array: >>> x = jnp.array([[1., 2., 3.], @@ -415,7 +415,7 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: """Compute eigenvalues and eigenvectors for a Hermitian matrix - JAX implementation of :func:`jax.scipy.linalg.eigh`. + JAX implementation of :func:`scipy.linalg.eigh`. Args: a: Hermitian input array of shape ``(..., N, N)`` @@ -513,10 +513,10 @@ def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]: transformation matrix. See also: - - :func:`jax.scipy.linalg.rsf2csf`: conver real Schur form to complex Schur form. + - :func:`jax.scipy.linalg.rsf2csf`: convert real Schur form to complex Schur form. - :func:`jax.lax.linalg.schur`: XLA-style API for Schur decomposition. - Example: + Examples: A Schur decomposition of a 3x3 matrix: >>> a = jnp.array([[1., 2., 3.], @@ -570,7 +570,7 @@ def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> A - :func:`jax.numpy.linalg.inv`: NumPy-style API for matrix inverse - :func:`jax.scipy.linalg.solve`: direct linear solver - Example: + Examples: Compute the inverse of a 3x3 matrix >>> a = jnp.array([[1., 2., 3.], @@ -630,7 +630,7 @@ def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True - :func:`jax.scipy.linalg.lu` - :func:`jax.scipy.linalg.lu_solve` - Example: + Examples: Solving a small linear system via LU factorization: >>> a = jnp.array([[2., 1.], @@ -686,7 +686,7 @@ def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0, - :func:`jax.scipy.linalg.lu` - :func:`jax.scipy.linalg.lu_factor` - Example: + Examples: Solving a small linear system via LU factorization: >>> a = jnp.array([[2., 1.], @@ -784,7 +784,7 @@ def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, - :func:`jax.lax.linalg.lu`: XLA-style API for LU decomposition. - :func:`jax.scipy.linalg.lu_solve`: LU-based linear solver. - Example: + Examples: An LU decomposition of a 3x3 matrix: >>> a = jnp.array([[1., 2., 3.], @@ -888,7 +888,7 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = " - ``"economic"``: return `Q` of shape ``(M, K)`` and `R` of shape ``(K, N)``, where K = min(M, N). - pivoting: Not implemened in JAX. + pivoting: Not implemented in JAX. overwrite_a: unused in JAX lwork: unused in JAX check_finite: unused in JAX @@ -906,7 +906,7 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = " See also: - :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API - - :func:`jax.lax.linalg.qr`: XLA-style QR decompositon API + - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API Examples: Compute the QR decomposition of a matrix: @@ -999,7 +999,7 @@ def solve(a: ArrayLike, b: ArrayLike, lower: bool = False, - :func:`jax.numpy.linalg.solve`: NumPy-style API for solving linear systems. - :func:`jax.lax.custom_linear_solve`: matrix-free linear solver. - Example: + Examples: A simple 3x3 linear system: >>> A = jnp.array([[1., 2., 3.], @@ -1083,7 +1083,7 @@ def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bo See also: :func:`jax.scipy.linalg.solve`: Solve a general linear system. - Example: + Examples: A simple 3x3 triangular linear system: >>> A = jnp.array([[1., 2., 3.], @@ -1105,7 +1105,7 @@ def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bo >>> x Array([10. , -4. , -3.4], dtype=float32) - Confiriming that the result solves the system: + Confirming that the result solves the system: >>> jnp.allclose(A.T @ x, b) Array(True, dtype=bool) @@ -1137,6 +1137,32 @@ def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 1 See Also: :func:`jax.scipy.linalg.expm_frechet` + + Examples: + + ``expm`` is the matrix exponential, and has similar properties to the more + familiar scalar exponential. For scalars ``a`` and ``b``, :math:`e^{a + b} + = e^a e^b`. However, for matrices, this property only holds when ``A`` and + ``B`` commute (``AB = BA``). In this case, ``expm(A+B) = expm(A) @ expm(B)`` + + >>> A = jnp.array([[2, 0], + ... [0, 1]]) + >>> B = jnp.array([[3, 0], + ... [0, 4]]) + >>> jnp.allclose(jax.scipy.linalg.expm(A+B), + ... jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B), + ... rtol=0.0001) + Array(True, dtype=bool) + + If a matrix ``X`` is invertible, then + ``expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)`` + + >>> X = jnp.array([[3, 1], + ... [2, 5]]) + >>> X_inv = jax.scipy.linalg.inv(X) + >>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv), + ... X @ jax.scipy.linalg.expm(A) @ X_inv) + Array(True, dtype=bool) """ A, = promote_dtypes_inexact(A) @@ -1358,7 +1384,7 @@ def block_diag(*arrs: ArrayLike) -> Array: 2D block-diagonal array constructed by placing the input arrays along the diagonal. - Example: + Examples: >>> A = jnp.ones((1, 1)) >>> B = jnp.ones((2, 2)) >>> C = jnp.ones((3, 3)) @@ -1642,6 +1668,36 @@ def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float ``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on whether ``side`` is ``"right"`` or ``"left"``, respectively. + Examples: + + Polar decomposition of a 3x3 matrix: + + >>> a = jnp.array([[1., 2., 3.], + ... [5., 4., 2.], + ... [3., 2., 1.]]) + >>> U, P = jax.scipy.linalg.polar(a) + + U is a Unitary Matrix: + + >>> jnp.round(U.T @ U) + Array([[ 1., -0., -0.], + [-0., 1., 0.], + [-0., 0., 1.]], dtype=float32) + + P is positive-semidefinite Matrix: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(P) + [[4.79 3.25 1.23] + [3.25 3.06 2.01] + [1.23 2.01 2.91]] + + The original matrix can be reconstructed by multiplying the U and P: + + >>> a_reconstructed = U @ P + >>> jnp.allclose(a, a_reconstructed) + Array(True, dtype=bool) + .. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999 """ arr = jnp.asarray(a) @@ -1730,7 +1786,7 @@ def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array: See Also: :func:`jax.scipy.linalg.expm` - Example: + Examples: >>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) @@ -1780,7 +1836,7 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Arra See Also: :func:`jax.scipy.linalg.schur`: Schur decomposition - Example: + Examples: >>> A = jnp.array([[0., 3., 3.], ... [0., 1., 2.], ... [2., 0., 1.]]) @@ -1803,7 +1859,7 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Arra [ 0. -0.88 -0.35] [ 0. 2.37 -0.88]] - By contrast, the complex form is truely upper-triangular: + By contrast, the complex form is truly upper-triangular: >>> with jnp.printoptions(precision=2, suppress=True): ... print(Tc) @@ -1905,7 +1961,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, - ``H`` has shape ``(..., N, N)`` and is the Hessenberg form of ``a`` - ``Q`` has shape ``(..., N, N)`` and is the associated unitary matrix - Example: + Examples: Computing the Hessenberg form of a 4x4 matrix >>> a = jnp.array([[1., 2., 3., 4.], diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index 4445a6130d06..d81008308b94 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import itertools import operator -from typing import Callable from jax._src import api from jax._src import util @@ -151,8 +150,13 @@ def map_coordinates( * 1: Linear mode: Points outside the boundaries of the input are filled according to the given mode. - JAX supports one of ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``. - Default is 'constant'. + JAX supports one of ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``. Note the + ``'wrap'`` mode in JAX behaves as ``'grid-wrap'`` mode in SciPy, and ``'constant'`` + mode in JAX behaves as ``'grid-constant'`` mode in SciPy. This discrepancy was caused + by a former bug in those modes in SciPy (https://github.com/scipy/scipy/issues/2640), + which was first fixed in JAX by changing the behavior of the existing modes, and later + on fixed in SciPy, by adding modes with new names, rather than fixing the existing + ones, for backwards compatibility reasons. Default is 'constant'. cval: Value used for points outside the boundaries of the input if ``mode='constant'`` Default is 0.0. diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index 4df9647debe5..aa82ab4fd0c8 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -15,8 +15,9 @@ from __future__ import annotations -from typing import Callable, NamedTuple +from collections.abc import Callable from functools import partial +from typing import NamedTuple import jax import jax.numpy as jnp diff --git a/jax/_src/scipy/optimize/bfgs.py b/jax/_src/scipy/optimize/bfgs.py index b6fd9f9dda17..657b7610e6e1 100644 --- a/jax/_src/scipy/optimize/bfgs.py +++ b/jax/_src/scipy/optimize/bfgs.py @@ -15,8 +15,9 @@ from __future__ import annotations +from collections.abc import Callable from functools import partial -from typing import Callable, NamedTuple +from typing import NamedTuple import jax import jax.numpy as jnp diff --git a/jax/_src/scipy/optimize/line_search.py b/jax/_src/scipy/optimize/line_search.py index 078d23d97a96..189009693cdd 100644 --- a/jax/_src/scipy/optimize/line_search.py +++ b/jax/_src/scipy/optimize/line_search.py @@ -118,7 +118,7 @@ def body(state): # This will cause the line search to stop, and since the Wolfe conditions # are not satisfied the minimization should stop too. - threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10) + threshold = jnp.where((jnp.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10) state = state._replace(failed=state.failed | (dalpha <= threshold)) # Cubmin is sometimes nan, though in this case the bounds check will fail. diff --git a/jax/_src/scipy/optimize/minimize.py b/jax/_src/scipy/optimize/minimize.py index 830f1228424a..4fc006be6df0 100644 --- a/jax/_src/scipy/optimize/minimize.py +++ b/jax/_src/scipy/optimize/minimize.py @@ -14,8 +14,8 @@ from __future__ import annotations -from collections.abc import Mapping -from typing import Any, Callable +from collections.abc import Callable, Mapping +from typing import Any import jax from jax._src.scipy.optimize.bfgs import minimize_bfgs diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index ae093a3ac3c1..1282650ae1e5 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -14,11 +14,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import math import operator -from typing import Callable import warnings import numpy as np @@ -188,7 +187,7 @@ def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', precision: PrecisionLike = None) -> Array: """Convolution of two N-dimensional arrays. - JAX implementation of :func:`jax.scipy.signal.convolve`. + JAX implementation of :func:`scipy.signal.convolve`. Args: in1: left-hand input to the convolution. @@ -253,7 +252,7 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill fillvalue: float = 0, precision: PrecisionLike = None) -> Array: """Convolution of two 2-dimensional arrays. - JAX implementation of :func:`jax.scipy.signal.convolve2d`. + JAX implementation of :func:`scipy.signal.convolve2d`. Args: in1: left-hand input to the convolution. Must have ``in1.ndim == 2``. @@ -284,6 +283,37 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill - :func:`jax.numpy.convolve`: 1D convolution - :func:`jax.scipy.signal.convolve`: ND convolution - :func:`jax.scipy.signal.correlate`: ND correlation + + Examples: + A few 2D convolution examples: + + >>> x = jnp.array([[1, 2], + ... [3, 4]]) + >>> y = jnp.array([[2, 1, 1], + ... [4, 3, 4], + ... [1, 3, 2]]) + + Full 2D convolution uses implicit zero-padding at the edges: + + >>> jax.scipy.signal.convolve2d(x, y, mode='full') + Array([[ 2., 5., 3., 2.], + [10., 22., 17., 12.], + [13., 30., 32., 20.], + [ 3., 13., 18., 8.]], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered 2D convolution of the same size + as the first input: + + >>> jax.scipy.signal.convolve2d(x, y, mode='same') + Array([[22., 17.], + [30., 32.]], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion of 2D convolution + where the two arrays fully overlap: + + >>> jax.scipy.signal.convolve2d(x, y, mode='valid') + Array([[22., 17.], + [30., 32.]], dtype=float32) """ if boundary != 'fill' or fillvalue != 0: raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0") @@ -296,7 +326,7 @@ def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', precision: PrecisionLike = None) -> Array: """Cross-correlation of two N-dimensional arrays. - JAX implementation of :func:`jax.scipy.signal.correlate`. + JAX implementation of :func:`scipy.signal.correlate`. Args: in1: left-hand input to the cross-correlation. @@ -325,6 +355,29 @@ def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', - :func:`jax.numpy.correlate`: 1D cross-correlation - :func:`jax.scipy.signal.correlate2d`: 2D cross-correlation - :func:`jax.scipy.signal.convolve`: ND convolution + + Examples: + A few 1D correlation examples: + + >>> x = jnp.array([1, 2, 3, 2, 1]) + >>> y = jnp.array([1, 3, 2]) + + Full 1D correlation uses implicit zero-padding at the edges: + + >>> jax.scipy.signal.correlate(x, y, mode='full') + Array([ 2., 7., 13., 15., 11., 5., 1.], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered 1D correlation of the same + size as the first input: + + >>> jax.scipy.signal.correlate(x, y, mode='same') + Array([ 7., 13., 15., 11., 5.], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion of 1D correlation + where the two arrays fully overlap: + + >>> jax.scipy.signal.correlate(x, y, mode='valid') + Array([13., 15., 11.], dtype=float32) """ return convolve(in1, jnp.flip(in2.conj()), mode, precision=precision, method=method) @@ -333,7 +386,7 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil fillvalue: float = 0, precision: PrecisionLike = None) -> Array: """Cross-correlation of two 2-dimensional arrays. - JAX implementation of :func:`jax.scipy.signal.correlate2d`. + JAX implementation of :func:`scipy.signal.correlate2d`. Args: in1: left-hand input to the cross-correlation. Must have ``in1.ndim == 2``. @@ -364,6 +417,38 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil - :func:`jax.numpy.correlate`: 1D cross-correlation - :func:`jax.scipy.signal.correlate`: ND cross-correlation - :func:`jax.scipy.signal.convolve`: ND convolution + + Examples: + A few 2D correlation examples: + + >>> x = jnp.array([[2, 1, 3], + ... [1, 3, 1], + ... [4, 1, 2]]) + >>> y = jnp.array([[1, 3], + ... [4, 2]]) + + Full 2D correlation uses implicit zero-padding at the edges: + + >>> jax.scipy.signal.correlate2d(x, y, mode='full') + Array([[ 4., 10., 10., 12.], + [ 8., 15., 24., 7.], + [11., 28., 14., 9.], + [12., 7., 7., 2.]], dtype=float32) + + Specifying ``mode = 'same'`` returns a centered 2D correlation of the same + size as the first input: + + >>> jax.scipy.signal.correlate2d(x, y, mode='same') + Array([[15., 24., 7.], + [28., 14., 9.], + [ 7., 7., 2.]], dtype=float32) + + Specifying ``mode = 'valid'`` returns only the portion of 2D correlation + where the two arrays fully overlap: + + >>> jax.scipy.signal.correlate2d(x, y, mode='valid') + Array([[15., 24.], + [28., 14.]], dtype=float32) """ if boundary != 'fill' or fillvalue != 0: raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0") @@ -415,7 +500,7 @@ def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0, Returns: The detrended data array. - Example: + Examples: A simple detrend operation in one dimension: >>> data = jnp.array([1., 4., 8., 8., 9.]) @@ -996,7 +1081,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', See Also: :func:`jax.scipy.signal.stft`: short-time Fourier transform. - Example: + Examples: Demonstrate that this gives the inverse of :func:`~jax.scipy.signal.stft`: >>> x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.]) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 7c1cb34a8971..70f3ccd2ef80 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -16,7 +16,7 @@ from functools import partial import operator -from typing import cast, Any +from typing import cast, overload, Any import numpy as np @@ -28,12 +28,15 @@ from jax._src import core from jax._src import custom_derivatives +from jax._src import deprecations from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact from jax._src.ops import special as ops_special from jax._src.third_party.scipy.betaln import betaln as _betaln_impl from jax._src.typing import Array, ArrayLike +from jax._src.nn.functions import softmax as nn_softmax +from jax._src.nn.functions import log_softmax as nn_log_softmax def gammaln(x: ArrayLike) -> Array: @@ -186,8 +189,16 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array: n, = promote_args_inexact("factorial", n) return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1))) +@overload +def beta(a: ArrayLike, b: ArrayLike) -> Array: ... -def beta(x: ArrayLike, y: ArrayLike) -> Array: +@overload +def beta(a: ArrayLike, *, y: ArrayLike) -> Array: ... + +@overload +def beta(*, x: ArrayLike, y: ArrayLike) -> Array: ... + +def beta(*args, **kwds): r"""The beta function JAX implementation of :obj:`scipy.special.beta`. @@ -209,9 +220,27 @@ def beta(x: ArrayLike, y: ArrayLike) -> Array: - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.betaln` """ - x, y = promote_args_inexact("beta", x, y) - sign = gammasgn(x) * gammasgn(y) * gammasgn(x + y) - return sign * lax.exp(betaln(x, y)) + # TODO(jakevdp): deprecation warning added 2024-06-10; finalize after 2024-09-10 + if 'x' in kwds: + msg = "The `x` parameter of jax.scipy.special.beta is deprecated, use `a` instead." + deprecations.warn('jax-scipy-beta-args', msg, stacklevel=2) + if 'a' in kwds: + raise TypeError("beta() got both parameter 'a' and parameter 'x'.") + kwds['a'] = kwds.pop('x') + if 'y' in kwds: + msg = "The `y` parameter of jax.scipy.special.beta is deprecated, use `b` instead." + deprecations.warn('jax-scipy-beta-args', msg, stacklevel=2) + if 'b' in kwds: + raise TypeError("beta() got both parameter 'b' and parameter 'y'.") + kwds['b'] = kwds.pop('y') + if extra := kwds.keys() - {'a', 'b'}: + raise TypeError(f"beta() got unexpected keyword arguments {list(extra)}") + return _beta(*args, **kwds) + +def _beta(a, b): + a, b = promote_args_inexact("beta", a, b) + sign = gammasgn(a) * gammasgn(b) * gammasgn(a + b) + return sign * lax.exp(betaln(a, b)) def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: @@ -616,24 +645,7 @@ def kl_div( - :func:`jax.scipy.special.rel_entr` """ p, q = promote_args_inexact("kl_div", p, q) - zero = _lax_const(p, 0.0) - both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero)) - one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero)) - - safe_p = jnp.where(both_gt_zero_mask, p, 1) - safe_q = jnp.where(both_gt_zero_mask, q, 1) - - log_val = lax.sub( - lax.add( - lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q)), - safe_q, - ), - safe_p, - ) - result = jnp.where( - both_gt_zero_mask, log_val, jnp.where(one_zero_mask, q, np.inf) - ) - return result + return rel_entr(p, q) - p + q def rel_entr( @@ -2375,7 +2387,6 @@ def poch(z: ArrayLike, m: ArrayLike) -> Array: Notes: The JAX version supports only real-valued inputs. """ - # Factorial definition when m is close to an integer, otherwise gamma definition. z, m = promote_args_inexact("poch", z, m) return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma(z + m) / gamma(z)) @@ -2412,6 +2423,8 @@ def _hyp1f1_serie(a, b, x): https://doi.org/10.48550/arXiv.1407.7786 """ + precision = jnp.finfo(x.dtype).eps + def body(state): serie, k, term = state serie += term @@ -2423,7 +2436,7 @@ def body(state): def cond(state): serie, k, term = state - return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8) + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 1, 1, a / b * x @@ -2437,6 +2450,8 @@ def _hyp1f1_asymptotic(a, b, x): https://doi.org/10.48550/arXiv.1407.7786 """ + precision = jnp.finfo(x.dtype).eps + def body(state): serie, k, term = state serie += term @@ -2448,7 +2463,7 @@ def body(state): def cond(state): serie, k, term = state - return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8) + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 1, 1, (b - a) * (1 - a) / x serie = lax.while_loop(cond, body, init)[0] @@ -2464,6 +2479,8 @@ def _hyp1f1_a_derivative(a, b, x): https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/ """ + precision = jnp.finfo(x.dtype).eps + def body(state): serie, k, term = state serie += term * (digamma(a + k) - digamma(a)) @@ -2475,7 +2492,7 @@ def body(state): def cond(state): serie, k, term = state - return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15) + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 0, 1, a / b * x @@ -2490,6 +2507,8 @@ def _hyp1f1_b_derivative(a, b, x): https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/ """ + precision = jnp.finfo(x.dtype).eps + def body(state): serie, k, term = state serie += term * (digamma(b) - digamma(b + k)) @@ -2501,7 +2520,7 @@ def body(state): def cond(state): serie, k, term = state - return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15) + return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 0, 1, a / b * x @@ -2565,3 +2584,72 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot, lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot ) + + +def softmax(x: ArrayLike, + /, + *, + axis: int | tuple[int, ...] | None = None, + ) -> Array: + r"""Softmax function. + + JAX implementation of :func:`scipy.special.softmax`. + + Computes the function which rescales elements to the range :math:`[0, 1]` + such that the elements along :code:`axis` sum to :math:`1`. + + .. math :: + \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + + Args: + x : input array + axis: the axis or axes along which the softmax should be computed. The + softmax output summed across these dimensions should sum to :math:`1`. + + Returns: + An array of the same shape as ``x``. + + Note: + If any input values are ``+inf``, the result will be all ``NaN``: this + reflects the fact that ``inf / inf`` is not well-defined in the context of + floating-point math. + + See also: + :func:`log_softmax` + """ + return nn_softmax(x, axis=axis) + + +def log_softmax(x: ArrayLike, + /, + *, + axis: int | tuple[int, ...] | None = None, + ) -> Array: + r"""Log-Softmax function. + + JAX implementation of :func:`scipy.special.log_softmax` + + Computes the logarithm of the :code:`softmax` function, which rescales + elements to the range :math:`[-\infty, 0)`. + + .. math :: + \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} + \right) + + Args: + x : input array + axis: the axis or axes along which the :code:`log_softmax` should be + computed. + + Returns: + An array of the same shape as ``x`` + + Note: + If any input values are ``+inf``, the result will be all ``NaN``: this + reflects the fact that ``inf / inf`` is not well-defined in the context of + floating-point math. + + See also: + :func:`softmax` + """ + return nn_log_softmax(x, axis=axis) diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 916e6939cd28..08d1c0b6b538 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -46,6 +46,40 @@ def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keep Returns: A tuple of arrays, ``(mode, count)``. ``mode`` is the array of modal values, and ``count`` is the number of times each value appears in the input array. + + Examples: + >>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3]) + >>> mode, count = jax.scipy.stats.mode(x) + >>> mode, count + (Array(4, dtype=int32), Array(3, dtype=int32)) + + For multi dimensional arrays, ``jax.scipy.stats.mode`` computes the ``mode`` + and the corresponding ``count`` along ``axis=0``: + + >>> x1 = jnp.array([[1, 2, 1, 3, 2, 1], + ... [3, 1, 3, 2, 1, 3], + ... [1, 2, 2, 3, 1, 2]]) + >>> mode, count = jax.scipy.stats.mode(x1) + >>> mode, count + (Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32)) + + If ``axis=1``, ``mode`` and ``count`` will be computed along ``axis 1``. + + >>> mode, count = jax.scipy.stats.mode(x1, axis=1) + >>> mode, count + (Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32)) + + By default, ``jax.scipy.stats.mode`` reduces the dimension of the result. + To keep the dimensions same as that of the input array, the argument + ``keepdims`` must be set to ``True``. + + >>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True) + >>> mode, count + (Array([[1], + [3], + [2]], dtype=int32), Array([[3], + [3], + [3]], dtype=int32)) """ check_arraylike("mode", a) x = jnp.atleast_1d(a) @@ -142,7 +176,7 @@ def rankdata( if nan_policy not in ["propagate", "omit", "raise"]: raise ValueError( f"Illegal nan_policy value {nan_policy!r}; expected one of " - "{'propoagate', 'omit', 'raise'}" + "{'propagate', 'omit', 'raise'}" ) if nan_policy == "omit": raise NotImplementedError( @@ -201,15 +235,69 @@ def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "pr Returns: array + + Examples: + >>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x) + Array(0.41, dtype=float32) + + For multi dimensional arrays, ``sem`` computes standard error of mean along + ``axis=0``: + + >>> x1 = jnp.array([[1, 2, 1, 3, 2, 1], + ... [3, 1, 3, 2, 1, 3], + ... [1, 2, 2, 3, 1, 2]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x1) + Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32) + + If ``axis=1``, standard error of mean will be computed along ``axis 1``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x1, axis=1) + Array([0.33, 0.4 , 0.31], dtype=float32) + + If ``axis=None``, standard error of mean will be computed along all the axes. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x1, axis=None) + Array(0.2, dtype=float32) + + By default, ``sem`` reduces the dimension of the result. To keep the + dimensions same as that of the input array, the argument ``keepdims`` must + be set to ``True``. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x1, axis=1, keepdims=True) + Array([[0.33], + [0.4 ], + [0.31]], dtype=float32) + + Since, by default, ``nan_policy='propagate'``, ``sem`` propagates the ``nan`` + values in the result. + + >>> nan = jnp.nan + >>> x2 = jnp.array([[1, 2, 3, nan, 4, 2], + ... [4, 5, 4, 3, nan, 1], + ... [7, nan, 8, 7, 9, nan]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x2) + Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32) + + If ``nan_policy='omit```, ``sem`` omits the ``nan`` values and computes the error + for the remainging values along the specified axis. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jax.scipy.stats.sem(x2, nan_policy='omit') + Array([1.73, 1.5 , 1.53, 2. , 2.5 , 0.5 ], dtype=float32) """ b, = promote_args_inexact("sem", a) - if axis is None: - b = b.ravel() - axis = 0 if nan_policy == "propagate": - return b.std(axis, ddof=ddof) / jnp.sqrt(b.shape[axis]).astype(b.dtype) + size = b.size if axis is None else b.shape[axis] + return b.std(axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(size).astype(b.dtype) elif nan_policy == "omit": - count = (~jnp.isnan(b)).sum(axis) - return jnp.nanstd(b, axis, ddof=ddof) / jnp.sqrt(count).astype(b.dtype) + count = (~jnp.isnan(b)).sum(axis, keepdims=keepdims) + return jnp.nanstd(b, axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(count).astype(b.dtype) else: raise ValueError(f"{nan_policy} is not supported") diff --git a/jax/_src/scipy/stats/bernoulli.py b/jax/_src/scipy/stats/bernoulli.py index 3dfc25f06aea..96e4a68b7697 100644 --- a/jax/_src/scipy/stats/bernoulli.py +++ b/jax/_src/scipy/stats/bernoulli.py @@ -25,7 +25,7 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: JAX implementation of :obj:`scipy.stats.bernoulli` ``logpmf`` - The Bernoulli probablility mass function is defined as + The Bernoulli probability mass function is defined as .. math:: @@ -62,7 +62,7 @@ def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: JAX implementation of :obj:`scipy.stats.bernoulli` ``pmf`` - The Bernoulli probablility mass function is defined as + The Bernoulli probability mass function is defined as .. math:: diff --git a/jax/_src/scipy/stats/multivariate_normal.py b/jax/_src/scipy/stats/multivariate_normal.py index 968ad57c7a5b..8ba34703aada 100644 --- a/jax/_src/scipy/stats/multivariate_normal.py +++ b/jax/_src/scipy/stats/multivariate_normal.py @@ -33,7 +33,7 @@ def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right) - where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covarance matrix (``cov``), and + where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and :math:`k` is the rank of :math:`\Sigma`. Args: @@ -83,7 +83,7 @@ def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array: f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right) - where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covarance matrix (``cov``), and + where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and :math:`k` is the rank of :math:`\Sigma`. Args: diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index 54d94b832b21..b222e187f255 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -175,7 +175,7 @@ def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: scale: arraylike, distribution scale parameter Returns: - array of ppdf values. + array of ppf values. See Also: - :func:`jax.scipy.stats.norm.cdf` diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index 7c294d66fb6e..14dbbba6e975 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -17,9 +17,12 @@ from collections.abc import Mapping, Sequence import functools -from jax._src import util +from jax._src.util import safe_zip, use_cpp_class, cache from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.op_shardings import ( + are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated, + op_sharding_to_indices) Shape = tuple[int, ...] Device = xc.Device @@ -27,7 +30,7 @@ XLADeviceAssignment = Sequence[Device] -@functools.lru_cache(maxsize=4096) +@cache(max_size=4096, trace_context_in_key=False) def _addressable_devices_indices_map( sharding: Sharding, global_shape: Shape) -> Mapping[Device, Index | None]: global_map = sharding.devices_indices_map(global_shape) @@ -39,8 +42,41 @@ def _addressable_devices_indices_map( return {d: ind for d, ind in global_map.items() if d.process_index == d.client.process_index()} - -@util.use_cpp_class(xc.Sharding) +@cache(max_size=4096, trace_context_in_key=False) +def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]: + s.shard_shape(global_shape) # raises a good error message + hlo_sharding = s._to_xla_hlo_sharding(len(global_shape)) + indices = op_sharding_to_indices(hlo_sharding, global_shape, + len(s._device_assignment)) + return dict(safe_zip(s._device_assignment, indices)) + + +@cache(max_size=4096, trace_context_in_key=False) +def _common_shard_shape(self, global_shape: Shape) -> Shape: + hlo_sharding = self._to_xla_hlo_sharding(len(global_shape)) + if is_op_sharding_replicated(hlo_sharding): + return global_shape + partitions, _ = get_num_ways_dim_sharded(hlo_sharding) + assert len(partitions) == len(global_shape), (len(partitions), len(global_shape)) + out = [] + for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)): + try: + quotient, remainder = divmod(s, p) + except TypeError: + # TODO Figure out how to partition dynamic shapes + raise NotImplementedError + if remainder != 0: + raise ValueError( + f"Sharding {self} implies that array axis {dim} is partitioned " + f"{p} times, but the dimension size is {s} " + f"(full shape: {global_shape}, " + f"per-dimension tiling factors: {partitions} should evenly divide " + "the shape)") + out.append(quotient) + return tuple(out) + + +@use_cpp_class(xc.Sharding) class Sharding: """Describes how a :class:`jax.Array` is laid out across devices. """ @@ -55,35 +91,6 @@ def device_set(self) -> set[Device]: """ raise NotImplementedError('Subclasses should implement this method.') - def devices_indices_map( - self, global_shape: Shape) -> Mapping[Device, Index | None]: - """Returns a mapping from devices to the array slices each contains. - - The mapping includes all global devices, i.e., including - non-addressable devices from other processes. - """ - raise NotImplementedError('Subclasses should implement this method.') - - def shard_shape(self, global_shape: Shape) -> Shape: - """Returns the shape of the data on each device. - - The shard shape returned by this function is calculated from - ``global_shape`` and the properties of the sharding. - """ - raise NotImplementedError('Subclasses should implement this method.') - - def is_equivalent_to(self, other: Sharding, ndim: int) -> bool: - """Returns ``True`` if two shardings are equivalent. - - Two shardings are equivalent if they place the same logical array shards on - the same devices. - - For example, a :class:`NamedSharding` may be equivalent - to a :class:`PositionalSharding` if both place the same shards of the array - on the same devices. - """ - raise NotImplementedError('Subclasses should implement this method.') - @property def is_fully_replicated(self) -> bool: """Is this sharding fully replicated? @@ -112,6 +119,14 @@ def with_memory_kind(self, kind: str) -> Sharding: """Returns a new Sharding instance with the specified memory kind.""" raise NotImplementedError('Subclasses should implement this method') + @property + def _device_assignment(self) -> XLADeviceAssignment: + raise NotImplementedError('Subclasses should implement this method.') + + def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: + raise NotImplementedError('Subclasses should implement this method.') + + ############################################################################# # Default implementations below that all subclasses will inherit. @@ -134,3 +149,49 @@ def addressable_devices_indices_map( ``device_indices_map`` that applies to the addressable devices. """ return _addressable_devices_indices_map(self, global_shape) + + def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: + """Returns a mapping from devices to the array slices each contains. + + The mapping includes all global devices, i.e., including + non-addressable devices from other processes. + """ + return common_devices_indices_map(self, global_shape) + + @functools.cached_property + def _addressable_device_assignment(self) -> XLADeviceAssignment: + if self.is_fully_addressable: + return self._device_assignment + if hasattr(self, '_internal_device_list'): + return tuple(self._internal_device_list.addressable_device_list) + return tuple(d for d in self._device_assignment + if d.process_index == d.client.process_index()) + + def shard_shape(self, global_shape: Shape) -> Shape: + """Returns the shape of the data on each device. + + The shard shape returned by this function is calculated from + ``global_shape`` and the properties of the sharding. + """ + return _common_shard_shape(self, global_shape) + + def is_equivalent_to(self: Sharding, other: Sharding, ndim: int) -> bool: + """Returns ``True`` if two shardings are equivalent. + + Two shardings are equivalent if they place the same logical array shards on + the same devices. + + For example, a :class:`NamedSharding` may be equivalent + to a :class:`PositionalSharding` if both place the same shards of the array + on the same devices. + """ + try: + return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim), + other._to_xla_hlo_sharding(ndim)) + and self._internal_device_list == other._internal_device_list and # type: ignore + self.memory_kind == other.memory_kind) + # NotImplementedError is raised by PmapSharding because it can't lower + # to OpSharding. So if `other` is a PmapSharding, default to a strict + # equality check. + except NotImplementedError: + return self == other diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 25a02facb670..acceeef86a4a 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -30,10 +30,10 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge +from jax._src import core from jax._src.lib import xla_client as xc -from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, - is_op_sharding_replicated, - op_sharding_to_indices) # pyformat: disable +from jax._src.op_shardings import ( + are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method import numpy as np @@ -43,97 +43,15 @@ Device = xc.Device Index = tuple[slice, ...] XLADeviceAssignment = tuple[Device, ...] - +# TODO(yashkatariya): Remove this after 3 months of deprecation. +XLACompatibleSharding = sharding.Sharding @dataclasses.dataclass(frozen=True) class TransferToMemoryKind: memory_kind: str -@functools.lru_cache(maxsize=4096) -def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]: - hlo_sharding = s._to_xla_hlo_sharding(len(global_shape)) - gspmd_sharding = GSPMDSharding(s._device_assignment, hlo_sharding) - return gspmd_sharding.devices_indices_map(global_shape) - - -@functools.lru_cache(maxsize=4096) -def _common_shard_shape(self, global_shape: Shape) -> Shape: - hlo_sharding = self._to_xla_hlo_sharding(len(global_shape)) - if is_op_sharding_replicated(hlo_sharding): - return global_shape - partitions, _ = get_num_ways_dim_sharded(hlo_sharding) - assert len(partitions) == len(global_shape), (len(partitions), len(global_shape)) - out = [] - for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)): - try: - quotient, remainder = divmod(s, p) - except TypeError: - # TODO Figure out how to partition dynamic shapes - raise NotImplementedError - if remainder != 0: - raise ValueError( - f"Sharding {self} implies that array axis {dim} is partitioned " - f"{p} times, but the dimension size is {s} " - f"(full shape: {global_shape}, " - f"per-dimension tiling factors: {partitions} should evenly divide " - "the shape)") - out.append(quotient) - return tuple(out) - - -# Shardings that inherit from XLACompatibleSharding should implement the -# `_device_assignment` property and `_to_xla_hlo_sharding` method. -@use_cpp_class(xc.XLACompatibleSharding) -class XLACompatibleSharding(sharding.Sharding): - """A :class:`Sharding` that describes shardings expressible to XLA. - - Subclasses of :class:`XLACompatibleSharding` work with - all JAX APIs and transformations that use XLA. - """ - - # Abstract methods below that subclasses should implement. - - @property - def _device_assignment(self) -> XLADeviceAssignment: - raise NotImplementedError('Subclasses should implement this method.') - - def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: - raise NotImplementedError('Subclasses should implement this method.') - - ############################################################################# - # Default implementations below that all subclasses will inherit. - - def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: - return common_devices_indices_map(self, global_shape) - - @functools.cached_property - def _addressable_device_assignment(self) -> XLADeviceAssignment: - if self.is_fully_addressable: - return self._device_assignment - if hasattr(self, '_internal_device_list'): - return tuple(self._internal_device_list.addressable_device_list) - return tuple(d for d in self._device_assignment - if d.process_index == d.client.process_index()) - - def shard_shape(self, global_shape: Shape) -> Shape: - return _common_shard_shape(self, global_shape) - - def is_equivalent_to(self: XLACompatibleSharding, # type: ignore - other: XLACompatibleSharding, ndim: int) -> bool: - try: - return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim), - other._to_xla_hlo_sharding(ndim)) - and self._internal_device_list == other._internal_device_list and # type: ignore - self.memory_kind == other.memory_kind) - # NotImplementedError is raised by PmapSharding because it can't lower - # to OpSharding. So if `other` is a PmapSharding, default to a strict - # equality check. - except NotImplementedError: - return self == other - - -@functools.lru_cache +@util.cache(max_size=128, trace_context_in_key=False) def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes): try: for p in parsed_pspec: @@ -156,7 +74,7 @@ def hashed_index(x) -> int: return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x)) -@functools.lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]: try: device_indices_map_fn = sharding.devices_indices_map @@ -176,7 +94,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int] return out -@functools.lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( self, num_dimensions: int) -> xc.HloSharding: mesh_shape = self.mesh.shape @@ -239,7 +157,7 @@ def named_sharding_to_xla_hlo_sharding( @use_cpp_class(xc.NamedSharding) -class NamedSharding(XLACompatibleSharding): +class NamedSharding(sharding.Sharding): r"""A :class:`NamedSharding` expresses sharding using named axes. A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and @@ -265,7 +183,7 @@ class NamedSharding(XLACompatibleSharding): mesh: A :class:`jax.sharding.Mesh` object. spec: A :class:`jax.sharding.PartitionSpec` object. - Example: + Examples: >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P @@ -322,7 +240,7 @@ def __eq__(self, other): return False return self.mesh is other.mesh or self.mesh == other.mesh - def is_compatible_aval(self, aval_shape: Shape): + def check_compatible_aval(self, aval_shape: Shape) -> None: assert self._parsed_pspec is not None if len(aval_shape) < len(self._parsed_pspec): extra_msg = (' For scalars the PartitionSpec should be P()' @@ -378,19 +296,19 @@ def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) -@functools.lru_cache +@util.cache(max_size=128, trace_context_in_key=False) def get_replicated_hlo_sharding(): return xc.HloSharding.replicate() @use_cpp_class(xc.SingleDeviceSharding) -class SingleDeviceSharding(XLACompatibleSharding): +class SingleDeviceSharding(sharding.Sharding): """A :class:`Sharding` that places its data on a single device. Args: device: A single :py:class:`Device`. - Example: + Examples: >>> single_device_sharding = jax.sharding.SingleDeviceSharding( ... jax.devices()[0]) @@ -454,7 +372,7 @@ def is_fully_addressable(self) -> bool: return True -@functools.lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def pmap_sharding_devices_indices_map( self, global_shape: Shape) -> Mapping[Device, Index]: self.shard_shape(global_shape) # raises a good error message @@ -463,7 +381,7 @@ def pmap_sharding_devices_indices_map( @use_cpp_class(xc.PmapSharding) -class PmapSharding(XLACompatibleSharding): +class PmapSharding(sharding.Sharding): """Describes a sharding used by :func:`jax.pmap`.""" devices: np.ndarray sharding_spec: sharding_specs.ShardingSpec @@ -642,7 +560,7 @@ def _op_sharding_to_pos_sharding( return p -@functools.lru_cache(maxsize=4096) +@util.cache(max_size=4096, trace_context_in_key=False) def _positional_sharding_to_xla_hlo_sharding( self, num_dimensions: int) -> xc.HloSharding: if self.shape == (1,) * self.ndim: @@ -664,7 +582,7 @@ def _positional_sharding_to_xla_hlo_sharding( return xc.HloSharding.from_proto(pbuf) -class PositionalSharding(XLACompatibleSharding): +class PositionalSharding(sharding.Sharding): _devices: tuple[xc.Device, ...] _memory_kind: str | None _ids: np.ndarray # dtype DeviceIdSet @@ -716,6 +634,13 @@ def replicate(self, axis=None, keepdims=True) -> PositionalSharding: new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union return self._remake(self._devices, new_ids) + def check_compatible_aval(self, aval_shape: Shape) -> None: + if len(aval_shape) != len(self.shape) and not self.is_fully_replicated: + raise ValueError( + f"Sharding {self} is only valid for values of rank " + f"{len(self.shape)}, but was applied to a value of rank " + f"{len(aval_shape)}") + @classmethod def _remake( cls, devices: tuple[xc.Device, ...], ids: np.ndarray, @@ -764,7 +689,7 @@ def with_memory_kind(self, kind: str) -> PositionalSharding: def is_fully_replicated(self) -> bool: return self.shape == (1,) * self.ndim - # XLACompatibleSharding interface + # sharding.Sharding interface @property def _device_assignment(self) -> XLADeviceAssignment: @@ -807,17 +732,8 @@ def __eq__(self, other) -> bool: self._ids == other._ids) -@functools.lru_cache(maxsize=4096) -def gspmd_sharding_devices_indices_map( - self, global_shape: Shape) -> Mapping[Device, Index]: - self.shard_shape(global_shape) # raises a good error message - indices = op_sharding_to_indices(self._hlo_sharding, global_shape, - len(self._devices)) - return dict(safe_zip(self._devices, indices)) - - @use_cpp_class(xc.GSPMDSharding) -class GSPMDSharding(XLACompatibleSharding): +class GSPMDSharding(sharding.Sharding): _devices: tuple[Device, ...] _hlo_sharding: xc.HloSharding _memory_kind: str | None @@ -865,7 +781,7 @@ def __repr__(self): mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' return f'GSPMDSharding({self._hlo_sharding!r}{mem})' - def is_compatible_aval(self, aval_shape: Shape): + def check_compatible_aval(self, aval_shape: Shape) -> None: num_ways_dim_sharded, _ = get_num_ways_dim_sharded(self._hlo_sharding) if len(aval_shape) < len(num_ways_dim_sharded): raise ValueError( @@ -884,9 +800,6 @@ def memory_kind(self) -> str | None: def with_memory_kind(self, kind: str) -> GSPMDSharding: return GSPMDSharding(self._devices, self._hlo_sharding, memory_kind=kind) - def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: - return gspmd_sharding_devices_indices_map(self, global_shape) - @property def _device_assignment(self) -> XLADeviceAssignment: return self._devices @@ -1140,9 +1053,6 @@ def prepare_axis_resources(axis_resources, if isinstance(entry, PmapSharding): raise ValueError(f'One of {what} got sharding {entry} which is not ' 'allowed.') - if not isinstance(entry, XLACompatibleSharding): - raise ValueError(f'One of {what} got sharding {entry} which is not a ' - 'subclass of XLACompatibleSharding.') new_entries.append(entry) else: new_entries.append(ParsedPartitionSpec.from_user_input( @@ -1156,7 +1066,7 @@ def _check_unique_resources(axis_resources, arg_name): for arg_axis_resources in axis_resources: if not arg_axis_resources: continue if (is_unspecified_or_auto(arg_axis_resources) or - isinstance(arg_axis_resources, XLACompatibleSharding)): + isinstance(arg_axis_resources, sharding.Sharding)): continue constrained_dims = [d for d in arg_axis_resources if d is not None] resource_counts = collections.Counter( @@ -1193,11 +1103,6 @@ class SPMDAxisContext: def axis_env(self): # All collectives that touch axis_env should remember to set use_global_device_ids # when this context is enabled! - if self.manual_axes != frozenset(self.mesh.axis_names): - raise NotImplementedError( - "Collectives in manually partitioned computations are only supported " - "when all mesh axes are partitioned manually (no partial automatic sharding). " - "Make sure that you mention all mesh axes in axis_resources!") return self.unsafe_axis_env @property @@ -1386,6 +1291,192 @@ def _slice_as_tuple(s: slice): return (s.start, s.stop) +class NonUniformShardingError(ValueError): + """Raised when sharding is not uniform across processes.""" + + +def get_process_index_and_count( + tensor_sharding: sharding.Sharding, + dim: int, + ndims: int, +) -> tuple[int, int]: + """Get current process index and number of unique processes for given dimension. + + This function facilitates mapping of process-level data to individual + devices. Each process can use its index to obtain the data corresponding + to that index. If process level data is sharded on multiple dimensions + this function can be used to build the cross product of indices in + each sharded axis. Processes that need to load the same data will have + the same index. For shardings whose per-process data is not distributed + on a grid, the number of distinct shards will be such that it is possible to + build the target shape while maintaining a "cube" shape of local-process data. + + For example, in case of 4 hosts with sharding distributed like so: + + 1234 + 2143 + + For dim 0 (rows): all processes need to access all rows, so we return (0, 1) + For dim 1 (cols): + process 1 and 2 returns index 0 out of 2 (need cols 0 and 1), + process 3 and 4 returns index 1 out of 2 (need cols 2 and 3). + + On the other hand, for a sharding like: + + 1212 + 3434 + + Dim 0 (rows): process 1 and 2 returns (0, 2), process 3 and 4 returns (1, 2) + Dim 1 (cols): process 1 and 3 returns (0, 2), process 2 and 4 returns (1, 2) + + Note: This function requires sharding to be process uniform in dimension + `dim`: + each process has the same number of addressable indices in that + dimension and all index sets across processes are either disjoint or the same. + + For sharding to be process uniform the addressable shards doesn't need to + form contiguous subtensor, or even a sparse grid and in case of + interleaved high-dimensional tensor it is possible for sharding to be + process uniform only in some dimensions but not others. + + For example: + 1111 and 12 and 1212 and 1212 + 2222 21 2121 1212 + + are all sharding uniform, in both dimensions. However + + 1122 + 2121 + 1121 + 1222 + + is uniform in dimension 0 (both hosts access all rows), but + is not uniform in dimension 1 (host 1 accesses columns: 0, 1, and 3), + while host 2 accesses (0, 1, 2, 3). + + Returns: + A tuple of (index, num_distinct_shards) for the given dimension. + It is guaranteed that `index` will cover 0 to `num_distinct_shards - 1`, + across all processes. + + Raises: + NonUniformShardingError: if the sharding is not process uniform in dimension + `dim`. + """ + # TODO(sandler, yashkatariya): Consider making this function public. + + if ( + tensor_sharding.is_fully_addressable + or tensor_sharding.is_fully_replicated + ): + return (0, 1) + num_devices = len(tensor_sharding.device_set) + # Get device to indices map, we don't care about the concrete + # global shape here, only to get the distribution of shards across the tensor + # using (num_devices, num_devices, ...) This is a universal shape that is + # compatible with any mesh with num_devices. + device_map = tensor_sharding.devices_indices_map((num_devices,) * ndims) + + # Get the slices for 'dim' for all devices. + global_slice = {k: v[dim] for k, v in device_map.items()} + + # Contains mapping from process_index to a set of slices for that process. + process_to_slice = collections.defaultdict(set) + # Contains global set of slices across all processes. + all_slices = set() + + # Compute the set of slices for each process and the global set of slices. + for d, v in global_slice.items(): + key = (v.start, v.stop) + process_to_slice[d.process_index].add(key) + all_slices.add(key) + + # Get the set of slices for the current process which we will use to compute + # the index of the current process. + current_pid = next(iter(tensor_sharding.addressable_devices)).process_index + addressable_slices = frozenset(process_to_slice[current_pid]) + + # Verify that all processes have the same number of slices. + slices_per_process = len(addressable_slices) + if any(len(x) != slices_per_process for x in process_to_slice.values()): + raise NonUniformShardingError( + f'{tensor_sharding=} is non-uniform on {dim=} as some processes have ' + 'different number of slices.' + ) + unique_processes = list({frozenset(x) for x in process_to_slice.values()}) + + # After removing duplicate processes all unique slices should + # cover the dimension exactly once. If they don' it means that + # the sharding is not uniform. + if sum(len(h) for h in unique_processes) != len(all_slices): + raise NonUniformShardingError( + f'{tensor_sharding=} is non-uniform on {dim=}' + ) + return (unique_processes.index(addressable_slices), len(unique_processes)) + + +def local_to_global_shape( + sharding: sharding.Sharding, + local_shape: Shape, +) -> tuple[int | None, ...]: + """Computes the global shape given the per process if possible. + + The returned shape will have the size of the global tensor in that dimension + or None, if it is not computable. The latter can happen when sharding + is not uniform along that dimension, e.g. different hosts require + different shapes, or if different processes have partial data overlap. + + If at most one dimension is sharded the shape is always computable. + Generally, global shape is computable for most practical meshes (including + topology aware such as meshes returned by mesh_utils.create_device_mesh) + + Some examples: Suppose mesh is {'a': 2, 'b': 2, 'c': 2} with 2 devices + per host, 4 hosts total. For different specs we get: + - P(): + global_shape = local_shape + + - P(('a', 'b', 'c'), None): + global_shape = (4 * local_shape[0], local_shape[1]) + Note: per device shape is (local_shape[0] / 2, local_shape[1]) + + - P(('a', 'b'), None) + global_shape = (4 * local_shape[0], local_shape[1]) + # NB: the same global shape as above, since sharding along 'c' dimension + # happens to be within process, and thus doesn't affect the global shape. + # The underlying difference will be in the per *device* shape, which + # would be (local_shape[0], local_shape[1]) in this case. + + - P(None, ('a', 'c')) + global_shape = (local_shape[0], 2 * local_shape[1]) + # Per device shape is (local_shape[0], local_shape[1] / 2) + - P(('a', 'c'), 'b'): + global_shape = (2 * local_shape[0], 2 * local_shape[1]) + # Per device shape is (local_shape[0] / 2, local_shape[1]) + - If devices in the Mesh are randomly permuted: For any partition spec + which shards more than 1 axis: e.g. P('a', ('b', 'c')): + global_shape = (None, None) + + Args: + local_shape: global shape of the tensor. + + Returns: + global_shape with Nones in non-uniform dimensions. + """ + + global_shape : list[int | None] = [None] * len(local_shape) + for i, local_dim in enumerate(local_shape): + try: + _, shard_count = get_process_index_and_count( + sharding, i, ndims=len(local_shape) + ) + global_shape[i] = local_dim * shard_count + except NonUniformShardingError: + global_shape[i] = None + continue + + return tuple(global_shape) + + def num_addressable_indices( tensor_sharding: sharding.Sharding, dim: int, @@ -1430,3 +1521,111 @@ def num_addressable_indices( }) shard_size = tensor_sharding.shard_shape(global_shape)[dim] return shard_size * num_unique_slices + + +def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + new_op_sharding = hlo_sharding.to_proto().clone() + partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + tad = partitions + [1] * elt_aval.ndim + suffix + new_op_sharding.tile_assignment_dimensions = tad + return xc.HloSharding.from_proto(new_op_sharding) + +def is_single_device_sharding(sharding: sharding.Sharding) -> bool: + # Special case PmapSharding here because PmapSharding maps away an axis + # and needs to be handled separately.test_pjit_single_device_sharding_add + return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) + +def make_key_array_phys_sharding(aval, sharding): + if is_single_device_sharding(sharding): + return sharding + elif isinstance(sharding, PmapSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim + phys_sharding_spec = sharding_specs.ShardingSpec( + sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), + mesh_mapping=sharding.sharding_spec.mesh_mapping) + return PmapSharding(devices=sharding.devices, + sharding_spec=phys_sharding_spec) + elif isinstance(sharding, NamedSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + trailing_spec = [None] * elt_aval.ndim + return NamedSharding( + sharding.mesh, + PartitionSpec(*sharding.spec, *trailing_spec)) + else: + hlos = sharding._to_xla_hlo_sharding(aval.ndim) + return GSPMDSharding( + sharding._device_assignment, physical_hlo_sharding(aval, hlos)) + + +def physical_sharding( + aval, sharding: sharding.Sharding) -> sharding.Sharding: + return make_key_array_phys_sharding(aval, sharding) + + +def get_logical_gspmd_sharding(aval, phys_sharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( + aval.ndim + elt_aval.ndim) + partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + # Create logical sharding by cutting off the replicated trailing dims. + logical_op_sharding = phys_hlo_sharding.to_proto().clone() + tad = partitions[:-elt_aval.ndim] + suffix + logical_op_sharding.tile_assignment_dimensions = tad + return GSPMDSharding(phys_sharding._device_assignment, + xc.HloSharding.from_proto(logical_op_sharding)) + +def check_replicated_trailing_dims(sharding: sharding.Sharding, aval): + if isinstance(sharding, PmapSharding): + return + phys_aval = core.physical_aval(aval) + hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim) + partitions, _ = get_num_ways_dim_sharded(hlo_s) + num_trailing_dims = phys_aval.ndim - aval.ndim + if not all(i == 1 for i in partitions[-num_trailing_dims:]): + raise AssertionError( + "The trailing dims of extended dtypes should be replicated. Got" + f" sharding: {sharding}, partitions: {partitions}, " + f"num_trailing_dims: {num_trailing_dims}") + +def logical_sharding(aval, phys_sharding) -> sharding.Sharding: + # The trailing dims should always be replicated. + check_replicated_trailing_dims(phys_sharding, aval) + + if is_single_device_sharding(phys_sharding): + return phys_sharding + elif isinstance(phys_sharding, PmapSharding): + elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + logical_sharding_spec = sharding_specs.ShardingSpec( + sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim], + mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) + return PmapSharding(devices=phys_sharding.devices, + sharding_spec=logical_sharding_spec) + elif isinstance(phys_sharding, NamedSharding): + logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) + return _gspmd_to_named_sharding_via_mesh( + logical_gs, phys_sharding.mesh) + else: + return get_logical_gspmd_sharding(aval, phys_sharding) + + +@util.cache() +def create_mesh_pspec_sharding( + mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, parsed_pspec=None, + memory_kind: str | None = None) -> NamedSharding: + if pspec is None: + pspec, parsed_pspec = PartitionSpec(), None + return NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec, + memory_kind=memory_kind) + + +def _gspmd_to_named_sharding_via_mesh( + out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding: + parsed_pspec = parse_flatten_op_sharding( + out_s._hlo_sharding, mesh)[0] + return create_mesh_pspec_sharding( + mesh, parsed_pspec.get_partition_spec(), parsed_pspec, + out_s.memory_kind) diff --git a/jax/_src/sourcemap.py b/jax/_src/sourcemap.py index 276a39e26444..b54f2193ff26 100644 --- a/jax/_src/sourcemap.py +++ b/jax/_src/sourcemap.py @@ -18,9 +18,10 @@ from __future__ import annotations +from collections.abc import Iterable, Sequence from dataclasses import dataclass import json -from typing import Iterable, Sequence, Union +from typing import Union # A Segment encodes how parts in the generated source relate to the original source. # Each segment is made up of 1, 4 or 5 variable-length fields. For their semantics see diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 683f2fcdceb3..874ef8834557 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,9 +30,10 @@ """ from __future__ import annotations +import functools from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, NamedTuple, Protocol, Union +from typing import Any, NamedTuple, Protocol, Union, runtime_checkable import warnings import jax @@ -44,6 +45,7 @@ from jax._src import tree_util from jax._src.tree_util import tree_unflatten, keystr from jax._src import util +from jax._src.sharding_impls import is_unspecified_or_auto from jax._src.layout import Layout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -71,7 +73,7 @@ def call(self, *args_flat) -> Sequence[Any]: # TODO(frostig): improve annotation (sequences of arrays/buffers) raise NotImplementedError - def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: + def input_shardings(self) -> Sequence[jax.sharding.Sharding]: """Flat sequence of input shardings. May raise ``NotImplementedError`` if unavailable, e.g. based on backend, @@ -79,7 +81,7 @@ def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: """ raise NotImplementedError - def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: + def output_shardings(self) -> Sequence[jax.sharding.Sharding]: """Flat sequence of output shardings. May raise ``NotImplementedError`` if unavailable, e.g. based on backend, @@ -217,11 +219,11 @@ def xla_extension_executable(self) -> xc.LoadedExecutable: def call(self, *args_flat) -> Sequence[Any]: raise NotImplementedError("must override") - def input_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: + def input_shardings(self) -> Sequence[jax.sharding.Sharding]: raise NotImplementedError( "compiled executable carries no input sharding information") - def output_shardings(self) -> Sequence[jax.sharding.XLACompatibleSharding]: + def output_shardings(self) -> Sequence[jax.sharding.Sharding]: raise NotImplementedError( "compiled executable carries no output sharding information") @@ -313,7 +315,7 @@ class XlaLowering(Lowering): def hlo(self) -> xc.XlaComputation: """Return an HLO representation of this computation.""" hlo = self.stablehlo() - m: Union[str, bytes] + m: str | bytes m = mlir.module_to_bytecode(hlo) return xla_extension.mlir.mlir_module_to_xla_computation( m, use_tuple_args=self.compile_args["tuple_args"]) @@ -368,11 +370,26 @@ def cost_analysis(self) -> dict[str, float]: # -- Public-facing API, plus helpers -@dataclass +@dataclass(frozen=True) class ArgInfo: - aval: core.AbstractValue + _aval: core.AbstractValue donated: bool + @property + def shape(self): + return self._aval.shape # pytype: disable=attribute-error + + @property + def dtype(self): + return self._aval.dtype # pytype: disable=attribute-error + + +@dataclass(frozen=True) +class OutInfo: + shape: tuple[int, ...] + dtype: jax.typing.DTypeLike + sharding: jax.sharding.Sharding | None = None + class Stage: args_info: Any # PyTree of ArgInfo @@ -385,7 +402,7 @@ def in_tree(self) -> tree_util.PyTreeDef: @property def in_avals(self): """Tree of input avals.""" - return tree_util.tree_map(lambda x: x.aval, self.args_info) + return tree_util.tree_map(lambda x: x._aval, self.args_info) @property def donate_argnums(self): @@ -409,6 +426,37 @@ class CompiledCallParams(NamedTuple): out_tree: tree_util.PyTreeDef +class Traced(Stage): + __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable", + "_args_flat", "_arg_names", "_num_consts"] + + def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, + lower_callable, args_flat=None, arg_names=None, + num_consts: int = 0): + self.jaxpr = jaxpr + self.args_info = args_info + self.fun_name = fun_name + self._out_tree = out_tree + self._lower_callable = lower_callable + self._args_flat = args_flat + self._arg_names = arg_names + self._num_consts = num_consts + + @property + def out_info(self): + return self._out_tree.unflatten( + [OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals]) + + def lower(self, lowering_platforms: tuple[str, ...] | None = None, + _private_parameters: mlir.LoweringParameters | None = None): + if _private_parameters is None: + _private_parameters = mlir.LoweringParameters() + new_callable = functools.partial( + self._lower_callable, lowering_platforms=lowering_platforms, + lowering_parameters=_private_parameters) + return Lowered(new_callable(), self.args_info, self._out_tree) + + class Compiled(Stage): """Compiled representation of a function specialized to types/values. @@ -496,12 +544,17 @@ def runtime_executable(self) -> Any | None: return self._executable.runtime_executable() @property - def input_shardings(self): # PyTree[sharding.XLACompatibleSharding] + def input_shardings(self): # PyTree[sharding.Sharding] shardings_flat = self._executable.input_shardings() + # Some input shardings got DCE'd + if self.in_tree.num_leaves > len(shardings_flat): + iter_shardings_flat = iter(shardings_flat) + shardings_flat = [next(iter_shardings_flat) if i in self._executable._kept_var_idx + else None for i in range(self.in_tree.num_leaves)] return tree_util.tree_unflatten(self.in_tree, shardings_flat) # pytype: disable=attribute-error @property - def output_shardings(self): # PyTree[sharding.XLACompatibleSharding] + def output_shardings(self): # PyTree[sharding.Sharding] shardings_flat = self._executable.output_shardings() return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error @@ -601,11 +654,10 @@ class Lowered(Stage): querying properties of lowered computations across JAX's various lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). """ - __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"] - + __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"] + _lowering: XlaLowering args_info: Any # PyTree of ArgInfo out_tree: tree_util.PyTreeDef - _lowering: XlaLowering _no_kwargs: bool def __init__( @@ -614,10 +666,11 @@ def __init__( args_info, # PyTree of ArgInfo out_tree: tree_util.PyTreeDef, no_kwargs: bool = False): + self._lowering = lowering - self._no_kwargs = no_kwargs self.args_info = args_info self.out_tree = out_tree + self._no_kwargs = no_kwargs @classmethod def from_flat_info(cls, @@ -642,6 +695,14 @@ def from_flat_info(cls, out_tree, no_kwargs=no_kwargs) + @property + def out_info(self): # PyTree of OutInfo + out_avals = self._lowering.compile_args["global_out_avals"] + out_shardings = self._lowering.compile_args["out_shardings"] + return self.out_tree.unflatten( + [OutInfo(o.shape, o.dtype, None if is_unspecified_or_auto(s) else s) + for o, s in zip(out_avals, out_shardings)]) + def compile( self, compiler_options: CompilerOptions | None = None) -> Compiled: """Compile, returning a corresponding ``Compiled`` instance.""" @@ -701,8 +762,9 @@ def cost_analysis(self) -> Any | None: return None +@runtime_checkable class Wrapped(Protocol): - """A function ready to be specialized, lowered, and compiled. + """A function ready to be traced, lowered, and compiled. This protocol reflects the output of functions such as ``jax.jit``. Calling it results in JIT (just-in-time) lowering, @@ -714,6 +776,17 @@ def __call__(self, *args, **kwargs): """Executes the wrapped function, lowering and compiling as needed.""" raise NotImplementedError + def trace(self, *args, **kwargs) -> Traced: + """Trace this function explicitly for the given arguments. + + A traced function is staged out of Python and translated to a jaxpr. It is + ready for lowering but not yet lowered. + + Returns: + A ``Traced`` instance representing the tracing. + """ + raise NotImplementedError + def lower(self, *args, **kwargs) -> Lowered: """Lower this function explicitly for the given arguments. diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 09488c8bb165..f3a3e61a2ace 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -14,11 +14,11 @@ """Module for discharging state primitives.""" from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses from functools import partial import operator -from typing import Any, Callable, Protocol +from typing import Any, Protocol import numpy as np diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 7ed5043854a0..acf1c7216240 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -17,7 +17,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Union, List +from typing import Any, Union from jax._src import core from jax._src import tree_util @@ -30,7 +30,12 @@ @tree_util.register_pytree_node_class @dataclasses.dataclass class Slice: - """Represents a slice with a dynamic start index and a fixed size.""" + """A slice with a start index and a size. + + Both start index and size can either be static, i.e. known at tracing + and compilation time, or dynamic. + """ + start: int | Array size: int | Array stride: int = 1 @@ -60,9 +65,9 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children) -> Slice: - start, size = [ + start, size = ( a if a is not None else b for a, b in zip(children, aux_data[:2]) - ] + ) return cls(start, size, aux_data[2]) @classmethod @@ -78,7 +83,15 @@ def dslice( size: int | Array | None = None, stride: int | None = None, ) -> slice | Slice: - """Constructs a `Slice` from a start and a size.""" + """Constructs a ``Slice`` from a start index and a size. + + The semantics of ``dslice`` mirror those of the builtin ``slice`` type: + + * ``dslice(None)`` is ``:`` + * ``dslice(j)`` is ``:j`` + * ``dslice(i, j)`` is ``i:i+j`` + * ``dslice(i, j, stride)`` is ``i:i+j:stride`` + """ if start is None: return slice(None) if stride is None: @@ -123,6 +136,7 @@ class NDIndexer: indices: tuple[DimIndexer, ...] shape: tuple[int, ...] int_indexer_shape: tuple[int, ...] + # Off by default to avoid doing validation during pytree operations. validate: bool = False def __post_init__(self): @@ -181,53 +195,60 @@ def tree_flatten(self): def tree_unflatten(cls, data, flat_idx): idx_tree, shape, int_indexer_shape = data indices = tree_util.tree_unflatten(idx_tree, flat_idx) - return NDIndexer(tuple(indices), shape, int_indexer_shape) + return cls(tuple(indices), shape, int_indexer_shape) @classmethod def from_indices_shape(cls, indices, shape) -> NDIndexer: if not isinstance(indices, tuple): + # TODO(slebedev): Consider requiring `indices` to be a Sequence. indices = (indices,) - if len(indices) == 1 and indices[0] is ...: - indices = (slice(None),) * len(shape) - if any(idx is ... for idx in indices): - new_indices : List[Any] = [] - num_ellipsis = sum(1 for idx in indices if idx is ...) + + indices = list(indices) + if num_ellipsis := sum(idx is ... for idx in indices): if num_ellipsis > 1: raise ValueError("Only one ellipsis is supported.") - for idx in indices: - if idx is ...: - expand = (slice(None),) * (len(shape) - len(indices) + 1) - new_indices.extend(expand) - else: - new_indices.append(idx) - indices = tuple(new_indices) + # Expand ... so that `indices` has the same length as `shape`. + ip = indices.index(...) + indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1) if len(indices) > len(shape): + indices = tuple(indices) raise ValueError("`indices` must not be longer than `shape`: " f"{indices=}, {shape=}") - # Pad out indices with slice(None) - indices = [*indices, *[slice(None)] * (len(shape) - len(indices))] - # Convert all `slice`s to `Slice`s - indices = tuple(Slice.from_slice(i, s) if isinstance(i, slice) - else i for i, s in zip(indices, shape)) + elif len(indices) < len(shape): + # Pad `indices` to have the same length as `shape`. + indices.extend([slice(None)] * (len(shape) - len(indices))) + + # Promote all builtin `slice`s to `Slice`. + indices = tuple( + Slice.from_slice(i, s) if isinstance(i, slice) else i + for i, s in zip(indices, shape)) + is_int_indexing = [not isinstance(i, Slice) for i in indices] - other_indexers, int_indexers = partition_list(is_int_indexing, indices) - indexer_shapes = [core.get_aval(i).shape for i in int_indexers] - if indexer_shapes: + if any(is_int_indexing): + other_indexers, int_indexers = partition_list(is_int_indexing, indices) + indexer_shapes = tuple(core.get_aval(i).shape for i in int_indexers) try: - bcast_shape = np.broadcast_shapes(*indexer_shapes) + int_indexer_shape = np.broadcast_shapes(*indexer_shapes) except ValueError as e: # Raise a nicer error than the NumPy one. - raise ValueError("Cannot broadcast shapes for indexing: " - f"{tuple(a for a in indexer_shapes)}") from e + raise ValueError( + f"Cannot broadcast shapes for indexing: {indexer_shapes}") from e + + # Here we use the `broadcast_to` primitive instead of composing lax + # primitives together because it is easier to lower in targets like + # Triton/Mosaic. + # + # The local import avoids a circular dependency between primitives + # and this module. + from jax._src.state import primitives as sp # pytype: disable=import-error + int_indexers = [ + sp.broadcast_to(i, int_indexer_shape) for i in int_indexers + ] + indices = tuple(merge_lists(is_int_indexing, other_indexers, int_indexers)) else: - bcast_shape = () - # Here we use the `broadcast_to` primitive instead of composing lax - # primitives together because it is easier to lower in targets like - # Triton/Mosaic. - from jax._src.state import primitives as sp # pytype: disable=import-error - int_indexers = [sp.broadcast_to(i, bcast_shape) for i in int_indexers] - indices = merge_lists(is_int_indexing, other_indexers, int_indexers) - return NDIndexer(tuple(indices), shape, bcast_shape, validate=True) + int_indexer_shape = () + + return cls(indices, shape, int_indexer_shape, validate=True) def get_indexer_shape(self) -> tuple[int | Array, ...]: _, slice_indexers, _ = unpack_ndindexer(self) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c688fe61f0dc..edd769aff5c6 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -11,21 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# pyformat: disable from __future__ import annotations -from collections.abc import Generator, Iterable, Mapping, Sequence +import collections +from collections.abc import Callable, Generator, Iterable, Sequence from contextlib import ExitStack, contextmanager import datetime import functools from functools import partial import inspect -import io import math import os import re +import sys import tempfile import textwrap -from typing import Any, Callable +from typing import Any import unittest import warnings import zlib @@ -63,18 +66,18 @@ # jax.test_util. Functionality appearing here is for internal use only, and # may be changed or removed at any time and without any deprecation cycle. -_TEST_DUT = config.DEFINE_string( +_TEST_DUT = config.string_flag( 'jax_test_dut', '', help= 'Describes the device under test in case special consideration is required.' ) -NUM_GENERATED_CASES = config.DEFINE_integer( +NUM_GENERATED_CASES = config.int_flag( 'jax_num_generated_cases', int(os.getenv('JAX_NUM_GENERATED_CASES', '10')), help='Number of generated cases to test') -_MAX_CASES_SAMPLING_RETRIES = config.DEFINE_integer( +_MAX_CASES_SAMPLING_RETRIES = config.int_flag( 'max_cases_sampling_retries', int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')), 'Number of times a failed test sample should be retried. ' @@ -82,23 +85,23 @@ 'sampling process is terminated.' ) -_SKIP_SLOW_TESTS = config.DEFINE_bool( +_SKIP_SLOW_TESTS = config.bool_flag( 'jax_skip_slow_tests', config.bool_env('JAX_SKIP_SLOW_TESTS', False), help='Skip tests marked as slow (> 5 sec).' ) -_TEST_TARGETS = config.DEFINE_string( +_TEST_TARGETS = config.string_flag( 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), 'Regular expression specifying which tests to run, called via re.search on ' 'the test name. If empty or unspecified, run all tests.' ) -_EXCLUDE_TEST_TARGETS = config.DEFINE_string( +_EXCLUDE_TEST_TARGETS = config.string_flag( 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), 'Regular expression specifying which tests NOT to run, called via re.search ' 'on the test name. If empty or unspecified, run all tests.' ) -TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.DEFINE_bool( +TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag( 'jax_test_with_persistent_compilation_cache', config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), help='If enabled, the persistent compilation cache will be enabled for all ' @@ -177,12 +180,39 @@ def check_eq(xs, ys, err_msg=''): tree_all(tree_map(assert_close, xs, ys)) +# TODO(yashkatariya): Make this context manager check for deprecation message +# in OSS. +@contextmanager +def unaccelerate_getattr_deprecation(module, name): + message, prev_attr = module._deprecations[name] + module._deprecations[name] = (message, getattr(module, f"_deprecated_{name}")) + try: + yield + finally: + module._deprecations[name] = (message, prev_attr) + @contextmanager -def capture_stdout() -> Generator[Callable[[], str], None, None]: - with unittest.mock.patch('sys.stdout', new_callable=io.StringIO) as fp: - def _read() -> str: - return fp.getvalue() - yield _read +def capture_stdout() -> Generator[Callable[[], str | None], None, None]: + """Context manager to capture all stdout output.""" + + # The encoding should also work on windows, the default doesn't necessarily. + with tempfile.NamedTemporaryFile(mode="w+", delete=True, encoding='utf-8') as f: + original_stdout = os.dup(sys.stdout.fileno()) + os.dup2(f.fileno(), sys.stdout.fileno()) + + # if get_stdout returns not it means we are not done capturing + # stdout. it should only be used after the context has exited. + captured = None + get_stdout: Callable[[], str | None] = lambda: captured + + try: + yield get_stdout + finally: + # Python also has its own buffers, make sure everything is flushed. + sys.stdout.flush() + f.seek(0) + captured = f.read() + os.dup2(original_stdout, sys.stdout.fileno()) @contextmanager @@ -226,18 +256,18 @@ def count_primitive_compiles(): @contextmanager def count_device_put_fast_path_hit(): - original_fn = xc.copy_array_to_devices_with_sharding + original_fn = xc.batched_copy_array_to_devices_with_sharding count = [0] - def copy_array_to_devices_with_sharding_and_count(*args, **kwargs): + def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs): count[0] += 1 return original_fn(*args, **kwargs) - xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count + xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count try: yield count finally: - xc.copy_array_to_devices_with_sharding = original_fn + xc.batched_copy_array_to_devices_with_sharding = original_fn @contextmanager @@ -255,6 +285,20 @@ def pjit_lower_and_count(*args, **kwargs): finally: pjit_lib._pjit_lower = original_pjit_lower +@contextmanager +def count_cached_compilation_cache_miss(): + original_cached_compilation = pxla._cached_compilation + count = [0] + + def cached_compilation_and_count(*args, **kwargs): + count[0] += 1 + return original_cached_compilation(*args, **kwargs) + + pxla._cached_compilation = cached_compilation_and_count + try: + yield count + finally: + pxla._cached_compilation = original_cached_compilation @contextmanager def count_jit_tracing_cache_miss(): @@ -272,6 +316,21 @@ def create_pjit_jaxpr_and_count(*args): finally: pjit_lib._create_pjit_jaxpr = original_create_pjit_jaxpr +@contextmanager +def count_jit_infer_params_cache_miss(): + original_infer_params_impl = pjit_lib._infer_params_impl + count = collections.defaultdict(int) + + def infer_params_impl_and_count(fun, *args, **kw): + count[fun] += 1 + return original_infer_params_impl(fun, *args, **kw) + + pjit_lib._infer_params_impl = infer_params_impl_and_count + try: + yield count + finally: + pjit_lib._infer_params_impl = original_infer_params_impl + @contextmanager def count_aot_jit_cpp_cache_miss(): @@ -356,7 +415,7 @@ def supported_dtypes(): return types def is_device_rocm(): - return xla_bridge.get_backend().platform_version.startswith('rocm') + return 'rocm' in xla_bridge.get_backend().platform_version def is_device_cuda(): return 'cuda' in xla_bridge.get_backend().platform_version @@ -474,8 +533,13 @@ def device_supports_buffer_donation(): ) +@contextmanager def set_host_platform_device_count(nr_devices: int): - """Returns a closure that undoes the operation.""" + """Context manager to set host platform device count if not specified by user. + + This should only be used by tests at the top level in setUpModule(); it will + not work correctly if applied to individual test cases. + """ prev_xla_flags = os.getenv("XLA_FLAGS") flags_str = prev_xla_flags or "" # Don't override user-specified device count, or other XLA flags. @@ -484,13 +548,14 @@ def set_host_platform_device_count(nr_devices: int): f" --xla_force_host_platform_device_count={nr_devices}") # Clear any cached backends so new CPU backend will pick up the env var. xla_bridge.get_backend.cache_clear() - def undo(): + try: + yield + finally: if prev_xla_flags is None: del os.environ["XLA_FLAGS"] else: os.environ["XLA_FLAGS"] = prev_xla_flags xla_bridge.get_backend.cache_clear() - return undo def skip_on_flag(flag_name, skip_value): @@ -519,6 +584,18 @@ def wrap(func_or_class): return wrap +def is_running_under_pytest(): + return "pytest" in sys.modules + + +def skip_under_pytest(reason: str): + """A decorator for test methods to skip the test when run under pytest.""" + reason = "Running under pytest: " + reason + def skip(test_method): + return unittest.skipIf(is_running_under_pytest(), reason)(test_method) + return skip + + def format_test_name_suffix(opname, shapes, dtypes): arg_descriptions = (format_shape_dtype_string(shape, dtype) for shape, dtype in zip(shapes, dtypes)) @@ -981,6 +1058,38 @@ def wrapper(*args, **kw): return fun(*args, **kw) return wrapper +@contextmanager +def global_config_context(**kwds): + original_config = {} + try: + for key, value in kwds.items(): + original_config[key] = config._read(key) + config.update(key, value) + yield + finally: + for key, value in original_config.items(): + config.update(key, value) + + +class NotPresent: + def __repr__(self): + return "" + + +@contextmanager +def assert_global_configs_unchanged(): + starting_config = jax.config.values.copy() + yield + ending_config = jax.config.values + + if starting_config == ending_config: + return + differing = {k: (starting_config.get(k, NotPresent()), ending_config.get(k, NotPresent())) + for k in (starting_config.keys() | ending_config.keys()) + if (k not in starting_config or k not in ending_config + or starting_config[k] != ending_config[k])} + raise AssertionError(f"Test changed global config values. Differing values are: {differing}") + class JaxTestCase(parameterized.TestCase): """Base class for JAX tests including numerical checks and boilerplate.""" @@ -1001,26 +1110,20 @@ class JaxTestCase(parameterized.TestCase): def setUp(self): super().setUp() - self._original_config = {} - for key, value in self._default_config.items(): - self._original_config[key] = config._read(key) - config.update(key, value) + self.enter_context(assert_global_configs_unchanged()) # We use the adler32 hash for two reasons. # a) it is deterministic run to run, unlike hash() which is randomized. # b) it returns values in int32 range, which RandomState requires. self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) - def tearDown(self): - for key, value in self._original_config.items(): - config.update(key, value) - super().tearDown() - @classmethod def setUpClass(cls): + cls._compilation_cache_exit_stack = ExitStack() + stack = cls._compilation_cache_exit_stack + stack.enter_context(global_config_context(**cls._default_config)) + if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: - cls._compilation_cache_exit_stack = ExitStack() - stack = cls._compilation_cache_exit_stack stack.enter_context(config.enable_compilation_cache(True)) stack.enter_context(config.raise_persistent_cache_errors(True)) stack.enter_context(config.persistent_cache_min_compile_time_secs(0)) @@ -1032,8 +1135,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: - cls._compilation_cache_exit_stack.close() + cls._compilation_cache_exit_stack.close() def rng(self): return self._rng @@ -1415,7 +1517,7 @@ def register_event_duration_listener(callback): def set_env(**kwargs): """Context manager to temporarily set/unset one or more environment variables. - Example: + Examples: >>> import os >>> os.environ['my_var'] = 'original' @@ -1906,101 +2008,3 @@ def worker(ctx, s, e, r, v): return worker(ctx, scale, exact, reference, value) else: assert 0 # unreachable - - -def get_process_index_and_count( - tensor_sharding: jax.sharding.Sharding, - dim: int, - global_shape: tuple[int, ...], -) -> tuple[int, int]: - """Returns current process index and total count for the given dimension. - - This function facilitates mapping of process-level data to individual - devices. Each process can use its index to obtain the data corresponding - to that index. If process level data is sharded on multiple dimensions - this function can be used to build the cross product of indices in - each sharded axis. Processes that need to load the same data will have - the same index. For shardings whose per-process data is not distributed - on a grid, the number of distinct shards will be such that it is possible to - build the target shape while maintaining a "cube" shape of local-process data. - - For example, in case of 4 hosts with sharding distributed like so: - - 1234 - 2143 - - For dim 0 (rows): all processes need to access all rows, so we return (0, 1) - For dim 1 (cols): - process 1 and 2 returns index 0 out of 2 (need cols 0 and 1), - process 3 and 4 returns index 1 out of 2 (need cols 2 and 3). - - On the other hand, for a sharding like: - - 1212 - 3434 - - Dim 0 (rows): process 1 and 2 returns (0, 2), process 3 and 4 returns (1, 2) - Dim 1 (cols): process 1 and 3 returns (0, 2), process 2 and 4 returns (1, 2) - - Note: This function requires sharding to be process uniform in dimension `dim`: - each process has the same number of addressable indices in that - dimension and all index sets across processes are either disjoint or the same. - - For sharding to be process uniform the addressable shards doesn't need to - form contiguous subtensor, or even a sparse grid and in case of - interleaved high-dimensional tensor it is possible for sharding to be - process uniform only in some dimensions but not others. - - For example: - 1111 and 12 and 1212 and 1212 - 2222 21 2121 1212 - - are all sharding uniform, in both dimensions. However - - 1122 - 2121 - 1121 - 1222 - - is uniform in dimension 0 (both hosts access all rows), but - is not uniform in dimension 1 (host 1 accesses columns: 0, 1, and 3), - while host 2 accesses (0, 1, 2, 3). - - Returns: - A tuple of (index, num_distinct_shards) for the given dimension. - It is guaranteed that `index` will cover 0 to `num_distinct_shards - 1`, - across all processes. - - Raises: - ValueError if the sharding is not process uniform in dimension `dim`. - """ - # TODO(sandler, yashkatariya): Consider making this function public. - - if tensor_sharding.is_fully_addressable or tensor_sharding.is_fully_replicated: - return (0, 1) - # NB: For most types of shardings, global_shape is a superfluous argument - # and could be replaced by [d, d, ...., d, d], where d is the number of - # devices. - device_map: Mapping[jax.sharding.Device, jax.sharding.Index] = ( - tensor_sharding.devices_indices_map(global_shape) - ) - - global_slice = {k: v[dim] for k, v in device_map.items()} - process_map: dict[int, set[tuple[int, int]]] = {} - all_slices = set() - - current_pid = next(iter(tensor_sharding.addressable_devices)).process_index - for d, v in global_slice.items(): - key = (v.start, v.stop) - process_map.setdefault(d.process_index, set()).add(key) - all_slices.add(key) - addressable = frozenset(process_map[current_pid]) - slices_per_process = len(addressable) - if any(len(x) != slices_per_process for x in process_map.values()): - raise ValueError(f'{tensor_sharding=} is non-uniform on {dim=}') - unique_processes = list({frozenset(x) for x in process_map.values()}) - - # After removing duplicate processes each slide should appear exactly once. - if sum(len(h) for h in unique_processes) != len(all_slices): - raise ValueError(f'{tensor_sharding=} is non-uniform on {dim=}') - return (unique_processes.index(addressable), len(unique_processes)) diff --git a/jax/_src/third_party/scipy/interpolate.py b/jax/_src/third_party/scipy/interpolate.py index a5bf460a2674..1eb726ea863c 100644 --- a/jax/_src/third_party/scipy/interpolate.py +++ b/jax/_src/third_party/scipy/interpolate.py @@ -45,7 +45,7 @@ class RegularGridInterpolator: Returns: interpolator: callable interpolation object. - Example: + Examples: >>> points = (jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) >>> values = jnp.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]]) >>> interpolate = RegularGridInterpolator(points, values, method='linear') diff --git a/jax/_src/third_party/scipy/linalg.py b/jax/_src/third_party/scipy/linalg.py index 03093c280b2a..dce4df1fb817 100644 --- a/jax/_src/third_party/scipy/linalg.py +++ b/jax/_src/third_party/scipy/linalg.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from collections.abc import Callable from jax import jit, lax import jax.numpy as jnp @@ -65,7 +65,7 @@ def funm(A: ArrayLike, func: Callable[[Array], Array], close to zero, the SciPy function may return a real-valued array, whereas the JAX implementation will return a complex-valued array. - Example: + Examples: Applying an arbitrary matrix function: >>> A = jnp.array([[1., 2.], [3., 4.]]) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 3699e7ee4209..14721cea7682 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -19,12 +19,13 @@ import base64 import collections.abc +from collections.abc import Callable, Sequence import dataclasses import functools import io import os import time -from typing import Any, Callable +from typing import Any import jax from jax import core @@ -33,13 +34,10 @@ from jax._src.interpreters import mlir from jax._src.lib import tpu from jax._src.lib import xla_client -from jax._src.lib.mlir.dialects import hlo from jax.interpreters import xla from jaxlib.mlir import ir from jaxlib.mlir.dialects import mhlo -from jaxlib.mlir.dialects import stablehlo from jaxlib.mlir.passmanager import PassManager -import numpy as np try: from absl import flags @@ -47,7 +45,7 @@ except ImportError: FLAGS = {} -_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state( +_MOSAIC_USE_PYTHON_PIPELINE = config.bool_state( name="mosaic_use_python_pipeline", default=False, help=( @@ -57,7 +55,7 @@ ), ) -_MOSAIC_ALLOW_HLO = config.define_bool_state( +_MOSAIC_ALLOW_HLO = config.bool_state( name="jax_mosaic_allow_hlo", default=False, help="Allow hlo dialects in Mosaic", @@ -96,6 +94,7 @@ class CustomCallBackendConfig: flags: dict[str, bool | int | float] | None allow_input_fusion: list[bool] | None serialization_format: int | None + internal_scratch_in_bytes: int | None # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -135,6 +134,9 @@ def to_json(self) -> bytes: if i + 1 != len(self.allow_input_fusion): config.write(b",") config.write(b"]") + if self.internal_scratch_in_bytes is not None: + config.write(b', "internal_scratch_in_bytes": ') + config.write(str(self.internal_scratch_in_bytes).encode("ascii")) config.write(b"}") # End of custom_call_config. if self.device_type is not None: config.write(b', "device_type": ') @@ -181,13 +183,8 @@ def _tpu_custom_call_abstract_eval(*_, out_avals, **__): return out_avals -def _aval_to_layout(aval): - arange = np.arange(aval.ndim, dtype=np.dtype(np.int64))[::-1].copy() - return ir.DenseIntElementsAttr.get(arange, type=ir.IndexType.get()) - - -def _avals_to_layouts(avals): - return ir.ArrayAttr.get([_aval_to_layout(a) for a in avals]) +def _avals_to_layouts(avals) -> Sequence[Sequence[int]]: + return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] def _tpu_custom_call_lowering( @@ -195,18 +192,11 @@ def _tpu_custom_call_lowering( *in_nodes, # pylint: disable=missing-function-docstring config: CustomCallBackendConfig, kernel_name: str | None, - kernel_regeneration_metadata: bytes | None, out_avals: Any, input_output_aliases: tuple[tuple[int, int], ...], ) -> ...: i32_type = ir.IntegerType.get_signless(32) - multiple_results = len(out_avals) > 1 - if multiple_results: - result_type = ir.TupleType.get_tuple( - [mlir.aval_to_ir_type(aval) for aval in out_avals] - ) - else: - result_type = mlir.aval_to_ir_type(out_avals[0]) + result_types = [mlir.aval_to_ir_type(aval) for aval in out_avals] axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.SPMDAxisContext): if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names): @@ -224,45 +214,23 @@ def _tpu_custom_call_lowering( raise NotImplementedError( "Replica lowering for Mosaic kernels not implemented." ) - call = stablehlo.CustomCallOp( - [result_type], - in_nodes, - call_target_name=ir.StringAttr.get(b"tpu_custom_call"), - has_side_effect=ir.BoolAttr.get(False), - backend_config=ir.StringAttr.get(config.to_json()), - api_version=ir.IntegerAttr.get(i32_type, 1), - called_computations=ir.ArrayAttr.get([]), + extra_attributes = {} + # Add kernel_name and kernel_metadata as attributes to the custom call op. + # This is because we do not want to pollute the backend_config with this + # information. + if kernel_name is not None: + extra_attributes = dict(kernel_name=ir.StringAttr.get(kernel_name)) + call = mlir.custom_call( + "tpu_custom_call", + result_types=result_types, + operands=in_nodes, + backend_config=config.to_json(), + api_version=1, + operand_output_aliases=dict(input_output_aliases), operand_layouts=_avals_to_layouts(ctx.avals_in), result_layouts=_avals_to_layouts(ctx.avals_out), - output_operand_aliases=ir.ArrayAttr.get([ - hlo.OutputOperandAlias.get( - # if len(result_types) == 1 then the aliasing refers implicitly to - # the only output. - output_tuple_indices=[output_idx] - if len(out_avals) > 1 - else [], - operand_index=input_idx, - operand_tuple_indices=[], - ) - for input_idx, output_idx in input_output_aliases - ]), - ) - - # Add kernel_name and kernel_regeneration_metadata as attributes to the - # custom call op. This is because we do not want to pollute the backend_config - # with this information. - if kernel_name is not None: - call.attributes["kernel_name"] = ir.StringAttr.get(kernel_name) - if kernel_regeneration_metadata is not None: - call.attributes["kernel_regeneration_metadata"] = ir.StringAttr.get( - base64.b64encode(kernel_regeneration_metadata) - ) - if multiple_results: - results = [stablehlo.get_tuple_element(call, mlir.i32_attr(i)) - for i in range(len(out_avals))] - else: - results = call.results - return results + extra_attributes=extra_attributes) + return call.results mlir.register_lowering(tpu_custom_call_p, _tpu_custom_call_lowering, @@ -376,24 +344,20 @@ def as_tpu_kernel( backend: str | xla_client.Client = "tpu", device_type: str | None = None, kernel_name: str | None = None, - kernel_regeneration_metadata: bytes | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, allow_input_fusion: list[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] = (), + internal_scratch_in_bytes: int | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" # We use jax.jit to make sure we hit the fast compilation cache. - some_tpu = jax.devices(backend)[0] - device_kind = some_tpu.device_kind - if not device_kind.startswith("TPU v"): - raise ValueError(f"Unrecognized TPU device kind: {device_kind}.") + if vmem_limit_bytes is not None and not isinstance(vmem_limit_bytes, int): raise ValueError( "vmem_limit_bytes must be an int: provided with a" f" {type(vmem_limit_bytes)}." ) - hardware_generation = int(device_kind[len("TPU v")]) has_communication, has_custom_barrier = tpu.private_has_communication( module.operation ) @@ -405,6 +369,14 @@ def as_tpu_kernel( module.operation.get_asm(binary=True, enable_debug_info=True) ) if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: + some_tpu = jax.devices(backend)[0] + device_kind = some_tpu.device_kind + if not device_kind.startswith("TPU v"): + raise ValueError( + f"Unrecognized TPU device kind: {device_kind}. " + "tpu_custom_call cannot be lowered on a machine without TPUs " + "when mosaic_use_python_pipeline=True.") + hardware_generation = int(device_kind[len("TPU v")]) module = _lower_tpu_kernel(module, hardware_generation) needs_hlo_passes = False needs_layout_passes = False @@ -431,12 +403,12 @@ def as_tpu_kernel( has_communication=has_communication, has_custom_barrier=has_custom_barrier, kernel_name=kernel_name, - kernel_regeneration_metadata=kernel_regeneration_metadata, cost_estimate=cost_estimate, vmem_limit_bytes=vmem_limit_bytes, flags=flags, allow_input_fusion=allow_input_fusion, input_output_aliases=input_output_aliases, + internal_scratch_in_bytes=internal_scratch_in_bytes, ) @@ -451,12 +423,12 @@ def _lowered_as_tpu_kernel( has_communication: bool = False, has_custom_barrier: bool = False, kernel_name: str | None = None, - kernel_regeneration_metadata: bytes | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, allow_input_fusion: list[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] = (), serialization_format: int | None = 1, + internal_scratch_in_bytes: int | None = None, ): """Turns a low-level MLIR Mosaic kernel into a JAX-compatible function.""" unpack = False @@ -486,13 +458,13 @@ def apply_kernel(*args, collective_id: int | None = None): vmem_limit_bytes, flags, allow_input_fusion, - serialization_format=serialization_format, + serialization_format, + internal_scratch_in_bytes, ) result = tpu_custom_call_p.bind( *args, config=config, kernel_name=kernel_name, - kernel_regeneration_metadata=kernel_regeneration_metadata, out_avals=out_avals, input_output_aliases=input_output_aliases, ) diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index 0a6ef8da4263..d66cbb912a99 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -14,12 +14,13 @@ from __future__ import annotations +from collections.abc import Callable import functools import os import sys import traceback import types -from typing import Any, Callable, TypeVar, cast +from typing import Any, TypeVar, cast from jax._src import config from jax._src import util diff --git a/jax/_src/tree.py b/jax/_src/tree.py index f8358670af7b..49faaa774ef2 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -13,51 +13,145 @@ # limitations under the License. from __future__ import annotations -import functools -from typing import Any, Callable, Iterable, TypeVar, overload +from collections.abc import Callable, Iterable +from typing import Any, TypeVar, overload from jax._src import tree_util T = TypeVar("T") -def _add_doc(docstr): - def wrapper(fun): - doc = fun.__doc__ - firstline, rest = doc.split('\n', 1) - fun.__doc__ = f'{firstline}\n\n {docstr}\n{rest}' - return fun - return wrapper +def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool: + """Call all() over the leaves of a tree. + Args: + tree: the pytree to evaluate + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. -@_add_doc("Alias of :func:`jax.tree_util.tree_all`.") -@functools.wraps(tree_util.tree_all) -def all(tree: Any) -> bool: - return tree_util.tree_all(tree) + Returns: + result: boolean True or False + + Examples: + >>> import jax + >>> jax.tree.all([True, {'a': True, 'b': (True, True)}]) + True + >>> jax.tree.all([False, (True, False)]) + False + + See Also: + - :func:`jax.tree.reduce` + - :func:`jax.tree.leaves` + """ + return tree_util.tree_all(tree, is_leaf=is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_flatten`.") -@functools.wraps(tree_util.tree_flatten) def flatten(tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> tuple[list[tree_util.Leaf], tree_util.PyTreeDef]: + """Flattens a pytree. + + The flattening order (i.e. the order of elements in the output list) + is deterministic, corresponding to a left-to-right depth-first tree + traversal. + + Args: + tree: a pytree to flatten. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, with true stopping the + traversal and the whole subtree being treated as a leaf, and false + indicating the flattening should traverse the current object. + + Returns: + A pair where the first element is a list of leaf values and the second + element is a treedef representing the structure of the flattened tree. + + Examples: + >>> import jax + >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) + >>> vals + [1, 2, 3, 4, 5] + >>> treedef + PyTreeDef([*, (*, *), [*, *]]) + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.structure` + - :func:`jax.tree.unflatten` + """ return tree_util.tree_flatten(tree, is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_leaves`.") -@functools.wraps(tree_util.tree_leaves) def leaves(tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tree_util.Leaf]: + """Gets the leaves of a pytree. + + Args: + tree: the pytree for which to get the leaves + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. + + Returns: + leaves: a list of tree leaves. + + Examples: + >>> import jax + >>> jax.tree.leaves([1, (2, 3), [4, 5]]) + [1, 2, 3, 4, 5] + + See Also: + - :func:`jax.tree.flatten` + - :func:`jax.tree.structure` + - :func:`jax.tree.unflatten` + """ return tree_util.tree_leaves(tree, is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_map`.") -@functools.wraps(tree_util.tree_map) def map(f: Callable[..., Any], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any: + """Maps a multi-input function over pytree args to produce a new pytree. + + Args: + f: function that takes ``1 + len(rest)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree: a pytree to be mapped over, with each leaf providing the first + positional argument to ``f``. + rest: a tuple of pytrees, each of which has the same structure as ``tree`` + or has ``tree`` as a prefix. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each + leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding + leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in + ``rest``. + + Examples: + + >>> import jax + >>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42}) + {'x': 8, 'y': 43} + + If multiple inputs are passed, the structure of the tree is taken from the + first input; subsequent inputs need only have ``tree`` as a prefix: + + >>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.reduce` + """ return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf) @@ -73,32 +167,120 @@ def reduce(function: Callable[[T, Any], T], initializer: T, is_leaf: Callable[[Any], bool] | None = None) -> T: ... -@_add_doc("Alias of :func:`jax.tree_util.tree_reduce`.") -@functools.wraps(tree_util.tree_reduce) def reduce(function: Callable[[T, Any], T], tree: Any, initializer: Any = tree_util.no_initializer, is_leaf: Callable[[Any], bool] | None = None) -> T: + """Call reduce() over the leaves of a tree. + + Args: + function: the reduction function + tree: the pytree to reduce over + initializer: the optional initial value + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. + + Returns: + result: the reduced value. + + Examples: + >>> import jax + >>> import operator + >>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]]) + 21 + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.map` + """ return tree_util.tree_reduce(function, tree, initializer, is_leaf=is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_structure`.") -@functools.wraps(tree_util.tree_structure) def structure(tree: Any, is_leaf: None | (Callable[[Any], bool]) = None) -> tree_util.PyTreeDef: + """Gets the treedef for a pytree. + + Args: + tree: the pytree for which to get the leaves + is_leaf : an optionally specified function that will be called at each + flattening step. It should return a boolean, which indicates whether the + flattening should traverse the current object, or if it should be stopped + immediately, with the whole subtree being treated as a leaf. + + Returns: + pytreedef: a PyTreeDef representing the structure of the tree. + + Examples: + >>> import jax + >>> jax.tree.structure([1, (2, 3), [4, 5]]) + PyTreeDef([*, (*, *), [*, *]]) + + See Also: + - :func:`jax.tree.flatten` + - :func:`jax.tree.leaves` + - :func:`jax.tree.unflatten` + """ return tree_util.tree_structure(tree, is_leaf) -@_add_doc("Alias of :func:`jax.tree_util.tree_transpose`.") -@functools.wraps(tree_util.tree_transpose) def transpose(outer_treedef: tree_util.PyTreeDef, inner_treedef: tree_util.PyTreeDef, pytree_to_transpose: Any) -> Any: + """Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). + + Args: + outer_treedef: PyTreeDef representing the outer tree. + inner_treedef: PyTreeDef representing the inner tree. + If None, then it will be inferred from outer_treedef and the structure of + pytree_to_transpose. + pytree_to_transpose: the pytree to be transposed. + + Returns: + transposed_pytree: the transposed pytree. + + Examples: + >>> import jax + >>> tree = [(1, 2, 3), (4, 5, 6)] + >>> inner_structure = jax.tree.structure(('*', '*', '*')) + >>> outer_structure = jax.tree.structure(['*', '*']) + >>> jax.tree.transpose(outer_structure, inner_structure, tree) + ([1, 4], [2, 5], [3, 6]) + + Inferring the inner structure: + + >>> jax.tree.transpose(outer_structure, None, tree) + ([1, 4], [2, 5], [3, 6]) + """ return tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose) -@_add_doc("Alias of :func:`jax.tree_util.tree_unflatten`.") -@functools.wraps(tree_util.tree_unflatten) def unflatten(treedef: tree_util.PyTreeDef, leaves: Iterable[tree_util.Leaf]) -> Any: + """Reconstructs a pytree from the treedef and the leaves. + + The inverse of :func:`tree_flatten`. + + Args: + treedef: the treedef to reconstruct + leaves: the iterable of leaves to use for reconstruction. The iterable must + match the leaves of the treedef. + + Returns: + The reconstructed pytree, containing the ``leaves`` placed in the structure + described by ``treedef``. + + Examples: + >>> import jax + >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) + >>> newvals = [100, 200, 300, 400, 500] + >>> jax.tree.unflatten(treedef, newvals) + [100, (200, 300), [400, 500]] + + See Also: + - :func:`jax.tree.flatten` + - :func:`jax.tree.leaves` + - :func:`jax.tree.structure` + """ return tree_util.tree_unflatten(treedef, leaves) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 857ad46ffd10..32f59b1df36e 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -14,14 +14,14 @@ from __future__ import annotations import collections -from collections.abc import Hashable, Iterable +from collections.abc import Callable, Hashable, Iterable, Sequence from dataclasses import dataclass import difflib import functools from functools import partial import operator as op import textwrap -from typing import Any, Callable, NamedTuple, Sequence, TypeVar, Union, overload +from typing import Any, NamedTuple, TypeVar, Union, overload from jax._src import traceback_util from jax._src.lib import pytree @@ -74,66 +74,13 @@ def tree_flatten(tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> tuple[list[Leaf], PyTreeDef]: - """Flattens a pytree. - - The flattening order (i.e. the order of elements in the output list) - is deterministic, corresponding to a left-to-right depth-first tree - traversal. - - Args: - tree: a pytree to flatten. - is_leaf: an optionally specified function that will be called at each - flattening step. It should return a boolean, with true stopping the - traversal and the whole subtree being treated as a leaf, and false - indicating the flattening should traverse the current object. - - Returns: - A pair where the first element is a list of leaf values and the second - element is a treedef representing the structure of the flattened tree. - - Example: - >>> import jax - >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) - >>> vals - [1, 2, 3, 4, 5] - >>> treedef - PyTreeDef([*, (*, *), [*, *]]) - - See Also: - - :func:`jax.tree.leaves` - - :func:`jax.tree.structure` - - :func:`jax.tree.unflatten` - """ + """Alias of :func:`jax.tree.flatten`.""" return default_registry.flatten(tree, is_leaf) @export def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any: - """Reconstructs a pytree from the treedef and the leaves. - - The inverse of :func:`tree_flatten`. - - Args: - treedef: the treedef to reconstruct - leaves: the iterable of leaves to use for reconstruction. The iterable must - match the leaves of the treedef. - - Returns: - The reconstructed pytree, containing the ``leaves`` placed in the structure - described by ``treedef``. - - Example: - >>> import jax - >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) - >>> newvals = [100, 200, 300, 400, 500] - >>> jax.tree.unflatten(treedef, newvals) - [100, (200, 300), [400, 500]] - - See Also: - - :func:`jax.tree.flatten` - - :func:`jax.tree.leaves` - - :func:`jax.tree.structure` - """ + """Alias of :func:`jax.tree.unflatten`.""" return treedef.unflatten(leaves) @@ -141,28 +88,7 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any: def tree_leaves(tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[Leaf]: - """Gets the leaves of a pytree. - - Args: - tree: the pytree for which to get the leaves - is_leaf : an optionally specified function that will be called at each - flattening step. It should return a boolean, which indicates whether the - flattening should traverse the current object, or if it should be stopped - immediately, with the whole subtree being treated as a leaf. - - Returns: - leaves: a list of tree leaves. - - Example: - >>> import jax - >>> jax.tree.leaves([1, (2, 3), [4, 5]]) - [1, 2, 3, 4, 5] - - See Also: - - :func:`jax.tree.flatten` - - :func:`jax.tree.structure` - - :func:`jax.tree.unflatten` - """ + """Alias of :func:`jax.tree.leaves`.""" return default_registry.flatten(tree, is_leaf)[0] @@ -170,28 +96,7 @@ def tree_leaves(tree: Any, def tree_structure(tree: Any, is_leaf: None | (Callable[[Any], bool]) = None) -> PyTreeDef: - """Gets the treedef for a pytree. - - Args: - tree: the pytree for which to get the leaves - is_leaf : an optionally specified function that will be called at each - flattening step. It should return a boolean, which indicates whether the - flattening should traverse the current object, or if it should be stopped - immediately, with the whole subtree being treated as a leaf. - - Returns: - pytreedef: a PyTreeDef representing the structure of the tree. - - Example: - >>> import jax - >>> jax.tree.structure([1, (2, 3), [4, 5]]) - PyTreeDef([*, (*, *), [*, *]]) - - See Also: - - :func:`jax.tree.flatten` - - :func:`jax.tree.leaves` - - :func:`jax.tree.unflatten` - """ + """Alias of :func:`jax.tree.structure`.""" return default_registry.flatten(tree, is_leaf)[1] @@ -205,7 +110,7 @@ def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef: Returns: a single treedef representing a tuple of the structures - Example: + Examples: >>> import jax >>> x = [1, 2, 3] >>> y = {'a': 4, 'b': 5} @@ -256,7 +161,7 @@ def treedef_is_leaf(treedef: PyTreeDef) -> bool: Returns: True if treedef is a leaf (i.e. has a single node); False otherwise. - Example: + Examples: >>> import jax >>> tree1 = jax.tree.structure(1) >>> jax.tree_util.treedef_is_leaf(tree1) @@ -288,7 +193,7 @@ def all_leaves(iterable: Iterable[Any], Returns: A boolean indicating if all elements in the input are leaves. - Example: + Examples: >>> import jax >>> tree = {"a": [1, 2, 3]} >>> assert all_leaves(jax.tree_util.tree_leaves(tree)) @@ -331,7 +236,7 @@ def register_pytree_node(nodetype: type[T], - :func:`~jax.tree_util.register_pytree_node_class` - :func:`~jax.tree_util.register_pytree_with_keys_class` - Example: + Examples: First we'll define a custom type: >>> class MyContainer: @@ -401,7 +306,7 @@ def register_pytree_node_class(cls: Typ) -> Typ: - :func:`~jax.tree_util.register_pytree_with_keys` - :func:`~jax.tree_util.register_pytree_with_keys_class` - Example: + Examples: Here we'll define a custom container that will be compatible with :func:`jax.jit` and other JAX transformations: @@ -432,42 +337,7 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any: - """Maps a multi-input function over pytree args to produce a new pytree. - - Args: - f: function that takes ``1 + len(rest)`` arguments, to be applied at the - corresponding leaves of the pytrees. - tree: a pytree to be mapped over, with each leaf providing the first - positional argument to ``f``. - rest: a tuple of pytrees, each of which has the same structure as ``tree`` - or has ``tree`` as a prefix. - is_leaf: an optionally specified function that will be called at each - flattening step. It should return a boolean, which indicates whether the - flattening should traverse the current object, or if it should be stopped - immediately, with the whole subtree being treated as a leaf. - - Returns: - A new pytree with the same structure as ``tree`` but with the value at each - leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding - leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in - ``rest``. - - Examples: - - >>> import jax.tree_util - >>> jax.tree_util.tree_map(lambda x: x + 1, {"x": 7, "y": 42}) - {'x': 8, 'y': 43} - - If multiple inputs are passed, the structure of the tree is taken from the - first input; subsequent inputs need only have ``tree`` as a prefix: - - >>> jax.tree_util.tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) - [[5, 7, 9], [6, 1, 2]] - - See Also: - - :func:`jax.tree.leaves` - - :func:`jax.tree.reduce` - """ + """Alias of :func:`jax.tree.map`.""" leaves, treedef = tree_flatten(tree, is_leaf) all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest] return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) @@ -487,7 +357,7 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any: See Also: - :func:`jax.tree.unflatten` - Example: + Examples: >>> import jax >>> tree = [(1, 2), {'a': 3, 'b': 4}] >>> treedef = jax.tree.structure(tree) @@ -507,31 +377,7 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any: @export def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None, pytree_to_transpose: Any) -> Any: - """Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). - - Args: - outer_treedef: PyTreeDef representing the outer tree. - inner_treedef: PyTreeDef representing the inner tree. - If None, then it will be inferred from outer_treedef and the structure of - pytree_to_transpose. - pytree_to_transpose: the pytree to be transposed. - - Returns: - transposed_pytree: the transposed pytree. - - Examples: - >>> import jax - >>> tree = [(1, 2, 3), (4, 5, 6)] - >>> inner_structure = jax.tree.structure(('*', '*', '*')) - >>> outer_structure = jax.tree.structure(['*', '*']) - >>> jax.tree.transpose(outer_structure, inner_structure, tree) - ([1, 4], [2, 5], [3, 6]) - - Inferring the inner structure: - - >>> jax.tree.transpose(outer_structure, None, tree) - ([1, 4], [2, 5], [3, 6]) - """ + """Alias of :func:`jax.tree.transpose`.""" flat, treedef = tree_flatten(pytree_to_transpose) if inner_treedef is None: inner_treedef = tree_structure(outer_treedef.flatten_up_to(pytree_to_transpose)[0]) @@ -592,30 +438,7 @@ def tree_reduce(function: Callable[[T, Any], T], tree: Any, initializer: Any = no_initializer, is_leaf: Callable[[Any], bool] | None = None) -> T: - """Call reduce() over the leaves of a tree. - - Args: - function: the reduction function - tree: the pytree to reduce over - initializer: the optional initial value - is_leaf : an optionally specified function that will be called at each - flattening step. It should return a boolean, which indicates whether the - flattening should traverse the current object, or if it should be stopped - immediately, with the whole subtree being treated as a leaf. - - Returns: - result: the reduced value. - - Examples: - >>> import jax - >>> import operator - >>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]]) - 21 - - See Also: - - :func:`jax.tree.leaves` - - :func:`jax.tree.map` - """ + """Alias of :func:`jax.tree.reduce`.""" if initializer is no_initializer: return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf)) else: @@ -623,27 +446,9 @@ def tree_reduce(function: Callable[[T, Any], T], @export -def tree_all(tree: Any) -> bool: - """Call all() over the leaves of a tree. - - Args: - tree: the pytree to evaluate - - Returns: - result: boolean True or False - - Examples: - >>> import jax - >>> jax.tree.all([True, {'a': True, 'b': (True, True)}]) - True - >>> jax.tree.all([False, (True, False)]) - False - - See Also: - - :func:`jax.tree_util.tree_reduce` - - :func:`jax.tree_util.tree_leaves` - """ - return all(tree_leaves(tree)) +def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool: + """Alias of :func:`jax.tree.all`.""" + return all(tree_leaves(tree, is_leaf=is_leaf)) register_pytree_node( @@ -786,7 +591,7 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: ValueError: If the given pytree is not a built-in or registered container via ``register_pytree_node`` or ``register_pytree_with_keys``. - Example: + Examples: >>> import jax >>> from jax._src.tree_util import flatten_one_level >>> flattened, meta = flatten_one_level({'a': [1, 2], 'b': {'c': 3}}) @@ -940,7 +745,7 @@ def keystr(keys: KeyPath): Returns: A string that joins all string representations of the keys. - Example: + Examples: >>> import jax >>> keys = (0, 1, 'a', 'b') >>> jax.tree_util.keystr(keys) @@ -1017,7 +822,7 @@ def register_pytree_with_keys( This argument is optional and only needed for faster traversal when calling functions without keys like ``tree_map`` and ``tree_flatten``. - Example: + Examples: First we'll define a custom type: >>> class MyContainer: @@ -1086,7 +891,7 @@ class that defines how it could be flattened with keys. - :func:`~jax.tree_util.register_pytree_with_keys` - :func:`~jax.tree_util.register_pytree_node_class` - Example: + Examples: >>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey >>> @register_pytree_with_keys_class ... class Special: @@ -1141,7 +946,7 @@ def register_dataclass( pytree registry. This return value allows ``register_dataclass`` to be partially evaluated and used as a decorator as in the example below. - Example: + Examples: >>> from dataclasses import dataclass >>> from functools import partial >>> diff --git a/jax/_src/util.py b/jax/_src/util.py index eb7b4b80adfb..7aab80b2def3 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -15,13 +15,14 @@ from __future__ import annotations import abc -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence +import dataclasses import functools from functools import partial import itertools as it import logging import operator -from typing import (Any, Callable, Generic, TypeVar, overload, TYPE_CHECKING, cast) +from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast) import weakref import numpy as np @@ -285,7 +286,18 @@ def merge(new_lhs, new_rhs): return lhs, rhs, merge -def cache(max_size=4096): + +@dataclasses.dataclass(frozen=True) +class _IgnoreKey: + + def __hash__(self): + return hash(self.__class__) + + def __eq__(self, other): + return isinstance(other, _IgnoreKey) + + +def cache(max_size=4096, trace_context_in_key=True): def wrap(f): @functools.lru_cache(max_size) def cached(_, *args, **kwargs): @@ -295,14 +307,24 @@ def cached(_, *args, **kwargs): def wrapper(*args, **kwargs): if config.check_tracer_leaks.value: return f(*args, **kwargs) - else: + elif trace_context_in_key: return cached(config.trace_context(), *args, **kwargs) + else: + return cached(_IgnoreKey(), *args, **kwargs) wrapper.cache_clear = cached.cache_clear wrapper.cache_info = cached.cache_info + cache_clearing_funs.add(wrapper.cache_clear) return wrapper return wrap +cache_clearing_funs = weakref.WeakSet() # type: ignore + +def clear_all_caches(): + global cache_clearing_funs + for clear in cache_clearing_funs: + clear() + memoize = cache(max_size=None) def weakref_lru_cache(call: Callable, maxsize=2048): @@ -360,6 +382,9 @@ def __eq__(self, other): def wrap_name(name, transform_name): return transform_name + '(' + name + ')' +def fun_name(fun: Callable): + return getattr(fun, "__name__", "") + def canonicalize_axis(axis, num_dims) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" axis = operator.index(axis) @@ -399,7 +424,7 @@ def wraps( """ def wrapper(fun: T) -> T: try: - name = getattr(wrapped, "__name__", "") + name = fun_name(wrapped) doc = getattr(wrapped, "__doc__", "") or "" fun.__dict__.update(getattr(wrapped, "__dict__", {})) fun.__annotations__ = getattr(wrapped, "__annotations__", {}) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index c68d39abf96f..41fd2c586593 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -20,10 +20,8 @@ """ from __future__ import annotations -from __future__ import annotations - import atexit -from collections.abc import Mapping +from collections.abc import Callable, Mapping import dataclasses from functools import lru_cache, partial import importlib @@ -32,23 +30,20 @@ import os import pkgutil import platform as py_platform -import traceback -import sys import threading -from typing import Any, Callable, Union +import traceback +from typing import Any, Union import warnings from jax._src import config from jax._src import distributed +from jax._src import hardware_utils from jax._src import traceback_util from jax._src import util -from jax._src import hardware_utils -from jax._src.cloud_tpu_init import maybe_import_libtpu +from jax._src.cloud_tpu_init import get_tpu_library_path from jax._src.lib import cuda_versions from jax._src.lib import xla_client from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version -from jax._src.lib import jaxlib logger = logging.getLogger(__name__) @@ -68,46 +63,46 @@ MIN_COMPUTE_CAPABILITY = 52 # TODO(phawkins): Remove jax_xla_backend. -_XLA_BACKEND = config.DEFINE_string( +_XLA_BACKEND = config.string_flag( 'jax_xla_backend', '', 'Deprecated, please use --jax_platforms instead.') -BACKEND_TARGET = config.DEFINE_string( +BACKEND_TARGET = config.string_flag( 'jax_backend_target', os.getenv('JAX_BACKEND_TARGET', '').lower(), 'Either "local" or "rpc:address" to connect to a remote service target.') # TODO(skye): warn when this is used once we test out --jax_platforms a bit -_PLATFORM_NAME = config.DEFINE_string( +_PLATFORM_NAME = config.string_flag( 'jax_platform_name', os.getenv('JAX_PLATFORM_NAME', '').lower(), 'Deprecated, please use --jax_platforms instead.') -CUDA_VISIBLE_DEVICES = config.DEFINE_string( +CUDA_VISIBLE_DEVICES = config.string_flag( 'jax_cuda_visible_devices', 'all', 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_ROCM_VISIBLE_DEVICES = config.DEFINE_string( +_ROCM_VISIBLE_DEVICES = config.string_flag( 'jax_rocm_visible_devices', 'all', 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_USE_MOCK_GPU_CLIENT = config.DEFINE_bool( +_USE_MOCK_GPU_CLIENT = config.bool_flag( name="use_mock_gpu_client", default=False, help="If True, use a mock GPU client instead of a real one.", ) -_MOCK_NUM_GPUS = config.DEFINE_integer( +_MOCK_NUM_GPUS = config.int_flag( name="mock_num_gpus", default=1, help="Mock GPU client number of gpus.", ) -_CPU_ENABLE_GLOO_COLLECTIVES = config.DEFINE_bool( +_CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag( name="jax_cpu_enable_gloo_collectives", default=False, help="Deprecated, please use jax_cpu_collectives_implementation instead.", ) -_CPU_COLLECTIVES_IMPLEMENTATION = config.DEFINE_string( +_CPU_COLLECTIVES_IMPLEMENTATION = config.string_flag( name='jax_cpu_collectives_implementation', default='none', help='Cross-process collective implementation used on CPU. Either "none", ' @@ -116,7 +111,7 @@ # TODO(yueshengys): turn default back to True after resolving memory increase # issue. -_CPU_ENABLE_ASYNC_DISPATCH = config.DEFINE_bool( +_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag( name="jax_cpu_enable_async_dispatch", default=False, help="Only applies to non-parallel computations. If False, run computations" @@ -136,17 +131,6 @@ def _at_fork(): # Backends -def _get_tpu_library_path() -> str | None: - path_from_env = os.getenv("TPU_LIBRARY_PATH") - if path_from_env is not None: - return path_from_env - - libtpu_module = maybe_import_libtpu() - if libtpu_module is not None: - return libtpu_module.get_library_path() - - return None - def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: def _log_warning(): @@ -161,13 +145,9 @@ def _log_warning(): t.start() try: - if xla_extension_version >= 267: - client = xla_client.make_tpu_client( # type: ignore - _get_tpu_library_path(), - _options_from_jax_configs("tpu")) - else: - client = xla_client.make_tpu_client( - _get_tpu_library_path()) + client = xla_client.make_tpu_client( # type: ignore + get_tpu_library_path(), + _options_from_jax_configs("tpu")) finally: t.cancel() @@ -400,19 +380,16 @@ def _version_check(name: str, _version_check("cuPTI", cuda_versions.cupti_get_version, cuda_versions.cupti_build_version, min_supported_version=18) - # TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21 - if hasattr(cuda_versions, "cublas_get_version"): - _version_check("cuBLAS", cuda_versions.cublas_get_version, - cuda_versions.cublas_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=120100) - if hasattr(cuda_versions, "cusparse_get_version"): - _version_check("cuSPARSE", cuda_versions.cusparse_get_version, - cuda_versions.cusparse_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=12100) + _version_check("cuBLAS", cuda_versions.cublas_get_version, + cuda_versions.cublas_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=120100) + _version_check("cuSPARSE", cuda_versions.cusparse_get_version, + cuda_versions.cusparse_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=12100) errors = [] debug_results = [] @@ -434,7 +411,7 @@ def _version_check(name: str, def make_gpu_client( - *, platform_name: str, visible_devices_flag: config.FlagHolder[str] + *, platform_name: str, visible_devices_flag: config.Flag[str] ) -> xla_client.Client: visible_devices = visible_devices_flag.value allowed_devices = None @@ -454,11 +431,12 @@ def make_gpu_client( print('Skipped CUDA versions constraints check due to the ' 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') - # TODO(micky774): remove this check when minimum jaxlib is v0.4.26 - if jaxlib.version.__version_info__ >= (0, 4, 26): - devices_to_check = (allowed_devices if allowed_devices else - range(cuda_versions.cuda_device_count())) - _check_cuda_compute_capability(devices_to_check) + devices_to_check = ( + allowed_devices + if allowed_devices + else range(cuda_versions.cuda_device_count()) + ) + _check_cuda_compute_capability(devices_to_check) return xla_client.make_gpu_client( distributed_client=distributed.global_state.client, @@ -591,12 +569,7 @@ def discover_pjrt_plugins() -> None: logger.debug("No jax_plugins namespace packages available") # Augment with advertised entrypoints. - if sys.version_info < (3, 10): - # Use the backport library because it provides a forward-compatible - # implementation. - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points + from importlib.metadata import entry_points for entry_point in entry_points(group="jax_plugins"): logger.debug("Discovered entry-point based JAX plugin: %s", @@ -1228,7 +1201,7 @@ def make_pjrt_topology(platform: str, topology_name='', **kwargs): # TODO(parkers): Get rid of this in favor of a generic way to get topologies. def make_pjrt_tpu_topology(topology_name='', **kwargs): if not xla_client.pjrt_plugin_loaded("tpu"): - library_path = _get_tpu_library_path() + library_path = get_tpu_library_path() if library_path is None: raise RuntimeError( "JAX TPU support not installed; cannot generate TPU topology. See" diff --git a/jax/core.py b/jax/core.py index edc31778fd25..c23d37123d00 100644 --- a/jax/core.py +++ b/jax/core.py @@ -118,18 +118,6 @@ no_effects as no_effects, non_negative_dim as _deprecated_non_negative_dim, outfeed_primitives as outfeed_primitives, - pp_aval as pp_aval, - pp_eqn as pp_eqn, - pp_eqn_rules as pp_eqn_rules, - pp_eqns as pp_eqns, - pp_jaxpr as pp_jaxpr, - pp_jaxpr_eqn_range as pp_jaxpr_eqn_range, - pp_jaxpr_skeleton as pp_jaxpr_skeleton, - pp_jaxprs as pp_jaxprs, - pp_kv_pair as pp_kv_pair, - pp_kv_pairs as pp_kv_pairs, - pp_var as pp_var, - pp_vars as pp_vars, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, primitive_uses_outfeed as primitive_uses_outfeed, process_env_traces_call as process_env_traces_call, @@ -162,6 +150,19 @@ from jax._src import core as _src_core _deprecations = { + # Added 2024-06-12 + "pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval), + "pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn), + "pp_eqn_rules": ("jax.core.pp_eqn_rules is deprecated.", _src_core.pp_eqn_rules), + "pp_eqns": ("jax.core.pp_eqns is deprecated.", _src_core.pp_eqns), + "pp_jaxpr": ("jax.core.pp_jaxpr is deprecated.", _src_core.pp_jaxpr), + "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range is deprecated.", _src_core.pp_jaxpr_eqn_range), + "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton is deprecated.", _src_core.pp_jaxpr_skeleton), + "pp_jaxprs": ("jax.core.pp_jaxprs is deprecated.", _src_core.pp_jaxprs), + "pp_kv_pair": ("jax.core.pp_kv_pair is deprecated.", _src_core.pp_kv_pair), + "pp_kv_pairs": ("jax.core.pp_kv_pairs is deprecated.", _src_core.pp_kv_pairs), + "pp_var": ("jax.core.pp_var is deprecated.", _src_core.pp_var), + "pp_vars": ("jax.core.pp_vars is deprecated.", _src_core.pp_vars), # Finalized 2024-05-13; remove after 2024-08-13 "DimSize": ( "jax.core.DimSize is deprecated. Use DimSize = int | Any.", @@ -196,6 +197,18 @@ dimension_as_value = _deprecated_dimension_as_value definitely_equal = _deprecated_definitely_equal non_negative_dim = _deprecated_non_negative_dim + pp_aval = _src_core.pp_aval + pp_eqn = _src_core.pp_eqn + pp_eqn_rules = _src_core.pp_eqn_rules + pp_eqns = _src_core.pp_eqns + pp_jaxpr = _src_core.pp_jaxpr + pp_jaxpr_eqn_range = _src_core.pp_jaxpr_eqn_range + pp_jaxpr_skeleton = _src_core.pp_jaxpr_skeleton + pp_jaxprs = _src_core.pp_jaxprs + pp_kv_pair = _src_core.pp_kv_pair + pp_kv_pairs = _src_core.pp_kv_pairs + pp_var = _src_core.pp_var + pp_vars = _src_core.pp_vars symbolic_equal_dim = _deprecated_definitely_equal else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py index 8bf57792ffd8..71680ca61b96 100644 --- a/jax/example_libraries/optimizers.py +++ b/jax/example_libraries/optimizers.py @@ -91,7 +91,8 @@ def step(step, opt_state): from __future__ import annotations -from typing import Any, Callable, NamedTuple +from collections.abc import Callable +from typing import Any, NamedTuple from collections import namedtuple import functools diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index f7375a80fa8a..e0d8c4ee67f5 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -48,6 +48,7 @@ add as add, all as all, any as any, + arange as arange, argmax as argmax, argmin as argmin, argsort as argsort, @@ -66,6 +67,7 @@ broadcast_arrays as broadcast_arrays, broadcast_to as broadcast_to, can_cast as can_cast, + ceil as ceil, complex128 as complex128, complex64 as complex64, concat as concat, @@ -82,9 +84,11 @@ exp as exp, expand_dims as expand_dims, expm1 as expm1, + eye as eye, flip as flip, float32 as float32, float64 as float64, + floor as floor, floor_divide as floor_divide, from_dlpack as from_dlpack, full as full, @@ -160,6 +164,7 @@ tile as tile, tril as tril, triu as triu, + trunc as trunc, uint16 as uint16, uint32 as uint32, uint64 as uint64, @@ -180,9 +185,7 @@ ) from jax.experimental.array_api._creation_functions import ( - arange as arange, asarray as asarray, - eye as eye, linspace as linspace, ) @@ -192,11 +195,8 @@ ) from jax.experimental.array_api._elementwise_functions import ( - ceil as ceil, clip as clip, - floor as floor, hypot as hypot, - trunc as trunc, ) from jax.experimental.array_api._statistical_functions import ( diff --git a/jax/experimental/array_api/_creation_functions.py b/jax/experimental/array_api/_creation_functions.py index 99b8e3ed4465..5b9789ed732d 100644 --- a/jax/experimental/array_api/_creation_functions.py +++ b/jax/experimental/array_api/_creation_functions.py @@ -17,15 +17,9 @@ import jax import jax.numpy as jnp -# TODO(micky774): Deprecate after adding device argument to jax.numpy functions -def arange(start, /, stop=None, step=1, *, dtype=None, device=None): - return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device) def asarray(obj, /, *, dtype=None, device=None, copy=None): return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device) -def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None): - return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device) - def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True): return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device) diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index 5587b9a60f1a..103f8ab7d1ef 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -18,15 +18,6 @@ from jax._src.numpy.util import promote_args -# TODO(micky774): Update jnp.ceil to preserve integral dtype -def ceil(x, /): - """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i.""" - x, = promote_args("ceil", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.ceil(x) - - # TODO(micky774): Remove when jnp.clip deprecation is completed # (began 2024-4-2) and default behavior is Array API 2023 compliant def clip(x, /, min=None, max=None): @@ -43,15 +34,6 @@ def clip(x, /, min=None, max=None): return jax.numpy.clip(x, min=min, max=max) -# TODO(micky774): Update jnp.floor to preserve integral dtype -def floor(x, /): - """Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i.""" - x, = promote_args("floor", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.floor(x) - - # TODO(micky774): Remove when jnp.hypot deprecation is completed # (began 2024-4-14) and default behavior is Array API 2023 compliant def hypot(x1, x2, /): @@ -64,12 +46,3 @@ def hypot(x1, x2, /): "values first, such as by using jnp.real or jnp.imag to take the real " "or imaginary components respectively.") return jax.numpy.hypot(x1, x2) - - -# TODO(micky774): Update jnp.trunc to preserve integral dtype -def trunc(x, /): - """Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i.""" - x, = promote_args("trunc", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.trunc(x) diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py index c5dac25fd8c6..f75b2e2e29af 100644 --- a/jax/experimental/array_api/_utility_functions.py +++ b/jax/experimental/array_api/_utility_functions.py @@ -15,7 +15,6 @@ from __future__ import annotations import jax -from typing import Tuple from jax._src.sharding import Sharding from jax._src.lib import xla_client as xc from jax._src import dtypes as _dtypes, config @@ -71,7 +70,7 @@ def default_dtypes(self, *, device: xc.Device | Sharding | None = None): def dtypes( self, *, device: xc.Device | Sharding | None = None, - kind: str | Tuple[str, ...] | None = None): + kind: str | tuple[str, ...] | None = None): # Array API supported dtypes are device-independent in JAX del device data_types = self._build_dtype_dict() diff --git a/jax/experimental/array_api/linalg.py b/jax/experimental/array_api/linalg.py index f19955409d5f..6494884135fe 100644 --- a/jax/experimental/array_api/linalg.py +++ b/jax/experimental/array_api/linalg.py @@ -35,8 +35,7 @@ vector_norm as vector_norm, ) -# TODO(micky774): Add trace to jax.numpy.linalg -from jax.numpy import trace as trace +from jax.numpy.linalg import trace as trace from jax.experimental.array_api._linear_algebra_functions import ( matrix_rank as matrix_rank, diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt index 5eb5fd8e27a5..c865fabcfb55 100644 --- a/jax/experimental/array_api/skips.txt +++ b/jax/experimental/array_api/skips.txt @@ -9,16 +9,28 @@ array_api_tests/test_array_object.py::test_setitem array_api_tests/test_special_cases.py::test_binary array_api_tests/test_special_cases.py::test_unary -# fft test suite is buggy as of 83f0bcdc -array_api_tests/test_fft.py - # Pending implementation update for proper dtype promotion behavior, # see https://github.com/data-apis/array-api-tests/issues/234 array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod -array_api_tests/test_linalg.py::test_trace # Pending bugfix, see https://github.com/data-apis/array-api-tests/issues/256 array_api_tests/test_signatures.py::test_func_signature[logical_and] array_api_tests/test_signatures.py::test_func_signature[logical_or] -array_api_tests/test_signatures.py::test_func_signature[logical_xor] \ No newline at end of file +array_api_tests/test_signatures.py::test_func_signature[logical_xor] + +# Returns int32 when int64 is expected +array_api_tests/test_searching_functions.py::test_searchsorted + +# Various info functions not yet defined +# Pending bugfix, see https://github.com/data-apis/array-api-tests/pull/262 +array_api_tests/test_has_names.py::test_has_names[info-capabilities] +array_api_tests/test_has_names.py::test_has_names[info-default_device] +array_api_tests/test_has_names.py::test_has_names[info-default_dtypes] +array_api_tests/test_has_names.py::test_has_names[info-devices] +array_api_tests/test_has_names.py::test_has_names[info-dtypes] +array_api_tests/test_signatures.py::test_func_signature[capabilities] +array_api_tests/test_signatures.py::test_func_signature[default_device] +array_api_tests/test_signatures.py::test_func_signature[default_dtypes] +array_api_tests/test_signatures.py::test_func_signature[devices] +array_api_tests/test_signatures.py::test_func_signature[dtypes] diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 9d947527bb1c..c7aa8b590412 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -17,16 +17,15 @@ import abc import asyncio -from collections.abc import Awaitable, Sequence +from collections.abc import Awaitable, Callable, Sequence from functools import partial import itertools import logging import os import re -import sys import threading import time -from typing import Any, Callable, Optional, Union +from typing import Any import jax from jax._src import array @@ -70,7 +69,7 @@ class BarrierTimeoutException(Exception): async def create_async_array_from_callback( global_shape: array.Shape, - inp_sharding: sharding_impls.XLACompatibleSharding, + inp_sharding: jax.sharding.Sharding, data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], ): device_to_index_map = inp_sharding.devices_indices_map(global_shape) @@ -130,7 +129,7 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False): return spec -def is_remote_storage(tspec: Union[dict[str, Any], str]) -> bool: +def is_remote_storage(tspec: dict[str, Any] | str) -> bool: """Detect if user is using cloud storages. This can detect common defines and unable to detect some corner cases such as @@ -170,7 +169,7 @@ def __init__(self, num_bytes): self._cv = asyncio.Condition(lock=asyncio.Lock()) async def wait_for_bytes(self, requested_bytes): - if requested_bytes >= self._max_bytes: + if requested_bytes > self._max_bytes: raise ValueError('Requested more bytes than we reserved space for: ' f'{requested_bytes} > {self._max_bytes}') async with self._cv: @@ -190,7 +189,7 @@ async def async_serialize( tensorstore_spec, commit_future=None, context=TS_CONTEXT, - primary_host: Optional[int] = 0, + primary_host: int | None = 0, replica_id: int = 0, ): """Serialize an array using TensorStore. @@ -310,7 +309,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore, async def async_deserialize( - user_in_sharding: sharding_impls.XLACompatibleSharding | Layout, + user_in_sharding: jax.sharding.Sharding | Layout, tensorstore_spec: ts.Spec | dict[str, Any], global_shape: Sequence[int] | None = None, dtype=None, @@ -320,10 +319,10 @@ async def async_deserialize( ): in_sharding = (user_in_sharding.sharding if isinstance(user_in_sharding, Layout) else user_in_sharding) - if not isinstance(in_sharding, sharding_impls.XLACompatibleSharding): + if not isinstance(in_sharding, jax.sharding.Sharding): raise ValueError( 'sharding passed to deserialization should be specified, concrete and' - f' an instance of `jax.XLACompatibleSharding`. Got {in_sharding}') + f' an instance of `jax.sharding.Sharding`. Got {in_sharding}') dll = (user_in_sharding.device_local_layout if isinstance(user_in_sharding, Layout) else None) t = await ts.open( @@ -412,7 +411,7 @@ class GlobalAsyncCheckpointManagerBase(util.StrictABC): is finished, checkpoint for step 2 will need to be blocked. Maintaining a class allows to maintain that state. - Example: + Examples: Below is a simplified training loop: diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index f0bf4f456292..b71c3cac2afd 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -14,6 +14,7 @@ """Tests for serialization and deserialization of GDA.""" import asyncio +import contextlib import math from functools import partial import re @@ -36,17 +37,13 @@ import tensorstore as ts jax.config.parse_flags_with_absl() - -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(8) + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() pattern = re.compile(r"\{(.*?):") diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index ac076a4d57ea..8176465c1470 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -14,6 +14,7 @@ from __future__ import annotations +from contextlib import contextmanager from typing import Any from jax._src import core @@ -30,59 +31,64 @@ zip, unsafe_zip = safe_zip, zip JaxVal = Any +Pytree = Any register = api_util.register_class_with_attrs -class GetAttrPrimitive(core.Primitive): - def bind_with_trace(self, trace, args, params): - () = args - return trace.process_getattr(**params) -getattr_p = GetAttrPrimitive('getattr') - -class SetAttrPrimitive(core.Primitive): - def bind_with_trace(self, trace, args, params): - val, = args - return trace.process_setattr(trace.full_raise(val), **params) -setattr_p = SetAttrPrimitive('setattr') +@contextmanager +def top_trace(): + stack = core.thread_local_state.trace_state.trace_stack.stack + main = stack.pop() + try: + trace = main.with_cur_sublevel() + yield trace + finally: + stack.append(main) def jax_getattr(obj: Any, attr: str): - return getattr_p.bind(obj=obj, attr=attr) - -def jax_setattr(obj: Any, attr: str, val: JaxVal): - setattr_p.bind(val, obj=obj, attr=attr) + with top_trace() as trace: + return trace.process_getattr(obj, attr) +def jax_setattr(obj: Any, attr: str, val: Pytree): + with top_trace() as trace: + return trace.process_setattr(obj, attr, val) -def _getattr_impl(_, *, obj, attr): +def _getattr_impl(_, obj, attr): return getattr(obj, attr) core.EvalTrace.process_getattr = _getattr_impl -def _setattr_impl(_, val, *, obj, attr): +def _setattr_impl(_, obj, attr, val): setattr(obj, attr, val) core.EvalTrace.process_setattr = _setattr_impl - def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): frame = trace.main.jaxpr_stack[-1] # type: ignore - if (obj, attr) not in frame.attrs_tracked: - init_val = getattr(obj, attr) - aval = core.raise_to_shaped(core.get_aval(init_val)) + + def new_tracer(x): + aval = core.raise_to_shaped(core.get_aval(x)) tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current()) var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) - setattr(obj, attr, tracer) - frame.attrs_tracked.append((obj, attr)) - frame.attrs_inits.append(init_val) frame.attrs_vars.append(var) frame.tracers.append(tracer) + return tracer + + if (obj, attr) not in frame.attrs_tracked: + init_val = getattr(obj, attr) + frame.attrs_inits.append(init_val) + init_vals, init_tree = tree_flatten(init_val) + tracers = map(new_tracer, init_vals) + setattr(obj, attr, tree_unflatten(init_tree, tracers)) + frame.attrs_tracked.append((obj, attr)) pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked -def _getattr_staging(trace, *, obj, attr): +def _getattr_staging(trace, obj, attr): trace._ensure_tracked(obj, attr) return getattr(obj, attr) pe.DynamicJaxprTrace.process_getattr = _getattr_staging -def _setattr_staging(trace, tracer, *, obj, attr): +def _setattr_staging(trace, obj, attr, val): trace._ensure_tracked(obj, attr) - setattr(obj, attr, tracer) + setattr(obj, attr, val) pe.DynamicJaxprTrace.process_setattr = _setattr_staging @@ -134,12 +140,19 @@ def jvp_subtrace2(main, primals, tangents): del main.attrs_tracked yield out_primals, out_tangents, tangent_attrs_out -def _setattr_jvp(trace, tracer, *, obj, attr): +def _setattr_jvp(trace, obj, attr, maybe_tracer): + tracer = trace.full_raise(maybe_tracer) + if isinstance(tracer.tangent, ad.Zero): + return setattr(obj, attr, tracer.primal) if (obj, attr) not in trace.main.attrs_tracked: trace.main.attrs_tracked.append((obj, attr)) - setattr(obj, attr, tracer) + return setattr(obj, attr, tracer) ad.JVPTrace.process_setattr = _setattr_jvp +def _getattr_jvp(trace, obj, attr): + return getattr(obj, attr) +ad.JVPTrace.process_getattr = _getattr_jvp + def linearize(f, *primals, attrs: list[tuple[Any, str]] = []): attr_primals = [jax_getattr(o, a) for o, a in attrs] diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index ec2311736f5e..aa138fe88993 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -133,10 +133,8 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape, def _to_hlo_sharding(sharding, num_dimensions): - if not isinstance(sharding, jax.sharding.XLACompatibleSharding): - raise ValueError( - "Custom Partitioning rules must return XLACompatibleShardings." - ) + if not isinstance(sharding, jax.sharding.Sharding): + raise ValueError("Custom Partitioning rules must return Sharding.") return sharding._to_xla_hlo_sharding(num_dimensions) @@ -301,7 +299,7 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape): Positional arguments can be specified as static using static_argnums. JAX uses :code:`inspect.signature(fun)` to resolve these positional arguments. - Example: + Examples: As an example, assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched diff --git a/jax/experimental/export/BUILD b/jax/experimental/export/BUILD index 05c892b38d2d..1246b0d407af 100644 --- a/jax/experimental/export/BUILD +++ b/jax/experimental/export/BUILD @@ -31,11 +31,6 @@ py_library( name = "export", srcs = [ "__init__.py", - "_export.py", - "_serialization.py", - "_shape_poly.py", - "_shape_poly_decision.py", - "serialization_generated.py", ], srcs_version = "PY3", # TODO: b/255503696: enable pytype diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index ac4fd199e37f..b67354bb4248 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -13,26 +13,61 @@ # limitations under the License. # ============================================================================== -from jax.experimental.export._export import ( - minimum_supported_serialization_version, - maximum_supported_serialization_version, - Exported, - export, - call_exported, # TODO: deprecate - call, - DisabledSafetyCheck, - default_lowering_platform, - - args_specs, # TODO: deprecate +_deprecation_message = ( + "The jax.experimental.export module is deprecated. " + "Use jax.export instead. " + "See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export." ) -from jax.experimental.export._shape_poly import ( - is_symbolic_dim, - symbolic_shape, - symbolic_args_specs, - SymbolicScope, -) -from jax.experimental.export._serialization import ( - serialize, - deserialize, -) -from jax.experimental.export import _shape_poly_decision + +from jax._src.export import _export as _src_export +from jax._src.export import shape_poly as _src_shape_poly +from jax._src.export import serialization as _src_serialization +# Import only to set the shape poly decision procedure +from jax._src.export import shape_poly_decision +del shape_poly_decision + +# All deprecations added Jun 14, 2024 +_deprecations = { + # Added Jun 13, 2024 + "Exported": (_deprecation_message, _src_export.Exported), + "DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck), + "export": (_deprecation_message, _src_export.export_back_compat), + "call": (_deprecation_message, _src_export.call), + "call_exported": (_deprecation_message, _src_export.call_exported), + "default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform), + "minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version), + "maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version), + + "serialize": (_deprecation_message, _src_serialization.serialize), + "deserialize": (_deprecation_message, _src_serialization.deserialize), + + "SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope), + "is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim), + "symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape), + "symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs), +} + +import typing +if typing.TYPE_CHECKING: + Exported = _src_export.Exported + DisabledSafetyCheck = _src_export.DisabledSafetyCheck + export = _src_export.export_back_compat + call = _src_export.call + call_exported = _src_export.call_exported + default_lowering_platform = _src_export.default_lowering_platform + + serialize = _src_serialization.serialize + deserialize = _src_serialization.deserialize + + SymbolicScope = _src_shape_poly.SymbolicScope + is_symbolic_dim = _src_shape_poly.is_symbolic_dim + symbolic_shape = _src_shape_poly.symbolic_shape + symbolic_args_specs = _src_shape_poly.symbolic_args_specs +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _src_export +del _src_serialization +del _src_shape_poly diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index b56bd7cec49a..deaac9c72c8b 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -503,14 +503,14 @@ def power3_with_cotangents(x): import atexit import enum -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import itertools import logging import math import threading import traceback -from typing import Any, Callable, cast +from typing import Any, cast import jax from jax._src import api @@ -541,12 +541,12 @@ def power3_with_cotangents(x): import numpy as np -_HOST_CALLBACK_INLINE = config.DEFINE_bool( +_HOST_CALLBACK_INLINE = config.bool_flag( 'jax_host_callback_inline', config.bool_env('JAX_HOST_CALLBACK_INLINE', False), help='Inline the host_callback, if not in a staged context.' ) -_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer( +_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.int_flag( 'jax_host_callback_max_queue_byte_size', config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)), help=('The size in bytes of the buffer used to hold outfeeds from each ' @@ -555,7 +555,7 @@ def power3_with_cotangents(x): 'until the Python callback consume more outfeeds.'), lower_bound=int(16 * 1e6) ) -_HOST_CALLBACK_OUTFEED = config.DEFINE_bool( +_HOST_CALLBACK_OUTFEED = config.bool_flag( 'jax_host_callback_outfeed', config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False), help=( @@ -564,7 +564,7 @@ def power3_with_cotangents(x): 'Has no effect on TPU, since only the outfeed mechanism is implemented.' ) ) -_HOST_CALLBACK_LEGACY = config.DEFINE_bool( +_HOST_CALLBACK_LEGACY = config.bool_flag( 'jax_host_callback_legacy', config.bool_env('JAX_HOST_CALLBACK_LEGACY', True), help=( diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 0ce5bcb170f2..adf43b6b94c0 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -25,10 +25,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools -from typing import Any, Callable, Optional +from typing import Any from absl import logging import jax diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py index 1dae2752ffd3..41173c79a5b9 100644 --- a/jax/experimental/jax2tf/examples/mnist_lib.py +++ b/jax/experimental/jax2tf/examples/mnist_lib.py @@ -21,12 +21,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools import logging import re import time -from typing import Any, Callable, Optional +from typing import Any from absl import flags import flax diff --git a/jax/experimental/jax2tf/examples/saved_model_lib.py b/jax/experimental/jax2tf/examples/saved_model_lib.py index f0fa145728fe..8f2f0982fd3d 100644 --- a/jax/experimental/jax2tf/examples/saved_model_lib.py +++ b/jax/experimental/jax2tf/examples/saved_model_lib.py @@ -26,8 +26,8 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any from jax.experimental import jax2tf import tensorflow as tf diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 3305654f2243..5ecde602cdaa 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -16,12 +16,12 @@ from __future__ import annotations import builtins -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses from functools import partial, wraps import math import string -from typing import Any, Callable, Optional +from typing import Any from jax._src import core from jax import lax diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 58f47be2af6b..5f3230599a25 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -15,7 +15,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial import contextlib import math @@ -23,7 +23,7 @@ import os import re import threading -from typing import Any, Callable, Union +from typing import Any, Union import warnings from absl import logging @@ -36,9 +36,7 @@ from jax import numpy as jnp from jax import tree_util from jax import sharding -from jax.experimental import export -from jax.experimental.export import _export -from jax.experimental.export import _shape_poly +from jax import export from jax.experimental.jax2tf import impl_no_xla from jax.interpreters import xla @@ -61,6 +59,8 @@ from jax._src import source_info_util from jax._src import util from jax._src import shard_alike +from jax._src.export import _export +from jax._src.export import shape_poly from jax._src.interpreters import ad from jax._src.interpreters import mlir from jax._src.lax import control_flow as lax_control_flow @@ -88,7 +88,7 @@ # pylint: enable=g-direct-tensorflow-import NameStack = source_info_util.NameStack -PolyShape = _shape_poly.PolyShape # TODO: deprecate +PolyShape = shape_poly.PolyShape # TODO: deprecate DType = Any DisabledSafetyCheck = export.DisabledSafetyCheck @@ -388,13 +388,13 @@ def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct: args_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, args_tf) args_specs = export.symbolic_args_specs( - args_jax_specs, polymorphic_shapes=polymorphic_shapes, - symbolic_constraints=polymorphic_constraints) + args_jax_specs, polymorphic_shapes, + constraints=polymorphic_constraints) # The polymorphic_shapes argument refers to positional arguments only. # We assume None for the kwargs. kwargs_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, kwargs_tf) kwargs_specs = export.symbolic_args_specs( - kwargs_jax_specs, polymorphic_shapes=None) + kwargs_jax_specs, None) combined_args_tf = (args_tf, kwargs_tf) args_flat_tf: Sequence[TfVal] args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf) @@ -514,11 +514,15 @@ def _restore_context(): _thread_local_state.call_tf_concrete_function_list = _prev_func_list self._restore_context = _restore_context - self.exported = export.export( + _exported_device_assignment = [None] + self.exported = _export.export_back_compat( self.fun_jax, lowering_platforms=self.native_serialization_platforms, - disabled_checks=self.native_serialization_disabled_checks + disabled_checks=self.native_serialization_disabled_checks, + _device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment, )(*self.args_specs, **self.kwargs_specs) + assert(_exported_device_assignment[0] is not None) + self.device_assignment = _exported_device_assignment[0] def after_conversion(self): self._restore_context() @@ -531,15 +535,13 @@ def run_fun_tf(self, def get_vjp_fun(self) -> tuple[Callable, Sequence[core.AbstractValue]]: - # TODO(necula): use the actual device assignment from the primal function - device_assignment = jax.devices(jax.default_backend())[:self.exported.nr_devices] return _export._get_vjp_fun(self.fun_jax, in_tree=self.exported.in_tree, in_avals=self.exported.in_avals, - in_shardings=self.exported.in_shardings, + in_shardings_hlo=self.exported.in_shardings_hlo, out_avals=self.exported.out_avals, - out_shardings=self.exported.out_shardings, - device_assignment=device_assignment, + out_shardings_hlo=self.exported.out_shardings_hlo, + device_assignment=self.device_assignment, apply_jit=True) class GraphSerializationImpl(SerializationImpl): @@ -577,9 +579,9 @@ def _restore_context(): (self.args_specs, self.kwargs_specs)) self.args_avals_flat = tuple( map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat)) - dim_vars = _shape_poly.all_dim_vars(self.args_avals_flat) + dim_vars = shape_poly.all_dim_vars(self.args_avals_flat) dim_values, _ = _interpret_fun_jax( - partial(_shape_poly.compute_dim_vars_from_arg_shapes, + partial(shape_poly.compute_dim_vars_from_arg_shapes, self.args_avals_flat, args_kwargs_tree=self.in_tree), self.args_flat_tf, self.args_avals_flat, self.name_stack) @@ -608,9 +610,9 @@ def get_vjp_fun(self) -> tuple[Callable, return _export._get_vjp_fun(self.fun_jax, in_tree=self.in_tree, in_avals=self.args_avals_flat, - in_shardings=(None,) * len(self.args_avals_flat), + in_shardings_hlo=(None,) * len(self.args_avals_flat), out_avals=self.outs_avals, - out_shardings=(None,) * len(self.outs_avals), + out_shardings_hlo=(None,) * len(self.outs_avals), device_assignment=None, # Not used when apply_jit = False apply_jit=False) @@ -675,7 +677,7 @@ def eval_polymorphic_shape(fun_jax: Callable, """ def do_eval_polymorphic_shape(*args_specs) -> Any: args_poly_specs = export.symbolic_args_specs( - args_specs, polymorphic_shapes=polymorphic_shapes) + args_specs, polymorphic_shapes) res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs) # TODO(necula): For now we export the polymorphic shapes using `str`. res_polymorphic_shape = tree_util.tree_map(lambda r: str(r.shape), res_poly_spec) @@ -855,7 +857,7 @@ def _convert_value(val, aval): kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx] kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx] - version = exported.mlir_module_serialization_version + version = exported.calling_convention_version try: get_max_supported_version = tfxla.call_module_maximum_supported_version @@ -888,7 +890,7 @@ def _convert_value(val, aval): has_token_input_output=False ) - call_module_attrs["platforms"] = tuple(p.upper() for p in exported.lowering_platforms) + call_module_attrs["platforms"] = tuple(p.upper() for p in exported.platforms) if version >= 6: call_module_attrs["disabled_checks"] = tuple( str(dc) @@ -914,7 +916,7 @@ def _convert_value(val, aval): # See b/255511660. kept_in_shardings = [] for i in exported.module_kept_var_idx: - kept_in_shardings.append(exported.in_shardings[i]) + kept_in_shardings.append(exported.in_shardings_hlo[i]) args_flat_tf = tuple( map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), @@ -931,7 +933,7 @@ def _convert_value(val, aval): res = list(map(partial(_shard_value, skip_replicated_sharding=tf.executing_eagerly()), - res, exported.out_shardings)) + res, exported.out_shardings_hlo)) res = tuple(map(_convert_value, res, exported.out_avals)) return res @@ -1146,8 +1148,8 @@ def _tfval_to_tensor_jax_dtype(val: TfVal, return tf_val, jax_dtype -def _eval_shape(shape: Sequence[_shape_poly.DimSize], dtype=None) -> Sequence[TfVal]: - # Returns a tuple of _shape_poly.dim_as_value_dtype +def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]: + # Returns a tuple of shape_poly.dim_as_value_dtype # Used only for non-native lowering assert all(map(lambda x: x is not None, shape)), ( f"Argument shape should be a valid JAX shape but got {shape}") @@ -1172,7 +1174,7 @@ def _ensure_tf_shape_if_dynamic(x: TfVal, shape): return tf.ensure_shape(x, shape) -def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[_shape_poly.DimSize]): +def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]): """Asserts that shape matches x.shape in the known dimensions and has dimension polynomials elsewhere.""" # Ensures that the shape does not contain None; it should contain symbolic expressions. @@ -1545,7 +1547,7 @@ def _add(x: TfVal, y: TfVal) -> TfVal: tf_impl[ad_util.add_jaxvals_p] = _add -tf_impl[dispatch.device_put_p] = lambda x, device=None, src=None: x +tf_impl[dispatch.device_put_p] = lambda *xs, devices=None, srcs=None: xs tf_impl[lax_internal.copy_p] = lambda x: x def _shard_alike(*args: TfVal, **_): @@ -3461,7 +3463,7 @@ def split_to_logical_devices(tensor: TfVal, def _xla_compatible_sharding_to_hlo_sharding( - s: sharding.XLACompatibleSharding, + s: sharding.Sharding, aval: core.ShapedArray) -> xla_client.HloSharding | None: if sharding_impls.is_unspecified(s): return None @@ -3513,8 +3515,8 @@ def _shard_value(val: TfVal, def _pjit(*args: TfVal, jaxpr: core.ClosedJaxpr, - in_shardings: Sequence[sharding.XLACompatibleSharding], - out_shardings: Sequence[sharding.XLACompatibleSharding], + in_shardings: Sequence[sharding.Sharding], + out_shardings: Sequence[sharding.Sharding], in_layouts, out_layouts, resource_env: mesh.ResourceEnv, donated_invars, @@ -3547,7 +3549,7 @@ def _pjit(*args: TfVal, def _pjit_sharding_constraint(arg: TfVal, *, - sharding: sharding.XLACompatibleSharding, + sharding: sharding.Sharding, resource_env: mesh.ResourceEnv, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray, @@ -3567,13 +3569,13 @@ def _dimension_size_jax2tf(op: TfVal, *, dimension, _in_avals, _out_aval): else: return dim_tf -tf_impl_with_avals[_shape_poly.dimension_size_p] = _dimension_size_jax2tf +tf_impl_with_avals[shape_poly.dimension_size_p] = _dimension_size_jax2tf -def _dim_as_value_jax2tf(dim: _shape_poly.DimSize): +def _dim_as_value_jax2tf(dim: shape_poly.DimSize): dim_tf, = _eval_shape((dim,)) return dim_tf -tf_impl[_shape_poly.dim_as_value_p] = _dim_as_value_jax2tf +tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf def _shape_assertion_jax2tf(assert_what, *error_message_inputs, error_message: str): @@ -3583,7 +3585,7 @@ def _shape_assertion_jax2tf(assert_what, *error_message_inputs, message=error_message.format(*error_message_inputs)) return [] -tf_impl[_shape_poly.shape_assertion_p] = _shape_assertion_jax2tf +tf_impl[shape_poly.shape_assertion_p] = _shape_assertion_jax2tf def _reduce_precision(x, *, exponent_bits, mantissa_bits): return tfxla.reduce_precision(x, exponent_bits=exponent_bits, diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 56091b2c7eae..7f903b70d987 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -20,11 +20,10 @@ from __future__ import annotations import base64 -from collections.abc import Sequence +from collections.abc import Callable, Sequence import io import os import tarfile -from typing import Callable, Optional from absl.testing import absltest import jax diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index d740cf26d733..5740b76038d8 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -13,10 +13,10 @@ # limitations under the License. """Tests for call_tf.""" +from collections.abc import Callable import contextlib from functools import partial import os -from typing import Callable import unittest from absl import logging @@ -25,13 +25,13 @@ import jax from jax import dlpack from jax import dtypes +from jax import export from jax import lax from jax import numpy as jnp +from jax._src import config from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax.experimental import export from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util import numpy as np @@ -695,7 +695,6 @@ def cos_tf_sin_jax(x): jax.grad(cos_tf_sin_jax)(x) logging.info(jax.make_jaxpr(cos_tf_sin_jax)(x)) - logging.info(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text()) def test_tf_gather(self): """tf_gather gradient output is tf.IndexSlices.""" @@ -778,7 +777,7 @@ def f_jax(x): lowering_platforms = ("tpu", "cpu", "cuda") - exp = export.export(f_jax, + exp = export.export(jax.jit(f_jax), lowering_platforms=lowering_platforms)(x) for jax_platform in jax_and_tf_platforms: with self.subTest(jax_platform): @@ -787,7 +786,7 @@ def f_jax(x): logging.info("Running harness natively on %s", jax_device) native_res = f_jax(x_device) logging.info("Running exported harness on %s", jax_device) - exported_res = export.call_exported(exp)(x_device) + exported_res = exp.call(x_device) self.assertAllClose(native_res, exported_res) def test_multi_platform_call_tf_graph(self): @@ -1149,17 +1148,6 @@ def setUp(self): _ = tf.add(1, 1) super().setUp() - def override_serialization_version(self, version_override: int): - version = jax.config.jax_serialization_version - if version != version_override: - self.addCleanup(partial(jax.config.update, - "jax_serialization_version", - version_override)) - jax.config.update("jax_serialization_version", version_override) - logging.info( - "Using JAX serialization version %s", - jax.config.jax_serialization_version) - def test_alternate(self): # Alternate sin/cos with sin in TF and cos in JAX f_tf_inner = tf.math.sin @@ -1660,116 +1648,127 @@ def tf_f_2(): _, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[]) @jtu.parameterized_filterable( - kwargs=[dict(version=version) for version in [8, 9]] + kwargs=[dict(version=version) for version in [9]] ) def test_call_tf_graph_ordered(self, *, version: int): - self.override_serialization_version(version) - @tf.function - def tf_print(x): - tf.print(x) - - call_tf_print = jax2tf.call_tf( - tf_print, - call_tf_graph=True, - ordered=True, - ) - - x = jnp.array(1.0, dtype=jnp.float32) + with config.jax_export_calling_convention_version(version): + logging.info( + "Using JAX serialization version %s", + jax.config.jax_export_calling_convention_version) - def body(i, x): - call_tf_print(x) - return x + 1 + @tf.function + def tf_print(x): + tf.print(x) - @jax.jit - def f_jax(x): - return jax.lax.fori_loop(0, 4, body, x) + call_tf_print = jax2tf.call_tf( + tf_print, + call_tf_graph=True, + ordered=True, + ) - num_custom_calls = 0 + x = jnp.array(1.0, dtype=jnp.float32) - def _check_mlir_ops(op): - nonlocal num_custom_calls + def body(i, x): + call_tf_print(x) + return x + 1 - if ( - op.operation.name == "stablehlo.custom_call" - and ir.StringAttr(op.attributes["call_target_name"]).value - == "tf.call_tf_function" + @jax.jit + def f_jax(x): + return jax.lax.fori_loop(0, 4, body, x) + + num_custom_calls = 0 + + def _check_mlir_ops(op): + nonlocal num_custom_calls + + if ( + op.operation.name == "stablehlo.custom_call" + and ir.StringAttr(op.attributes["call_target_name"]).value + == "tf.call_tf_function" + ): + num_custom_calls += 1 + + # The custom call op must have `has_token_input_output` attribute. + tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"]) + self.assertTrue( + ir.BoolAttr(tf_backend_config["has_token_input_output"]).value + ) + + # Verify that the first argument/result of the custom call op is a token + # type. This is a calling convention defined by `has_token_input_output`. + self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) + self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) + + stablehlo_module = None + with self.assertRaisesRegex( + ValueError, + "call_tf_graph=True only support exporting by jax2tf.convert currently", ): - num_custom_calls += 1 - - # The custom call op must have `has_token_input_output` attribute. - tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"]) - self.assertTrue( - ir.BoolAttr(tf_backend_config["has_token_input_output"]).value - ) - - # Verify that the first argument/result of the custom call op is a token - # type. This is a calling convention defined by `has_token_input_output`. - self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) - self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) - - stablehlo_module = None - with self.assertRaisesRegex( - ValueError, - "call_tf_graph=True only support exporting by jax2tf.convert currently", - ): - lower = f_jax.lower(x) - self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"]) - stablehlo_module = lower.compiler_ir("stablehlo") - if stablehlo_module: - self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops) - self.assertEqual(num_custom_calls, 1) - - f_tf = jax2tf.convert( - f_jax, - native_serialization=True, - with_gradient=False, - ) - _, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x]) + lower = f_jax.lower(x) + self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"]) + stablehlo_module = lower.compiler_ir("stablehlo") + if stablehlo_module: + self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops) + self.assertEqual(num_custom_calls, 1) + + f_tf = jax2tf.convert( + f_jax, + native_serialization=True, + with_gradient=False, + ) + _, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x]) @jtu.parameterized_filterable( kwargs=[dict(poly=poly, version=version) for poly in [True, False] - for version in [8, 9]] + for version in [9]] ) def test_call_tf_ordered_dead_inputs(self, *, poly: bool, version: int): - self.override_serialization_version(version) - def f_jax(x1, x_dead, x3): - return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True, - call_tf_graph=True)(x3)) - if poly: - polymorphic_shapes = ["b", None, None] - else: - polymorphic_shapes = None - f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes) - x1 = np.arange(3, dtype=np.float32) - x_dead = np.arange(4, dtype=np.float32) - x3 = np.arange(5, dtype=np.float32) - self.assertAllClose(f_jax(x1, x_dead, x3), - f_tf(x1, x_dead, x3)) + with config.jax_export_calling_convention_version(version): + logging.info( + "Using JAX serialization version %s", + jax.config.jax_export_calling_convention_version) + def f_jax(x1, x_dead, x3): + return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True, + call_tf_graph=True)(x3)) + if poly: + polymorphic_shapes = ["b", None, None] + else: + polymorphic_shapes = None + f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes) + x1 = np.arange(3, dtype=np.float32) + x_dead = np.arange(4, dtype=np.float32) + x3 = np.arange(5, dtype=np.float32) + self.assertAllClose(f_jax(x1, x_dead, x3), + f_tf(x1, x_dead, x3)) @jtu.parameterized_filterable( kwargs=[dict(ordered=ordered, version=version) for ordered in [True, False] - for version in [8, 9] + for version in [9] ] ) def test_call_tf_graph_polymorphic(self, ordered: bool, version: int): - self.override_serialization_version(version) - @tf.function(jit_compile=True, autograph=False) - @partial(jax2tf.convert, - with_gradient=False, - native_serialization=True, - polymorphic_shapes=["(b)"]) - @jax.jit - def tf_f_2(x): - tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world"))) - jax2tf.call_tf(tf_f, - call_tf_graph=True, - ordered=ordered)(x) - return x + with config.jax_export_calling_convention_version(version): + logging.info( + "Using JAX serialization version %s", + jax.config.jax_export_calling_convention_version) + + @tf.function(jit_compile=True, autograph=False) + @partial(jax2tf.convert, + with_gradient=False, + native_serialization=True, + polymorphic_shapes=["(b)"]) + @jax.jit + def tf_f_2(x): + tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world"))) + jax2tf.call_tf(tf_f, + call_tf_graph=True, + ordered=ordered)(x) + return x - x = np.arange(3, dtype=np.int32) - _ = tf.function(tf_f_2, autograph=False).get_concrete_function(x) + x = np.arange(3, dtype=np.int32) + _ = tf.function(tf_f_2, autograph=False).get_concrete_function(x) # TODO(b/293927250): call_tf_graph=True only accept concrete_function. The # workaround here is to set `module.call=concrete_fn.`. diff --git a/jax/experimental/jax2tf/tests/converters.py b/jax/experimental/jax2tf/tests/converters.py index f0a293ca52d5..1ed017fc0819 100644 --- a/jax/experimental/jax2tf/tests/converters.py +++ b/jax/experimental/jax2tf/tests/converters.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Converters for jax2tf.""" + +from collections.abc import Callable import dataclasses import functools import tempfile -from typing import Any, Callable +from typing import Any + from jax.experimental import jax2tf import tensorflow as tf import tensorflowjs as tfjs diff --git a/jax/experimental/jax2tf/tests/cross_compilation_check.py b/jax/experimental/jax2tf/tests/cross_compilation_check.py index 0a4bf61f8847..cc34d78e88d4 100644 --- a/jax/experimental/jax2tf/tests/cross_compilation_check.py +++ b/jax/experimental/jax2tf/tests/cross_compilation_check.py @@ -26,12 +26,11 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import dataclasses import os import re -from typing import Callable, Optional import zlib from absl import app diff --git a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py index 9bb7466125c5..5b1169224ed9 100644 --- a/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py +++ b/jax/experimental/jax2tf/tests/flax_models/bilstm_classifier.py @@ -18,8 +18,9 @@ from __future__ import annotations +from collections.abc import Callable import functools -from typing import Any, Callable, Optional +from typing import Any from flax import linen as nn import jax diff --git a/jax/experimental/jax2tf/tests/flax_models/gnn.py b/jax/experimental/jax2tf/tests/flax_models/gnn.py index 6746da7a2700..4a74be446ba1 100644 --- a/jax/experimental/jax2tf/tests/flax_models/gnn.py +++ b/jax/experimental/jax2tf/tests/flax_models/gnn.py @@ -16,8 +16,7 @@ https://github.com/google/flax/tree/main/examples/ogbg_molpcba """ -from collections.abc import Sequence -from typing import Callable +from collections.abc import Callable, Sequence from flax import linen as nn diff --git a/jax/experimental/jax2tf/tests/flax_models/resnet.py b/jax/experimental/jax2tf/tests/flax_models/resnet.py index bb6e519deceb..48829127b304 100644 --- a/jax/experimental/jax2tf/tests/flax_models/resnet.py +++ b/jax/experimental/jax2tf/tests/flax_models/resnet.py @@ -19,9 +19,9 @@ # See issue #620. # pytype: disable=wrong-arg-count -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Any, Callable +from typing import Any from flax import linen as nn import jax.numpy as jnp diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py index 334248219962..27535c784e89 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_lm1b.py @@ -24,7 +24,8 @@ from __future__ import annotations -from typing import Callable, Any, Optional +from collections.abc import Callable +from typing import Any from flax import linen as nn from flax import struct diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py index 04111e6c4d5b..cc78b5a41496 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_nlp_seq.py @@ -18,7 +18,8 @@ from __future__ import annotations -from typing import Callable, Any, Optional +from collections.abc import Callable +from typing import Any from flax import linen as nn from flax import struct diff --git a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py index 58e50dacd914..1cdeffeb6ea9 100644 --- a/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py +++ b/jax/experimental/jax2tf/tests/flax_models/transformer_wmt.py @@ -24,7 +24,8 @@ from __future__ import annotations -from typing import Callable, Any, Optional +from collections.abc import Callable +from typing import Any from flax import linen as nn from flax import struct diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index dde28315251f..03e6086a4924 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -15,9 +15,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import itertools -from typing import Any, Callable, Optional, Union +from typing import Any import jax from jax import lax @@ -756,7 +756,11 @@ def fft(cls, harness): enabled=(str(harness.params["fft_type"]) in ["FftType.IFFT", "FftType.IRFFT"])), # TODO: very high tolerance - custom_numeric(tol=1e-3, modes=("eager", "graph", "compiled")), + custom_numeric(tol=1e-3, modes=("eager", "graph", "compiled"), + native_serialization=Jax2TfLimitation.FOR_NON_NATIVE), + custom_numeric(tol=1e-5, modes=("eager", "graph", "compiled"), + native_serialization=Jax2TfLimitation.FOR_NATIVE, + devices=("cpu",)), ] @classmethod diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 9b0765b4b673..26266f67d4f2 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -15,9 +15,7 @@ Specific JAX primitive conversion tests are in primitives_test.""" import collections -from collections.abc import Sequence import contextlib -import functools import math import os import re @@ -29,6 +27,7 @@ import jax from jax import ad_checkpoint from jax import dtypes +from jax import export from jax import lax from jax import numpy as jnp from jax import sharding @@ -39,7 +38,6 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax.experimental import jax2tf -from jax.experimental import export from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.shard_map import shard_map from jax.experimental import pjit @@ -52,6 +50,15 @@ # pylint: enable=g-direct-tensorflow-import config.parse_flags_with_absl() +_exit_stack = contextlib.ExitStack() + +# TODO(necula): Remove once tensorflow is 2.10.0 everywhere. +def setUpModule(): + if not hasattr(tfxla, "optimization_barrier"): + _exit_stack.enter_context(jtu.global_config_context(jax_remat_opt_barrier=False)) + +def tearDownModule(): + _exit_stack.close() class Jax2TfTest(tf_test_util.JaxToTfTestCase): @@ -1307,7 +1314,7 @@ def body_fun(carry): shape = (3, 2) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - jax_comp = jax.xla_computation(f_while)(x) + jax_comp = jax.jit(f_while).lower(x).compiler_ir('hlo') backend = xb.get_backend() modules = backend.compile(jax_comp).hlo_modules() jax_opt_hlo = modules[0].to_string() @@ -1552,7 +1559,7 @@ def apply_transform(func, transform: str): # Run the JAX native version, to check it works, and to fill caches. _ = func_to_convert(*args) exported = export.export( - func_to_convert, + (jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert), lowering_platforms=("tpu",) )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) @@ -1779,7 +1786,4 @@ def test_simple(self): if __name__ == "__main__": - # TODO: Remove once tensorflow is 2.10.0 everywhere. - if not hasattr(tfxla, "optimization_barrier"): - jax.config.update("jax_remat_opt_barrier", False) absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/jax2tf/tests/model_harness.py b/jax/experimental/jax2tf/tests/model_harness.py index 9af7229c0530..91aacf2f596f 100644 --- a/jax/experimental/jax2tf/tests/model_harness.py +++ b/jax/experimental/jax2tf/tests/model_harness.py @@ -15,10 +15,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools -from typing import Any, Callable, Optional, Union +from typing import Any import re import numpy as np diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 3c5f0bd660d0..22315f04c881 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -179,9 +179,12 @@ def test_primitive_coverage(self): # TODO: Remove once tensorflow is 2.10.0 everywhere. if p.name == "optimization_barrier": continue - if p.name == "debug_callback": + if p.name == "debug_callback" or p.name == "debug_print": # TODO(sharadmv,necula): enable debug callbacks in TF continue + if p.name in ("max_contiguous", "multiple_of"): + # Pallas-specific primitives are not supported. + continue if p.name == "pallas_call": continue if p.name in tf_not_yet_impl: diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index f5a32a5d771f..83aac43f2d9d 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -15,10 +15,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import math -from typing import Any, Callable +from typing import Any import unittest from absl import logging @@ -32,9 +32,8 @@ import jax from jax.experimental import jax2tf -from jax.experimental import export -from jax.experimental.export import _shape_poly as shape_poly from jax.experimental import pjit +from jax import export from jax import lax import jax.numpy as jnp from jax import random @@ -43,6 +42,7 @@ from jax._src import core from jax._src import test_util as jtu from jax._src import util +from jax._src.export import shape_poly from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow from jax._src.lib import xla_client @@ -615,7 +615,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -624,7 +624,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -634,7 +634,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -643,7 +643,7 @@ def conv_and_run(*, arg_shape: core.Shape, "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." )), ]) def test_shape_constraints_errors(self, *, @@ -2523,6 +2523,12 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] lambda x: jnp.tri(x.shape[0], M=x.shape[0] + 2) + x, arg_descriptors=[RandArg((3, 1), _f32)], polymorphic_shapes=["b, ..."]), + PolyHarness("tril", "", + lambda x: jnp.tril(jnp.ones((x.shape[0], x.shape[0] + x.shape[1]), + dtype=_f32), + k=x.shape[1]), + arg_descriptors=[RandArg((3, 4), _f32)], + polymorphic_shapes=["m, n"]), [ PolyHarness("triangular_solve", f"shape={jtu.format_shape_dtype_string(a_shape, dtype)}_{left_side=}_{a_poly=}_{b_poly=}", @@ -2809,17 +2815,8 @@ def test_harness(self, harness: PolyHarness): if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]): raise unittest.SkipTest("JAX implements eig only on CPU.") - prev_jax_config_flags = { - fname: getattr(jax.config, fname) - for fname, fvalue in harness.override_jax_config_flags.items() - } - try: - for fname, fvalue in harness.override_jax_config_flags.items(): - jax.config.update(fname, fvalue) + with jtu.global_config_context(**harness.override_jax_config_flags): harness.run_test(self) - finally: - for fname, _ in harness.override_jax_config_flags.items(): - jax.config.update(fname, prev_jax_config_flags[fname]) if __name__ == "__main__": diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 524308367fc5..b6750133090e 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -34,7 +34,6 @@ from jax._src import compiler from jax._src import config from jax._src import maps -from jax._src.maps import xmap from jax._src import test_util as jtu from jax._src import xla_bridge from jax import lax @@ -55,13 +54,13 @@ # Must come after initializing the flags from jax.experimental.jax2tf.tests import tf_test_util -prev_xla_flags = None -prev_spmd_lowering_flag = None - +_exit_stack = contextlib.ExitStack() topology = None def setUpModule(): - global prev_xla_flags, topology + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + + global topology if jtu.test_device_matches(["tpu"]): resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') tf.config.experimental_connect_to_cluster(resolver) @@ -70,26 +69,8 @@ def setUpModule(): else: topology = None - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - global prev_spmd_lowering_flag - prev_spmd_lowering_flag = maps.SPMD_LOWERING.value - config.update('experimental_xmap_spmd_lowering', True) - - def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() - config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag) + _exit_stack.close() class ShardingTest(tf_test_util.JaxToTfTestCase): @@ -455,22 +436,26 @@ def f_grad_tf(x_v, res_ct): ]) def test_grad_sharding_different_mesh(self): - self.skipTest("TODO: fix the plumbing of device_assignment for jax2tf: https://github.com/google/jax/pull/21319") # Convert with two similar meshes, the only difference being # the order of the devices. grad should not fail. # https://github.com/google/jax/issues/21314 + devices = jax.local_devices()[:2] + if len(devices) < 2: + raise unittest.SkipTest("Test requires 2 local devices") def f_jax(x): return jnp.sum(x * 2.) - mesh = Mesh(jax.local_devices(), "i") + mesh = Mesh(devices, "i") # The same mesh with reversed order of devices - mesh_rev = Mesh(list(reversed(jax.local_devices())), "i") + mesh_rev = Mesh(list(reversed(devices)), "i") shardings = NamedSharding(mesh, jax.sharding.PartitionSpec(("i",))) shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",))) - f_tf = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings))) - f_tf_rev = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings_rev))) - inp = np.ones((jax.local_device_count(), 4), dtype=np.float32) + f_tf = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings)), + autograph=False) + f_tf_rev = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings_rev)), + autograph=False) + inp = np.ones((2, 4), dtype=np.float32) input_v = tf.Variable(inp) with tf.GradientTape(persistent=True) as tape: @@ -532,115 +517,6 @@ def f_nested_pjit_replicated(a): "function with sharded arguments or results must be used under a `tf.function` context"): jax2tf.convert(f_jax)(a) - def test_xmap_basic(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - bshape = (2, 7) - b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) - - # f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28] - # lambda ...: f32[5], f32[7] -> f32[10], f32[28] - f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2., - jnp.concatenate([b, b, b, b], axis=0) * 4.), - in_axes=({0: 'a', 1: 'b'}, ['c', ...]), - out_axes=({0: 'a', 1: 'b'}, ['c', ...]), - axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) - - @tf.function(autograph=False, jit_compile=True) - def f_tf(a, b): - # xmap works only with native serialization - f_converted = jax2tf.convert(f_jax, native_serialization=True) - if jtu.test_device_matches(["tpu"]): - res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], - device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) - ) - return (res[0], res[1]) - else: - return f_converted(a, b) - - with Mesh(devices, ('x', 'y')): - res_jax = f_jax(a, b) - self.assertAllClose(res_jax, (jnp.concatenate([a, a], axis=2) * 2., - jnp.concatenate([b, b, b, b], axis=1) * 4.)) - res_tf = f_tf(a, b) - self.assertAllClose(res_tf, res_jax) - - self.check_sharding( - jax2tf.convert(f_jax, native_serialization=True), [a, b], - checks=[ - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1), - # The output sharding - (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 1), - (r"f32\[2,28\].*custom_call_target.*Sharding.*sharding.*replicated", 1), - ]) - - def test_xmap_collective_reduce(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - bshape = (2, 7) - b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) - f_jax = xmap(lambda a, b: (lax.psum(a * 2., 'a'), b * 4.), - in_axes=(['a', 'b', ...], {0: 'c'}), - out_axes=(['b', ...], {0: 'c'}), - axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) - - @tf.function(autograph=False, jit_compile=True) - def f_tf(a, b): - f_converted = jax2tf.convert(f_jax, native_serialization=True) - if jtu.test_device_matches(["tpu"]): - res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], - device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) - ) - return (res[0], res[1]) - else: - return f_converted(a, b) - - with Mesh(devices, ('x', 'y')): - res_jax = f_jax(a, b) - self.assertAllClose(res_jax, ((a * 2.).sum(0), b * 4.)) - res_tf = f_tf(a, b) - self.assertAllClose(res_tf, res_jax) - self.check_sharding( - jax2tf.convert(f_jax, native_serialization=True), [a, b], - checks=[ - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1), - (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 2), - (r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), - ]) - - def test_grad_xmap(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - - # f_jax: f32[16,8,5]-> f32[16,8,10] - # lambda ...: f32[5]-> f32[10] - f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2., - in_axes=({0: 'a', 1: 'b'}), - out_axes={0: 'a', 1: 'b'}, - axis_resources={'a': 'x', 'b': 'y'}) - - def f_grad_tf(a, res_ct): - with tf.GradientTape(persistent=True) as tape: - tape.watch(a) - res_tf = jax2tf.convert(f_jax, native_serialization=True)(a) - return tape.gradient(res_tf, a, output_gradients=res_ct) - - with Mesh(devices, ('x', 'y')): - self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)], - checks=[ - # Primal input and grad output - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)), - # Input cotangent - (r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)), - ]) - @jtu.ignore_warning(category=UserWarning, message="all_to_all .* are only implemented properly for TPUs and GPUs .*") def test_shmap_all_to_all(self): diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 7ad8a90da73f..32f89e533daf 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -14,12 +14,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import contextlib import dataclasses import re import os -from typing import Any, Callable, Optional +from typing import Any from absl.testing import absltest from absl import logging @@ -31,7 +31,7 @@ from jax import tree_util from jax.experimental import jax2tf -from jax.experimental import export +from jax import export from jax._src import config from jax._src import xla_bridge import numpy as np @@ -180,18 +180,18 @@ def setUp(self): # We run the tests using the maximum version supported, even though # the default serialization version may be held back for a while to # ensure compatibility - version = config.jax_serialization_version.value + version = config.jax_export_calling_convention_version.value if self.use_max_serialization_version: # Use the largest supported by both export and tfxla.call_module - version = min(export.maximum_supported_serialization_version, + version = min(export.maximum_supported_calling_convention_version, tfxla.call_module_maximum_supported_version()) self.assertGreaterEqual(version, - export.minimum_supported_serialization_version) - self.enter_context(config.jax_serialization_version(version)) + export.minimum_supported_calling_convention_version) + self.enter_context(config.jax_export_calling_convention_version(version)) logging.info( "Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)", version, - export.maximum_supported_serialization_version, + export.maximum_supported_calling_convention_version, tfxla.call_module_maximum_supported_version()) with contextlib.ExitStack() as stack: diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index ac23debd6d27..1ed6183b1229 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -52,7 +52,8 @@ `outstanding primitive rules `__. """ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from functools import partial diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 4250ba9b2677..b4989e151a53 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -15,8 +15,9 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Callable, Iterator from functools import partial, reduce, total_ordering, wraps -from typing import Any, Callable, Iterator, NamedTuple +from typing import Any, NamedTuple import jax from jax import lax diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 0060a954f543..dd112db3b269 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -17,11 +17,11 @@ from __future__ import annotations import collections -from collections.abc import Sequence +from collections.abc import Callable, Generator, MutableMapping, Sequence import itertools import logging import math -from typing import Any, Callable, Generator, MutableMapping +from typing import Any from jax._src import xla_bridge as xb import numpy as np @@ -31,6 +31,7 @@ _TPU_V2 = 'TPU v2' _TPU_V3 = 'TPU v3' _TPU_V4 = 'TPU v4' +_TPU_V5_LITE = "TPU v5 lite" # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. @@ -64,7 +65,8 @@ # Physical ordering of core IDs in a tray that creates a ring _TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) - +_TRAY_2x2_RING_ORDER = (0, 1, 3, 2) +_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) def _tpu_v2_v3_create_device_mesh( mesh_shape: Sequence[int], @@ -94,6 +96,45 @@ def _tpu_v2_v3_create_device_mesh( return np.asarray(devices).reshape(mesh_shape) +def _vlc_create_device_mesh( + mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs +) -> np.ndarray | None: + """Creates rotated pincer device assignment for selected topologies. + + Args: + mesh_shape: Logical mesh shape used by the model. + devices: TPU devices. + **unused_kwargs: ... + + Returns: + None or reordered devices reshaped as `mesh_shape`. + """ + max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices) + bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1 + # Our ring re-ordering makes sense only if the passed-in devices are + # sequential, which may not always be the case. reversed() changes z-minor to + # x-minor. + sequential_devices = sorted( + devices, + key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) + + if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2 + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4 + # Only uses ring order if the whole mesh is a replica group. + if max(mesh_shape) == len(devices): + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + return None + + # Registers functions to create device mesh for specific device kinds. Takes # precedence over the more general logic in create_device_mesh(). Handler may # return None; in that case, it will fall back to using the default logic. @@ -103,6 +144,7 @@ def _tpu_v2_v3_create_device_mesh( ] = { _TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh, + _TPU_V5_LITE: _vlc_create_device_mesh, } diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index f2c5b30a2e93..2cf9e9d4ff39 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -1,4 +1,3 @@ -from collections.abc import Callable # Copyright 2024 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,23 +13,24 @@ # limitations under the License. # ============================================================================== +from collections.abc import Callable, Sequence import contextlib import ctypes import dataclasses import functools +import itertools import os import pathlib import subprocess import tempfile import time -from typing import Any, Generic, Sequence, TypeVar +from typing import Any, Generic, TypeVar import jax from jax._src import config from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.lib import xla_client -from jax._src.lib import mosaic_gpu as mosaic_gpu_lib from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import builtin @@ -40,7 +40,6 @@ from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvgpu from jaxlib.mlir.dialects import nvvm -from jaxlib.mlir.execution_engine import ExecutionEngine from jaxlib.mlir.passmanager import PassManager import numpy as np @@ -68,34 +67,20 @@ c = mgpu.c # This is too common to fully qualify. -xla_client.register_custom_call_target( - "mosaic_gpu", - mosaic_gpu_lib._mosaic_gpu_ext._custom_call_capsule(), - platform="CUDA", -) -mosaic_gpu_lib._mosaic_gpu_ext.register_passes() - - -mosaic_gpu_dump_ptx = config.define_bool_state( - name="mosaic_gpu_dump_ptx", - default=config.bool_env("MOSAIC_GPU_DUMP_PTX", False), - help="If set, prints the kernel PTX", -) -mosaic_gpu_dump_ptxas = config.define_bool_state( - name="mosaic_gpu_dump_ptxas", - default=config.bool_env("MOSAIC_GPU_DUMP_PTXAS", False), - help="If set, prints the ptxas verbose output", -) -mosaic_gpu_dump_sass = config.define_bool_state( - name="mosaic_gpu_dump_sass", - default=config.bool_env("MOSAIC_GPU_DUMP_SASS", False), - help="If set, prints the kernel SASS", -) -mosaic_gpu_print_after_all = config.define_bool_state( - name='mosaic_gpu_print_after_all', - default=config.bool_env('MOSAIC_GPU_PRINT_AFTER_ALL', False), - help="If set, prints the kernel module after every pass", -) +RUNTIME_PATH = None +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib + + RUNTIME_PATH = ( + pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent + / "libmosaic_gpu_runtime.so" + ) +except ImportError: + pass + +if RUNTIME_PATH and RUNTIME_PATH.exists(): + # Set this so that the custom call can find it + os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") @@ -107,30 +92,12 @@ def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes): del module, gmem_scratch_bytes # Unused. return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] +# TODO(apaszke): Implement a proper system for managing kernel lifetimes +kernel_idx = itertools.count() def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes): del out_types # Unused. - runtime_path = ( - pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent.parent.parent - / "mosaic" / "gpu" / "libmosaic_gpu_runtime.so" - ) - shared_libs = [str(runtime_path)] if runtime_path.exists() else [] - engine = ExecutionEngine( - module, opt_level=3, shared_libs=shared_libs, enable_object_dump=False - ) - ctx.module_context.add_keepalive(engine) - launch_func_ptr = ctypes.cast(engine.lookup("main"), ctypes.c_void_p) - init_func_ptr = ctypes.cast(engine.lookup("main_init"), ctypes.c_void_p) - # Make sure we won't get accidental hits due to address reuse. - mosaic_gpu_lib._mosaic_gpu_ext.invalidate_cache(init_func_ptr.value) - - trampoline_args = (ctypes.c_void_p * 2)() - trampoline_args[0] = launch_func_ptr - trampoline_args[1] = init_func_ptr - ctx.module_context.add_keepalive(trampoline_args) - ptr_bytes = ctypes.cast(trampoline_args, ctypes.c_void_p).value.to_bytes( - 8, byteorder="little" - ) # pytype: disable=attribute-error + idx_bytes = next(kernel_idx).to_bytes(8, byteorder="little") op = mlir.custom_call( "mosaic_gpu", result_types=[ @@ -140,7 +107,8 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes) ), ], operands=args, - backend_config=ptr_bytes, + backend_config=idx_bytes + + module.operation.get_asm(binary=True, enable_debug_info=True), ) return op.results[:-1] # Skip the scratch space. @@ -241,7 +209,6 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: @dataclasses.dataclass() class LaunchContext: launch_op: gpu.LaunchOp - gmem_scratch_ptr: ir.Value profiler: OnDeviceProfiler | None = None next_scratch_offset: int = 0 host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( @@ -281,14 +248,21 @@ def _alloc_scratch( alloc_base = self.next_scratch_offset self.next_scratch_offset += size def host_init_wrapped(host_ptr): - with ir.InsertionPoint(self.launch_op): - host_init( - llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) - ) + host_init( + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + ) self.host_scratch_init.append(host_init_wrapped) + with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]): + ptr_ty = ir.Type.parse("!llvm.ptr") + const_ptr_ty = ir.Type.parse("!llvm.ptr<4>") + gmem_scratch_ptr = llvm.call_intrinsic( + ptr_ty, + "llvm.nvvm.ptr.constant.to.gen.p0.p4", + [llvm.mlir_addressof(const_ptr_ty, "global_scratch")], + ) return device_init(llvm.getelementptr( - ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 + ptr_ty, gmem_scratch_ptr, [], [alloc_base], i8 )) def _get_tma_desc( @@ -317,43 +291,43 @@ def _get_tma_desc( ref = t.apply(ref) ref_ty = ir.MemRefType(ref.type) - i64 = ir.IntegerType.get_signless(64) - ptr_ty = ir.Type.parse("!llvm.ptr") - def init_tma_desc(host_ptr): - _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) - aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) - as_i64 = lambda i: arith.index_cast(i64, i) - alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) - llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... - base_ptr = llvm.getelementptr( - ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, - ) - rank = ref_ty.rank - assert rank * 2 == len(sizes_and_strides) - args = [ - host_ptr, - base_ptr, - c(utils.bytewidth(ref_ty.element_type), i64), - c(rank, i64), - utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), - utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), - c(0 if swizzle is None else swizzle, i64), - utils.pack_array([c(v, i64) for v in transformed_slice_shape]), - ] - func.call([], "mosaic_gpu_init_tma_desc", args) - def cast_tma_desc(device_ptr): - # TODO(apaszke): Investigate why prefetching can cause launch failures - # nvvm.prefetch_tensormap(device_ptr) - return builtin.unrealized_conversion_cast( - [tensor_map_ty], [device_ptr] - ) - tma_desc = self._alloc_scratch( - TMA_DESCRIPTOR_BYTES, - alignment=TMA_DESCRIPTOR_ALIGNMENT, - host_init=init_tma_desc, - device_init=cast_tma_desc, + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + def init_tma_desc(host_ptr): + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) + aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) + as_i64 = lambda i: arith.index_cast(i64, i) + alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) + llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... + base_ptr = llvm.getelementptr( + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ) + rank = ref_ty.rank + assert rank * 2 == len(sizes_and_strides) + args = [ + host_ptr, + base_ptr, + c(utils.bytewidth(ref_ty.element_type), i64), + c(rank, i64), + utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), + utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), + c(0 if swizzle is None else swizzle, i64), + utils.pack_array([c(v, i64) for v in transformed_slice_shape]), + ] + func.call([], "mosaic_gpu_init_tma_desc", args) + def cast_tma_desc(device_ptr): + # TODO(apaszke): Investigate why prefetching can cause launch failures + # nvvm.prefetch_tensormap(device_ptr) + return builtin.unrealized_conversion_cast( + [tensor_map_ty], [device_ptr] ) - self.tma_descriptors[tma_desc_key] = tma_desc + tma_desc = self._alloc_scratch( + TMA_DESCRIPTOR_BYTES, + alignment=TMA_DESCRIPTOR_ALIGNMENT, + host_init=init_tma_desc, + device_init=cast_tma_desc, + ) + self.tma_descriptors[tma_desc_key] = tma_desc return tma_desc def async_copy( @@ -430,7 +404,11 @@ def async_copy( # nvgpu TMA instructions expect reversed indices... rev_dyn_based_indices = reversed(dyn_base_indices) - uniform_ctx = mgpu.once if uniform else contextlib.nullcontext + uniform_ctx = ( + functools.partial(mgpu.single_thread, per_block=False) + if uniform + else contextlib.nullcontext + ) if gmem_ref is src_ref: with uniform_ctx(): @@ -498,7 +476,6 @@ def _launch( token, grid, block, - gmem_scratch_ptr, smem_buffers: ShapeTree | Union[ShapeTree], profiler_spec: profiler.ProfilerSpec | None = None, maybe_prof_buffer: ir.Value | None = None, @@ -523,7 +500,7 @@ def _launch( smem_bytes = compute_smem_bytes if profiler_spec is not None: - smem_bytes += profiler_spec.smem_bytes(grid) + smem_bytes += profiler_spec.smem_bytes(block=block) # TODO(cperivol): Query the shared memory size programmatically. if smem_bytes > 228 * 1024: @@ -549,7 +526,7 @@ def _launch( if profiler_spec: prof_smem = memref.view( ir.MemRefType.get( - (profiler_spec.smem_i32_elements(grid=grid),), + (profiler_spec.smem_i32_elements(block=block),), i32, memory_space=smem, ), dynamic_smem, c(compute_smem_bytes, index), [], @@ -565,9 +542,9 @@ def _launch( else: smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else [] - yield LaunchContext(launch_op, gmem_scratch_ptr, prof), smem_ref_tree + yield LaunchContext(launch_op, prof), smem_ref_tree if prof is not None: - prof.finalize(grid=grid) + prof.finalize(grid=grid, block=block) gpu.terminator() @@ -583,6 +560,7 @@ def _lower_as_gpu_kernel( ptr_ty = ir.Type.parse("!llvm.ptr") token_ty = ir.Type.parse("!gpu.async.token") i8 = ir.IntegerType.get_signless(8) + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: @@ -598,30 +576,32 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: unwrap_output_tuple = True out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] if prof_spec is not None: - out_shape = (*out_shape, prof_spec.jax_buffer_type) - out_ref_tys.append(prof_spec.mlir_buffer_type) + out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) + out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) module = ir.Module.create() with ir.InsertionPoint(module.body): _declare_runtime_functions() gmem_scratch_bytes = 0 - @func.FuncOp.from_py_func(ptr_ty, ptr_ty) - def main(token_ptr, buffers): + global_scratch = llvm.GlobalOp( + ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet. + "global_scratch", + ir.Attribute.parse("#llvm.linkage"), + addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. + ) + @func.FuncOp.from_py_func(ptr_ty, ptr_ty, ptr_ty) + def main(token_ptr, buffers, gmem_scratch_ptr): nonlocal gmem_scratch_bytes token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) arg_refs = [] - i = -1 for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) - gmem_scratch_ptr = llvm.LoadOp( - ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i + 1], ptr_ty) - ) in_refs = arg_refs[:len(in_ref_tys)] out_refs = arg_refs[len(in_ref_tys):] prof_buffer = out_refs.pop() if prof_spec is not None else None with _launch( - token, grid, block, gmem_scratch_ptr, smem_scratch_shape, + token, grid, block, smem_scratch_shape, prof_spec, prof_buffer ) as (launch_ctx, smem_refs): body(launch_ctx, *in_refs, *out_refs, smem_refs) @@ -633,6 +613,9 @@ def main(token_ptr, buffers): host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8) for init_callback in launch_ctx.host_scratch_init: init_callback(host_scratch_ptr) + global_scratch.global_type = ir.TypeAttr.get( + ir.Type.parse("!llvm.array<" + str(gmem_scratch_bytes) + " x i8>") + ) func.call( [], "mosaic_gpu_memcpy_async_h2d", @@ -644,15 +627,11 @@ def main(token_ptr, buffers): ], ) main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + sym_tab = ir.SymbolTable(module.operation) + sym_tab.insert(main.func_op) + sym_tab.insert(global_scratch) module.operation.verify() - dump_low_level(module) - - pass_manager = _get_mosaic_gpu_pipeline("fatbin") - if mosaic_gpu_print_after_all.value: - pass_manager.enable_ir_printing() - pass_manager.run(module.operation) - return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple @@ -705,7 +684,7 @@ def dump_profile(prof_buffer): ) try: with open(out_file, "x") as f: - prof_spec.dump(prof_buffer, f) + prof_spec.dump(prof_buffer, f, grid=grid, block=block) except FileExistsError: pass # TODO: Retry jax.debug.callback(dump_profile, prof_buffer) @@ -733,82 +712,3 @@ def _declare_runtime_functions(): func.FuncOp( "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" ) - - -def dump_low_level(module): - dump_ptx = mosaic_gpu_dump_ptx.value - dump_ptxas = mosaic_gpu_dump_ptxas.value - dump_sass = mosaic_gpu_dump_sass.value - if not any([dump_ptx, dump_ptxas, dump_sass]): - return - module = ir.Module.parse( - module.operation.get_asm(binary=True, enable_debug_info=True) - ) - pm = _get_mosaic_gpu_pipeline("isa") - pm.run(module.operation) - - for op in module.body: - if op.OPERATION_NAME == "gpu.binary": - objects = ir.ArrayAttr(op.objects) - if len(objects) != 1: - raise NotImplementedError("Expected a single object") - obj = str(objects[0]) - start = obj.find('assembly = "') + len('assembly = "') - end = obj.find('"', start) - ptx = obj[start:end] - ptx = ptx.replace("\\09", "\t").replace("\\0A", "\n")[:-3] - if dump_ptx: - print(ptx) - if dump_ptxas or dump_sass: - with tempfile.TemporaryDirectory() as tmp: - ptx_path = os.path.join(tmp, "kernel.ptx") - with open(ptx_path, "w") as f: - f.write(ptx) - elf_path = os.path.join(tmp, 'kernel.o') - v_flag = "-v" if dump_ptxas else "" - ptxas_flags = f"{v_flag} --opt-level 3 --gpu-name sm_90a" - ptxas_out = subprocess.check_output( - f"{PTXAS_PATH} {ptxas_flags} --output-file {elf_path} {ptx_path}", - stderr=subprocess.STDOUT, - shell=True, - ) - if dump_ptxas: - print(ptxas_out.decode()) - if dump_sass: - sass = subprocess.check_output( - f"{NVDISASM_PATH} -ndf -c {elf_path}", - stderr=subprocess.STDOUT, - shell=True, - ) - print(sass.decode()) - - -def _get_mosaic_gpu_pipeline(kernel_format) -> PassManager: - passes = [ - "convert-nvgpu-to-nvvm", - "gpu-kernel-outlining{data-layout-str=}", - "convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1}", - "convert-scf-to-cf", - "convert-nvvm-to-llvm", - "expand-strided-metadata", - "nvvm-attach-target{O=3 chip=sm_90a fast=false features=+ptx80 ftz=false module= triple=nvptx64-nvidia-cuda}", - "lower-affine", - "convert-arith-to-llvm{index-bitwidth=0}", - "convert-index-to-llvm{index-bitwidth=64}", - "canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true}", - "cse", - "gpu.module(strip-debuginfo)", - "gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false})", - "gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true})", - "gpu.module(cse)", - "gpu.module(reconcile-unrealized-casts)", - "gpu-to-llvm{gpu-binary-annotation=gpu.binary use-bare-pointers-for-host=false use-bare-pointers-for-kernels=false}", - "gpu-module-to-binary{format=" + kernel_format + "}", - "convert-math-to-llvm{approximate-log1p=true}", - "canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true}", - "cse", - "reconcile-unrealized-casts", - *(["gpu-launch-lowering"] if kernel_format in {"bin", "fatbin"} else []), - "convert-func-to-llvm{index-bitwidth=0 use-bare-ptr-memref-call-conv=false}", - ] - return PassManager.parse(f"builtin.module({','.join(passes)})") diff --git a/jax/experimental/mosaic/gpu/dsl.py b/jax/experimental/mosaic/gpu/dsl.py index 8bb3eca4d367..bd8960c74c2b 100644 --- a/jax/experimental/mosaic/gpu/dsl.py +++ b/jax/experimental/mosaic/gpu/dsl.py @@ -37,9 +37,9 @@ memref_transpose, memref_unfold, memref_unsqueeze, - once, - tile_shape, + single_thread, thread_idx, + tile_shape, warp_idx, warpgroup_idx, ) diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 0aa12e78861f..3f9496b38376 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "py_deps") +load("@rules_python//python:defs.bzl", "py_library", "py_test") licenses(["notice"]) @@ -46,3 +47,19 @@ py_library( "//third_party/py/jax:mosaic_gpu", ], ) + +py_test( + name = "run_matmul", + srcs = ["matmul.py"], + main = "matmul.py", + tags = [ + "manual", + "notap", + "requires-gpu-sm90-only", + ], + deps = [ + "//learning/brain/research/jax:gpu_support", + "//third_party/py/jax", + "//third_party/py/jax:mosaic_gpu", + ] + py_deps("numpy"), +) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 4c155aa68819..3e1f08868706 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -13,9 +13,11 @@ # limitations under the License. # ============================================================================== +import contextlib import dataclasses import enum import itertools +import os from absl import app import jax @@ -47,7 +49,7 @@ class BlockSizes: # TODO(apaszke): Implement a Q-scaled, base2 exp implementation. -class ExpImplementation(enum.StrEnum): +class ExpImplementation(enum.Enum): EXACT = enum.auto() APPROX = enum.auto() @@ -55,6 +57,7 @@ class ExpImplementation(enum.StrEnum): def build_kernel( batch_size: int, q_heads: int, + kv_heads: int, q_seq_len: int, kv_seq_len: int, head_dim: int, @@ -62,17 +65,15 @@ def build_kernel( prof_spec: profiler.ProfilerSpec | None = None, exp_impl: ExpImplementation = ExpImplementation.EXACT, ): - q_shape = jax.ShapeDtypeStruct( - (q_heads, q_seq_len, head_dim), jnp.float16 - ) - kv_shape = jax.ShapeDtypeStruct( - (1, kv_seq_len, head_dim), jnp.float16 - ) + wgs_per_block = 2 + if batch_size != 1: raise NotImplementedError if blocks.stages < 2: raise ValueError("Kernel requires at least 2 stages.") - if q_seq_len % blocks.q: + if q_heads % kv_heads: + raise ValueError("kv_heads must divide q_heads.") + if q_seq_len % (blocks.q * wgs_per_block): raise ValueError if kv_seq_len % blocks.kv: raise ValueError @@ -85,13 +86,21 @@ def build_kernel( if blocks.stages * blocks.kv > kv_seq_len: raise NotImplementedError + q_shape = jax.ShapeDtypeStruct( + (q_heads, q_seq_len, head_dim), jnp.float16 + ) + kv_shape = jax.ShapeDtypeStruct( + (kv_heads, kv_seq_len, head_dim), jnp.float16 + ) + q_heads_per_kv_head = q_heads // kv_heads + def exp(x: FragmentedArray) -> FragmentedArray: return x.exp(approx=exp_impl == ExpImplementation.APPROX) block_partition = Partition( elements=(batch_size, q_seq_len, q_heads), partition=(0, 1, 2), - chunk_size=(1, blocks.q, 1), + chunk_size=(1, blocks.q * wgs_per_block, 1), ) index = ir.IndexType.get() @@ -99,10 +108,10 @@ def exp(x: FragmentedArray) -> FragmentedArray: f32 = ir.F32Type.get() grid = block_partition.num_chunks - block = (128, 1, 1) + block = (wgs_per_block * 128, 1, 1) tiling = (64, 64) qo_scratch = jax.ShapeDtypeStruct( - tile_shape((blocks.q, head_dim), tiling), jnp.float16 + (wgs_per_block, *tile_shape((blocks.q, head_dim), tiling)), jnp.float16 ) k_scratch = jax.ShapeDtypeStruct( tile_shape((blocks.stages, head_dim, blocks.kv), tiling), jnp.float16 @@ -129,28 +138,48 @@ def kernel( out_gmem, smem_scratch, ): - barriers = BarrierArray(blocks.stages + 1) + barriers = BarrierArray(blocks.stages + wgs_per_block) + schedule_barrier = BarrierArray(1, arrival_count=256)[0] + def perform_schedule_barrier(): + schedule_barrier.arrive() + schedule_barrier.wait() + wg_idx = warpgroup_idx(sync=True) qo_smem, k_smem, v_smem = smem_scratch + qo_smem = memref_slice(qo_smem, arith.index_cast(index, wg_idx)) + + @contextlib.contextmanager + def only_wg(idx): + i32 = ir.IntegerType.get_signless(32) + is_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(idx, i32)) + with ir.InsertionPoint(scf.IfOp(is_wg).then_block): + yield + scf.yield_([]) - batch_idx, q_seq_base, head_idx = block_partition.get_base( + batch_idx, q_seq_base, q_head_idx = block_partition.get_base( gpu.block_id(gpu.Dimension.x), gpu.block_id(gpu.Dimension.y), gpu.block_id(gpu.Dimension.z), ) + q_seq_base = arith.addi( + q_seq_base, arith.muli(arith.index_cast(index, wg_idx), c(blocks.q)) + ) del batch_idx + q_barrier = arith.addi(c(blocks.stages), arith.index_cast(index, wg_idx)) with ctx.named_region("Q TMA start"): ctx.async_copy( src_ref=q_gmem, - gmem_slice=(head_idx, ds(q_seq_base, blocks.q)), + gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), gmem_transform=mosaic_gpu.TileTransform(tiling), dst_ref=qo_smem, - barrier=barriers[blocks.stages], + barrier=barriers[q_barrier], swizzle=128, ) + kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) + def kv_copy_init(slot, kv_seq_base): - with once(): + with single_thread(per_block=False): txcount = c(2 * blocks.kv * head_dim * bytewidth(f16)) nvgpu.mbarrier_arrive_expect_tx(barriers.value, txcount, slot) k_tr = ( @@ -162,7 +191,7 @@ def kv_copy_init(slot, kv_seq_base): ctx.async_copy( dst_ref=memref_slice(smem, slot), src_ref=gmem, - gmem_slice=(0, ds(kv_seq_base, blocks.kv)), + gmem_slice=(kv_head_idx, ds(kv_seq_base, blocks.kv)), gmem_transform=t, barrier=barriers[slot], arrive=False, @@ -171,12 +200,12 @@ def kv_copy_init(slot, kv_seq_base): ) loop_partition = Partition1D(kv_seq_len, chunk_size=blocks.kv) - with ctx.named_region("KV TMA warmup"): + with only_wg(1), ctx.named_region("KV TMA warmup"): for i in range(blocks.stages - 1): kv_copy_init(c(i), loop_partition.get_base(c(i))) with ctx.named_region("Q TMA wait"): - barriers[blocks.stages].wait() + barriers[q_barrier].wait() m_i = FragmentedArray.splat( c(-jnp.inf, f32), shape=(blocks.q,), layout=WGMMA_ROW_LAYOUT @@ -188,7 +217,10 @@ def kv_copy_init(slot, kv_seq_base): c(0, f32), shape=(blocks.q, head_dim), layout=WGMMA_LAYOUT ) - with ctx.named_region("KV TMA wait"): + with only_wg(1): + perform_schedule_barrier() + + with only_wg(0): barriers[c(0)].wait() @fori(c(loop_partition.num_chunks), (acc, m_i, l_i)) @@ -204,7 +236,7 @@ def kv_loop(kv_step, carry): nvvm.wgmma_commit_group_sync_aligned() # We hide the TMA overhead by overlapping it with the QK matmul. - with ctx.named_region("KV TMA start"): + with only_wg(1), ctx.named_region("KV TMA start"): tma_step = arith.addi(kv_step, c(blocks.stages - 1)) tma_slot = arith.remui(tma_step, c(blocks.stages)) tma_step_in_bounds = arith.cmpi( @@ -215,6 +247,8 @@ def kv_loop(kv_step, carry): kv_copy_init(tma_slot, loop_partition.get_base(tma_step)) scf.yield_([]) + perform_schedule_barrier() + with ctx.named_region("QK wait"): nvvm.wgmma_wait_group_sync_aligned(0) qk = qk_acc.value @@ -227,22 +261,18 @@ def kv_loop(kv_step, carry): acc *= alpha.broadcast_minor(head_dim) l_i *= alpha l_i += p.reduce(arith.addf, axis=1) + p = p.astype(f16) + + perform_schedule_barrier() - # For small head_dim we're not really constrained by the register budget. - # Even though unfusing the adds should have negative performance impact, - # it ends up emitting slightly better code for unclear reasons. - duplicate_acc = head_dim == 64 # TODO(apaszke): Investigate why. with ctx.named_region("PV issue"): - if duplicate_acc: - acc_update = WGMMAAccumulator.zero(*acc.shape) - else: - acc_update = WGMMAAccumulator.from_registers(acc) v = memref_slice(v_smem, slot) - acc_update = wgmma(acc_update, p.astype(f16), v) + acc_update = WGMMAAccumulator.from_registers(acc) + acc_update = wgmma(acc_update, p, v) nvvm.wgmma_commit_group_sync_aligned() # We hide the barrier overhead by overlapping it with the PV matmul. - with ctx.named_region("KV TMA wait"): + with only_wg(0), ctx.named_region("KV TMA wait"): wait_step = arith.addi(kv_step, c(1)) wait_slot = arith.remui(wait_step, c(blocks.stages)) wait_step_in_bounds = arith.cmpi( @@ -254,12 +284,13 @@ def kv_loop(kv_step, carry): with ctx.named_region("PV wait"): nvvm.wgmma_wait_group_sync_aligned(0) - if duplicate_acc: - acc += acc_update.value # We can now safely extract the update. - else: - acc = acc_update.value + acc = acc_update.value return acc, m_i, l_i + + with only_wg(0): + perform_schedule_barrier() + acc, m_i, l_i = kv_loop.results del m_i # TODO(apaszke): Invert and multiply to avoid expensive divisions. @@ -276,7 +307,7 @@ def kv_loop(kv_step, carry): ctx.async_copy( src_ref=qo_smem, dst_ref=out_gmem, - gmem_slice=(head_idx, ds(q_seq_base, blocks.q)), + gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), gmem_transform=mosaic_gpu.TileTransform(tiling), swizzle=128, ) @@ -292,28 +323,28 @@ def benchmark_and_verify( q_seq_len, kv_seq_len, num_q_heads, + num_kv_heads, head_dim, **kwargs, ) -> float: with mlir.make_ir_context(), ir.Location.unknown(): - kq, kk, kv = random.split(random.PRNGKey(1234), 3) + kq, kk, kv = random.split(random.key(1234), 3) q = random.normal( kq, (batch_size, num_q_heads, q_seq_len, head_dim), dtype=jnp.float16 ) k = random.normal( - kk, (batch_size, 1, kv_seq_len, head_dim), dtype=jnp.float16 + kk, (batch_size, num_kv_heads, kv_seq_len, head_dim), dtype=jnp.float16 ) v = random.normal( - kv, (batch_size, 1, kv_seq_len, head_dim), dtype=jnp.float16 + kv, (batch_size, num_kv_heads, kv_seq_len, head_dim), dtype=jnp.float16 ) - f = build_kernel( batch_size=batch_size, q_heads=num_q_heads, + kv_heads=num_kv_heads, q_seq_len=q_seq_len, kv_seq_len=kv_seq_len, head_dim=head_dim, - blocks=BlockSizes(q=64, kv=64, stages=2), **kwargs, ) out, runtime = profiler.measure(f, q[0], k[0], v[0]) @@ -322,12 +353,15 @@ def benchmark_and_verify( q = q.astype(jnp.float32) k = k.astype(jnp.float32) v = v.astype(jnp.float32) - logits = jnp.einsum("bhqc,bxkc->bhqk", q, k) + q_reshaped = q.reshape( + batch_size, num_kv_heads, num_q_heads // num_kv_heads, q_seq_len, + head_dim) + logits = jnp.einsum("bxhqc,bxkc->bxhqk", q_reshaped, k) m = logits.max(axis=-1) unnormalized = jnp.exp(logits - m[..., None]) l = unnormalized.sum(axis=-1) weights = unnormalized / l[..., None] - expected = jnp.einsum("bhqk,bxkc->bhqc", weights, v) + expected = jnp.einsum("bxhqk,bxkc->bxhqc", weights, v).reshape(*q.shape) np.testing.assert_allclose(out, expected, atol=2e-3, rtol=2e-3) return runtime @@ -335,29 +369,42 @@ def benchmark_and_verify( if __name__ == "__main__": batch_size = 1 num_q_heads = 4 + num_kv_heads = 1 prof_spec = None - # prof_spec = profiler.ProfilerSpec((4 * 32) * 4096) - param_it = itertools.product( - (4096,), (4096,), (64, 128, 256), ExpImplementation - ) - for kv_seq_len, q_seq_len, head_dim, exp_impl in param_it: - runtime_ms = benchmark_and_verify( - batch_size, - q_seq_len, - kv_seq_len, - num_q_heads, - head_dim, - prof_spec=prof_spec, - exp_impl=exp_impl, - ) - runtime_us = runtime_ms * 1e3 - matmul_flops = ( - 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size - ) - peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS - optimal_time = matmul_flops / peak_flops * 1e6 # us - achieved_tc_util = optimal_time / runtime_us * 100 + problem_it = itertools.product((4096,), (4096,), (64, 128, 256)) + for kv_seq_len, q_seq_len, head_dim in problem_it: print( - f"{kv_seq_len=:<6} {q_seq_len=:<6} {num_q_heads=:<4} {head_dim=:<6} exp_impl={str(exp_impl):<6}:" - f" {runtime_us:<7.1f}us = {achieved_tc_util:4.1f}% TC utilization" + "====" + f" {kv_seq_len=:<6} {q_seq_len=:<6} {num_q_heads=:<4} {head_dim=:<6} ====" + ) + param_it = itertools.product( + ExpImplementation, (64,), (64, 128, 256), ) + for exp_impl, block_q, block_kv in param_it: + try: + runtime_ms = benchmark_and_verify( + batch_size, + q_seq_len, + kv_seq_len, + num_q_heads, + num_kv_heads, + head_dim, + prof_spec=prof_spec, + exp_impl=exp_impl, + blocks=BlockSizes(q=block_q, kv=block_kv, stages=2), + ) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: + continue + raise + runtime_us = runtime_ms * 1e3 + matmul_flops = ( + 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size + ) + peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + print( + f"exp_impl={exp_impl.name:<6} block_q={block_q:<4}block_kv={block_kv:<4}: {runtime_us:<7.1f}us" + f" = {achieved_tc_util:4.1f}% TC utilization" + ) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 84c1995f5f34..53818f2ef39d 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -16,6 +16,7 @@ import dataclasses import enum +import functools import jax from jax import random @@ -216,9 +217,9 @@ def wgmma( nvvm.wgmma_wait_group_sync_aligned(0) arr.store_tiled(smem_scratch["cvt"], swizzle=128) commit_shared() - nvvm.wgmma_fence_aligned() - return wgmma(acc, a_slice, smem_scratch["cvt"], b_order=b_order) - + acc = wgmma(acc, a_slice, smem_scratch["cvt"], b_order=b_order) + nvvm.wgmma_commit_group_sync_aligned() + return acc def mlir_context(f): @@ -239,7 +240,8 @@ def build_kernel( wgmma_impl=WGMMADefaultImpl, profiler_spec: profiler.ProfilerSpec | None = None, ): - out_128b_elems = 128 // bytewidth(ir.F32Type.get()) + f32 = ir.F32Type.get() + out_128b_elems = 128 // bytewidth(f32) out_tiling = (64, out_128b_elems) out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), jnp.float32) if tile_m % 64 != 0: @@ -251,9 +253,10 @@ def build_kernel( if stages < 2: raise ValueError(f"Need at least 2 stages, but got {stages=}") - smem = ir.Attribute.parse("#gpu.address_space") - lhs_128b_elems = 128 // bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) - rhs_128b_elems = 128 // bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) + lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) + rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) + lhs_128b_elems = 128 // lhs_elem_bytes + rhs_128b_elems = 128 // rhs_elem_bytes tile_k = max(lhs_128b_elems, rhs_128b_elems) if tile_n % rhs_128b_elems != 0: @@ -262,18 +265,13 @@ def build_kernel( f" {((lhs_128b_elems, lhs_dtype), (rhs_128b_elems, rhs_dtype))}" ) - if k % (stages * tile_k) != 0: - raise ValueError( - f"k must be divisible by {stages=} * {tile_k=} (={stages * tile_k})," - f" but got {k=}" - ) + if k % tile_k != 0: + raise ValueError(f"k must be divisible by {tile_k=}, but got {k=}") block_tiling = Tiling(m=tile_m, n=tile_n, k=tile_k) tma_tiling = Tiling(m=64, n=rhs_128b_elems, k=lhs_128b_elems) k_steps = k // block_tiling.k - - f32 = ir.F32Type.get() - index = ir.IndexType.get() + stages = min(stages, k_steps) def safe_div(x, y): assert x % y == 0, (x, y) @@ -282,8 +280,8 @@ def safe_div(x, y): grid = (safe_div(m, block_tiling.m), safe_div(n, block_tiling.n), 1) block = (128, 1, 1) - def c(value, ty=index): - return arith.ConstantOp(ty, ir.IntegerAttr.get(ty, value)) + c = arith.ConstantOp.create_index + divmod = lambda x, y: (arith.divui(x, c(y)), arith.remui(x, c(y))) compute_scratch_shapes = { "lhs": jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.mk, tma_tiling.mk)), lhs_dtype), @@ -311,13 +309,13 @@ def _main(ctx, a_device, b_device, c_device, def fetch(slot, ki): barrier = barrier_group[slot] k_start = arith.muli(c(block_tiling.k), ki) - lhs_tma_tile_bytes = np.prod(block_tiling.mk) * bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) - rhs_tma_tile_bytes = np.prod(block_tiling.kn) * bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) + lhs_tma_tile_bytes = int(np.prod(block_tiling.mk) * lhs_elem_bytes) + rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) txcount = c(lhs_tma_tile_bytes + rhs_tma_tile_bytes) common_copy_args = dict( swizzle=128, barrier=barrier, arrive=False, uniform=False, ) - with once(): + with single_thread(): nvgpu.mbarrier_arrive_expect_tx(barrier_group.value, txcount, slot) ctx.async_copy( src_ref=a_device, @@ -389,26 +387,35 @@ def stage_loop_body(ki, accs): # TODO(apaszke): Make this into a proper copy function. warps_per_warpgroup = 4 lanes_per_warp = 32 - n_out_tiling = out_tiling[-1] - tidx = gpu.thread_id(gpu.Dimension.x) - warp_id = arith.divui(tidx, c(lanes_per_warp)) - lane_id = arith.remui(tidx, c(lanes_per_warp)) + m_out_tiling, n_out_tiling = out_tiling[-2:] + warp_id, lane_id = divmod(gpu.thread_id(gpu.Dimension.x), lanes_per_warp) # We store 4 f32 numbers for a block of 16B. vector_len = 4 - num_vectors = safe_div(tile_n, vector_len) - for_op = scf.ForOp(warp_id, c(tile_m), c(warps_per_warpgroup)) - with ir.InsertionPoint(for_op.body): - nested_for_op = scf.ForOp(lane_id, c(num_vectors), c(lanes_per_warp)) - with ir.InsertionPoint(nested_for_op.body): - vector_idx = nested_for_op.induction_variable + num_vectors_per_row = safe_div(tile_n, vector_len) + # Process several rows at once if it is necessary to fully exploit each + # warp. + if tile_n < lanes_per_warp * vector_len: + num_rows_per_warp = min( + safe_div(lanes_per_warp * vector_len, tile_n), + safe_div(tile_m, warps_per_warpgroup)) + else: + num_rows_per_warp = 1 + lanes_per_row = safe_div(lanes_per_warp, num_rows_per_warp) + lane_row_offset, lane_col_offset = divmod(lane_id, lanes_per_row) + warp_for_op = scf.ForOp(arith.muli(warp_id, c(num_rows_per_warp)), + c(tile_m), + c(warps_per_warpgroup * num_rows_per_warp)) + with ir.InsertionPoint(warp_for_op.body): + start_row = warp_for_op.induction_variable + m_row_idx = arith.addi(start_row, lane_row_offset) + vector_for_op = scf.ForOp(lane_col_offset, c(num_vectors_per_row), + c(lanes_per_row)) + with ir.InsertionPoint(vector_for_op.body): + vector_idx = vector_for_op.induction_variable n_store = arith.muli(vector_idx, c(vector_len)) - col_group = arith.divui(n_store, c(n_out_tiling)) - n_load = arith.remui(n_store, c(n_out_tiling)) - - m_smem = for_op.induction_variable - m_within_tile = arith.remui(m_smem, c(64)) - m_tile = arith.divui(m_smem, c(64)) - swizzle_source = arith.shli(arith.remui(m_smem, c(8)), c(2)) + col_group, n_load = divmod(n_store, n_out_tiling) + m_tile, m_within_tile = divmod(m_row_idx, m_out_tiling) + swizzle_source = arith.shli(arith.remui(m_row_idx, c(8)), c(2)) n_acc = arith.xori(n_load, swizzle_source) acc_part = vector.load( ir.VectorType.get((vector_len,), f32), @@ -418,7 +425,7 @@ def stage_loop_body(ki, accs): vector.store( acc_part, c_device, - [arith.addi(m_start, m_smem), arith.addi(n_start, n_store)], + [arith.addi(m_start, m_row_idx), arith.addi(n_start, n_store)], ) scf.yield_([]) scf.yield_([]) @@ -480,7 +487,7 @@ def verify( case F32Precision.TF32_X3: impl = WGMMATF32x3Impl - prof_spec = profiler.ProfilerSpec(132 * 4096) if profile else None + prof_spec = profiler.ProfilerSpec(4096) if profile else None f = build_kernel( m, n, k, jnp.dtype(lhs_dtype), jnp.dtype(rhs_dtype), @@ -503,16 +510,21 @@ def verify( jax.lax.reduce_precision(v, exponent_bits, mantissa_bits) for v in (x, y) ) - ref = jax.lax.dot_general( - x, y, dimension_numbers, + + ref_f = functools.partial( + jax.lax.dot_general, + dimension_numbers=dimension_numbers, preferred_element_type=jnp.float32, ) + + ref, ref_runtime = profiler.measure(ref_f, x, y) np.testing.assert_allclose(z, ref, atol=1e-3, rtol=1e-3) - return runtime + return runtime, ref_runtime if __name__ == "__main__": m, k, n = 33 * 128, 2048, 4 * 128 - runtime = verify(m=m, k=k, n=n) + runtime, ref_runtime = verify(m=m, k=k, n=n) tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12 - print(f"{runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") + print(f"Kernel: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") + print(f"Reference: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 6f3f4abc8b4c..ff7d1e0591b0 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -27,8 +27,8 @@ from jaxlib.mlir.dialects import vector import numpy as np -from . import utils from . import dsl as mgpu +from . import utils # mypy: ignore-errors @@ -46,7 +46,7 @@ class WGSplatFragLayout: wants. This means we can trivially broadcast, reshape and do elementwise operations with all other layouts. - Example: + Examples: To load a value in ``` @@ -424,7 +424,7 @@ def reduce_sum(self, scratch) -> ir.Value: memref.store(warp_result, scratch, [warp_id]) utils.commit_shared() zero_index = c(0, index) - with mgpu.once(): + with mgpu.single_thread(): scratch_vec = vector.load( ir.VectorType.get((4,), self.mlir_dtype), scratch, @@ -596,18 +596,14 @@ def load_tiled(cls, ref, swizzle: int | None): @staticmethod def transfer_tiled(shape, dtype, swizzle: int | None): bw = mgpu.bytewidth(dtype) - cols_per_tile = 128 // bw m, n = shape if n % 32 != 0: raise NotImplementedError cols_per_tile = 128 // bw if swizzle != 128: raise NotImplementedError("Only 128B swizzle supported") - index = ir.IndexType.get() - - def c(x): - return arith.ConstantOp(index, ir.IntegerAttr.get(index, x)) + c = arith.ConstantOp.create_index tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE)) lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31} warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3} @@ -619,7 +615,7 @@ def c(x): else: # We rely on canonicalization to clean up the selects. i1 = ir.IntegerType.get_signless(1) - is_even_row = arith.constant(i1, ir.IntegerAttr.get(i1, 1)) + is_even_row = arith.constant(i1, ir.BoolAttr.get(True)) row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16))) col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6} # The swizzle pattern is constant for a given thread. diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 597ea78a6f9a..9cec7361d756 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -17,6 +17,7 @@ import ctypes import functools import json +import math import jax from jax._src.interpreters import mlir @@ -69,16 +70,21 @@ def _record_event(args, event): treedef, record_event_p.bind(*flat_args, event=event) ) -def measure(f, *args): +def measure(f, *args, **kwargs): # TODO(apaszke): Raise if this is called under jit. start_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() end_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() try: + @jax.jit - def run(*args): - return _record_event(f(*_record_event(args, start_event)), end_event) - jax.block_until_ready(run(*args)) # Warmup. - results = jax.block_until_ready(run(*args)) + def run(*args, **kwargs): + flat_args, treedef = jax.tree.flatten((args, kwargs)) + flat_args = _record_event(flat_args, start_event) + args, kwargs = jax.tree.unflatten(treedef, flat_args) + return _record_event(f(*args, **kwargs), end_event) + + jax.block_until_ready(run(*args, **kwargs)) # Warmup. + results = jax.block_until_ready(run(*args, **kwargs)) elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed( start_event, end_event ) @@ -92,26 +98,40 @@ class ProfilerSpec: ENTER = 0 EXIT = 1 << 31 - def __init__(self, num_entries: int): - self.num_entries = num_entries + def __init__(self, entries_per_warpgroup: int): + self.entries_per_warpgroup = entries_per_warpgroup self.interned_names = {} - @property - def mlir_buffer_type(self) -> ir.Type: + def _num_warpgroups( + self, grid: tuple[int, ...], block: tuple[int, ...] + ) -> int: + if math.prod(block) % WARPGROUP_SIZE: + raise ValueError("Block size is not a multiple of warpgroup size") + return math.prod(grid) * math.prod(block) // WARPGROUP_SIZE + + def mlir_buffer_type( + self, grid: tuple[int, ...], block: tuple[int, ...] + ) -> ir.Type: return ir.MemRefType.get( - (1 + self.num_entries,), ir.IntegerType.get_signless(32) + (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,), + ir.IntegerType.get_signless(32), ) - @property - def jax_buffer_type(self) -> ir.Type: - return jax.ShapeDtypeStruct((1 + self.num_entries,), jnp.uint32) + def jax_buffer_type( + self, grid: tuple[int, ...], block: tuple[int, ...] + ) -> ir.Type: + return jax.ShapeDtypeStruct( + (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,), + jnp.uint32, + ) - def smem_i32_elements(self, grid: tuple[int, ...]): - return int(self.num_entries // np.prod(grid)) + def smem_i32_elements(self, block: tuple[int, ...]): + num_warpgroups = self._num_warpgroups((), block) + return int(num_warpgroups * self.entries_per_warpgroup) - def smem_bytes(self, grid: tuple[int, ...]): + def smem_bytes(self, block: tuple[int, ...]): bytes_per_entry = 4 - return self.smem_i32_elements(grid) * bytes_per_entry + return self.smem_i32_elements(block) * bytes_per_entry def intern_name(self, name: str) -> int: if name_id := self.interned_names.get(name, None): @@ -121,31 +141,31 @@ def intern_name(self, name: str) -> int: raise RuntimeError("Allocated too many names") return name_id - def dump(self, buffer, f): + def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): buffer = np.asarray(buffer) - num_blocks = buffer[0] - per_block = self.num_entries // num_blocks - block_entries = buffer[1 : 1 + num_blocks * per_block].reshape( - num_blocks, per_block + num_blocks = math.prod(grid) + warpgroups_per_block = self._num_warpgroups((), block) + entries = buffer.reshape( + num_blocks, warpgroups_per_block, self.entries_per_warpgroup ) - start_times = block_entries[:, :2].astype(np.int64) - start_times = (start_times[:, 0] << 32) + start_times[:, 1] + start_times = entries[..., :2].astype(np.int64) + start_times = (start_times[..., 0] << 32) + start_times[..., 1] start_times -= start_times.min() # Normalize - entries_used = block_entries[:, 2] - if np.any(entries_used > per_block - 2): + entries_used = entries[..., 2] + if np.any(entries_used > self.entries_per_warpgroup - 2): raise RuntimeError("Insufficient space to capture a full trace") - block_traces = block_entries[:, 3:] + traces = entries[..., 3:] unintern = {v: k for k, v in self.interned_names.items()} events = [] - for block_idx in range(num_blocks): - valid_entries = entries_used[block_idx] - 3 + for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block): + valid_entries = entries_used[block_idx, wg_idx] - 3 local_clock_offset = None - assert valid_entries % 2 == 0 - start_time = start_times[block_idx] + assert valid_entries % 2 == 0, valid_entries + start_time = start_times[block_idx, wg_idx] block_events = [] for i in range(0, valid_entries, 2): - tag = block_traces[block_idx, i] - time = block_traces[block_idx, i + 1] + tag = traces[block_idx, wg_idx, i] + time = traces[block_idx, wg_idx, i + 1] if local_clock_offset is None: local_clock_offset = time time -= local_clock_offset @@ -162,8 +182,8 @@ def dump(self, buffer, f): "name": name, "ph": "B" if begin else "E", "ts": float(start_time + time) / 1e3, - "pid": 0, - "tid": block_idx, + "pid": 1 + block_idx, + "tid": 1 + wg_idx, }) else: # If we didn't break events.extend(block_events) @@ -177,12 +197,17 @@ def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Va # self.should_store = gpu.thread_id(gpu.Dimension.x) i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() - num_blocks = c(1, index) - for dim in gpu.Dimension: - num_blocks = arith.muli(num_blocks, gpu.grid_dim(dim)) - memref.store(arith.index_cast(i32, num_blocks), gmem_buffer, [c(0, index)]) - self.entries_per_block = arith.divui(c(spec.num_entries, index), num_blocks) - self.smem_buffer = smem_buffer + self.entries_per_wg = spec.entries_per_warpgroup + wg_idx = warpgroup_idx(sync=False) + self.smem_buffer = memref_slice( + smem_buffer, + ds( + arith.index_cast( + index, arith.muli(wg_idx, c(self.entries_per_wg, i32)) + ), + self.entries_per_wg, + ), + ) self.gmem_buffer = gmem_buffer # Hopefully mem2reg will remove the allocation. self.offset = memref.alloca(ir.MemRefType.get((), i32), [], []) @@ -213,51 +238,56 @@ def store(modifier): yield store(ProfilerSpec.EXIT) - def finalize(self, grid): + def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]): index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) + gpu.barrier() # Make sure all warpgroups are done. + block_idx = c(0, index) - for dim in reversed(gpu.Dimension): # pytype: disable=wrong-arg-types + for dim in gpu.Dimension: # pytype: disable=wrong-arg-types block_idx = arith.addi( arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim) ) - start_offset = arith.addi( - arith.muli(block_idx, self.entries_per_block), c(1, index) + wg_idx = warpgroup_idx(sync=False) + wg_per_block = math.prod(block) // WARPGROUP_SIZE + global_wg_idx = arith.addi( + arith.muli(block_idx, c(wg_per_block, index)), + arith.index_cast(index, wg_idx), ) - block_gmem_buffer = memref.subview( - self.gmem_buffer, [start_offset], [self.spec.num_entries], [1], + start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index)) + wg_gmem_buffer = memref.subview( + self.gmem_buffer, [start_offset], [self.entries_per_wg], [1], result_type=ir.Type.parse( - f"memref<{self.spec.num_entries}xi32, strided<[1], offset: ?>>" + f"memref<{self.entries_per_wg}xi32, strided<[1], offset: ?>>" ), ) - # TODO(apaszke): Either use globaltimer or delete - # memref.store(globaltimer("high"), block_gmem_buffer, [c(0, index)]) - # memref.store(globaltimer("low"), block_gmem_buffer, [c(1, index)]) - memref.store(c(0, i32), block_gmem_buffer, [c(0, index)]) - memref.store(c(0, i32), block_gmem_buffer, [c(1, index)]) - memref.store( - arith.addi(memref.load(self.offset, []), c(3, i32)), - block_gmem_buffer, - [c(2, index)], - ) - + thread_in_wg = arith.remui(thread_idx(), c(128, i32)) if_first = scf.IfOp( - arith.cmpi( - arith.CmpIPredicate.eq, gpu.thread_id(gpu.Dimension.x), c(0, index) - ) + arith.cmpi(arith.CmpIPredicate.eq, thread_in_wg, c(0, i32)) ) with ir.InsertionPoint(if_first.then_block): + # TODO(apaszke): Either use globaltimer or delete + # memref.store(globaltimer("high"), block_gmem_buffer, [c(0, index)]) + # memref.store(globaltimer("low"), block_gmem_buffer, [c(1, index)]) + memref.store(c(0, i32), wg_gmem_buffer, [c(0, index)]) + memref.store(c(0, i32), wg_gmem_buffer, [c(1, index)]) + memref.store( + arith.addi(memref.load(self.offset, []), c(3, i32)), + wg_gmem_buffer, + [c(2, index)], + ) + for_op = scf.ForOp( c(0, index), - c(self.spec.smem_i32_elements(grid) - 3, index), + c(self.entries_per_wg - 3, index), c(1, index), ) with ir.InsertionPoint(for_op.body): x = memref.load(self.smem_buffer, [for_op.induction_variable]) memref.store( x, - block_gmem_buffer, + wg_gmem_buffer, [arith.addi(for_op.induction_variable, c(3, index))], ) scf.yield_([]) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 14ffa7152ad7..c9de1eb29985 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -14,10 +14,12 @@ # ============================================================================== """Utilities for code generator.""" -from collections.abc import Iterator +from collections.abc import Iterator, Sequence import contextlib import dataclasses -from typing import Any, Literal, Sequence +import enum +import functools +from typing import Any, Literal import jax from jaxlib.mlir import ir @@ -116,10 +118,8 @@ def debug_print(fmt, *args, uniform=True): ty_format = "%llu" if ir.IntegerType.isinstance(arg.type): width = ir.IntegerType(arg.type).width - if width == 64: - ty_format = "%llu" - elif width == 1: - ty_format = "%llu" + ty_format = "%llu" + if width < 64: arg = arith.extui(ir.IntegerType.get_signless(64), arg) if ir.F32Type.isinstance(arg.type): ty_format = "%f" @@ -130,7 +130,11 @@ def debug_print(fmt, *args, uniform=True): raise NotImplementedError(arg.type) type_formats.append(ty_format) new_args.append(arg) - ctx = once if uniform else contextlib.nullcontext + ctx = ( + functools.partial(single_thread, per_block=False) + if uniform + else contextlib.nullcontext + ) with ctx(): gpu.printf(fmt.format(*type_formats) + "\n", new_args) @@ -190,55 +194,69 @@ def thread_idx(): return tidx -def warp_idx(sync=True): +def _warp_bcast(val, lane_idx=0): i32 = ir.IntegerType.get_signless(32) - warp_idx = arith.shrui(thread_idx(), c(5, i32)) - if not sync: - return warp_idx mask = c(0xFFFFFFFF, i32) return nvvm.shfl_sync( - warp_idx.type, mask, warp_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx + val.type, mask, val, c(lane_idx, i32), c(0x1F, i32), nvvm.ShflKind.idx ) + +def warp_idx(sync=True): + i32 = ir.IntegerType.get_signless(32) + warp_idx = arith.shrui(thread_idx(), c(5, i32)) + # Performing a warp broadcast improves performance as compiler understands + # that the value is uniform across the warp. + return _warp_bcast(warp_idx) if sync else warp_idx + + def warpgroup_idx(sync=True): i32 = ir.IntegerType.get_signless(32) wg_idx = arith.shrui(thread_idx(), c(7, i32)) - if not sync: - return wg_idx - mask = c(0xFFFFFFFF, i32) - return nvvm.shfl_sync( - wg_idx.type, mask, wg_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx - ) + # Performing a warp broadcast improves performance as compiler understands + # that the value is uniform across the warp. + return _warp_bcast(wg_idx) if sync else wg_idx + + +class ThreadSubset(enum.IntEnum): + WARPGROUP = enum.auto() + BLOCK = enum.auto() # True withon `once()` contexts. -_ONCE_REGION_ACTIVE = False +_ONCE_PER: ThreadSubset | None = None @contextlib.contextmanager -def once(): - """Runs the context only from a single thread from the first warp. +def single_thread(per_block=True): + """Runs the context only from a single thread. - The block is assumed to have a size of 1 in both y and z dimensions. + Args: + per_block: If True, only one thread per block will run the context. + Otherwise, only one thread per warp group will run the context. """ - global _ONCE_REGION_ACTIVE - - if _ONCE_REGION_ACTIVE: + global _ONCE_PER + scope = ThreadSubset.BLOCK if per_block else ThreadSubset.WARPGROUP + # If we're already in a single-thread context, we don't have to do anything. + if _ONCE_PER is not None and _ONCE_PER >= scope: yield return warp = warp_idx() + if not per_block: + warp = arith.remui(warp, c(4, warp.type)) first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) should_run = arith.andi(first_warp, elected) if_op = scf.IfOp(should_run) - _ONCE_REGION_ACTIVE = True + prev_scope = _ONCE_PER + _ONCE_PER = scope try: with ir.InsertionPoint(if_op.then_block): yield scf.YieldOp([]) finally: - _ONCE_REGION_ACTIVE = False + _ONCE_PER = prev_scope def clock(): @@ -499,9 +517,10 @@ def __init__(self, num_barriers: int, arrival_count: int = 1): i32 = ir.IntegerType.get_signless(32) self.phases = memref.alloca(ir.MemRefType.get((), i32), [], []) memref.store(c(0, i32), self.phases, []) - with once(): + with single_thread(per_block=True): for i in range(num_barriers): nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index)) + gpu.barrier() def __iter__(self) -> Iterator["Barrier"]: for offset in range(self.num_barriers): @@ -518,13 +537,32 @@ class Barrier: barrier_array: BarrierArray offset: ir.Value - def wait_parity(self, parity): + def wait_parity(self, parity, expect_wait=False): + i1 = ir.IntegerType.get_signless(1) index = ir.IndexType.get() - nvgpu.mbarrier_try_wait_parity( - self.barrier_array.value, parity, c(10000000, index), self.offset, + if expect_wait: + nvgpu.mbarrier_try_wait_parity( + self.barrier_array.value, parity, c(10000000, index), self.offset, + ) + return + barrier_ptr = self.get_ptr() + barrier_ready = llvm.inline_asm( + i1, + [barrier_ptr, parity], + "mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;", + "=b,l,r", + asm_dialect=0, + has_side_effects=True, ) + should_wait = arith.xori(barrier_ready, c(1, i1)) + should_wait = llvm.intr_expect(should_wait, c(0, i1)) + with ir.InsertionPoint(scf.IfOp(should_wait).then_block): + nvgpu.mbarrier_try_wait_parity( + self.barrier_array.value, parity, c(10000000, index), self.offset, + ) + scf.yield_([]) - def wait(self): + def wait(self, expect_wait=False): i32 = ir.IntegerType.get_signless(32) parities = memref.load(self.barrier_array.phases, []) offset_i32 = arith.index_castui(i32, self.offset) @@ -534,12 +572,31 @@ def wait(self): ) new_parities = arith.xori(parities, bitmask) memref.store(new_parities, self.barrier_array.phases, []) - self.wait_parity(parity) + self.wait_parity(parity, expect_wait=expect_wait) def arrive(self): token_ty = ir.Type.parse("!nvgpu.mbarrier.token") nvgpu.mbarrier_arrive(token_ty, self.barrier_array.value, self.offset) + def get_ptr(self): + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr<3>") + smem = ir.IntegerAttr.get(i64, 3) + num_barriers = self.barrier_array.num_barriers + mbarrier_ref_ty = ir.MemRefType.get((num_barriers,), i64, memory_space=smem) + mbarrier_ref = builtin.unrealized_conversion_cast( + [mbarrier_ref_ty], [self.barrier_array.value], + ) + mbarrier_ref_ptr = memref.extract_aligned_pointer_as_index(mbarrier_ref) + barrier_arr_ptr = llvm.inttoptr( + ptr_ty, arith.index_cast(i64, mbarrier_ref_ptr), + ) + offset_i32 = arith.index_cast(i32, self.offset) + return llvm.getelementptr( + ptr_ty, barrier_arr_ptr, [offset_i32], [-2147483648], i64, + ) + class Partition: source_bounds: tuple[int, ...] diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 5b5739fb1a44..8fe3d8c79d61 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -15,6 +15,7 @@ import dataclasses import enum +import functools import itertools import jax @@ -50,14 +51,16 @@ def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True): raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator") self.value = _value if _sync: - nvvm.wgmma_fence_aligned() + self.value = wgmma_fence(_value) @classmethod - def zero(cls, m, n): + def zero(cls, m, n, dtype=None): if m % 64 or n % 8: raise ValueError f32 = ir.F32Type.get() - zero = arith.constant(f32, ir.FloatAttr.get(f32, 0.0)) + if dtype is None: + dtype = f32 + zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) return cls( _value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT) ) @@ -145,6 +148,23 @@ def create_descriptor( return desc.result +def _unpack_i32(vec_ty, r): + i32 = ir.IntegerType.get_signless(32) + return vector.bitcast( + vec_ty, vector.splat(ir.VectorType.get((1,), i32), r) + ) + + +def _supported_wgmma_types(dtype, abtype) -> bool: + input_types_are = lambda ty: ty.isinstance(abtype) + if ir.F32Type.isinstance(dtype): + return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, ir.F16Type)) + elif ir.F16Type.isinstance(dtype): + return input_types_are(ir.F16Type) + else: + return False + + def wgmma_m64k128B( acc: np.ndarray, # of register Values a, @@ -156,7 +176,11 @@ def wgmma_m64k128B( n: int, element_type: ir.Type, ): - f32 = ir.F32Type.get() + out_ty = ir.VectorType(acc.flat[0].type).element_type + if not _supported_wgmma_types(out_ty, element_type): + raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}") + + f16 = ir.F16Type.get() i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) index = ir.IndexType.get() @@ -181,7 +205,26 @@ def wgmma_m64k128B( if a_transpose is None: raise ValueError - num_acc_regs = n // 2 + if ir.F32Type.isinstance(out_ty): + num_acc_regs = n // 2 + out_ty_field = out_ty + acc_regs = [ # pylint: disable=g-complex-comprehension + vector.extractelement(reg, position=c(pos, index)) + for reg in acc.flat + for pos in range(2) + ] + to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape) + acc_constraint = "f" + elif ir.F16Type.isinstance(out_ty): + num_acc_regs = n // 4 + out_ty_field = i32 + acc_regs = [_as_i32_reg(reg) for reg in acc.flat] + vec_ty = ir.VectorType(acc.flat[0].type) + to_acc_vec_regs = lambda regs : np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape) + acc_constraint = "r" + else: + raise ValueError(f"WGMMA instruciton only supports f32 and f16 out (got {out_ty})") + num_imm_regs = 4 if supports_transpose else 2 if a_in_regs: @@ -192,7 +235,7 @@ def wgmma_m64k128B( # Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html # Seems like it's not actually documented in LLVM IR docs. reg_constraints_list = ( - ["=f"] * num_acc_regs # accumulator registers + [f"={acc_constraint}"] * num_acc_regs # accumulator registers + [str(i) for i in range(num_acc_regs)] # we alias outputs as inputs, too. + a_reg_constraints # a descriptor / registers + ["l"] * 1 # b descriptor @@ -218,7 +261,7 @@ def take_regs(n): el_ty = element_type k_instr = 32 // bytewidth(element_type) wgmma_instr = ( - f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.f32.{el_ty}.{el_ty} " + f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} " f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};" ) ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n" @@ -226,11 +269,6 @@ def take_regs(n): def lc(x): return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result - def as_i32_reg(v): - return llvm.extractelement( - vector.bitcast(ir.VectorType.get((1,), i32), v), lc(0) - ) - use_out = scale_a = scale_b = lc(1) imms = [use_out, scale_a, scale_b] if supports_transpose and a_transpose is not None: @@ -239,19 +277,14 @@ def as_i32_reg(v): imms += [lc(int(b_transpose))] if acc.ndim != 4 or acc.shape[0] != 1 or acc.shape[2:] != (2, 1): raise ValueError(acc.shape) - acc_regs = [ # pylint: disable=g-complex-comprehension - vector.extractelement(reg, position=c(pos, index)) - for reg in acc.flat - for pos in range(2) - ] acc_struct_type = ir.Type.parse( - f"!llvm.struct<({','.join('f32' for _ in acc_regs)})>" + f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>" ) for i in range(4): # Slice out the relevant part of A or advance the A descriptor. if a_in_regs: a_slice = a[:, (i * 16) : ((i + 1) * 16)] - a_args = [as_i32_reg(v) for v in a_slice.registers.flat] + a_args = [_as_i32_reg(v) for v in a_slice.registers.flat] else: if i > 0: a = llvm_add( @@ -275,15 +308,9 @@ def as_i32_reg(v): has_side_effects=True, ) acc_regs = [ - llvm.extractvalue(f32, acc_struct, [i]) for i in range(len(acc_regs)) + llvm.extractvalue(out_ty_field, acc_struct, [i]) for i in range(len(acc_regs)) ] - acc_vec_regs = [] - for first, second in zip(acc_regs[::2], acc_regs[1::2]): - vec = llvm.mlir_undef(ir.VectorType.get((2,), f32)) - vec = llvm.insertelement(vec, first, position=lc(0)) - vec = llvm.insertelement(vec, second, position=lc(1)) - acc_vec_regs.append(vec) - return np.asarray(acc_vec_regs, dtype=object).reshape(acc.shape) + return to_acc_vec_regs(acc_regs) class WGMMALayout(enum.Enum): @@ -373,7 +400,7 @@ def wgmma( wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None if a_in_regs: - nvvm.wgmma_fence_aligned() # Make sure the registers are ready. + a = wgmma_fence(a) # Make sure the registers are ready. a_m_byte_stride = a_k_byte_stride = a_desc_base = None # Silence pytype. else: a_desc_base = create_descriptor(a, **a_desc_fields) @@ -410,3 +437,82 @@ def wgmma( ), _sync=False, ) + + +def wgmma_fence(array: mgpu.FragmentedArray): + """Fences the array construction from WGMMA instructions. + + This is a little workaround to force LLVM to initialize the PTX registers + before the wgmma.fence.sync.aligned instruction. Otherwise, LLVM treats + in-register computation as pure and can move it after the fence, which is + explicitly disallowed by the PTX programming model. + """ + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + dtype = array.mlir_dtype + src_vec_ty = ir.VectorType(array.registers.flat[0].type) + assert src_vec_ty.shape == [2] + + if dtype == ir.F32Type.get(): + regs = [ # pylint: disable=g-complex-comprehension + vector.extractelement(reg, position=c(pos, index)) + for reg in array.registers.flat + for pos in range(2) + ] + reg_dtype = dtype + reg_constraints_list = ["=f"] * len(regs) + ["f"] * len(regs) + ptx_lines = [f"mov.f32 ${i}, ${len(regs)+i}" for i in range(len(regs))] + elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get(): + regs = [_as_i32_reg(reg) for reg in array.registers.flat] + reg_dtype = i32 + reg_constraints_list = ["=r"] * len(regs) + ["r"] * len(regs) + ptx_lines = [f"mov.b32 ${i}, ${len(regs)+i}" for i in range(len(regs))] + else: + raise NotImplementedError(dtype) + reg_constraints = ",".join(reg_constraints_list) + # Copy over the registers. ptxas should be able to remove the moves. + ptx_lines.append("wgmma.fence.sync.aligned") + ptx = ";\n".join(ptx_lines) + ";\n" + dtype_str = str(reg_dtype) + struct_ty = ir.Type.parse( + f"!llvm.struct<({','.join(dtype_str for _ in regs)})>" + ) + acc_struct = llvm.inline_asm( + struct_ty, regs, ptx, reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [ + llvm.extractvalue(reg_dtype, acc_struct, [i]) for i in range(len(regs)) + ] + if dtype == ir.F32Type.get(): + registers = _as_fragmented_reg_ndarray( + regs, array.mlir_dtype, array.registers.shape + ) + elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get(): + regs = [_unpack_i32(src_vec_ty, r) for r in regs] + registers = np.asarray(regs, dtype=object).reshape(array.registers.shape) + else: + raise NotImplementedError(dtype) + return mgpu.FragmentedArray(_registers=registers, _layout=array.layout) + + +def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]): + vec_regs = [] + for first, second in zip(flat_regs[::2], flat_regs[1::2]): + vec = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) + vec = llvm.insertelement(vec, first, position=_lc(0)) + vec = llvm.insertelement(vec, second, position=_lc(1)) + vec_regs.append(vec) + return np.asarray(vec_regs, dtype=object).reshape(shape) + + +def _as_i32_reg(v): + i32 = ir.IntegerType.get_signless(32) + return llvm.extractelement( + vector.bitcast(ir.VectorType.get((1,), i32), v), _lc(0) + ) + + +def _lc(x): + i32 = ir.IntegerType.get_signless(32) + return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 11fd442ecd51..b3665b5845b7 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -57,6 +57,14 @@ def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any: A pytree matching in_tree where the leaves now all contain the data from the first host. """ + if jax.process_count() == 1: + # Note: This may return results that are different from the multi-host case + # below since it does not force-convert inputs to numpy arrays. We don't do + # such conversion here (and the API contract does not promise such a + # requirement) because doing so could be expensive for single-controller + # runtimes with lots of addressable devices. + return in_tree + if is_source is None: is_source = jax.process_index() == 0 @@ -424,21 +432,27 @@ def global_array_to_host_local_array( You can use this function to convert the globally shaped `jax.Array` output from pjit to host local values again so that the transition to jax.Array can - be a mechanical change. Example usage + be a mechanical change. + + Example usage: - >> from jax.experimental import multihost_utils # doctest: +SKIP - >> - >> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP - >> - >> with mesh: # doctest: +SKIP - >> global_out = pjitted_fun(global_inputs) # doctest: +SKIP - >> - >> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP + >>> from jax.experimental import multihost_utils # doctest: +SKIP + >>> + >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP + >>> + >>> with mesh: # doctest: +SKIP + ... global_out = pjitted_fun(global_inputs) # doctest: +SKIP + >>> + >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP Args: global_inputs: A Pytree of global jax.Array's. - global_mesh: A jax.sharding.Mesh object. - pspecs: A Pytree of jax.sharding.PartitionSpec's. + global_mesh: A :class:`jax.sharding.Mesh` object. The mesh must be contiguous + meaning all local devices of the host must form a subcube. + pspecs: A Pytree of :class:`jax.sharding.PartitionSpec` objects. + + Returns: + A Pytree of host local arrays. """ flat_inps, out_tree = tree_flatten(global_inputs) out_pspecs = _flatten_pspecs('output pspecs', out_tree, diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py index 8047aeed245b..adade4e8a72c 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/experimental/pallas/gpu.py @@ -14,5 +14,5 @@ """Triton-specific Pallas APIs.""" -from jax._src.pallas.triton import approx_tanh -from jax._src.pallas.triton import elementwise_inline_asm +from jax._src.pallas.triton.primitives import approx_tanh +from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/jax/experimental/pallas/ops/__init__.py b/jax/experimental/pallas/ops/__init__.py index 017372046b07..132d3839e212 100644 --- a/jax/experimental/pallas/ops/__init__.py +++ b/jax/experimental/pallas/ops/__init__.py @@ -12,13 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from jax.experimental.pallas.ops import attention -from jax.experimental.pallas.ops import layer_norm -from jax.experimental.pallas.ops import rms_norm -from jax.experimental.pallas.ops import softmax - - # All files within ops should be treated as user code. import os import jax._src.source_info_util diff --git a/jaxlib/cpu/_ducc_fft.pyi b/jax/experimental/pallas/ops/gpu/__init__.py similarity index 74% rename from jaxlib/cpu/_ducc_fft.pyi rename to jax/experimental/pallas/ops/gpu/__init__.py index 7d5c3071adea..862a661e24b9 100644 --- a/jaxlib/cpu/_ducc_fft.pyi +++ b/jax/experimental/pallas/ops/gpu/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -def dynamic_ducc_fft_descriptor(ndims: int, is_double: bool, fft_type: int, axes: list[int], forward: bool) -> bytes: ... -def registrations() -> dict: ... diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/gpu/attention.py similarity index 99% rename from jax/experimental/pallas/ops/attention.py rename to jax/experimental/pallas/ops/gpu/attention.py index a96d3e1cc8dd..6df8f94af2eb 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -119,7 +119,7 @@ def body(start_k, carry): # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q) upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k) else: - upper_bound = pl.cdiv(seq_len, block_k) # type: ignore + upper_bound = pl.cdiv(seq_len, block_k) o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) if residual_refs: diff --git a/jax/experimental/pallas/ops/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py similarity index 100% rename from jax/experimental/pallas/ops/layer_norm.py rename to jax/experimental/pallas/ops/gpu/layer_norm.py diff --git a/jax/experimental/pallas/ops/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py similarity index 100% rename from jax/experimental/pallas/ops/rms_norm.py rename to jax/experimental/pallas/ops/gpu/rms_norm.py diff --git a/jax/experimental/pallas/ops/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py similarity index 100% rename from jax/experimental/pallas/ops/softmax.py rename to jax/experimental/pallas/ops/gpu/softmax.py diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 0905176e4eea..f3b09c96486b 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -1099,7 +1099,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, grid=grid, - in_specs=in_specs, # type: ignore + in_specs=in_specs, out_specs=out_specs, scratch_shapes=scratch_shapes, ), @@ -1444,8 +1444,8 @@ def kv_segment_ids_index_map( grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, grid=grid, - in_specs=in_specs, # type: ignore - out_specs=out_specs, # type: ignore + in_specs=in_specs, + out_specs=out_specs, scratch_shapes=scratch_shapes, ), out_shape=out_shapes, diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py index e9072bbe1146..ba8ca6c1b617 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py +++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py @@ -14,8 +14,9 @@ """Grouped matrix multiplication kernels for TPU written in Pallas.""" +from collections.abc import Callable import functools -from typing import Any, Callable, Optional, Union +from typing import Any, Optional import jax from jax import lax @@ -315,9 +316,9 @@ def gmm( rhs: jnp.ndarray, group_sizes: jnp.ndarray, preferred_element_type: jnp.dtype = jnp.float32, - tiling: Optional[Union[tuple[int, int, int], LutFn]] = (128, 128, 128), - group_offset: Optional[jnp.ndarray] = None, - existing_out: Optional[jnp.ndarray] = None, + tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128), + group_offset: jnp.ndarray | None = None, + existing_out: jnp.ndarray | None = None, transpose_rhs: bool = False, interpret: bool = False, ) -> jnp.ndarray: @@ -577,10 +578,10 @@ def tgmm( rhs: jnp.ndarray, group_sizes: jnp.ndarray, preferred_element_type: jnp.dtype = jnp.float32, - tiling: Optional[Union[tuple[int, int, int], LutFn]] = (128, 128, 128), - group_offset: Optional[jnp.ndarray] = None, - num_actual_groups: Optional[int] = None, - existing_out: Optional[jnp.ndarray] = None, + tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128), + group_offset: jnp.ndarray | None = None, + num_actual_groups: int | None = None, + existing_out: jnp.ndarray | None = None, interpret: bool = False, ) -> jnp.ndarray: """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :]. diff --git a/jax/experimental/pallas/ops/tpu/megablox/ops.py b/jax/experimental/pallas/ops/tpu/megablox/ops.py index 874951db0452..015c6b3ade67 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/ops.py +++ b/jax/experimental/pallas/ops/tpu/megablox/ops.py @@ -14,8 +14,6 @@ """Grouped matrix multiplication operations with custom VJPs.""" -from typing import Optional - import jax from jax.experimental.pallas.ops.tpu.megablox import gmm as backend import jax.numpy as jnp @@ -33,8 +31,8 @@ def _gmm_fwd( group_sizes: jnp.ndarray, preferred_element_type: jnp.dtype = jnp.float32, tiling: tuple[int, int, int] = (128, 128, 128), - group_offset: Optional[jnp.ndarray] = None, - existing_out: Optional[jnp.ndarray] = None, + group_offset: jnp.ndarray | None = None, + existing_out: jnp.ndarray | None = None, transpose_rhs: bool = False, interpret: bool = False, ) -> tuple[ @@ -43,7 +41,7 @@ def _gmm_fwd( jnp.ndarray, jnp.ndarray, jnp.ndarray, - Optional[jnp.ndarray], + jnp.ndarray | None, int, ], ]: @@ -71,7 +69,7 @@ def _gmm_bwd( jnp.ndarray, jnp.ndarray, jnp.ndarray, - Optional[jnp.ndarray], + jnp.ndarray | None, int, ], grad: jnp.ndarray, diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index f4ea3f9a0a5e..7d47cc3d0efa 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -15,7 +15,6 @@ """PagedAttention TPU kernel.""" import functools -from typing import Optional, Union import jax from jax import lax @@ -364,14 +363,14 @@ def body(i, _): ) def paged_attention( q: jax.Array, - k_pages: Union[jax.Array, quantization_utils.QuantizedTensor], - v_pages: Union[jax.Array, quantization_utils.QuantizedTensor], + k_pages: jax.Array | quantization_utils.QuantizedTensor, + v_pages: jax.Array | quantization_utils.QuantizedTensor, lengths: jax.Array, page_indices: jax.Array, *, mask_value: float = DEFAULT_MASK_VALUE, pages_per_compute_block: int, - megacore_mode: Optional[str] = None, + megacore_mode: str | None = None, inline_seq_dim: bool = True, ) -> jax.Array: """Paged grouped query attention. diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index c2e5fe8cd714..a6c0715e6043 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -16,10 +16,11 @@ from __future__ import annotations +from collections.abc import Callable, Mapping import dataclasses import enum import functools -from typing import Any, Callable, Literal, NamedTuple, Union, Optional, overload +from typing import Any, Literal, NamedTuple, Optional, Union, overload import jax from jax import ad_checkpoint @@ -89,10 +90,13 @@ class SegmentIds(NamedTuple): def get_kernel_name( - is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str + block_metadata: Mapping[str, Any], + is_mqa: bool, + save_residuals: bool, + is_segmented: bool, + phase: str, ) -> str: """Returns a unique name for all SplashAttention kernel variants.""" - assert phase == "dq" or phase == "dkv" or phase == "fwd" # Saving residuals is supported only for the fwd phase. assert not save_residuals or phase == "fwd" @@ -103,7 +107,9 @@ def get_kernel_name( residuals = "_no_residuals" attention_type = "mqa" if is_mqa else "mha" segments = "_segmented" if is_segmented else "" - return f"splash_{attention_type}_{phase}{segments}{residuals}" + return f"splash_{attention_type}_{phase}{segments}{residuals}_" + "_".join( + f"{k}={v}" for k, v in sorted(block_metadata.items()) + ) # Reference attention implementations @@ -1054,28 +1060,17 @@ def logsumexp_index_map(h, i, *_): out_shapes += [None] out_specs += [None] - # Attach useful metadata to the custom-call HLO op. - # Having this information available in an HLO-dump or xprof is valuable for - # debugging and performance investigation. - metadata_dict = dict( - block_sizes=dataclasses.asdict(block_sizes), - is_mqa=is_mqa, - save_residuals=save_residuals, - mask_value=mask_value, - is_segmented=segment_ids is not None, - attn_logits_soft_cap=attn_logits_soft_cap, - residual_checkpoint_name=residual_checkpoint_name, - ) - - mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict) - - mosaic_params.update( + mosaic_params = dict( dimension_semantics=("parallel", "arbitrary", "arbitrary"), flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, ) kernel_name = get_kernel_name( - is_mqa, save_residuals, segment_ids is not None, "fwd" + dataclasses.asdict(block_sizes), + is_mqa=is_mqa, + save_residuals=save_residuals, + is_segmented=segment_ids is not None, + phase="fwd", ) if fwd_mask_info.data_next is not None: @@ -1526,28 +1521,24 @@ def logsumexp_index_map(h, i, *_): ) num_scalar_prefetch = 3 - # Attach useful metadata to the custom-call HLO op. - # Having this information available in an HLO-dump or xprof is valuable for - # debugging and performance investigation. - metadata_dict = dict( - block_q_dq=bq, - block_kv_dq=bkv, - q_layout=q_layout, - k_layout=k_layout, - v_layout=v_layout, - is_mqa=is_mqa, - mask_value=mask_value, - is_segmented=segment_ids is not None, - attn_logits_soft_cap=attn_logits_soft_cap, - ) - - mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict) - mosaic_params.update( + mosaic_params = dict( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, ) - kernel_name = get_kernel_name(is_mqa, False, segment_ids is not None, "dq") + kernel_name = get_kernel_name( + dict( + block_q_dq=bq, + block_kv_dq=bkv, + q_layout=q_layout, + k_layout=k_layout, + v_layout=v_layout, + ), + is_mqa=is_mqa, + save_residuals=False, + is_segmented=segment_ids is not None, + phase="dq", + ) with jax.named_scope(kernel_name): _, dq = pl.pallas_call( kernel, @@ -2072,35 +2063,30 @@ def logsumexp_index_map( ) num_scalar_prefetch = 3 - # Attach useful metadata to the custom-call HLO op. - # Having this information available in an HLO-dump or xprof is valuable for - # debugging and performance investigation. - metadata_dict = dict( - block_q_dkv=bq, - block_kv_dkv=bkv, - block_kv_dkv_compute=bkv_compute, - q_layout=q_layout, - k_layout=k_layout, - v_layout=v_layout, - use_fused_bwd_kernel=use_fused_bwd_kernel, - is_mqa=is_mqa, - mask_value=mask_value, - is_segmented=segment_ids is not None, - attn_logits_soft_cap=attn_logits_soft_cap, - ) - - mosaic_params = pltpu.encode_kernel_regeneration_metadata(metadata_dict) # We set all dimensions to arbitrary because: # 1) for kv_seq_len, the splash attention prefetch schedule assumes no # megacore # 2) for heads, we are reducing over heads # 3) for q_seq_len, we are reducing over it to compute dkv - mosaic_params.update( + mosaic_params = dict( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, ) - kernel_name = get_kernel_name(is_mqa, False, segment_ids is not None, "dkv") + kernel_name = get_kernel_name( + dict( + block_q_dkv=bq, + block_kv_dkv=bkv, + block_kv_dkv_compute=bkv_compute, + q_layout=q_layout, + k_layout=k_layout, + v_layout=v_layout, + ), + is_mqa=is_mqa, + save_residuals=False, + is_segmented=segment_ids is not None, + phase="dkv", + ) with jax.named_scope(kernel_name): _, _, _, dq_unreduced, dk, dv = pl.pallas_call( kernel, diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index e65d9b073a18..eab2a695dc02 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -16,8 +16,9 @@ from __future__ import annotations +from collections.abc import Callable, Sequence import dataclasses -from typing import Any, Callable, Sequence, Tuple +from typing import Any import numpy as np # mypy: ignore-errors @@ -26,7 +27,7 @@ class Mask: """A base class for splash attention masks.""" @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: raise NotImplementedError def __getitem__(self, idx) -> np.ndarray: @@ -38,14 +39,14 @@ def __bool__(self) -> bool: ' instead of bitwise operations on masks.' ) - def __or__(self, other: 'Mask') -> 'Mask': + def __or__(self, other: Mask) -> Mask: if self.shape != other.shape: raise ValueError( f'Invalid shape for other: {other.shape}, expected: {self.shape}' ) return LogicalOr(self, other) - def __and__(self, other: 'Mask') -> 'Mask': + def __and__(self, other: Mask) -> Mask: if self.shape != other.shape: raise ValueError( f'Invalid shape for other: {other.shape}, expected: {self.shape}' @@ -53,7 +54,7 @@ def __and__(self, other: 'Mask') -> 'Mask': return LogicalAnd(self, other) -def make_causal_mask(shape: Tuple[int, int], offset: int = 0) -> np.ndarray: +def make_causal_mask(shape: tuple[int, int], offset: int = 0) -> np.ndarray: """Makes a causal attention mask. Args: @@ -73,8 +74,8 @@ def make_causal_mask(shape: Tuple[int, int], offset: int = 0) -> np.ndarray: def make_local_attention_mask( - shape: Tuple[int, int], - window_size: Tuple[int | None, int | None], + shape: tuple[int, int], + window_size: tuple[int | None, int | None], *, offset: int = 0, ) -> np.ndarray: @@ -92,7 +93,7 @@ def make_local_attention_mask( def make_random_mask( - shape: Tuple[int, int], sparsity: float, seed: int + shape: tuple[int, int], sparsity: float, seed: int ) -> np.ndarray: """Makes a random attention mask.""" np.random.seed(seed) @@ -111,7 +112,7 @@ def __init__(self, left: Mask, right: Mask): self.right = right @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self.left.shape def __getitem__(self, idx) -> np.ndarray: @@ -133,7 +134,7 @@ def __init__(self, left: Mask, right: Mask): self.right = right @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self.left.shape def __getitem__(self, idx) -> np.ndarray: @@ -167,7 +168,7 @@ def __post_init__(self): raise ValueError('Nesting MultiHeadMasks is not supported') @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return (len(self.masks),) + self.masks[0].shape def __getitem__(self, idx) -> np.ndarray: @@ -208,13 +209,13 @@ class _ComputableMask(Mask): mask rather than loading it. """ - _shape: Tuple[int, int] + _shape: tuple[int, int] q_sequence: np.ndarray mask_function: Callable[..., Any] def __init__( self, - shape: Tuple[int, int], + shape: tuple[int, int], mask_function: Callable[..., Any], shard_count: int = 1, ): @@ -231,7 +232,7 @@ def __init__( self.q_sequence = np.arange(q_seq_len, dtype=np.int32) @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self._shape def __getitem__(self, idx) -> np.ndarray: @@ -271,7 +272,7 @@ class CausalMask(_ComputableMask): def __init__( self, - shape: Tuple[int, int], + shape: tuple[int, int], offset: int = 0, shard_count: int = 1, ): @@ -329,15 +330,15 @@ class LocalMask(Mask): # TODO(amagni): Transform LocalMask into a _ComputableMask. - _shape: Tuple[int, int] - window_size: Tuple[int | None, int | None] + _shape: tuple[int, int] + window_size: tuple[int | None, int | None] offset: int _q_sequence: np.ndarray | None = None def __init__( self, - shape: Tuple[int, int], - window_size: Tuple[int | None, int | None], + shape: tuple[int, int], + window_size: tuple[int | None, int | None], offset: int, shard_count: int = 1, ): @@ -352,7 +353,7 @@ def __init__( ) @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> tuple[int, int]: return self._shape def __getitem__(self, idx) -> np.ndarray: @@ -429,7 +430,7 @@ def __post_init__(self): raise ValueError('Mask must be a boolean array') @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self.array.shape def __getitem__(self, idx) -> np.ndarray: @@ -467,7 +468,7 @@ def __post_init__(self): raise ValueError(f'Unsupported shape type: {type(self.shape)}') @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: return self._shape def __getitem__(self, idx) -> np.ndarray: diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index af046688067f..3c672b8dbe88 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -16,8 +16,9 @@ from __future__ import annotations import collections +from collections.abc import Callable import functools -from typing import Callable, Dict, List, NamedTuple, Set, Tuple +from typing import Dict, List, NamedTuple, Set, Tuple from jax import util as jax_util from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib import numpy as np @@ -161,11 +162,11 @@ def __eq__(self, other: object) -> bool: def _get_mask_info_for_shard( - output_shape: Tuple[int, int, int], + output_shape: tuple[int, int, int], has_mask_next: bool, mask: mask_lib.MultiHeadMask, - block_shape: Tuple[int, int], - coords_to_partial_mask_block_index: Dict[Tuple[int, int, int], int], + block_shape: tuple[int, int], + coords_to_partial_mask_block_index: dict[tuple[int, int, int], int], masks_per_head_shard: int, head_start: int, num_heads: int, @@ -173,7 +174,7 @@ def _get_mask_info_for_shard( q_seq_shard_size: int, blocked_q_seq_start: int, is_dkv: bool, -) -> Tuple[np.ndarray, np.ndarray | None]: +) -> tuple[np.ndarray, np.ndarray | None]: """Process a slice of the mask to compute data_next and mask_next. Args: @@ -310,7 +311,7 @@ def _get_mask_info_for_shard( @functools.lru_cache(maxsize=12) def _process_mask( mask: mask_lib.MultiHeadMask, # [num_heads, q_seq_len, kv_seq_len] - block_shape: Tuple[int, int], + block_shape: tuple[int, int], is_dkv: bool, *, downcast_smem_data: bool = True, @@ -394,18 +395,18 @@ def assign_unique_ids(objects): id_map = collections.defaultdict(lambda: len(id_map)) return {obj: id_map[obj] for obj in objects} - unique_masks_dict: Dict[mask_lib.Mask, int] = assign_unique_ids( + unique_masks_dict: dict[mask_lib.Mask, int] = assign_unique_ids( head_mask for head_mask in mask.masks ) # Build a mapping of heads to unique masks and masks to unique masks. - head_to_mask_id: List[int] = [0] * head_count - head_shard_to_mask_ids: List[Set[int]] = [set() for _ in range(head_shards)] - mask_id_to_heads: List[List[int]] = [ + head_to_mask_id: list[int] = [0] * head_count + head_shard_to_mask_ids: list[set[int]] = [set() for _ in range(head_shards)] + mask_id_to_heads: list[list[int]] = [ [] for _ in range(len(unique_masks_dict)) ] - mask_id_to_head_shards: List[Set[int]] = [ + mask_id_to_head_shards: list[set[int]] = [ set() for _ in range(len(unique_masks_dict)) ] @@ -436,10 +437,10 @@ def assign_unique_ids(objects): # TODO(amagni): checking the validity of the masks is slow for large masks. # Disable it for now, reevalute in the future. - partial_mask_block_ids: Dict[_HashableNDArray, int] = collections.defaultdict( + partial_mask_block_ids: dict[_HashableNDArray, int] = collections.defaultdict( lambda: len(partial_mask_block_ids) ) - block_id_to_block_coords: Dict[int, List[Tuple[int, ...]]] = ( + block_id_to_block_coords: dict[int, list[tuple[int, ...]]] = ( collections.defaultdict(list) ) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index fcfe560958d1..ad5fb92719d0 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -12,38 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Contains Mosaic specific Pallas functions.""" -from jax._src.pallas.mosaic import ANY -from jax._src.pallas.mosaic import CMEM -from jax._src.pallas.mosaic import PrefetchScalarGridSpec -from jax._src.pallas.mosaic import SMEM -from jax._src.pallas.mosaic import SemaphoreType -from jax._src.pallas.mosaic import TPUMemorySpace -from jax._src.pallas.mosaic import VMEM -from jax._src.pallas.mosaic import DeviceIdType -from jax._src.pallas.mosaic import async_copy -from jax._src.pallas.mosaic import async_remote_copy -from jax._src.pallas.mosaic import bitcast -from jax._src.pallas.mosaic import dma_semaphore -from jax._src.pallas.mosaic import delay -from jax._src.pallas.mosaic import device_id -from jax._src.pallas.mosaic import emit_pipeline_with_allocations -from jax._src.pallas.mosaic import emit_pipeline -from jax._src.pallas.mosaic import get_pipeline_schedule -from jax._src.pallas.mosaic import make_pipeline_allocations -from jax._src.pallas.mosaic import BufferedRef -from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata -from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata -from jax._src.pallas.mosaic import get_barrier_semaphore -from jax._src.pallas.mosaic import make_async_copy -from jax._src.pallas.mosaic import make_async_remote_copy -from jax._src.pallas.mosaic import repeat -from jax._src.pallas.mosaic import roll -from jax._src.pallas.mosaic import run_scoped -from jax._src.pallas.mosaic import semaphore -from jax._src.pallas.mosaic import semaphore_read -from jax._src.pallas.mosaic import semaphore_signal -from jax._src.pallas.mosaic import semaphore_wait +"""Mosaic-specific Pallas APIs.""" + +from jax._src.pallas.mosaic import core +from jax._src.pallas.mosaic.core import dma_semaphore +from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec +from jax._src.pallas.mosaic.core import semaphore +from jax._src.pallas.mosaic.core import SemaphoreType +from jax._src.pallas.mosaic.core import TPUMemorySpace +from jax._src.pallas.mosaic.lowering import LoweringException +from jax._src.pallas.mosaic.pipeline import BufferedRef +from jax._src.pallas.mosaic.pipeline import emit_pipeline +from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations +from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule +from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations +from jax._src.pallas.mosaic.pipeline import ARBITRARY +from jax._src.pallas.mosaic.pipeline import PARALLEL +from jax._src.pallas.mosaic.primitives import async_copy +from jax._src.pallas.mosaic.primitives import async_remote_copy +from jax._src.pallas.mosaic.primitives import bitcast +from jax._src.pallas.mosaic.primitives import delay +from jax._src.pallas.mosaic.primitives import device_id +from jax._src.pallas.mosaic.primitives import DeviceIdType +from jax._src.pallas.mosaic.primitives import get_barrier_semaphore +from jax._src.pallas.mosaic.primitives import make_async_copy +from jax._src.pallas.mosaic.primitives import make_async_remote_copy +from jax._src.pallas.mosaic.primitives import repeat +from jax._src.pallas.mosaic.primitives import roll +from jax._src.pallas.mosaic.primitives import run_scoped +from jax._src.pallas.mosaic.primitives import semaphore_read +from jax._src.pallas.mosaic.primitives import semaphore_signal +from jax._src.pallas.mosaic.primitives import semaphore_wait +from jax._src.pallas.mosaic.primitives import prng_seed +from jax._src.pallas.mosaic.primitives import prng_random_bits from jax._src.tpu_custom_call import CostEstimate -from jax._src.pallas.mosaic import prng_seed -from jax._src.pallas.mosaic import prng_random_bits + +ANY = TPUMemorySpace.ANY +CMEM = TPUMemorySpace.CMEM +SMEM = TPUMemorySpace.SMEM +VMEM = TPUMemorySpace.VMEM diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index f63c9b8412b3..4957df4866f0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -13,14 +13,14 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Hashable, Sequence +from collections.abc import Callable, Hashable, Sequence import enum from functools import partial import inspect import itertools as it from math import prod import operator as op -from typing import Any, Callable, TypeVar, Union +from typing import Any, TypeVar, Union import numpy as np @@ -52,7 +52,7 @@ from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, as_hashable_function, memoize, partition_list, merge_lists, split_list, subs_list2) -from jax.api_util import flatten_fun_nokwargs, shaped_abstractify +from jax.api_util import flatten_fun_nokwargs, shaped_abstractify, argnums_partial from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -103,7 +103,8 @@ def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, be sharded along the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses sharding the corresponding argument array axis along that positional axis; not - mentioning an axis name expresses replication. + mentioning an axis name expresses replication. If an argument, or argument + subtree, has a corresponding spec of None, that argument is not sharded. out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves, with a tree structure that is a tree prefix of the output of ``f``. Each ``PartitionSpec`` represents how the corresponding output shards should be @@ -153,13 +154,17 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs, def wrapped(*args): fun = lu.wrap_init(f) args_flat, in_tree = tree_flatten(args) - try: in_specs_flat = broadcast_prefix(in_specs, args) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + try: in_specs_flat = broadcast_prefix(in_specs, args, + is_leaf=lambda x: x is None) except ValueError: e, *_ = prefix_errors(in_specs, args) raise e('shard_map in_specs') from None - _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat) + dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) + if s is not None) + fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat) + _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - fun, out_tree = flatten_fun_nokwargs(fun, in_tree) @memoize def out_names_thunk(): @@ -258,11 +263,13 @@ class NoFail: pass def _check_specs_vs_args( f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs, - in_specs_flat: list[P], xs: list) -> None: + dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], + xs: Sequence) -> None: in_avals = map(shaped_abstractify, xs) fail = [a if not len(p) <= a.ndim else no_fail for p, a in zip(in_specs_flat, in_avals)] if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) raise ValueError(msg) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) @@ -270,9 +277,18 @@ def _check_specs_vs_args( for d, ns in names.items()) else no_fail for a, names in zip(in_avals, in_names_flat)] if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) raise ValueError(msg) +def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], + fail: Sequence[core.ShapedArray | NoFail] + ) -> list[core.ShapedArray | NoFail]: + fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves + for i, f in zip(dyn_argnums, fail): + fail_[i] = f + return fail_ + def _spec_rank_error( error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, fails: list[core.ShapedArray | NoFail]) -> str: @@ -404,6 +420,7 @@ def _unmentioned(mesh: Mesh, names: AxisNames) -> list[AxisName]: name_set = {n for ns in names.values() for n in ns} return [n for n in mesh.axis_names if n not in name_set] + def _try_infer_args(f, tree): dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) try: @@ -417,11 +434,11 @@ def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] failures = tree_unflatten(tree, fails) failures_aug = generate_key_paths(failures) specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs)) - leaf = lambda x: type(x) is tuple and len(x) == 2 and type(x[1]) is P + leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf) - return [((spec_key, spec), (fail_key, fail_data)) - for (spec_key, spec), (fail_key, fail_data) - in zip(specs_aug, failures_aug) if fail_data is not no_fail] + return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) + in zip(specs_aug, failures_aug) + if s is not None and fail_data is not no_fail] # Primitive @@ -501,9 +518,7 @@ def _shard_map_staging( in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) main = trace.main with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic( - f, main, in_avals_ - ) + jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) out_avals_ = map(_check_shapedarray, genavals) _check_names(out_names_thunk(), out_avals_) in_rep = map(partial(_in_names_to_rep, mesh), in_names) @@ -653,7 +668,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - ns = aval_in.dtype._rules.physical_sharding(aval_in, ns) + ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() unspecified = set(range(aval_in.ndim)) if auto else set() @@ -667,7 +682,7 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - ns = aval_out.dtype._rules.physical_sharding(aval_out, ns) + ns = sharding_impls.physical_sharding(aval_out, ns) aval_out = core.physical_aval(aval_out) unspecified = set(range(aval_out.ndim)) if auto else set() manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) @@ -896,13 +911,13 @@ def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any], return [] eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule -def _device_put_eager_rule(mesh, x, *, src, device): - del mesh, src - if device is None: - return x - else: - raise ValueError("device_put with explicit device not allowed within " - f"shard_map-decorated functions, but got device {device}") +def _device_put_eager_rule(mesh, *xs, srcs, devices): + del mesh, srcs + for device in devices: + if device is not None: + raise ValueError("device_put with explicit device not allowed within " + f"shard_map-decorated functions, but got device {device}") + return xs eager_rules[dispatch.device_put_p] = _device_put_eager_rule # New primitives for efficient transposition @@ -1145,8 +1160,8 @@ def _io_callback_rule(mesh, *_, result_avals, **__): @register_check(dispatch.device_put_p) -def _device_put_rule(mesh, x, **_): - return x +def _device_put_rule(mesh, *xs, **_): + return list(xs) register_norewrite(dispatch.device_put_p) @@ -1274,6 +1289,9 @@ def _shard_map_batch( for ax in names} for names, d in zip(in_names, in_dims)] spmd_axis_name = trace.spmd_axis_name if spmd_axis_name is not None: + used = {n for names in in_names for ns in names.values() for n in ns} + if set(spmd_axis_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore else ns for ns, d in zip(new_in_names, in_dims)] @as_hashable_function(closure=out_names_thunk) @@ -1306,6 +1324,9 @@ def _batch_out_names(spmd_axis_name, dims, out_names): out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] for ax in names} for names, d in zip(out_names, dims)] if spmd_axis_name is not None: + used = {n for names in out_names for ns in names.values() for n in ns} + if set(spmd_axis_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped else ns for ns, d in zip(out_names_, dims)] return out_names_ @@ -1475,13 +1496,21 @@ def fun(*res_and_args): jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr + +def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]: + # We use a filtered-down version of unmentioned to avoid defensive-psum over + # more chips than required in the transpose-no-check-rep case. + name_set = {n for ns in names.values() for n in ns} + return [n for n in _all_mesh_names(mesh) if n not in name_set] + + def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite - else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns)))) - for ns, x in zip(out_names, out_cts)] + else x if rewrite + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns)))) + for ns, x in zip(out_names, out_cts)] args = [x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) for ns, x in zip(in_names, args)] @@ -1497,8 +1526,9 @@ def fun_trans(out_cts, args): jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts ) out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite else jax.lax.psum(x, tuple(_unmentioned(mesh, ns))) - for ns, x in zip(in_names, out)] + else x if rewrite + else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns))) + for ns, x in zip(in_names, out)] return out fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) diff --git a/jax/experimental/slab/djax.py b/jax/experimental/slab/djax.py new file mode 100644 index 000000000000..c989ede8663e --- /dev/null +++ b/jax/experimental/slab/djax.py @@ -0,0 +1,187 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import collections +from collections.abc import Callable +from functools import partial +import sys + +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax + +from jax._src import core +from jax._src import util + +import jax.experimental.slab.slab as sl + +map, zip = util.safe_map, util.safe_zip + +def make_djaxpr(f, abstracted_axes, **make_jaxpr_kwargs): + def djaxpr_maker(*args, **kwargs): + with jax._src.config.dynamic_shapes(True): + jaxpr_maker = jax.make_jaxpr( + f, abstracted_axes=abstracted_axes, **make_jaxpr_kwargs) + return jaxpr_maker(*args, **kwargs) + return djaxpr_maker + +@partial(jax.jit, static_argnums=(0,)) +def interp(djaxpr, slab, sizes, args): + views = [] + in_types = [x.aval for x in djaxpr.invars] + _, arg_types = util.split_list(in_types, [len(djaxpr.invars) - len(args)]) + for ty, x in zip(arg_types, args): + if isinstance(ty, core.DShapedArray): + resolved_shape = tuple(sizes.get(d, d) for d in ty.shape) + # TODO(frostig,mattjj): reconstructing slab views seems off? + views.append(sl.SlabView(x, resolved_shape, ty.dtype)) + else: + views.append(x) + slab, outs = eval_djaxpr(djaxpr, slab, *sizes.values(), *views) + return slab, outs + +def _check_axis_size_conflicts(all_axes, sizes): + if len(all_axes) != len(set(all_axes)): + d = collections.defaultdict(list) + for name, sz in zip(all_axes, sizes): + d[name].append(sz) + msg = '; '.join([f'{name}: {" != ".join(map(str, sizes))}' + for name, sizes in d.items() if len(sizes) > 1]) + raise ValueError(f'abstracted axes resolve to conflicting sizes. {msg}') + +def djit(f, abstracted_axes, **djit_kwargs): + # TODO(frostig,mattjj): un/flatten f + def f_wrapped(slab, *args): # TODO(frostig,mattjj): kw support + djaxpr = make_djaxpr(f, abstracted_axes, **djit_kwargs)(*args).jaxpr + in_types = [x.aval for x in djaxpr.invars] + _, arg_types = util.split_list(in_types, [len(djaxpr.invars) - len(args)]) + + def upload(slab, ty, x): + if isinstance(ty, core.DShapedArray): + return sl.slab_upload(slab, x) + elif isinstance(ty, core.ShapedArray): + return slab, x + else: + assert False + + slab, views = sl.chain(slab, upload, *zip(arg_types, args)) + + sizes: dict[core.Var, int] = {} + for ty, x in zip(arg_types, args): + for v, d in zip(ty.shape, x.shape): + if isinstance(v, core.Var): + d_ = sizes.setdefault(v, d) + if d_ != d: + raise ValueError( + f'abstract dimension bound to unequal sizes: {d_} != {d}') + + slab, out_views = interp( + djaxpr, slab, sizes, + [v.addr if isinstance(v, sl.SlabView) else v for v in views]) + return slab, tuple(sl.slab_download(slab, v) for v in out_views) + + return f_wrapped + +def eval_djaxpr(jaxpr: core.Jaxpr, slab: sl.Slab, *args: jax.Array | sl.SlabView): + if jaxpr.constvars: raise NotImplementedError + + env: dict[core.Var, jax.Array | sl.SlabView] = {} + + def read(a): + return env[a] if type(a) is core.Var else a.val + + def write(v, val): + env[v] = val + + map(write, jaxpr.invars, args) + for eqn in jaxpr.eqns: + invals = map(read, eqn.invars) + slab, outvals = rules[eqn.primitive](slab, *invals, **eqn.params) + map(write, eqn.outvars, outvals) + return slab, map(read, jaxpr.outvars) + +rules: dict[core.Primitive, Callable] = {} + +def matmul_rule(slab, lhs, rhs, *, dimension_numbers, **_): + slab, out = sl.matmul(slab, lhs, rhs) + return slab, [out] +rules[lax.dot_general_p] = matmul_rule + +def tanh_rule(slab, x, **_): + slab, out = sl.tanh(slab, x) + return slab, [out] +rules[lax.tanh_p] = tanh_rule + +# ------- + +def print_seg(msg): + print() + print(f'-- {msg}') + print() + +def check_djit(slab, f, abstracted_axes, *args): + refs, _ = jax.tree.flatten(f(*args)) + f_djit = djit(f, abstracted_axes=abstracted_axes) + slab, outs = f_djit(slab, *args) + for out, ref in zip(outs, refs): + abs_err = jnp.max(jnp.abs(out - ref)) + rel_err = jnp.max(jnp.abs(out - ref) / jnp.abs(ref)) + msg = f'abs={abs_err}, rel={rel_err}' + assert jnp.allclose(out, ref, atol=1e-4), msg + +def test(slab, xs): + a, b = xs + + def f(a, b): + c = jnp.dot(a, b) + return jnp.tanh(c) + + abstracted_axes = (('m', 'k'), ('k', 'n')) + + print_seg('djaxpr') + djaxpr = make_djaxpr(f, abstracted_axes)(a, b).jaxpr + print(djaxpr) + + print_seg('djax output') + f_djit = djit(f, abstracted_axes=abstracted_axes) + slab, [c] = f_djit(slab, a, b) + print(c) + + print_seg('djax -> jax lowering') + big_jaxpr = jax.make_jaxpr(f_djit)(slab, a, b) + print('\n'.join(str(big_jaxpr).split('\n')[:20])) + print('...') + print('\n'.join(str(big_jaxpr).split('\n')[-20:])) + print(len(str(big_jaxpr).split('\n'))) + + check_djit(slab, f, abstracted_axes, a, b) + +def parse_arr(i, s): + shape = eval(s) + return np.random.RandomState(i).normal(size=shape).astype(np.float32) + +def main(args): + slab_sz = eval(args[0]) + print('slab size', slab_sz) + xs = map(parse_arr, range(len(args[1:])), args[1:]) + assert all(len(x.shape) == 2 for x in xs) + slab = sl.slab_make(slab_sz) + test(slab, xs) + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/jax/experimental/slab/slab.py b/jax/experimental/slab/slab.py new file mode 100644 index 000000000000..af7b079eeb7f --- /dev/null +++ b/jax/experimental/slab/slab.py @@ -0,0 +1,365 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from functools import partial, reduce +import sys +import typing +from typing import NamedTuple, Union + +import numpy as np + +import jax +import jax.numpy as jnp + +from jax._src import core +from jax._src import util + +map, zip = util.safe_map, util.safe_zip + +DInt = jax.Array +Address = DInt +XInt = Union[int, DInt] +DShape = tuple[XInt, ...] +SShape = tuple[int, ...] +DType = jnp.dtype + +class Slab(NamedTuple): + data: jax.Array + cursor: Address + +@jax.tree_util.register_pytree_node_class +class SlabView(NamedTuple): + addr: Address + shape: DShape + dtype: DType + + def size(self): + return jnp.prod(jnp.array(self.shape)) + + def ndim(self): + return len(self.shape) + + def tree_flatten(self): + return (self.addr, self.shape), self.dtype + + @classmethod + def tree_unflatten(cls, dtype, xs): + addr, shape = xs + return cls(addr, shape, dtype) + +word_b = 4 +phrase_b = 512 +phrase_w = 128 +tile_aspect = 8 + +def xceil_div(x: XInt, y: XInt) -> XInt: + """ceil(x / y)""" + return (x + y - 1) // y + +def _xadd(x: XInt, y: XInt) -> XInt: + return x + y + +def _xmul(x: XInt, y: XInt) -> XInt: + return x * y + +def xadd(*xs: XInt) -> XInt: + return reduce(_xadd, xs, typing.cast(XInt, 0)) + +def xmul(*xs: XInt) -> XInt: + return reduce(_xmul, xs, typing.cast(XInt, 1)) + +def xsum(xs: Iterable[XInt]) -> XInt: + return xadd(*list(xs)) + +def xprod(xs: Iterable[XInt]) -> XInt: + return xmul(*list(xs)) + +def static_int(x: XInt) -> bool: + return isinstance(core.get_aval(x), core.ConcreteArray) + +def static_shape(s: DShape) -> bool: + return all(map(static_int, s)) + +def assert_static_int(x: XInt) -> int: + if not static_int(x): + raise TypeError(f'{x} is not a static int') + return int(x) + +def assert_static_shape(s: DShape) -> SShape: + if not static_shape(s): + raise TypeError(f'{s} is not a static shape') + return tuple(map(int, s)) + +def tile_shape(shape: DShape, dtype) -> SShape: + # Units: (1, 1, ..., elements, 1) + if len(shape) < 2: + raise NotImplementedError('matrices or bust') + num_leading = len(shape) - 2 + return (1,) * num_leading + (tile_aspect * word_b // dtype.itemsize, + phrase_b // word_b) + +def tile_phrases(shape: DShape, dtype: DType): + # Units: phrases + return xprod(tile_shape(shape, dtype)) * dtype.itemsize // phrase_b + +def slab_make(num_phrases): + return Slab(jnp.zeros((num_phrases, phrase_w), dtype=jnp.uint32), + jnp.array(0, dtype=jnp.int32)) + +def slab_alloc(slab: Slab, shape: DShape, dtype): + if len(shape) < 2: + raise NotImplementedError('matrices or bust') + tiled_shape = map(xceil_div, shape, tile_shape(shape, dtype)) + num_p = xmul(*tiled_shape, tile_phrases(shape, dtype)) + new_slab = Slab(slab.data, slab.cursor + num_p) + slab_val = SlabView(slab.cursor, shape, dtype) + return new_slab, slab_val + +def strides(xs): + s = 1 + ss = [] + for x in reversed(xs): + ss.append(s) + s *= x + return tuple(reversed(ss)) + +def slab_slices(view, slice_base_e: DShape, slice_shape_e: SShape): + view_shape_e = tile_shape(view.shape, view.dtype) + # dassert all(s % t == 0 for s, t in zip(slice_base, view_shape_e)) + # dassert all(s % t == 0 for s, t in zip(slice_shape, view_shape_e)) + slice_base_t = [s // t for s, t in zip(slice_base_e, view_shape_e)] + slice_shape_t = [s // t for s, t in zip(slice_shape_e, view_shape_e)] + tiled_shape = map(xceil_div, view.shape, view_shape_e) + tiled_strides = strides(tiled_shape) + tp = tile_phrases(view.shape, view.dtype) + for idx in np.ndindex(*slice_shape_t[:-1]): + linear_idx_t = xsum( + map(xmul, map(xadd, slice_base_t, (*idx, 0)), tiled_strides)) + yield (view.addr + linear_idx_t * tp, slice_shape_t[-1] * tp) + +def reinterpret_cast(x: jax.Array, shape: SShape, dtype: DType): + x_bytes = x.size * x.dtype.itemsize + if -1 in shape: + assert x_bytes % xprod(s for s in shape if s != -1) * dtype.itemsize == 0 + else: + assert x_bytes == xprod(shape) * dtype.itemsize, (x.shape, x.dtype, shape, dtype) + if x.dtype.itemsize != dtype.itemsize: + # reshape(x, -1) in conversion below becomes reshape(-1, a, b) for some a,b + raise NotImplementedError('todo') + return jax.lax.bitcast_convert_type(x.reshape(-1), dtype).reshape(shape) + +def slab_read(slab, view, slice_base: DShape, slice_shape: SShape): + view_tile_shape = tile_shape(view.shape, view.dtype) + tiled_shape = assert_static_shape( + tuple(map(xceil_div, slice_shape, view_tile_shape))) + slices = [ + jax.lax.dynamic_slice_in_dim(slab.data, addr, phrases) + for addr, phrases in slab_slices(view, slice_base, slice_shape)] + slice_mem = jnp.stack(slices, axis=0) + return reinterpret_cast( + slice_mem, (*tiled_shape, *view_tile_shape), view.dtype + ).swapaxes(-2, -3).reshape(slice_shape) + +# TODO: just take vjp of slab_read +def slab_write(slab, view, slice_base: DShape, inval: jax.Array): + slice_shape = inval.shape + view_tile_shape = tile_shape(view.shape, view.dtype) + tiled_shape = map(xceil_div, inval.shape, view_tile_shape) + inval_linearized = inval.reshape( + *tiled_shape[:-1], view_tile_shape[-2], tiled_shape[-1], view_tile_shape[-1] + ).swapaxes(-2, -3) + slice_mem = reinterpret_cast(inval_linearized, (-1, phrase_w), + jnp.dtype('uint32')) + slice_addr = 0 + new_slab = slab.data + for slab_addr, slice_sz_p in slab_slices(view, slice_base, slice_shape): + s = jax.lax.dynamic_slice_in_dim(slice_mem, slice_addr, slice_sz_p) + slice_addr += slice_sz_p + new_slab = jax.lax.dynamic_update_slice_in_dim( + new_slab, s, slab_addr, axis=0) + return Slab(new_slab, slab.cursor) + +def elementwise(f, slab: Slab, xs: Sequence[SlabView], out: SlabView): + if len(xs) == 0: + raise TypeError('missing input arguments') + x = xs[0] + for y in xs[1:]: + if x.shape != y.shape: + raise ValueError(f'elementwise shapes mismatch: {x.shape} != {y.shape}') + if x.dtype != y.dtype: + raise ValueError(f'elementwise dtypes mismatch: {x.dtype} != {y.dtype}') + if x.shape != out.shape: + raise ValueError( + f'elementwise input/output shape mismatch: {x.shape} != {out.shape}') + + tiled_shape = map(xceil_div, x.shape, tile_shape(x.shape, x.dtype)) + x_sz_p = xprod(tiled_shape) * tile_phrases(x.shape, x.dtype) + compute_tile_p = 16 + num_whole_blocks = x_sz_p // compute_tile_p + + def f_u32(*zs): + a = zs[0] + return reinterpret_cast( + f(*[reinterpret_cast(z, a.shape, x.dtype) for z in zs]), + a.shape, jnp.dtype('uint32')) + + def body(i_b, mem): + i_p = i_b * compute_tile_p + slices = [ + jax.lax.dynamic_slice_in_dim(mem, z.addr + i_p, compute_tile_p) + for z in xs] + out_slice = f_u32(*slices) + return jax.lax.dynamic_update_slice_in_dim( + mem, out_slice, out.addr + i_p, axis=0) + mem = jax.lax.fori_loop(0, num_whole_blocks, body, slab.data) + + epi_start_p = num_whole_blocks * compute_tile_p + epi_size_p = x_sz_p - epi_start_p + slices = [ + jax.lax.dynamic_slice_in_dim(mem, z.addr + epi_start_p, compute_tile_p) + for z in xs] + out_slice = f_u32(*slices) + return Slab(masked_store(mem, out.addr + epi_start_p, out_slice, epi_size_p), + slab.cursor) + +def masked_store(mem, addr, update, num_p): + update_p = update.shape[0] + prev_val = jax.lax.dynamic_slice_in_dim(mem, addr, update_p) + new_val = jnp.where(jnp.arange(update_p)[:, None] < num_p, update, prev_val) + return jax.lax.dynamic_update_slice_in_dim(mem, new_val, addr, axis=0) + +def _matmul(slab: Slab, ins: Sequence[SlabView], out: SlabView): + lhs, rhs = ins + dtype = lhs.dtype + n, k, m = (*lhs.shape, rhs.shape[1]) + # todo: shape + dtype check + # dassert shapes are tile aligned + tile_n, tile_k, tile_m = 128, 128, 128 + n_tiles = n // tile_n + k_tiles = k // tile_k + m_tiles = m // tile_m + + mem = slab + def loop_n(ni, mem): + def loop_m(mi, mem): + acc = jnp.zeros((tile_n, tile_m), dtype=dtype) + def loop_k(ki, acc): + lhs_tile = slab_read(mem, lhs, (ni * tile_n, ki * tile_k), (tile_n, tile_k)) + rhs_tile = slab_read(mem, rhs, (ki * tile_k, mi * tile_m), (tile_k, tile_m)) + acc += lhs_tile @ rhs_tile + return acc + acc = jax.lax.fori_loop(0, k_tiles, loop_k, acc) + return slab_write(mem, out, (ni * tile_n, mi * tile_m), acc) + return jax.lax.fori_loop(0, m_tiles, loop_m, mem) + mem = jax.lax.fori_loop(0, n_tiles, loop_n, mem) + return mem + +def make_allocating_op(op, type_rule): + def made_op(slab, *xs: SlabView): + out_shape, out_dtype = type_rule(*xs) + slab, out = slab_alloc(slab, out_shape, out_dtype) + slab = op(slab, xs, out) + return slab, out + return made_op + +add = make_allocating_op(partial(elementwise, jax.lax.add), + lambda x, *_: (x.shape, x.dtype)) +mul = make_allocating_op(partial(elementwise, jax.lax.mul), + lambda x, *_: (x.shape, x.dtype)) +tanh = make_allocating_op(partial(elementwise, jax.lax.tanh), + lambda x, *_: (x.shape, x.dtype)) +matmul = make_allocating_op(_matmul, + lambda a, b: ((a.shape[0], b.shape[1]), a.dtype)) + +def parse_arr(i, s): + shape = eval(s) + return np.random.RandomState(i).normal(size=shape).astype(np.float32) + +def print_seg(msg): + print() + print(f'-- {msg}') + print() + +def make_jaxpr_slab_write(slab, view, inval): + return jax.make_jaxpr( + lambda slab, x: slab_write(slab, view, (0, 0), x))(slab, inval) + +def make_jaxpr_slab_read(slab, view, outval_shape): + return jax.make_jaxpr( + lambda slab: slab_read(slab, view, (0, 0), outval_shape))(slab) + +def slab_download(slab, v): + if not static_shape(v.shape): raise Exception + return slab_read(slab, v, (0,) * v.ndim(), v.shape) + +def slab_upload(slab, x): + slab, xv = slab_alloc(slab, x.shape, x.dtype) + slab = slab_write(slab, xv, (0,) * x.ndim, x) + return slab, xv + +def chain(slab, fs, *argss, unary=False): + if callable(fs): + fs = [fs] * len(argss) + outss = [] + for f, args in zip(fs, argss): + if unary: + slab, outs = f(slab, args) + else: + slab, outs = f(slab, *args) + outss.append(outs) + return slab, outss + +def test_binop(op, ref_op, slab, x, y): + z = ref_op(x, y) + slab, xv = slab_upload(slab, x) + slab, yv = slab_upload(slab, y) + slab, zv = op(slab, xv, yv) + assert jnp.allclose(slab_download(slab, xv), x, atol=1e-4) + assert jnp.allclose(slab_download(slab, yv), y, atol=1e-4) + assert jnp.allclose(slab_download(slab, zv), z, atol=1e-4) + +def main(args): + xs = map(parse_arr, range(len(args)), args) + assert all(len(x.shape) == 2 for x in xs) + + slab = slab_make(1024) + + x, y, *_ = xs + test_binop(add, jax.lax.add, slab, x, x) + test_binop(mul, jax.lax.mul, slab, x, x) + test_binop(matmul, lambda a, b: a @ b, slab, x, y) + + def put(slab, x): + slab, v = slab_upload(slab, x) + print_seg('slab_read result') + print(slab_download(slab, v)) + return slab, v + + slab, vals = chain(slab, put, *xs, unary=True) + + if len(vals) >= 2: + x, y, *_ = vals + slab, z = mul(slab, x, x) + print_seg('mul') + print(slab_download(slab, z)) + slab, w = add(slab, x, z) + print_seg('add') + print(slab_download(slab, w)) + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py index c789baf9d15e..2c235c9320d5 100644 --- a/jax/experimental/sparse/ad.py +++ b/jax/experimental/sparse/ad.py @@ -14,9 +14,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import itertools -from typing import Any, Callable, Union +from typing import Any import jax from jax._src import core @@ -81,7 +81,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, taking the gradient with respect to a :class:`jax.experimental.sparse` array, the gradient is computed in the subspace defined by the array's sparsity pattern. - Example: + Examples: >>> from jax.experimental import sparse >>> X = sparse.BCOO.fromdense(jnp.arange(6.)) @@ -109,7 +109,7 @@ def grad(fun: Callable, argnums: int | Sequence[int] = 0, the gradient with respect to a :class:`jax.experimental.sparse` array, the gradient is computed in the subspace defined by the array's sparsity pattern. - Example: + Examples: >>> from jax.experimental import sparse >>> X = sparse.BCOO.fromdense(jnp.arange(6.)) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index c942c4867f40..4cbe52383751 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1224,15 +1224,24 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic out_n_batch = lhs.n_batch + rhs.n_batch - len(lhs_batch) out_nse = min(out_nse, math.prod(out_aval.shape[out_n_batch:])) + lhs_batch_shape = np.broadcast_shapes( + tuple(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), + tuple(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), + ) + rhs_batch_shape = np.broadcast_shapes( + tuple(rhs_data.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + tuple(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + ) + data_shape = ( *(lhs_shape[dim] for dim in lhs_batch), - *(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), - *(rhs_data.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + *lhs_batch_shape, + *rhs_batch_shape, out_nse) indices_shape = ( *(lhs_shape[dim] for dim in lhs_batch), - *(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), - *(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + *lhs_batch_shape, + *rhs_batch_shape, out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting)) data_aval = core.ShapedArray(data_shape, out_aval.dtype) diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index c2990d3fed57..b0ac1fa5d380 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Union, Callable +from collections.abc import Callable import functools import jax diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py index 251bf45f00d1..6c827325befc 100644 --- a/jax/experimental/sparse/nm.py +++ b/jax/experimental/sparse/nm.py @@ -181,6 +181,9 @@ def _nm_spmm_abstract_eval( if gpu_sparse.cuda_is_supported: mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda") +if gpu_sparse.rocm_is_supported: + mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="rocm") + # -------------------------------------------------------------------- # nm_pack diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 38e6785f2ded..365c436521b8 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -15,12 +15,11 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence import functools import itertools import math -from typing import Any, Callable, Union -from typing import NamedTuple +from typing import Any, NamedTuple import jax from jax import lax diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 453083c57f2e..86eb8a9aefe8 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -47,9 +47,9 @@ -0.15574613], dtype=float32) """ -from collections.abc import Sequence +from collections.abc import Callable, Sequence import functools -from typing import Any, Callable, NamedTuple, Optional +from typing import Any, NamedTuple import numpy as np diff --git a/jax/export.py b/jax/export.py new file mode 100644 index 000000000000..13186f886f43 --- /dev/null +++ b/jax/export.py @@ -0,0 +1,36 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize", + "maximum_supported_calling_convention_version", + "minimum_supported_calling_convention_version", + "default_export_platform", + "SymbolicScope", "is_symbolic_dim", + "symbolic_shape", "symbolic_args_specs"] + +from jax._src.export._export import ( + DisabledSafetyCheck, + Exported, + export, + deserialize, + maximum_supported_calling_convention_version, + minimum_supported_calling_convention_version, + default_export_platform) + +from jax._src.export import shape_poly_decision # Import only to set the decision procedure +del shape_poly_decision +from jax._src.export.shape_poly import ( + SymbolicScope, + is_symbolic_dim, + symbolic_shape, + symbolic_args_specs) diff --git a/jax/extend/BUILD b/jax/extend/BUILD index e33569d6d109..229b6cd6ec9d 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -29,6 +29,7 @@ pytype_strict_library( deps = [ ":backend", ":core", + ":ffi", ":linear_util", ":random", ":source_info_util", @@ -70,3 +71,9 @@ pytype_strict_library( srcs = ["source_info_util.py"], deps = ["//jax:source_info_util"], ) + +pytype_strict_library( + name = "ffi", + srcs = ["ffi.py"], + deps = ["//jax"], +) diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index 3f4327dde917..e8ef32935cbf 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -31,6 +31,7 @@ from jax.extend import ( backend as backend, core as core, + ffi as ffi, linear_util as linear_util, random as random, source_info_util as source_info_util, diff --git a/jax/_src/ffi.py b/jax/extend/ffi.py similarity index 66% rename from jax/_src/ffi.py rename to jax/extend/ffi.py index 1b394802a453..565b37cbb542 100644 --- a/jax/_src/ffi.py +++ b/jax/extend/ffi.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 -import os - -from jax._src.lib import jaxlib - - -def include_dir() -> str: - """Get the path to the directory containing header files bundled with jaxlib""" - jaxlib_dir = os.path.dirname(os.path.abspath(jaxlib.__file__)) - return os.path.join(jaxlib_dir, "include") +from jax._src.extend.ffi import ( + ffi_lowering as ffi_lowering, + include_dir as include_dir, + pycapsule as pycapsule, +) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index a515f2293214..c5aa31a536f6 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -23,7 +23,6 @@ global_avals_to_results_handler as global_avals_to_results_handler, global_result_handlers as global_result_handlers, parallel_callable as parallel_callable, - shard_arg as shard_arg, shard_args as shard_args, xla_pmap_p as xla_pmap_p, ) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 2cf251a36258..6e9c7af5eabb 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -2,7 +2,8 @@ from __future__ import annotations import builtins -from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, TypeVar, Union, overload +from collections.abc import Callable, Sequence +from typing import Any, Literal, NamedTuple, TypeVar, Union, overload from jax._src import core as _core from jax._src import dtypes as _dtypes @@ -33,21 +34,21 @@ def acos(x: ArrayLike, /) -> Array: ... def acosh(x: ArrayLike, /) -> Array: ... def add(x: ArrayLike, y: ArrayLike, /) -> Array: ... def amax(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def amin(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def all(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., *, where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... alltrue = all def angle(z: ArrayLike, deg: builtins.bool = ...) -> Array: ... def any(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., *, where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def append( - arr: ArrayLike, values: ArrayLike, axis: Optional[int] = ... + arr: ArrayLike, values: ArrayLike, axis: int | None = ... ) -> Array: ... def apply_along_axis(func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs) -> Array: ... @@ -56,9 +57,10 @@ def apply_over_axes( ) -> Array: ... def arange( start: DimSize, - stop: Optional[DimSize] = ..., - step: Optional[DimSize] = ..., - dtype: Optional[DTypeLike] = ..., + stop: DimSize | None = ..., + step: DimSize | None = ..., + dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ..., ) -> Array: ... def arccos(x: ArrayLike, /) -> Array: ... def arccosh(x: ArrayLike, /) -> Array: ... @@ -69,20 +71,20 @@ def arctan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def arctanh(x: ArrayLike, /) -> Array: ... def argmax( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def argmin( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ... def argsort( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., *, stable: builtins.bool = ..., descending: builtins.bool = ..., @@ -92,8 +94,8 @@ def argsort( def argwhere( a: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... around = round def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True, @@ -105,17 +107,17 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: ... array_repr = _np.array_repr def array_split( ary: ArrayLike, - indices_or_sections: Union[int, Sequence[int], ArrayLike], + indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = ..., ) -> list[Array]: ... array_str = _np.array_str def asarray( - a: Any, dtype: Optional[DTypeLike] = ..., order: Optional[str] = ..., - *, copy: Optional[builtins.bool] = ... + a: Any, dtype: DTypeLike | None = ..., order: str | None = ..., + *, copy: builtins.bool | None = ... ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... -def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: _Device | _Sharding | None = ...) -> Array: ... +def astype(a: ArrayLike, dtype: DTypeLike | None, /, *, copy: builtins.bool = ..., device: _Device | _Sharding | None = ...) -> Array: ... def atan(x: ArrayLike, /) -> Array: ... def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def atanh(x: ArrayLike, /) -> Array: ... @@ -141,19 +143,19 @@ def atleast_3d(x: ArrayLike, /) -> Array: ... def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., returned: Literal[False] = False, keepdims: builtins.bool = False) -> Array: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., *, +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., *, returned: Literal[True], keepdims: builtins.bool = False) -> tuple[Array, Array]: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., - returned: builtins.bool = False, keepdims: builtins.bool = False) -> Union[Array, tuple[Array, Array]]: ... +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., + returned: builtins.bool = False, keepdims: builtins.bool = False) -> Array | tuple[Array, Array]: ... def bartlett(M: int) -> Array: ... bfloat16: Any -def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ..., - minlength: int = ..., *, length: Optional[int] = ...) -> Array: ... +def bincount(x: ArrayLike, weights: ArrayLike | None = ..., + minlength: int = ..., *, length: int | None = ...) -> Array: ... def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_invert(x: ArrayLike, /) -> Array: ... @@ -163,7 +165,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... def blackman(M: int) -> Array: ... -def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ... +def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... bool: Any bool_: Any def broadcast_arrays(*args: ArrayLike) -> list[Array]: ... @@ -172,8 +174,8 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: ... def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... @overload -def broadcast_shapes(*shapes: Sequence[Union[int, _core.Tracer]] - ) -> tuple[Union[int, _core.Tracer], ...]: ... +def broadcast_shapes(*shapes: Sequence[int | _core.Tracer] + ) -> tuple[int | _core.Tracer, ...]: ... def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: ... c_: _CClass @@ -187,53 +189,53 @@ def choose(a: ArrayLike, choices: Sequence[ArrayLike], def clip( x: ArrayLike | None = ..., /, - min: Optional[ArrayLike] = ..., - max: Optional[ArrayLike] = ..., + min: ArrayLike | None = ..., + max: ArrayLike | None = ..., a: ArrayLike | DeprecatedArg | None = ..., a_min: ArrayLike | DeprecatedArg | None = ..., a_max: ArrayLike | DeprecatedArg | None = ... ) -> Array: ... def column_stack( - tup: Union[_np.ndarray, Array, Sequence[ArrayLike]] + tup: _np.ndarray | Array | Sequence[ArrayLike] ) -> Array: ... complex128: Any complex64: Any complex_: Any complexfloating = _np.complexfloating -def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = ..., +def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., out: None = ...) -> Array: ... def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ... def concatenate( - arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]], - axis: Optional[int] = ..., - dtype: Optional[DTypeLike] = ..., + arrays: _np.ndarray | Array | Sequence[ArrayLike], + axis: int | None = ..., + dtype: DTypeLike | None = ..., ) -> Array: ... def conjugate(x: ArrayLike, /) -> Array: ... conj = conjugate def convolve(a: ArrayLike, v: ArrayLike, mode: str = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def copy(a: ArrayLike, order: Optional[str] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def copy(a: ArrayLike, order: str | None = ...) -> Array: ... def copysign(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = ..., rowvar: builtins.bool = ...) -> Array: ... +def corrcoef(x: ArrayLike, y: ArrayLike | None = ..., rowvar: builtins.bool = ...) -> Array: ... def correlate(a: ArrayLike, v: ArrayLike, mode: str = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def cos(x: ArrayLike, /) -> Array: ... def cosh(x: ArrayLike, /) -> Array: ... def count_nonzero(a: ArrayLike, axis: _Axis = ..., keepdims: builtins.bool = ...) -> Array: ... -def cov(m: ArrayLike, y: Optional[ArrayLike] = ..., rowvar: builtins.bool = ..., - bias: builtins.bool = ..., ddof: Optional[int] = ..., - fweights: Optional[ArrayLike] = ..., - aweights: Optional[ArrayLike] = ...) -> Array: ... +def cov(m: ArrayLike, y: ArrayLike | None = ..., rowvar: builtins.bool = ..., + bias: builtins.bool = ..., ddof: int | None = ..., + fweights: ArrayLike | None = ..., + aweights: ArrayLike | None = ...) -> Array: ... def cross( a: ArrayLike, b: ArrayLike, axisa: int = -1, axisb: int = -1, axisc: int = -1, - axis: Optional[int] = ..., + axis: int | None = ..., ) -> Array: ... csingle: Any def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., @@ -249,8 +251,8 @@ def deg2rad(x: ArrayLike, /) -> Array: ... degrees = rad2deg def delete( arr: ArrayLike, - obj: Union[ArrayLike, slice], - axis: Optional[int] = ..., + obj: ArrayLike | slice, + axis: int | None = ..., *, assume_unique_indices: builtins.bool = ..., ) -> Array: ... @@ -262,33 +264,33 @@ def diagonal( a: ArrayLike, offset: ArrayLike = ..., axis1: int = ..., axis2: int = ... ): ... def diff(a: ArrayLike, n: int = ..., axis: int = ..., - prepend: Optional[ArrayLike] = ..., - append: Optional[ArrayLike] = ...) -> Array: ... + prepend: ArrayLike | None = ..., + append: ArrayLike | None = ...) -> Array: ... def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ... divide = true_divide def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... double: Any def dsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def dstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def dstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... dtype = _np.dtype e: float -def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = ..., - to_begin: Optional[ArrayLike] = ...) -> Array: ... +def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = ..., + to_begin: ArrayLike | None = ...) -> Array: ... @overload def einsum( subscript: str, /, *operands: ArrayLike, out: None = ..., - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., ) -> Array: ... @@ -297,11 +299,11 @@ def einsum( def einsum( arr: ArrayLike, axes: Sequence[Any], /, - *operands: Union[ArrayLike, Sequence[Any]], + *operands: ArrayLike | Sequence[Any], out: None = ..., - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., ) -> Array: ... @@ -310,9 +312,9 @@ def einsum( subscripts, /, *operands, out: None = ..., - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = ..., _dot_general: Callable[..., Array] = ..., ) -> Array: ... @@ -321,49 +323,51 @@ def einsum( def einsum_path( subscripts: str, /, *operands: ArrayLike, - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... @overload def einsum_path( arr: ArrayLike, axes: Sequence[Any], /, - *operands: Union[ArrayLike, Sequence[Any]], - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + *operands: ArrayLike | Sequence[Any], + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... @overload def einsum_path( subscripts, /, *operands: ArrayLike, - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... -def empty(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def empty_like(prototype: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +def empty(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def empty_like(prototype: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... euler_gamma: float def exp(x: ArrayLike, /) -> Array: ... def exp2(x: ArrayLike, /) -> Array: ... -def expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array: ... +def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: ... def expm1(x: ArrayLike, /) -> Array: ... -def extract(condition: ArrayLike, arr: ArrayLike) -> Array: ... -def eye(N: DimSize, M: Optional[DimSize] = ..., k: int | ArrayLike = ..., - dtype: Optional[DTypeLike] = ...) -> Array: ... +def extract(condition: ArrayLike, arr: ArrayLike, *, + size: int | None = None, fill_value: ArrayLike = 0) -> Array: ... +def eye(N: DimSize, M: DimSize | None = ..., k: int | ArrayLike = ..., + dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ...) -> Array: ... def fabs(x: ArrayLike, /) -> Array: ... finfo = _dtypes.finfo def fix(x: ArrayLike, out: None = ...) -> Array: ... def flatnonzero( a: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike]] = ..., + size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike] = ..., ) -> Array: ... flexible = _np.flexible def flip( - m: ArrayLike, axis: Optional[Union[int, Sequence[int]]] = ... + m: ArrayLike, axis: int | Sequence[int] | None = ... ) -> Array: ... def fliplr(m: ArrayLike) -> Array: ... @@ -387,7 +391,7 @@ def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ... def from_dlpack(x: Any, /, *, device: _Device | _Sharding | None = None, copy: builtins.bool | None = None) -> Array: ... -def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ..., +def frombuffer(buffer: bytes | Any, dtype: DTypeLike = ..., count: int = ..., offset: int = ...) -> Array: ... def fromfile(*args, **kwargs): ... def fromfunction(function: Callable[..., Array], shape: Any, @@ -397,12 +401,12 @@ def fromstring( string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str ) -> Array: ... def full(shape: Any, fill_value: ArrayLike, - dtype: Optional[DTypeLike] = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def full_like(a: Union[ArrayLike, DuckTypedArray], - fill_value: ArrayLike, dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ...) -> Array: ... +def full_like(a: ArrayLike | DuckTypedArray, + fill_value: ArrayLike, dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: ... generic = _np.generic def geomspace( @@ -410,48 +414,48 @@ def geomspace( stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = ..., ) -> Array: ... get_printoptions = _np.get_printoptions def gradient(f: ArrayLike, *varargs: ArrayLike, - axis: Optional[Union[int, Sequence[int]]] = ..., - edge_order: Optional[int] = ...) -> Union[Array, list[Array]]: ... + axis: int | Sequence[int] | None = ..., + edge_order: int | None = ...) -> Array | list[Array]: ... def greater(x: ArrayLike, y: ArrayLike, /) -> Array: ... def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... def hamming(M: int) -> Array: ... def hanning(M: int) -> Array: ... def heaviside(x: ArrayLike, y: ArrayLike, /) -> Array: ... def histogram(a: ArrayLike, bins: ArrayLike = ..., - range: Optional[Sequence[ArrayLike]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ...) -> tuple[Array, Array]: ... + range: Sequence[ArrayLike] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ...) -> tuple[Array, Array]: ... def histogram2d( x: ArrayLike, y: ArrayLike, - bins: Union[ArrayLike, Sequence[ArrayLike]] = ..., - range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ..., + bins: ArrayLike | Sequence[ArrayLike] = ..., + range: Sequence[None | Array | Sequence[ArrayLike]] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ..., ) -> tuple[Array, Array, Array]: ... def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = ..., - range: Union[None, Array, Sequence[ArrayLike]] = ..., - weights: Optional[ArrayLike] = ...) -> Array: ... + range: None | Array | Sequence[ArrayLike] = ..., + weights: ArrayLike | None = ...) -> Array: ... def histogramdd( sample: ArrayLike, - bins: Union[ArrayLike, Sequence[ArrayLike]] = ..., - range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ..., + bins: ArrayLike | Sequence[ArrayLike] = ..., + range: Sequence[None | Array | Sequence[ArrayLike]] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ..., ) -> tuple[Array, list[Array]]: ... def hsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def hstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def hstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... def hypot(x: ArrayLike, y: ArrayLike, /) -> Array: ... def i0(x: ArrayLike) -> Array: ... -def identity(n: DimSize, dtype: Optional[DTypeLike] = ...) -> Array: ... +def identity(n: DimSize, dtype: DTypeLike | None = ...) -> Array: ... iinfo = _dtypes.iinfo def imag(x: ArrayLike, /) -> Array: ... index_exp = _np.index_exp @@ -464,15 +468,15 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, *, sparse: Literal[True]) -> tuple[Array, ...]: ... @overload def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, - sparse: builtins.bool = False) -> Union[Array, tuple[Array, ...]]: ... + sparse: builtins.bool = False) -> Array | tuple[Array, ...]: ... inexact = _np.inexact inf: float def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def insert(arr: ArrayLike, obj: Union[ArrayLike, slice], values: ArrayLike, - axis: Optional[int] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, + axis: int | None = ...) -> Array: ... int16: Any int32: Any int4: Any @@ -481,17 +485,17 @@ int8: Any int_: Any integer = _np.integer def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, - left: Union[ArrayLike, str, None] = ..., - right: Union[ArrayLike, str, None] = ..., - period: Optional[ArrayLike] = ...) -> Array: ... + left: ArrayLike | str | None = ..., + right: ArrayLike | str | None = ..., + period: ArrayLike | None = ...) -> Array: ... def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., - return_indices: builtins.bool = ...) -> Union[Array, tuple[Array, Array, Array]]: ... + return_indices: builtins.bool = ...) -> Array | tuple[Array, Array, Array]: ... def invert(x: ArrayLike, /) -> Array: ... def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... def iscomplex(m: ArrayLike) -> Array: ... def iscomplexobj(x: Any) -> builtins.bool: ... -def isdtype(dtype: DTypeLike, kind: Union[DType, str, tuple[Union[DType, str], ...]]) -> builtins.bool: ... +def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ... def isfinite(x: ArrayLike, /) -> Array: ... def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: builtins.bool = ..., invert: builtins.bool = ...) -> Array: ... @@ -517,23 +521,23 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, retstep: Literal[False] = False, - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int, endpoint: builtins.bool, retstep: Literal[True], - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, *, retstep: Literal[True], - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, retstep: builtins.bool = False, - dtype: Optional[DTypeLike] = ..., - axis: int = 0) -> Union[Array, tuple[Array, Array]]: ... + dtype: DTypeLike | None = ..., + axis: int = 0) -> Array | tuple[Array, Array]: ... def load(*args: Any, **kwargs: Any) -> Array: ... def log(x: ArrayLike, /) -> Array: ... @@ -548,20 +552,20 @@ def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., - dtype: Optional[DTypeLike] = ..., axis: int = ...) -> Array: ... + dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... def mask_indices( n: int, mask_func: Callable, k: int = ... ) -> tuple[Array, ...]: ... def matmul( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def matrix_transpose(x: ArrayLike, /) -> Array: ... max = amax def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ... def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... -def median(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + where: ArrayLike | None = ...) -> Array: ... +def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., keepdims: builtins.bool = ...) -> Array: ... def meshgrid(*xi: ArrayLike, copy: builtins.bool = ..., sparse: builtins.bool = ..., @@ -571,121 +575,121 @@ min = amin def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: ... def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... -def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]], - destination: Union[int, Sequence[int]]) -> Array: ... +def moveaxis(a: ArrayLike, source: int | Sequence[int], + destination: int | Sequence[int]) -> Array: ... def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ... nan: float def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ..., - posinf: Optional[ArrayLike] = ..., - neginf: Optional[ArrayLike] = ...) -> Array: ... + posinf: ArrayLike | None = ..., + neginf: ArrayLike | None = ...) -> Array: ... def nanargmax( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def nanargmin( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def nancumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def nancumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - where: Optional[ArrayLike] = ...) -> Array: ... -def nanmedian(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + where: ArrayLike | None = ...) -> Array: ... +def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., keepdims: builtins.bool = ...) -> Array: ... def nanmin(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanpercentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = ..., + axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... -def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... +def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + where: ArrayLike | None = ...) -> Array: ... def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = 0, keepdims: builtins.bool = False, - where: Optional[ArrayLike] = ...) -> Array: ... + where: ArrayLike | None = ...) -> Array: ... ndarray = Array ndim = _np.ndim def negative(x: ArrayLike, /) -> Array: ... newaxis = ... def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def nonzero(a: ArrayLike, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... +def nonzero(a: ArrayLike, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> tuple[Array, ...]: ... def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... number = _np.number object_ = _np.object_ ogrid: _Ogrid -def ones(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def ones_like(a: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +def ones(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def ones_like(a: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def outer(a: ArrayLike, b: Array, out: None = ...) -> Array: ... def packbits( - a: ArrayLike, axis: Optional[int] = ..., bitorder: str = ... + a: ArrayLike, axis: int | None = ..., bitorder: str = ... ) -> Array: ... PadValueLike = Union[_T, Sequence[_T], Sequence[Sequence[_T]]] def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | _np.ndarray], - mode: Union[str, Callable[..., Any]] = ..., **kwargs) -> Array: ... + mode: str | Callable[..., Any] = ..., **kwargs) -> Array: ... def partition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ... def percentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = ..., + axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array: ... pi: float -def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]], - funclist: Sequence[Union[ArrayLike, Callable[..., Array]]], +def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], + funclist: Sequence[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: ... def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: builtins.bool = ...) -> Array: ... -def poly(seq_of_zeros: Array) -> Array: ... -def polyadd(a1: Array, a2: Array) -> Array: ... -def polyder(p: Array, m: int = ...) -> Array: ... +def poly(seq_of_zeros: ArrayLike) -> Array: ... +def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: ... +def polyder(p: ArrayLike, m: int = ...) -> Array: ... def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> tuple[Array, Array]: ... -def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = ..., - full: builtins.bool = ..., w: Optional[Array] = ..., cov: builtins.bool = ... - ) -> Union[Array, tuple[Array, ...]]: ... -def polyint(p: Array, m: int = ..., k: Optional[int] = ...) -> Array: ... +def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = ..., + full: builtins.bool = ..., w: ArrayLike | None = ..., cov: builtins.bool = ... + ) -> Array | tuple[Array, ...]: ... +def polyint(p: ArrayLike, m: int = ..., k: int | ArrayLike | None = ...) -> Array: ... def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> Array: ... -def polysub(a1: Array, a2: Array) -> Array: ... -def polyval(p: Array, x: Array, *, unroll: int = ...) -> Array: ... +def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: ... +def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = ...) -> Array: ... def positive(x: ArrayLike, /) -> Array: ... def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ... printoptions = _np.printoptions def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., where: Optional[ArrayLike] = ..., + initial: ArrayLike | None = ..., where: ArrayLike | None = ..., promote_integers: builtins.bool = ...) -> Array: ... product = prod promote_types = _np.promote_types @@ -693,7 +697,7 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ...) -> Array: ... def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ... -def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., +def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... r_: _RClass @@ -706,19 +710,19 @@ def real(x: ArrayLike, /) -> Array: ... def reciprocal(x: ArrayLike, /) -> Array: ... register_jax_array_methods: Any def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = ..., *, - total_repeat_length: Optional[int] = ...) -> Array: ... +def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, + total_repeat_length: int | None = ...) -> Array: ... def reshape( - a: ArrayLike, shape: Union[DimSize, Shape] = ..., - newshape: Union[DimSize, Shape] | None = ..., order: str = ... + a: ArrayLike, shape: DimSize | Shape = ..., + newshape: DimSize | Shape | None = ..., order: str = ... ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... def result_type(*args: Any) -> DType: ... def right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def rint(x: ArrayLike, /) -> Array: ... -def roll(a: ArrayLike, shift: Union[ArrayLike, Sequence[int]], - axis: Optional[Union[int, Sequence[int]]] = ...) -> Array: ... +def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], + axis: int | Sequence[int] | None = ...) -> Array: ... def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: ... def roots(p: ArrayLike, *, strip_zeros: builtins.bool = ...) -> Array: ... def rot90(m: ArrayLike, k: int = ..., axes: tuple[int, int] = ...) -> Array: ... @@ -728,7 +732,7 @@ s_ = _np.s_ save = _np.save savez = _np.savez def searchsorted(a: ArrayLike, v: ArrayLike, side: str = ..., - sorter: None = ..., *, method: str = ...) -> Array: ... + sorter: ArrayLike | None = ..., *, method: str = ...) -> Array: ... def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], @@ -740,8 +744,8 @@ def setdiff1d( ar2: ArrayLike, assume_unique: builtins.bool = ..., *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... shape = _np.shape @@ -756,7 +760,7 @@ size = _np.size sometrue = any def sort( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., *, stable: builtins.bool = ..., descending: builtins.bool = ..., @@ -766,24 +770,24 @@ def sort( def sort_complex(a: ArrayLike) -> Array: ... def split( ary: ArrayLike, - indices_or_sections: Union[int, Sequence[int], ArrayLike], + indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = ..., ) -> list[Array]: ... def sqrt(x: ArrayLike, /) -> Array: ... def square(x: ArrayLike, /) -> Array: ... def squeeze( - a: ArrayLike, axis: Optional[Union[int, Sequence[int]]] = ... + a: ArrayLike, axis: int | Sequence[int] | None = ... ) -> Array: ... def stack( - arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]], + arrays: _np.ndarray | Array | Sequence[ArrayLike], axis: int = ..., out: None = ..., - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., ) -> Array: ... def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ... + where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ... def sum( a: ArrayLike, @@ -791,53 +795,53 @@ def sum( dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ..., + initial: ArrayLike | None = ..., + where: ArrayLike | None = ..., promote_integers: builtins.bool = ..., ) -> Array: ... def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: ... def take( a: ArrayLike, indices: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - mode: Optional[str] = ..., + mode: str | None = ..., unique_indices: builtins.bool = ..., indices_are_sorted: builtins.bool = ..., - fill_value: Optional[StaticScalar] = ..., + fill_value: StaticScalar | None = ..., ) -> Array: ... def take_along_axis( arr: ArrayLike, indices: ArrayLike, - axis: Optional[int], - mode: Optional[Union[str, GatherScatterMode]] = ..., - fill_value: Optional[StaticScalar] = None, + axis: int | None, + mode: str | GatherScatterMode | None = ..., + fill_value: StaticScalar | None = None, ) -> Array: ... def tan(x: ArrayLike, /) -> Array: ... def tanh(x: ArrayLike, /) -> Array: ... def tensordot(a: ArrayLike, b: ArrayLike, - axes: Union[int, Sequence[int], Sequence[Sequence[int]]] = ..., + axes: int | Sequence[int] | Sequence[Sequence[int]] = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: ... def trace(a: ArrayLike, offset: int | ArrayLike = ..., axis1: int = ..., axis2: int = ..., - dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ... -def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ... + dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... +def transpose(a: ArrayLike, axes: Sequence[int] | None = ...) -> Array: ... def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ..., axis: int = ...) -> Array: ... def tri( - N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ... + N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike = ... ) -> Array: ... def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( - n: int, k: int = ..., m: Optional[int] = ... + n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( - n: int, k: int = ..., m: Optional[int] = ... + n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... def triu_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -852,8 +856,8 @@ def union1d( ar1: ArrayLike, ar2: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... class _UniqueAllResult(NamedTuple): values: Array @@ -867,71 +871,71 @@ class _UniqueInverseResult(NamedTuple): values: Array inverse_indices: Array def unique(ar: ArrayLike, return_index: builtins.bool = ..., return_inverse: builtins.bool = ..., - return_counts: builtins.bool = ..., axis: Optional[int] = ..., - *, equal_nan: builtins.bool = ..., size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ... + return_counts: builtins.bool = ..., axis: int | None = ..., + *, equal_nan: builtins.bool = ..., size: int | None = ..., + fill_value: ArrayLike | None = ... ): ... -def unique_all(x: ArrayLike, /, *, size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ...) -> _UniqueAllResult: ... -def unique_counts(x: ArrayLike, /, *, size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ...) -> _UniqueCountsResult: ... -def unique_inverse(x: ArrayLike, /, *, size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ...) -> _UniqueInverseResult: ... -def unique_values(x: ArrayLike, /, *, size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ...) -> Array: ... +def unique_all(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueAllResult: ... +def unique_counts(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueCountsResult: ... +def unique_inverse(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueInverseResult: ... +def unique_values(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> Array: ... def unpackbits( a: ArrayLike, - axis: Optional[int] = ..., - count: Optional[ArrayLike] = ..., + axis: int | None = ..., + count: ArrayLike | None = ..., bitorder: str = ..., ) -> Array: ... def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ... unsignedinteger = _np.unsignedinteger def unstack(x: ArrayLike , /, *, axis: int = ...) -> tuple[Array, ...]: ... -def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ..., +def unwrap(p: ArrayLike, discont: ArrayLike | None = ..., axis: int = ..., period: ArrayLike = ...) -> Array: ... def vander( - x: ArrayLike, N: Optional[int] = ..., increasing: builtins.bool = ... + x: ArrayLike, N: int | None = ..., increasing: builtins.bool = ... ) -> Array: ... def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ... + where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def vdot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def vsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def vstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def vstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... @overload def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ..., - /, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... + /, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> tuple[Array, ...]: ... @overload def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *, - size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... + size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> Array: ... @overload -def where(condition: ArrayLike, x: Optional[ArrayLike] = ..., - y: Optional[ArrayLike] = ..., /, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... - ) -> Union[Array, tuple[Array, ...]]: ... - -def zeros(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def zeros_like(a: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +def where(condition: ArrayLike, x: ArrayLike | None = ..., + y: ArrayLike | None = ..., /, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... + ) -> Array | tuple[Array, ...]: ... + +def zeros(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def zeros_like(a: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def vectorize(pyfunc, *, excluded = ..., signature = ...) -> Callable: ... diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index 98b9ca3e0694..c342fde0ae6e 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -44,6 +44,7 @@ tensordot as tensordot, tensorinv as tensorinv, tensorsolve as tensorsolve, + trace as trace, vector_norm as vector_norm, vecdot as vecdot, ) diff --git a/jax/scipy/special.py b/jax/scipy/special.py index e32d2169ad5f..e244c3705af3 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -17,10 +17,10 @@ from jax._src.scipy.special import ( bernoulli as bernoulli, + bessel_jn as bessel_jn, + beta as beta, betainc as betainc, betaln as betaln, - beta as beta, - bessel_jn as bessel_jn, digamma as digamma, entr as entr, erf as erf, @@ -31,31 +31,33 @@ expit as expit, expn as expn, factorial as factorial, + gamma as gamma, gammainc as gammainc, gammaincc as gammaincc, gammaln as gammaln, gammasgn as gammasgn, - gamma as gamma, + hyp1f1 as hyp1f1, i0 as i0, i0e as i0e, i1 as i1, i1e as i1e, + kl_div as kl_div, + log_ndtr as log_ndtr, + log_softmax as log_softmax, logit as logit, logsumexp as logsumexp, lpmn as lpmn, lpmn_values as lpmn_values, multigammaln as multigammaln, - log_ndtr as log_ndtr, ndtr as ndtr, ndtri as ndtri, + poch as poch, polygamma as polygamma, + rel_entr as rel_entr, + softmax as softmax, spence as spence, sph_harm as sph_harm, - xlogy as xlogy, xlog1py as xlog1py, + xlogy as xlogy, zeta as zeta, - kl_div as kl_div, - rel_entr as rel_entr, - poch as poch, - hyp1f1 as hyp1f1, ) diff --git a/jax/sharding.py b/jax/sharding.py index 18caa9eb0f57..fe221f90af67 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -17,7 +17,7 @@ from jax._src.sharding import Sharding as Sharding from jax._src.sharding_impls import ( - XLACompatibleSharding as XLACompatibleSharding, + XLACompatibleSharding as _deprecated_XLACompatibleSharding, NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, PmapSharding as PmapSharding, @@ -28,3 +28,23 @@ PartitionSpec as PartitionSpec, ) from jax._src.interpreters.pxla import Mesh as Mesh + +_deprecations = { + # Added Jun 4, 2024. + "XLACompatibleSharding": ( + ( + "jax.sharding.XLACompatibleSharding is deprecated. Use" + " jax.sharding.Sharding instead." + ), + _deprecated_XLACompatibleSharding, + ) +} + +import typing +if typing.TYPE_CHECKING: + XLACompatibleSharding = _deprecated_XLACompatibleSharding +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/jax/stages.py b/jax/stages.py index 0a6e6082f2ea..6ffc3144c3bc 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -30,4 +30,6 @@ Lowered as Lowered, Wrapped as Wrapped, ArgInfo as ArgInfo, + OutInfo as OutInfo, + Traced as Traced, ) diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index d3a84c384932..84cc697d1894 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -107,3 +107,13 @@ def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str): ) with open(src_file, "w") as f: f.write(content) + +def update_setup_with_rocm_version(file_dir: pathlib.Path, rocm_version: str): + src_file = file_dir / "setup.py" + with open(src_file) as f: + content = f.read() + content = content.replace( + "rocm_version = 0 # placeholder", f"rocm_version = {rocm_version}" + ) + with open(src_file, "w") as f: + f.write(content) diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py index 220ac0ea8e8b..f3e6fc571594 100644 --- a/jax/tools/jax_to_ir.py +++ b/jax/tools/jax_to_ir.py @@ -151,7 +151,7 @@ def ordered_wrapper(*args): return fn_curried(**dict(zip(arg_names, args))) if format == 'HLO': - comp = jax.xla_computation(ordered_wrapper)(*args) + comp = jax.jit(ordered_wrapper).lower(*args).compiler_ir('hlo') serialized_proto = comp.as_serialized_hlo_module_proto() debug_txt = comp.as_hlo_text() else: diff --git a/jax/version.py b/jax/version.py index de7fed082a13..f3d007eec9b1 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.29" +_version = "0.4.31" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.27" +_minimum_jaxlib_version = "0.4.30" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/jax_plugins/BUILD.bazel b/jax_plugins/BUILD.bazel index 2102c6404c5a..6e2cf6aadbaf 100644 --- a/jax_plugins/BUILD.bazel +++ b/jax_plugins/BUILD.bazel @@ -17,6 +17,7 @@ licenses(["notice"]) load( "//jaxlib:jax.bzl", "if_cuda_is_configured", + "if_rocm_is_configured", "py_library_providing_imports_info", ) @@ -30,5 +31,7 @@ py_library( ":jax_plugins", ] + if_cuda_is_configured([ "//jax_plugins/cuda:cuda_plugin", + ]) + if_rocm_is_configured([ + "//jax_plugins/rocm:rocm_plugin", ]), -) \ No newline at end of file +) diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index 77454fb488a3..cd26731aa629 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -49,13 +49,33 @@ def has_ext_modules(self): author="JAX team", author_email="jax-dev@google.com", packages=[package_name], - python_requires=">=3.9", + python_requires=">=3.10", install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], + extras_require={ + 'with_cuda': [ + "nvidia-cublas-cu12>=12.1.3.1", + "nvidia-cuda-cupti-cu12>=12.1.105", + "nvidia-cuda-nvcc-cu12>=12.1.105", + "nvidia-cuda-runtime-cu12>=12.1.105", + "nvidia-cudnn-cu12>=9.0,<10.0", + "nvidia-cufft-cu12>=11.0.2.54", + "nvidia-cusolver-cu12>=11.4.5.107", + "nvidia-cusparse-cu12>=12.1.0.106", + "nvidia-nccl-cu12>=2.18.1", + # nvjitlink is not a direct dependency of JAX, but it is a transitive + # dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages + # do not have a version constraint on their dependencies, so the + # package doesn't get upgraded even though not doing that can cause + # problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196) + # Until NVIDIA add version constraints, add a version constraint + # here. + "nvidia-nvjitlink-cu12>=12.1.105", + ], + }, url="https://github.com/google/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel new file mode 100644 index 000000000000..08a61c786262 --- /dev/null +++ b/jax_plugins/rocm/BUILD.bazel @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +load("//jaxlib:symlink_files.bzl", "symlink_files") +load( + "//jaxlib:jax.bzl", + "if_windows", + "py_library_providing_imports_info", + "pytype_library", +) + +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) + +exports_files([ + "__init__.py", + "plugin_pyproject.toml", + "plugin_setup.py", + "pyproject.toml", + "setup.py", +]) + +symlink_files( + name = "pjrt_c_api_gpu_plugin", + srcs = if_windows( + ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], + ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], + ), + dst = ".", + flatten = True, +) + +py_library_providing_imports_info( + name = "rocm_plugin", + srcs = [ + "__init__.py", + ], + data = [":pjrt_c_api_gpu_plugin"], + lib_rule = pytype_library, +) diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py new file mode 100644 index 000000000000..4535f1b3bbc8 --- /dev/null +++ b/jax_plugins/rocm/__init__.py @@ -0,0 +1,91 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import importlib +import logging +import pathlib +import platform + +from jax._src.lib import xla_client +import jax._src.xla_bridge as xb + +# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without +# preinstalled jax rocm plugin packages. +for pkg_name in ['jax_rocm60_plugin', 'jaxlib']: + try: + rocm_plugin_extension = importlib.import_module( + f'{pkg_name}.rocm_plugin_extension' + ) + except ImportError: + rocm_plugin_extension = None + else: + break + +logger = logging.getLogger(__name__) + + +def _get_library_path(): + base_path = pathlib.Path(__file__).resolve().parent + installed_path = ( + base_path / 'xla_rocm_plugin.so' + ) + if installed_path.exists(): + return installed_path + + local_path = ( + base_path / 'pjrt_c_api_gpu_plugin.so' + ) + if local_path.exists(): + logger.debug( + 'Native library %s does not exist. This most likely indicates an issue' + ' with how %s was built or installed. Fallback to local test' + ' library %s', + installed_path, + __package__, + local_path, + ) + return local_path + + logger.debug( + 'WARNING: Native library %s and local test library path %s do not' + ' exist. This most likely indicates an issue with how %s was built or' + ' installed or missing src files.', + installed_path, + local_path, + __package__, + ) + return None + + +def initialize(): + path = _get_library_path() + if path is None: + return + options = xla_client.generate_pjrt_gpu_plugin_options() + options["platform_name"] = "ROCM" + c_api = xb.register_plugin( + 'rocm', priority=500, library_path=str(path), options=options + ) + if rocm_plugin_extension: + xla_client.register_custom_call_handler( + "ROCM", + functools.partial( + rocm_plugin_extension.register_custom_call_target, c_api + ), + ) + for _name, _value in rocm_plugin_extension.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") + else: + logger.warning('rocm_plugin_extension is not found.') diff --git a/jax_plugins/rocm/plugin_pyproject.toml b/jax_plugins/rocm/plugin_pyproject.toml new file mode 100644 index 000000000000..8fe2f47af9a1 --- /dev/null +++ b/jax_plugins/rocm/plugin_pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py new file mode 100644 index 000000000000..9ccf3bf44339 --- /dev/null +++ b/jax_plugins/rocm/plugin_setup.py @@ -0,0 +1,70 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from setuptools import setup +from setuptools.dist import Distribution + +__version__ = None +rocm_version = 0 # placeholder +project_name = f"jax-rocm{rocm_version}-plugin" +package_name = f"jax_rocm{rocm_version}_plugin" + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(package_name) +__version__ = _version_module._get_version_for_build() +_cmdclass = _version_module._get_cmdclass(package_name) + +class BinaryDistribution(Distribution): + """This class makes 'bdist_wheel' include an ABI tag on the wheel.""" + + def has_ext_modules(self): + return True + +setup( + name=project_name, + version=__version__, + cmdclass=_cmdclass, + description="JAX Plugin for AMD GPUs", + long_description="", + long_description_content_type="text/markdown", + author="Ruturaj4", + author_email="Ruturaj.Vaidya@amd.com", + packages=[package_name], + python_requires=">=3.9", + install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], + url="https://github.com/google/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + package_data={ + package_name: [ + "*", + ], + }, + zip_safe=False, + distclass=BinaryDistribution, +) diff --git a/jax_plugins/rocm/pyproject.toml b/jax_plugins/rocm/pyproject.toml new file mode 100644 index 000000000000..8fe2f47af9a1 --- /dev/null +++ b/jax_plugins/rocm/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/jax_plugins/rocm/setup.py b/jax_plugins/rocm/setup.py new file mode 100644 index 000000000000..8782676ce9a2 --- /dev/null +++ b/jax_plugins/rocm/setup.py @@ -0,0 +1,66 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from setuptools import setup, find_namespace_packages + +__version__ = None +rocm_version = 0 # placeholder +project_name = f"jax-rocm{rocm_version}-pjrt" +package_name = f"jax_plugins.xla_rocm{rocm_version}" + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(f"jax_plugins/xla_rocm{rocm_version}") +__version__ = _version_module._get_version_for_build() + +packages = find_namespace_packages( + include=[ + package_name, + f"{package_name}.*", + ] +) + +setup( + name=project_name, + version=__version__, + description="JAX XLA PJRT Plugin for AMD GPUs", + long_description="", + long_description_content_type="text/markdown", + author="Ruturaj4", + author_email="Ruturaj.Vaidya@amd.com", + packages=packages, + install_requires=[], + url="https://github.com/google/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + ], + package_data={ + package_name: ["xla_rocm_plugin.so"], + }, + zip_safe=False, + entry_points={ + "jax_plugins": [ + f"xla_rocm{rocm_version} = {package_name}", + ], + }, +) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a83fd625ced3..dc8b5148ca93 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -17,7 +17,6 @@ load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", - "if_building_mosaic_gpu", "py_library_providing_imports_info", "pybind_extension", "pytype_library", @@ -40,7 +39,6 @@ genrule( py_library_providing_imports_info( name = "jaxlib", srcs = [ - "ducc_fft.py", "gpu_common_utils.py", "gpu_linalg.py", "gpu_prng.py", @@ -55,21 +53,25 @@ py_library_providing_imports_info( ":xla_client", ":xla_extension_py", ], + data = [":ffi_headers"], lib_rule = pytype_library, deps = [ ":cpu_feature_guard", ":utils", - "//jaxlib/cpu:_ducc_fft", "//jaxlib/cpu:_lapack", "//jaxlib/mlir", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:chlo_dialect", "//jaxlib/mlir:func_dialect", + "//jaxlib/mlir:gpu_dialect", "//jaxlib/mlir:ir", + "//jaxlib/mlir:llvm_dialect", "//jaxlib/mlir:math_dialect", "//jaxlib/mlir:memref_dialect", "//jaxlib/mlir:mhlo_dialect", + "//jaxlib/mlir:nvgpu_dialect", + "//jaxlib/mlir:nvvm_dialect", "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:scf_dialect", "//jaxlib/mlir:sparse_tensor_dialect", @@ -78,7 +80,7 @@ py_library_providing_imports_info( "//jaxlib/mosaic", "//jaxlib/triton", "@xla//xla/python:xla_extension", - ] + if_building_mosaic_gpu(["//jaxlib/mosaic/gpu:mosaic_gpu"]), + ], ) symlink_files( @@ -95,6 +97,13 @@ symlink_files( flatten = True, ) +symlink_files( + name = "ffi_headers", + srcs = ["@xla//xla/ffi/api:all_headers"], + dst = "include/xla/ffi/api", + flatten = True, +) + exports_files([ "README.md", "setup.py", @@ -191,7 +200,6 @@ pybind_extension( "@nanobind", "//jaxlib:kernel_nanobind_helpers", "@xla//third_party/python_runtime:headers", - "@xla//xla:status", "@local_config_cuda//cuda:cuda_headers", "@xla//xla:util", "@xla//xla/ffi/api:c_api", @@ -207,6 +215,28 @@ pybind_extension( ], ) +pybind_extension( + name = "rocm_plugin_extension", + srcs = ["rocm_plugin_extension.cc"], + module_name = "rocm_plugin_extension", + deps = [ + "//jaxlib:kernel_nanobind_helpers", + "@xla//third_party/python_runtime:headers", + "@xla//xla:status", + "@xla//xla:util", + "@xla//xla/ffi/api:c_api", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/python/lib/core:numpy", + "@com_google_absl//absl/status", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], +) + # CPU kernels # TODO(phawkins): Remove this forwarding target. diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 279cf5c2aab5..bb406ffd3adc 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -16,7 +16,6 @@ load( "//jaxlib:jax.bzl", - "flatbuffer_cc_library", "pybind_extension", ) @@ -36,8 +35,14 @@ cc_library( copts = ["-fexceptions"], features = ["-use_header_modules"], deps = [ + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", ], ) @@ -51,6 +56,7 @@ cc_library( pybind_extension( name = "_lapack", srcs = ["lapack.cc"], + hdrs = ["lapack.h"], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -64,49 +70,7 @@ pybind_extension( deps = [ ":lapack_kernels", "//jaxlib:kernel_nanobind_helpers", - "@nanobind", - ], -) - -# DUCC (CPU FFTs) - -flatbuffer_cc_library( - name = "ducc_fft_flatbuffers_cc", - srcs = ["ducc_fft.fbs"], -) - -cc_library( - name = "ducc_fft_kernels", - srcs = ["ducc_fft_kernels.cc"], - hdrs = ["ducc_fft_kernels.h"], - copts = ["-fexceptions"], # DUCC may throw. - features = ["-use_header_modules"], - deps = [ - ":ducc_fft_flatbuffers_cc", - "@xla//xla/service:custom_call_status", - "@com_github_google_flatbuffers//:flatbuffers", - "@ducc//:fft", - ], -) - -pybind_extension( - name = "_ducc_fft", - srcs = ["ducc_fft.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - enable_stub_generation = True, - features = ["-use_header_modules"], - module_name = "_ducc_fft", - pytype_srcs = [ - "_ducc_fft.pyi", - ], - deps = [ - ":ducc_fft_flatbuffers_cc", - ":ducc_fft_kernels", - "//jaxlib:kernel_nanobind_helpers", - "@com_github_google_flatbuffers//:flatbuffers", + "@xla//xla/ffi/api:ffi", "@nanobind", ], ) @@ -114,11 +78,13 @@ pybind_extension( cc_library( name = "cpu_kernels", srcs = ["cpu_kernels.cc"], + hdrs = ["lapack.h"], visibility = ["//visibility:public"], deps = [ - ":ducc_fft_kernels", ":lapack_kernels", ":lapack_kernels_using_lapack", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index b7551cc61f91..0cb9e7cb3328 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -15,14 +15,24 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. + #include -#include "jaxlib/cpu/ducc_fft_kernels.h" + +#include "jaxlib/cpu/lapack.h" #include "jaxlib/cpu/lapack_kernels.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_target_registry.h" +#define JAX_CPU_REGISTER_HANDLER(name) \ + XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), #name, "Host", name); + namespace jax { namespace { +// Old-style kernels +// TODO(b/344892332): To be removed after the 6M compatibility period is over. + XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm::Kernel, "Host"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_dtrsm", Trsm::Kernel, @@ -105,8 +115,19 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "lapack_cgees", ComplexGees>::Kernel, "Host"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "lapack_zgees", ComplexGees>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "dynamic_ducc_fft", DynamicDuccFft, "Host"); + +// FFI Kernels + +JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zpotrf_ffi); + +#undef JAX_CPU_REGISTER_HANDLER } // namespace } // namespace jax diff --git a/jaxlib/cpu/ducc_fft.cc b/jaxlib/cpu/ducc_fft.cc deleted file mode 100644 index a8f0490ac31d..000000000000 --- a/jaxlib/cpu/ducc_fft.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2020 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "nanobind/nanobind.h" -#include "nanobind/stl/vector.h" -#include "jaxlib/cpu/ducc_fft_generated.h" -#include "jaxlib/cpu/ducc_fft_kernels.h" -#include "jaxlib/kernel_nanobind_helpers.h" - -namespace nb = nanobind; - -namespace jax { -namespace { - - -nb::bytes BuildDynamicDuccFftDescriptor( - const uint32_t ndims, - bool is_double, int fft_type, - const std::vector &axes, - bool forward) { - DynamicDuccFftDescriptorT descriptor; - descriptor.ndims = ndims; - descriptor.fft_type = static_cast(fft_type); - descriptor.dtype = - is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64; - descriptor.axes = axes; - descriptor.forward = forward; - flatbuffers::FlatBufferBuilder fbb; - fbb.Finish(DynamicDuccFftDescriptor::Pack(fbb, &descriptor)); - return nb::bytes(reinterpret_cast(fbb.GetBufferPointer()), - fbb.GetSize()); -} - -nb::dict Registrations() { - nb::dict dict; - // TODO(b/311175955): this must be kept until May 2024 for backwards - // of serialized functions using fft. - dict["dynamic_ducc_fft"] = EncapsulateFunction(DynamicDuccFft); - return dict; -} - -NB_MODULE(_ducc_fft, m) { - m.def("registrations", &Registrations); - m.def("dynamic_ducc_fft_descriptor", &BuildDynamicDuccFftDescriptor, - nb::arg("ndims"), nb::arg("is_double"), nb::arg("fft_type"), - nb::arg("axes"), nb::arg("forward")); -} - -} // namespace -} // namespace jax diff --git a/jaxlib/cpu/ducc_fft_kernels.cc b/jaxlib/cpu/ducc_fft_kernels.cc deleted file mode 100644 index 12f8327b6d03..000000000000 --- a/jaxlib/cpu/ducc_fft_kernels.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2020 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "ducc/src/ducc0/fft/fft.h" -#include "ducc/src/ducc0/fft/fft1d_impl.h" // NOLINT: required for fft definitions. -#include "ducc/src/ducc0/fft/fftnd_impl.h" // NOLINT: required for fft definitions. -#include "flatbuffers/flatbuffers.h" -#include "jaxlib/cpu/ducc_fft_generated.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -using shape_t = ducc0::fmav_info::shape_t; -using stride_t = ducc0::fmav_info::stride_t; - -namespace { - -void DuccFftImpl(void *out, void *operand, jax::DuccFftDtype dtype, - jax::DuccFftType fft_type, - shape_t shape, stride_t strides_in, stride_t strides_out, shape_t axes, - bool forward, double scale) { - - switch (fft_type) { - case DuccFftType_C2C: - if (dtype == DuccFftDtype_COMPLEX64) { - ducc0::cfmav> m_in( - reinterpret_cast *>(operand), shape, strides_in); - ducc0::vfmav> m_out( - reinterpret_cast *>(out), shape, strides_out); - ducc0::c2c(m_in, m_out, axes, forward, static_cast(scale)); - } else { - ducc0::cfmav> m_in( - reinterpret_cast *>(operand), shape, strides_in); - ducc0::vfmav> m_out( - reinterpret_cast *>(out), shape, strides_out); - ducc0::c2c(m_in, m_out, axes, forward, scale); - } - break; - case DuccFftType_C2R: - if (dtype == DuccFftDtype_COMPLEX64) { - auto shape_in = shape; - shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1; - ducc0::cfmav> m_in( - reinterpret_cast *>(operand), - shape_in, strides_in); - ducc0::vfmav m_out(reinterpret_cast(out), shape, - strides_out); - ducc0::c2r(m_in, m_out, axes, forward, static_cast(scale)); - } else { - auto shape_in = shape; - shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1; - ducc0::cfmav> m_in( - reinterpret_cast *>(operand), - shape_in, strides_in); - ducc0::vfmav m_out(reinterpret_cast(out), shape, - strides_out); - ducc0::c2r(m_in, m_out, axes, forward, scale); - } - break; - case DuccFftType_R2C: - if (dtype == DuccFftDtype_COMPLEX64) { - auto shape_out = shape; - shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1; - ducc0::cfmav m_in(reinterpret_cast(operand), shape, - strides_in); - ducc0::vfmav> m_out( - reinterpret_cast *>(out), - shape_out, strides_out); - ducc0::r2c(m_in, m_out, axes, forward, static_cast(scale)); - } else { - auto shape_out = shape; - shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1; - ducc0::cfmav m_in(reinterpret_cast(operand), shape, - strides_in); - ducc0::vfmav> m_out( - reinterpret_cast *>(out), - shape_out, strides_out); - ducc0::r2c(m_in, m_out, axes, forward, scale); - } - break; - } -} - -} // namespace - - -// TODO(b/311175955): this must be kept until May 2024 for backwards -// of serialized functions using fft. -void DynamicDuccFft(void *out, void **in, XlaCustomCallStatus *) { - // in[0]=descriptor, in[1]=operand, - // in[2]=shape, in[3]=strides_in, in[4]=strides_out, in[5]=scale. - const DynamicDuccFftDescriptor *descriptor = - flatbuffers::GetRoot(in[0]); - const std::uint32_t *dynamic_shape = - reinterpret_cast(in[2]); - shape_t shape(dynamic_shape, dynamic_shape + descriptor->ndims()); - const std::uint32_t *dynamic_strides_in = - reinterpret_cast(in[3]); - stride_t strides_in(dynamic_strides_in, - dynamic_strides_in + descriptor->ndims()); - const std::uint32_t *dynamic_strides_out = - reinterpret_cast(in[4]); - stride_t strides_out(dynamic_strides_out, - dynamic_strides_out + descriptor->ndims()); - shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end()); - const double *dynamic_scale = reinterpret_cast(in[5]); - - DuccFftImpl(out, in[1], descriptor->dtype(), descriptor->fft_type(), - shape, strides_in, strides_out, axes, - descriptor->forward(), *dynamic_scale); -} - -} // namespace jax diff --git a/jaxlib/cpu/ducc_fft_kernels.h b/jaxlib/cpu/ducc_fft_kernels.h deleted file mode 100644 index 13d0b1d4022e..000000000000 --- a/jaxlib/cpu/ducc_fft_kernels.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2020 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_CPU_DUCC_FFT_KERNELS_H_ -#define JAXLIB_CPU_DUCC_FFT_KERNELS_H_ - -#include "xla/service/custom_call_status.h" - -namespace jax { - - -// TODO(b/311175955): this must be kept until May 2024 for backwards -// of serialized functions using fft. -void DynamicDuccFft(void* out, void** in, XlaCustomCallStatus*); - -} // namespace jax - -#endif // JAXLIB_CPU_DUCC_FFT_KERNELS_H_ diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index ddf605fdd0bd..d01efa7f7864 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "jaxlib/cpu/lapack.h" + #include #include "nanobind/nanobind.h" @@ -24,6 +26,8 @@ namespace { namespace nb = nanobind; +using ::xla::ffi::DataType; + void GetLapackKernelsFromScipy() { static bool initialized = false; // Protected by GIL if (initialized) return; @@ -35,12 +39,11 @@ void GetLapackKernelsFromScipy() { auto blas_ptr = [&](const char* name) { return nb::cast(blas_capi[name]).data(); }; - Trsm::fn = reinterpret_cast::FnType*>(blas_ptr("strsm")); - Trsm::fn = reinterpret_cast::FnType*>(blas_ptr("dtrsm")); - Trsm>::fn = - reinterpret_cast>::FnType*>(blas_ptr("ctrsm")); - Trsm>::fn = - reinterpret_cast>::FnType*>(blas_ptr("ztrsm")); + + AssignKernelFn>(blas_ptr("strsm")); + AssignKernelFn>(blas_ptr("dtrsm")); + AssignKernelFn>>(blas_ptr("ctrsm")); + AssignKernelFn>>(blas_ptr("ztrsm")); nb::module_ cython_lapack = nb::module_::import_("scipy.linalg.cython_lapack"); @@ -48,106 +51,63 @@ void GetLapackKernelsFromScipy() { auto lapack_ptr = [&](const char* name) { return nb::cast(lapack_capi[name]).data(); }; - Getrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgetrf")); - Getrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgetrf")); - Getrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgetrf")); - Getrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgetrf")); - Geqrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgeqrf")); - Geqrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgeqrf")); - Geqrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgeqrf")); - Geqrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgeqrf")); - Orgqr::fn = - reinterpret_cast::FnType*>(lapack_ptr("sorgqr")); - Orgqr::fn = - reinterpret_cast::FnType*>(lapack_ptr("dorgqr")); - Orgqr>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cungqr")); - Orgqr>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zungqr")); - Potrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("spotrf")); - Potrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dpotrf")); - Potrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cpotrf")); - Potrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zpotrf")); - RealGesdd::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgesdd")); - RealGesdd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgesdd")); - ComplexGesdd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgesdd")); - ComplexGesdd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgesdd")); - RealSyevd::fn = - reinterpret_cast::FnType*>(lapack_ptr("ssyevd")); - RealSyevd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dsyevd")); - ComplexHeevd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cheevd")); - ComplexHeevd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zheevd")); - RealGeev::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgeev")); - RealGeev::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgeev")); - ComplexGeev>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgeev")); - ComplexGeev>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgeev")); - RealGees::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgees")); - RealGees::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgees")); - ComplexGees>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgees")); - ComplexGees>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgees")); - Gehrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgehrd")); - Gehrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgehrd")); - Gehrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgehrd")); - Gehrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgehrd")); - Sytrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("ssytrd")); - Sytrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dsytrd")); - Sytrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("chetrd")); - Sytrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zhetrd")); + AssignKernelFn>(lapack_ptr("sgetrf")); + AssignKernelFn>(lapack_ptr("dgetrf")); + AssignKernelFn>>(lapack_ptr("cgetrf")); + AssignKernelFn>>(lapack_ptr("zgetrf")); + AssignKernelFn>(lapack_ptr("sgetrf")); + AssignKernelFn>(lapack_ptr("dgetrf")); + AssignKernelFn>(lapack_ptr("cgetrf")); + AssignKernelFn>(lapack_ptr("zgetrf")); + + AssignKernelFn>(lapack_ptr("sgeqrf")); + AssignKernelFn>(lapack_ptr("dgeqrf")); + AssignKernelFn>>(lapack_ptr("cgeqrf")); + AssignKernelFn>>(lapack_ptr("zgeqrf")); + + AssignKernelFn>(lapack_ptr("sorgqr")); + AssignKernelFn>(lapack_ptr("dorgqr")); + AssignKernelFn>>(lapack_ptr("cungqr")); + AssignKernelFn>>(lapack_ptr("zungqr")); + + AssignKernelFn>(lapack_ptr("spotrf")); + AssignKernelFn>(lapack_ptr("dpotrf")); + AssignKernelFn>>(lapack_ptr("cpotrf")); + AssignKernelFn>>(lapack_ptr("zpotrf")); + AssignKernelFn>(lapack_ptr("spotrf")); + AssignKernelFn>(lapack_ptr("dpotrf")); + AssignKernelFn>(lapack_ptr("cpotrf")); + AssignKernelFn>(lapack_ptr("zpotrf")); + + AssignKernelFn>(lapack_ptr("sgesdd")); + AssignKernelFn>(lapack_ptr("dgesdd")); + AssignKernelFn>>(lapack_ptr("cgesdd")); + AssignKernelFn>>(lapack_ptr("zgesdd")); + + AssignKernelFn>(lapack_ptr("ssyevd")); + AssignKernelFn>(lapack_ptr("dsyevd")); + AssignKernelFn>>(lapack_ptr("cheevd")); + AssignKernelFn>>(lapack_ptr("zheevd")); + + AssignKernelFn>(lapack_ptr("sgeev")); + AssignKernelFn>(lapack_ptr("dgeev")); + AssignKernelFn>>(lapack_ptr("cgeev")); + AssignKernelFn>>(lapack_ptr("zgeev")); + + AssignKernelFn>(lapack_ptr("sgees")); + AssignKernelFn>(lapack_ptr("dgees")); + AssignKernelFn>>(lapack_ptr("cgees")); + AssignKernelFn>>(lapack_ptr("zgees")); + + AssignKernelFn>(lapack_ptr("sgehrd")); + AssignKernelFn>(lapack_ptr("dgehrd")); + AssignKernelFn>>(lapack_ptr("cgehrd")); + AssignKernelFn>>(lapack_ptr("zgehrd")); + + AssignKernelFn>(lapack_ptr("ssytrd")); + AssignKernelFn>(lapack_ptr("dsytrd")); + AssignKernelFn>>(lapack_ptr("chetrd")); + AssignKernelFn>>(lapack_ptr("zhetrd")); initialized = true; } @@ -222,14 +182,24 @@ nb::dict Registrations() { dict["lapack_zhetrd"] = EncapsulateFunction(Sytrd>::Kernel); + dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi); + dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi); + dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi); + dict["lapack_zgetrf_ffi"] = EncapsulateFunction(lapack_zgetrf_ffi); + dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi); + dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi); + dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi); + dict["lapack_zpotrf_ffi"] = EncapsulateFunction(lapack_zpotrf_ffi); + return dict; } NB_MODULE(_lapack, m) { // Populates the LAPACK kernels from scipy on first call. m.def("initialize", GetLapackKernelsFromScipy); - m.def("registrations", &Registrations); + + // Old-style LAPACK Workspace Size Queries m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), nb::arg("n")); m.def("lapack_dgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), diff --git a/jaxlib/cpu/lapack.h b/jaxlib/cpu/lapack.h new file mode 100644 index 000000000000..b00440616f19 --- /dev/null +++ b/jaxlib/cpu/lapack.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CPU_LAPACK_H_ +#define JAXLIB_CPU_LAPACK_H_ + +#include "jaxlib/cpu/lapack_kernels.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { + +// FFI Definition Macros (by DataType) + +#define JAX_CPU_DEFINE_GETRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER( \ + name, LuDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*ipiv*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + +#define JAX_CPU_DEFINE_POTRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER( \ + name, CholeskyFactorization::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + +// FFI Handlers + +JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GETRF(lapack_zgetrf_ffi, ::xla::ffi::DataType::C128); + +JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128); + +#undef JAX_CPU_DEFINE_GETRF +#undef JAX_CPU_DEFINE_POTRF + +} // namespace jax + +#endif // JAXLIB_CPU_LAPACK_H_ diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 00b54bab0822..85c4cc44b065 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -15,32 +15,133 @@ limitations under the License. #include "jaxlib/cpu/lapack_kernels.h" +#include #include +#include +#include #include -#include +#include #include +#include +#include +#include +#include +#include +#include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +static_assert(sizeof(jax::lapack_int) == sizeof(int32_t), + "Expected LAPACK integers to be 32-bit"); + +namespace ffi = xla::ffi; + +// TODO(danfm): These macros and the casting functions should be moved to a +// separate header for use in other FFI kernels. +#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs) \ + if (!rhs.ok()) { \ + return ffi::Error(static_cast(rhs.status().code()), \ + std::string(rhs.status().message())); \ + } \ + lhs = rhs.value() + +#define RETURN_IF_FFI_ERROR(...) \ + do { \ + ffi::Error err = (__VA_ARGS__); \ + if (err.failure()) { \ + return err; \ + } \ + } while (0) namespace { -inline int64_t catch_lapack_int_overflow(const std::string& source, int64_t value) { - if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) { +template +inline absl::StatusOr MaybeCastNoOverflow( + int64_t value, const std::string& source = __FILE__) { + if constexpr (sizeof(T) == sizeof(int64_t)) { return value; } else { - if (value > std::numeric_limits::max()) { - throw std::overflow_error(source + "(=" + std::to_string(value) + ") exceeds maximum value of jax::lapack_int"); + if (value > std::numeric_limits::max()) [[unlikely]] { + return absl::InvalidArgumentError( + absl::StrFormat("%s: Value (=%d) exceeds the maximum representable " + "value of the desired type", + source, value)); } - return value; + return static_cast(value); + } +} + +template +inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) { + auto result = MaybeCastNoOverflow(value, source); + if (!result.ok()) { + throw std::overflow_error{std::string(result.status().message())}; } + return result.value(); } +template +ffi::Error CheckMatrixDimensions(ffi::Span dims) { + if (dims.size() < 2) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "Matrix must have at least 2 dimensions"); + } + return ffi::Error::Success(); } +template +std::tuple SplitBatch2D(ffi::Span dims) { + auto matrix_dims = dims.last(2); + return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1, + std::multiplies()), + matrix_dims.front(), matrix_dims.back()); +} + +template +void CopyIfDiffBuffer(ffi::Buffer x, ffi::ResultBuffer x_out) { + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions); + if (x.data != x_out->data) { + const auto x_size = batch_count * x_rows * x_cols; + std::copy_n(x.data, x_size, x_out->data); + } +} + +} // namespace + +#define REGISTER_CHAR_ENUM_ATTR_DECODING(type) \ + std::optional xla::ffi::AttrDecoding::Decode( \ + XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \ + if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] { \ + return diagnostic.Emit("Wrong attribute type: expected ") \ + << XLA_FFI_AttrType_SCALAR << " but got" << attr_type; \ + } \ + auto* scalar = reinterpret_cast(attr); \ + if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] { \ + return diagnostic.Emit("Wrong scalar data type: expected ") \ + << XLA_FFI_DataType_U8 << " but got " << scalar->dtype; \ + } \ + auto underlying = \ + *reinterpret_cast*>(scalar->value); \ + return static_cast(underlying); \ + } + +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); + +#undef REGISTER_CHAR_ENUM_ATTR_DECODING + namespace jax { -static_assert(sizeof(lapack_int) == sizeof(int32_t), - "Expected LAPACK integers to be 32-bit"); +//== Triangular System Solver ==// + +// lapack trsm template typename Trsm::FnType* Trsm::fn = nullptr; @@ -92,7 +193,9 @@ template struct Trsm; template struct Trsm>; template struct Trsm>; -// Getrf +//== LU Decomposition ==// + +// lapack getrf template typename Getrf::FnType* Getrf::fn = nullptr; @@ -126,7 +229,47 @@ template struct Getrf; template struct Getrf>; template struct Getrf>; -// Geqrf +// FFI Kernel + +template +ffi::Error LuDecomposition::Kernel( + ffi::Buffer x, ffi::ResultBuffer x_out, + ffi::ResultBuffer ipiv, + ffi::ResultBuffer info) { + RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions)); + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions); + auto* x_out_data = x_out->data; + auto* ipiv_data = ipiv->data; + auto* info_data = info->data; + + CopyIfDiffBuffer(x, x_out); + + ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v, + MaybeCastNoOverflow(x_rows)); + ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v, + MaybeCastNoOverflow(x_cols)); + auto x_leading_dim_v = x_rows_v; + + const int64_t x_out_step{x_rows * x_cols}; + const int64_t ipiv_step{std::min(x_rows, x_cols)}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, ipiv_data, + info_data); + x_out_data += x_out_step; + ipiv_data += ipiv_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template struct LuDecomposition; +template struct LuDecomposition; +template struct LuDecomposition; +template struct LuDecomposition; + +//== QR Factorization ==// + +// lapack geqrf template typename Geqrf::FnType* Geqrf::fn = nullptr; @@ -173,7 +316,10 @@ template struct Geqrf; template struct Geqrf>; template struct Geqrf>; -// Orgqr +//== Orthogonal QR ==// +//== Computes orthogonal matrix Q from QR Decomposition ==// + +// lapack orgqr template typename Orgqr::FnType* Orgqr::fn = nullptr; @@ -221,7 +367,9 @@ template struct Orgqr; template struct Orgqr>; template struct Orgqr>; -// Potrf +//== Cholesky Factorization ==// + +// lapack potrf template typename Potrf::FnType* Potrf::fn = nullptr; @@ -255,7 +403,42 @@ template struct Potrf; template struct Potrf>; template struct Potrf>; -// Gesdd +// FFI Kernel + +template +ffi::Error CholeskyFactorization::Kernel( + ffi::Buffer x, MatrixParams::UpLo uplo, + ffi::ResultBuffer x_out, ffi::ResultBuffer info) { + RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions)); + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions); + auto* x_out_data = x_out->data; + auto* info_data = info->data; + + CopyIfDiffBuffer(x, x_out); + + auto uplo_v = static_cast(uplo); + ASSIGN_OR_RETURN_FFI_ERROR( + auto x_order_v, MaybeCastNoOverflow(x.dimensions.back())); + auto x_leading_dim_v = x_order_v; + + const int64_t x_out_step{x_rows * x_cols}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&uplo_v, &x_order_v, x_out_data, &x_leading_dim_v, info_data); + x_out_data += x_out_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template struct CholeskyFactorization; +template struct CholeskyFactorization; +template struct CholeskyFactorization; +template struct CholeskyFactorization; + +//== Singular Value Decomposition (SVD) ==// +//== using a divide and conquer method ==// + +// lapack gesdd static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { if (!job_opt_compute_uv) { @@ -267,7 +450,7 @@ static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { } lapack_int GesddIworkSize(int64_t m, int64_t n) { - return catch_lapack_int_overflow("gesdd iwork", 8 * std::min(m, n)); + return CastNoOverflow(8 * std::min(m, n), "gesdd iwork"); } template @@ -333,11 +516,12 @@ int64_t RealGesdd::Workspace(lapack_int m, lapack_int n, lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) { int64_t mn = std::min(m, n); if (compute_uv == 0) { - return catch_lapack_int_overflow("complex gesdd rwork", 7 * mn); + return CastNoOverflow(7 * mn, "complex gesdd rwork"); } int64_t mx = std::max(m, n); - return catch_lapack_int_overflow("complex gesdd rwork", - std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn)); + return CastNoOverflow( + std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn), + "complex gesdd rwork"); } template @@ -408,13 +592,17 @@ template struct RealGesdd; template struct ComplexGesdd>; template struct ComplexGesdd>; +//== Eigenvalues and eigenvectors ==// + +// lapack syevd/heevd + // # Workspace sizes, taken from the LAPACK documentation. lapack_int SyevdWorkSize(int64_t n) { - return catch_lapack_int_overflow("syevd lwork", 1 + 6 * n + 2 * n * n); + return CastNoOverflow(1 + 6 * n + 2 * n * n, "syevd lwork"); } lapack_int SyevdIworkSize(int64_t n) { - return catch_lapack_int_overflow("syevd iwork", 3 + 5 * n); + return CastNoOverflow(3 + 5 * n, "syevd iwork"); } template @@ -454,11 +642,11 @@ void RealSyevd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { // Workspace sizes, taken from the LAPACK documentation. lapack_int HeevdWorkSize(int64_t n) { - return catch_lapack_int_overflow("heevd work", 1 + 2 * n + n * n); + return CastNoOverflow(1 + 2 * n + n * n, "heevd work"); } lapack_int HeevdRworkSize(int64_t n) { - return catch_lapack_int_overflow("heevd rwork", 1 + 5 * n + 2 * n * n); + return CastNoOverflow(1 + 5 * n + 2 * n * n, "heevd rwork"); } template @@ -534,6 +722,8 @@ static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed, } } +// lapack geev + template typename RealGeev::FnType* RealGeev::fn = nullptr; @@ -679,7 +869,9 @@ template struct RealGeev; template struct ComplexGeev>; template struct ComplexGeev>; -// Gees +//== Schur Decomposition ==// + +// lapack gees template typename RealGees::FnType* RealGees::fn = nullptr; @@ -809,6 +1001,10 @@ template struct RealGees; template struct ComplexGees>; template struct ComplexGees>; +//== Hessenberg Decomposition ==// + +// lapack gehrd + template typename Gehrd::FnType* Gehrd::fn = nullptr; @@ -859,6 +1055,10 @@ template struct Gehrd; template struct Gehrd>; template struct Gehrd>; +//== Tridiagonal Reduction ==// + +// lapack sytrd/hetrd + template typename Sytrd::FnType* Sytrd::fn = nullptr; @@ -917,3 +1117,6 @@ template struct Sytrd>; template struct Sytrd>; } // namespace jax + +#undef ASSIGN_OR_RETURN_FFI_ERROR +#undef RETURN_IF_FFI_ERROR diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 4641b772c2ab..4119f6ba08a2 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -16,19 +16,70 @@ limitations under the License. #ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ #define JAXLIB_CPU_LAPACK_KERNELS_H_ -#include #include +#include +#include +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/api/c_api.h" #include "xla/service/custom_call_status.h" -// Underlying function pointers (e.g., Trsm::Fn) are initialized either +// Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either // by the pybind wrapper that links them to an existing SciPy lapack instance, // or using the lapack_kernels_strong.cc static initialization to link them // directly to lapack for use in a pure C++ context. namespace jax { -typedef int lapack_int; +struct MatrixParams { + enum class Side : char { kLeft = 'L', kRight = 'R' }; + enum class UpLo : char { kLower = 'L', kUpper = 'U' }; + enum class Diag : char { kNonUnit = 'N', kUnit = 'U' }; + enum class Transpose : char { + kNoTrans = 'N', + kTrans = 'T', + kConjTrans = 'C' + }; +}; + +template +void AssignKernelFn(void* func) { + KernelType::fn = reinterpret_cast(func); +} + +template +void AssignKernelFn(typename KernelType::FnType* func) { + KernelType::fn = func; +} + +} // namespace jax + +#define DEFINE_CHAR_ENUM_ATTR_DECODING(ATTR) \ + template <> \ + struct xla::ffi::AttrDecoding { \ + using Type = ATTR; \ + static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ + DiagnosticEngine& diagnostic); \ + } + +// XLA needs attributes to have deserialization method specified +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); + +#undef DEFINE_CHAR_ENUM_ATTR_DECODING + +namespace jax { + +using lapack_int = int; +inline constexpr auto LapackIntDtype = ::xla::ffi::DataType::S32; +static_assert( + std::is_same_v<::xla::ffi::NativeType, lapack_int>); + +//== Triangular System Solver ==// + +// lapack trsm template struct Trsm { @@ -40,6 +91,10 @@ struct Trsm { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +//== LU Decomposition ==// + +// lapack getrf + template struct Getrf { using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, @@ -49,6 +104,25 @@ struct Getrf { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct LuDecomposition { + using ValueType = ::xla::ffi::NativeType; + using FnType = void(lapack_int* m, lapack_int* n, ValueType* a, + lapack_int* lda, lapack_int* ipiv, lapack_int* info); + + inline static FnType* fn = nullptr; + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer ipiv, + ::xla::ffi::ResultBuffer info); +}; + +//== QR Factorization ==// + +// lapack geqrf + template struct Geqrf { using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, @@ -60,6 +134,10 @@ struct Geqrf { static int64_t Workspace(lapack_int m, lapack_int n); }; +//== Orthogonal QR ==// + +// lapack orgqr + template struct Orgqr { using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a, @@ -70,6 +148,10 @@ struct Orgqr { static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k); }; +//== Cholesky Factorization ==// + +// lapack potrf + template struct Potrf { using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, @@ -78,6 +160,24 @@ struct Potrf { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +template <::xla::ffi::DataType dtype> +struct CholeskyFactorization { + using ValueType = ::xla::ffi::NativeType; + using FnType = void(char* uplo, lapack_int* n, ValueType* a, lapack_int* lda, + lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer info); +}; + +//== Singular Value Decomposition (SVD) ==// + +// lapack gesdd + lapack_int GesddIworkSize(int64_t m, int64_t n); template @@ -109,6 +209,10 @@ struct ComplexGesdd { bool job_opt_full_matrices); }; +//== Eigenvalues and eigenvectors ==// + +// lapack syevd/heevd + lapack_int SyevdWorkSize(int64_t n); lapack_int SyevdIworkSize(int64_t n); @@ -135,6 +239,8 @@ struct ComplexHeevd { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// lapack geev + template struct RealGeev { using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, @@ -155,6 +261,10 @@ struct ComplexGeev { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +//== Schur Decomposition ==// + +// lapack gees + template struct RealGees { using FnType = void(char* jobvs, char* sort, bool (*select)(T, T), @@ -176,7 +286,11 @@ struct ComplexGees { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; -// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form. +//== Hessenberg Decomposition ==// +//== Reduces a non-symmetric square matrix to upper Hessenberg form ==// + +// lapack gehrd + template struct Gehrd { using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a, @@ -199,14 +313,16 @@ struct real_type> { typedef T type; }; -// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal -// form. +//== Tridiagonal Reduction ==// +//== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// + +// lapack sytrd/hetrd + template struct Sytrd { using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, typename real_type::type* d, - typename real_type::type* e, - T* tau, T* work, + typename real_type::type* e, T* tau, T* work, lapack_int* lwork, lapack_int* info); static FnType* fn; diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index bc67fc556a49..48b1d5bffc1b 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -13,12 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "jaxlib/cpu/lapack_kernels.h" // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but // a C++ user should link against LAPACK directly. This is needed when using // JAX-generated HLO from C++. +namespace ffi = xla::ffi; + extern "C" { jax::Trsm::FnType strsm_; @@ -26,10 +31,10 @@ jax::Trsm::FnType dtrsm_; jax::Trsm>::FnType ctrsm_; jax::Trsm>::FnType ztrsm_; -jax::Getrf::FnType sgetrf_; -jax::Getrf::FnType dgetrf_; -jax::Getrf>::FnType cgetrf_; -jax::Getrf>::FnType zgetrf_; +jax::LuDecomposition::FnType sgetrf_; +jax::LuDecomposition::FnType dgetrf_; +jax::LuDecomposition::FnType cgetrf_; +jax::LuDecomposition::FnType zgetrf_; jax::Geqrf::FnType sgeqrf_; jax::Geqrf::FnType dgeqrf_; @@ -41,10 +46,10 @@ jax::Orgqr::FnType dorgqr_; jax::Orgqr>::FnType cungqr_; jax::Orgqr>::FnType zungqr_; -jax::Potrf::FnType spotrf_; -jax::Potrf::FnType dpotrf_; -jax::Potrf>::FnType cpotrf_; -jax::Potrf>::FnType zpotrf_; +jax::CholeskyFactorization::FnType spotrf_; +jax::CholeskyFactorization::FnType dpotrf_; +jax::CholeskyFactorization::FnType cpotrf_; +jax::CholeskyFactorization::FnType zpotrf_; jax::RealGesdd::FnType sgesdd_; jax::RealGesdd::FnType dgesdd_; @@ -80,51 +85,106 @@ jax::Sytrd>::FnType zhetrd_; namespace jax { +#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch" + +static_assert(std::is_same_v::FnType, + jax::Getrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Potrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Potrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Potrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Potrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); + +#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG + static auto init = []() -> int { - Trsm::fn = strsm_; - Trsm::fn = dtrsm_; - Trsm>::fn = ctrsm_; - Trsm>::fn = ztrsm_; - Getrf::fn = sgetrf_; - Getrf::fn = dgetrf_; - Getrf>::fn = cgetrf_; - Getrf>::fn = zgetrf_; - Geqrf::fn = sgeqrf_; - Geqrf::fn = dgeqrf_; - Geqrf>::fn = cgeqrf_; - Geqrf>::fn = zgeqrf_; - Orgqr::fn = sorgqr_; - Orgqr::fn = dorgqr_; - Orgqr>::fn = cungqr_; - Orgqr>::fn = zungqr_; - Potrf::fn = spotrf_; - Potrf::fn = dpotrf_; - Potrf>::fn = cpotrf_; - Potrf>::fn = zpotrf_; - RealGesdd::fn = sgesdd_; - RealGesdd::fn = dgesdd_; - ComplexGesdd>::fn = cgesdd_; - ComplexGesdd>::fn = zgesdd_; - RealSyevd::fn = ssyevd_; - RealSyevd::fn = dsyevd_; - ComplexHeevd>::fn = cheevd_; - ComplexHeevd>::fn = zheevd_; - RealGeev::fn = sgeev_; - RealGeev::fn = dgeev_; - ComplexGeev>::fn = cgeev_; - ComplexGeev>::fn = zgeev_; - RealGees::fn = sgees_; - RealGees::fn = dgees_; - ComplexGees>::fn = cgees_; - ComplexGees>::fn = zgees_; - Gehrd::fn = sgehrd_; - Gehrd::fn = dgehrd_; - Gehrd>::fn = cgehrd_; - Gehrd>::fn = zgehrd_; - Sytrd::fn = ssytrd_; - Sytrd::fn = dsytrd_; - Sytrd>::fn = chetrd_; - Sytrd>::fn = zhetrd_; + AssignKernelFn>(strsm_); + AssignKernelFn>(dtrsm_); + AssignKernelFn>>(ctrsm_); + AssignKernelFn>>(ztrsm_); + + AssignKernelFn>(sgetrf_); + AssignKernelFn>(dgetrf_); + AssignKernelFn>>(cgetrf_); + AssignKernelFn>>(zgetrf_); + + AssignKernelFn>(sgeqrf_); + AssignKernelFn>(dgeqrf_); + AssignKernelFn>>(cgeqrf_); + AssignKernelFn>>(zgeqrf_); + + AssignKernelFn>(sorgqr_); + AssignKernelFn>(dorgqr_); + AssignKernelFn>>(cungqr_); + AssignKernelFn>>(zungqr_); + + AssignKernelFn>(spotrf_); + AssignKernelFn>(dpotrf_); + AssignKernelFn>>(cpotrf_); + AssignKernelFn>>(zpotrf_); + + AssignKernelFn>(sgesdd_); + AssignKernelFn>(dgesdd_); + AssignKernelFn>>(cgesdd_); + AssignKernelFn>>(zgesdd_); + + AssignKernelFn>(ssyevd_); + AssignKernelFn>(dsyevd_); + AssignKernelFn>>(cheevd_); + AssignKernelFn>>(zheevd_); + + AssignKernelFn>(sgeev_); + AssignKernelFn>(dgeev_); + AssignKernelFn>>(cgeev_); + AssignKernelFn>>(zgeev_); + + AssignKernelFn>(sgees_); + AssignKernelFn>(dgees_); + AssignKernelFn>>(cgees_); + AssignKernelFn>>(zgees_); + + AssignKernelFn>(sgehrd_); + AssignKernelFn>(dgehrd_); + AssignKernelFn>>(cgehrd_); + AssignKernelFn>>(zgehrd_); + + AssignKernelFn>(ssytrd_); + AssignKernelFn>(dsytrd_); + AssignKernelFn>>(chetrd_); + AssignKernelFn>>(zhetrd_); + + // FFI Kernels + + AssignKernelFn>(sgetrf_); + AssignKernelFn>(dgetrf_); + AssignKernelFn>(cgetrf_); + AssignKernelFn>(zgetrf_); + + AssignKernelFn>(spotrf_); + AssignKernelFn>(dpotrf_); + AssignKernelFn>(cpotrf_); + AssignKernelFn>(zpotrf_); return 0; }(); diff --git a/jaxlib/cpu_feature_guard.c b/jaxlib/cpu_feature_guard.c index b7fe688eaf52..7c8ff2951a79 100644 --- a/jaxlib/cpu_feature_guard.c +++ b/jaxlib/cpu_feature_guard.c @@ -77,9 +77,19 @@ static int GetXCR0EAX() { static void ReportMissingCpuFeature(const char* name) { PyErr_Format( PyExc_RuntimeError, +#if defined(__APPLE__) + "This version of jaxlib was built using %s instructions, which your " + "CPU and/or operating system do not support. This error is frequently " + "encountered on macOS when running an x86 Python installation on ARM " + "hardware. In this case, try installing an ARM build of Python. " + "Otherwise, you may be able work around this issue by building jaxlib " + "from source.", +#else "This version of jaxlib was built using %s instructions, which your " "CPU and/or operating system do not support. You may be able work around " - "this issue by building jaxlib from source.", name); + "this issue by building jaxlib from source.", +#endif + name); } static PyObject *CheckCpuFeatures(PyObject *self, PyObject *args) { diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 151bbe828408..27f0a8e2a8ed 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -283,6 +283,10 @@ cc_library( ":cuda_vendor", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -296,7 +300,7 @@ cuda_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", "@local_config_cuda//cuda:cuda_headers", ], @@ -513,8 +517,6 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:absl_status_casters", - "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -522,6 +524,7 @@ cc_library( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", + "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -575,6 +578,7 @@ py_library( ":_sparse", ":_triton", ":_versions", + "//jaxlib/mosaic/gpu:mosaic_gpu", ], ) diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc index e517b8c4f069..d42199d37467 100644 --- a/jaxlib/cuda/versions_helpers.cc +++ b/jaxlib/cuda/versions_helpers.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/base/dynamic_annotations.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" @@ -30,39 +31,45 @@ namespace jax::cuda { int CudaRuntimeGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CudaDriverGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } uint32_t CuptiGetVersion() { uint32_t version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CufftGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CusolverGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CublasGetVersion() { int version; - // NVIDIA promise that it's safe to parse nullptr as the handle to this + // NVIDIA promise that it's safe to pass a null pointer as the handle to this // function. JAX_THROW_IF_ERROR( JAX_AS_STATUS(cublasGetVersion(/*handle=*/nullptr, &version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } @@ -73,6 +80,9 @@ int CusparseGetVersion() { JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MAJOR_VERSION, &major))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MINOR_VERSION, &minor))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(PATCH_LEVEL, &patch))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&patch, sizeof patch); return major * 1000 + minor * 100 + patch; } size_t CudnnGetVersion() { @@ -82,6 +92,7 @@ size_t CudnnGetVersion() { if (version == 0) { throw std::runtime_error("cuDNN not found."); } + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CudaComputeCapability(int device) { @@ -91,6 +102,8 @@ int CudaComputeCapability(int device) { &major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute( &minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor); return major * 10 + minor; } @@ -99,6 +112,7 @@ int CudaDeviceCount() { JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&device_count, sizeof device_count); return device_count; } diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda_plugin_extension.cc index d0c197b2b545..0bb8cbbace65 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda_plugin_extension.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include #include +#include #include #include "nanobind/nanobind.h" @@ -34,9 +36,11 @@ namespace nb = nanobind; namespace xla { namespace { -Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, - nb::capsule fn, int api_version, - XLA_FFI_Handler_Traits traits) { +absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, + const char* fn_name_c_str, + size_t fn_name_size, nb::capsule fn, + int api_version, + XLA_FFI_Handler_Traits traits) { if (c_api->extension_start == nullptr) { return Unimplemented("The plugin does not have extension."); } @@ -57,8 +61,8 @@ Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, PJRT_Gpu_Register_Custom_Call_Args args; args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - args.function_name = fn_name.c_str(); - args.function_name_size = nb::len(fn_name); + args.function_name = fn_name_c_str; + args.function_name_size = fn_name_size; #if PJRT_API_GPU_EXTENSION_VERSION >= 1 args.api_version = api_version; #endif @@ -66,7 +70,7 @@ Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, RETURN_STATUS_IF_PJRT_ERROR( reinterpret_cast(next)->custom_call(&args), c_api); - return OkStatus(); + return absl::OkStatus(); } nb::dict Registrations() { @@ -93,12 +97,23 @@ NB_MODULE(cuda_plugin_extension, m) { tsl::ImportNumpy(); m.def( "register_custom_call_target", - [](nb::capsule c_api, nb::str fn_name, nb::capsule fn, + [](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn, nb::str xla_platform_name, int api_version, XLA_FFI_Handler_Traits traits) { + const char* fn_name_c_str; + size_t fn_name_size; + nb::str fn_name_bn_str; + if (nb::try_cast(fn_name_py, fn_name_bn_str)) { + fn_name_c_str = fn_name_bn_str.c_str(); + fn_name_size = nb::len(fn_name_bn_str); + } else{ + nb::bytes bytes = nb::cast(fn_name_py); + fn_name_c_str = bytes.c_str(); + fn_name_size = bytes.size(); + } xla::ThrowIfError(RegisterCustomCallTarget( - static_cast(c_api.data()), fn_name, std::move(fn), - api_version, traits)); + static_cast(c_api.data()), fn_name_c_str, + fn_name_size, std::move(fn), api_version, traits)); }, nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), nb::arg("xla_platform_name"), nb::arg("api_version") = 0, diff --git a/jaxlib/gpu/lu_pivot_kernels.cc b/jaxlib/gpu/lu_pivot_kernels.cc index dc5b71716d66..b2c6362273ab 100644 --- a/jaxlib/gpu/lu_pivot_kernels.cc +++ b/jaxlib/gpu/lu_pivot_kernels.cc @@ -16,8 +16,14 @@ limitations under the License. #include "jaxlib/gpu/lu_pivot_kernels.h" #include +#include +#include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/c_api.h" @@ -28,29 +34,51 @@ namespace JAX_GPU_NAMESPACE { namespace ffi = xla::ffi; -XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame) { - static const auto* kImpl = - ffi::Ffi::Bind() - .Ctx>() - .Attr("batch_size") - .Attr("pivot_size") - .Attr("permutation_size") - .Arg>() - .Ret>() - .To([](gpuStream_t stream, std::int64_t batch_size, - std::int32_t pivot_size, std::int32_t permutation_size, - auto pivots, auto permutation) -> ffi::Error { - LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size, - permutation_size, pivots.data, - permutation->data); - if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) { - return ffi::Error(static_cast(status.code()), - std::string(status.message())); - } - return ffi::Error::Success(); - }) - .release(); - return kImpl->Call(call_frame); +template +inline absl::StatusOr MaybeCastNoOverflow( + std::int64_t value, const std::string& source = __FILE__) { + if constexpr (sizeof(T) == sizeof(std::int64_t)) { + return value; + } else { + if (value > std::numeric_limits::max()) [[unlikely]] { + return absl::InvalidArgumentError(absl::StrFormat( + "%s: Value (=%d) exceeds the maximum representable value of the " + "desired type", + source, value)); + } + return static_cast(value); + } +} + +ffi::Error LuPivotsToPermutationImpl( + gpuStream_t stream, std::int32_t permutation_size, + ffi::Buffer pivots, + ffi::Result> permutation) { + auto dims = pivots.dimensions; + if (dims.size() < 1) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "pivots must have at least one dimension"); + } + auto maybe_pivot_size = MaybeCastNoOverflow(dims.back()); + if (!maybe_pivot_size.ok()) { + return ffi::Error( + static_cast(maybe_pivot_size.status().code()), + std::string(maybe_pivot_size.status().message())); + } + std::int32_t pivot_size = maybe_pivot_size.value(); + std::int64_t batch_size = 1; + if (dims.size() >= 2) { + batch_size = + absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>()); + } + LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size, + permutation_size, pivots.data, + permutation->data); + if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) { + return ffi::Error(static_cast(status.code()), + std::string(status.message())); + } + return ffi::Error::Success(); } } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/lu_pivot_kernels.h b/jaxlib/gpu/lu_pivot_kernels.h index a4af440d5b72..b2cceb883dc9 100644 --- a/jaxlib/gpu/lu_pivot_kernels.h +++ b/jaxlib/gpu/lu_pivot_kernels.h @@ -19,11 +19,13 @@ limitations under the License. #include #include "jaxlib/gpu/vendor.h" -#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" namespace jax { namespace JAX_GPU_NAMESPACE { +namespace ffi = xla::ffi; + void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, std::int64_t batch_size, std::int32_t pivot_size, @@ -31,7 +33,17 @@ void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, const std::int32_t* pivots, std::int32_t* permutation); -XLA_FFI_Error* LuPivotsToPermutation(XLA_FFI_CallFrame* call_frame); +ffi::Error LuPivotsToPermutationImpl( + gpuStream_t stream, std::int32_t permutation_size, + ffi::Buffer pivots, + ffi::Result> permutation); + +XLA_FFI_DEFINE_HANDLER(LuPivotsToPermutation, LuPivotsToPermutationImpl, + ffi::Ffi::Bind() + .Ctx>() + .Attr("permutation_size") + .Arg>() + .Ret>()); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 87fd6954d12e..af79b3ae756f 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -42,14 +42,21 @@ _name, _value, platform="CUDA", api_version=api_version ) -try: - from .rocm import _linalg as _hip_linalg # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_linalg = importlib.import_module( + f"{rocm_module_name}._linalg", package="jaxlib" + ) + except ImportError: + _hip_linalg = None + else: + break + +if _hip_linalg: for _name, _value in _hip_linalg.registrations().items(): xla_client.register_custom_call_target( _name, _value, platform="ROCM", api_version=1 ) -except ImportError: - _hip_linalg = None _prod = lambda xs: functools.reduce(operator.mul, xs, 1) @@ -59,13 +66,9 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s typ = ir.RankedTensorType(pivots.type) dims = typ.shape i32_type = ir.IntegerType.get_signless(32) - i64_type = ir.IntegerType.get_signless(64) assert typ.element_type == i32_type, typ - batch_size = _prod(dims[:-1]) - pivot_size = dims[-1] - if not gpu_linalg: raise GpuLibNotLinkedError() @@ -80,8 +83,6 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s result_types=[permutations_type], operands=[pivots], backend_config=dict( - batch_size=ir.IntegerAttr.get(i64_type, batch_size), - pivot_size=ir.IntegerAttr.get(i32_type, pivot_size), permutation_size=ir.IntegerAttr.get(i32_type, permutation_size), ), operand_layouts=[pivots_layout], diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index a03dc8a8a4ca..12dcacebfa46 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -41,12 +41,19 @@ for _name, _value in _cuda_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -try: - from .rocm import _prng as _hip_prng # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_prng = importlib.import_module( + f"{rocm_module_name}._prng", package="jaxlib" + ) + except ImportError: + _hip_prng = None + else: + break + +if _hip_prng: for _name, _value in _hip_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") -except ImportError: - _hip_prng = None _prod = lambda xs: functools.reduce(operator.mul, xs, 1) diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index e804068e5e6d..f9a704a79f3d 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -59,21 +59,34 @@ for _name, _value in _cusolver.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") - try: from .rocm import _blas as _hipblas # pytype: disable=import-error - for _name, _value in _hipblas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: - _hipblas = None + for rocm_module_name in ["jax_rocm60_plugin"]: + try: + _hipblas = importlib.import_module(f"{rocm_module_name}._blas") + except: + _hipblas = None + else: + break -try: - from .rocm import _solver as _hipsolver # pytype: disable=import-error - for _name, _value in _hipsolver.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") -except ImportError: - _hipsolver = None +if _hipblas: + for _name, _value in _hipblas.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") + +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hipsolver = importlib.import_module( + f"{rocm_module_name}._solver", package="jaxlib" + ) + except ImportError: + _hipsolver = None + else: + break +if _hipsolver: + for _name, _value in _hipsolver.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index 7b840e9595ef..84192d4d0286 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -41,11 +41,17 @@ for _name, _value in _cusparse.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -try: - from .rocm import _sparse as _hipsparse # pytype: disable=import-error -except ImportError: - _hipsparse = None -else: +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hipsparse = importlib.import_module( + f"{rocm_module_name}._sparse", package="jaxlib" + ) + except ImportError: + _hipsparse = None + else: + break + +if _hipsparse: for _name, _value in _hipsparse.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index e38a7feddefc..f2d37bfec03d 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -38,8 +38,17 @@ get_custom_call = _cuda_triton.get_custom_call get_serialized_metadata = _cuda_triton.get_serialized_metadata -try: - from .rocm import _triton as _hip_triton # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_triton = importlib.import_module( + f"{rocm_module_name}._triton", package="jaxlib" + ) + except ImportError: + _hip_triton = None + else: + break + +if _hip_triton: xla_client.register_custom_call_target( "triton_kernel_call", _hip_triton.get_custom_call(), platform='ROCM') @@ -51,5 +60,3 @@ get_compute_capability = _hip_triton.get_compute_capability get_custom_call = _hip_triton.get_custom_call get_serialized_metadata = _hip_triton.get_serialized_metadata -except ImportError: - _hip_triton = None diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 4ec995172585..0d57a04f1aa7 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -16,9 +16,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Callable, Union +from typing import Union import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index abc99abf1567..a4da463a3447 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -48,11 +48,6 @@ jax_internal_test_harnesses_visibility = [] jax_test_util_visibility = [] loops_visibility = [] -def get_importlib_metadata(): - if HERMETIC_PYTHON_VERSION == "3.9": - return ["@pypi_importlib_metadata//:pkg"] - return [] - # TODO(vam): remove this once zstandard builds against Python 3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION == "3.13": @@ -66,9 +61,9 @@ _py_deps = { "cloudpickle": ["@pypi_cloudpickle//:pkg"], "colorama": ["@pypi_colorama//:pkg"], "epath": ["@pypi_etils//:pkg"], # etils.epath + "filelock": ["@pypi_filelock//:pkg"], "flatbuffers": ["@pypi_flatbuffers//:pkg"], "hypothesis": ["@pypi_hypothesis//:pkg"], - "importlib_metadata": get_importlib_metadata(), "matplotlib": ["@pypi_matplotlib//:pkg"], "opt_einsum": ["@pypi_opt_einsum//:pkg"], "pil": ["@pypi_pillow//:pkg"], @@ -207,12 +202,6 @@ def if_building_jaxlib(if_building, if_not_building = ["@pypi_jaxlib//:pkg"]): "//conditions:default": if_not_building, }) -def if_building_mosaic_gpu(if_building, if_not_building = []): - return select({ - "//jax:enable_mosaic_gpu": if_building, - "//conditions:default": if_not_building, - }) - # buildifier: disable=function-docstring def jax_test( name, @@ -221,6 +210,7 @@ def jax_test( env = {}, shard_count = None, deps = [], + data = [], disable_backends = None, # buildifier: disable=unused-variable backend_variant_args = {}, # buildifier: disable=unused-variable backend_tags = {}, # buildifier: disable=unused-variable @@ -262,6 +252,7 @@ def jax_test( "//jax:enable_build_cuda_plugin_from_source": ["//jax_plugins:gpu_plugin_only_test_deps"], "//conditions:default": [], }), + data = data, shard_count = test_shards, tags = test_tags, main = main, diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index 01cdaf019cd6..6626d0d162ab 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -224,20 +224,6 @@ symlink_inputs( ], ) -symlink_inputs( - name = "execution_engine", - rule = py_library, - symlinked_inputs = {"srcs": { - ".": [ - "@llvm-project//mlir/python:ExecutionEnginePyFiles", - ], - }}, - deps = [ - ":mlir", - "//jaxlib/mlir/_mlir_libs:_mlirExecutionEngine", - ], -) - symlink_inputs( name = "nvgpu_dialect", rule = py_library, @@ -248,7 +234,7 @@ symlink_inputs( ":core", ":ir", ":mlir", - "//jaxlib/mlir/_mlir_libs:_mlirDialectsNvgpu", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU", ], ) diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index b52bb78b9235..22735eeaf1ad 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -14,7 +14,6 @@ load( "//jaxlib:jax.bzl", - "if_building_mosaic_gpu", "if_windows", "py_extension", "pybind_extension", @@ -59,34 +58,6 @@ py_extension( ], ) -pybind_extension( - name = "_mlirExecutionEngine", - srcs = [ - "@llvm-project//mlir:lib/Bindings/Python/ExecutionEngineModule.cpp", - ], - copts = COPTS, - linkopts = LINKOPTS, - pytype_srcs = [ - ":_mlirExecutionEnginePyi", - ], - deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIExecutionEngineHeaders", - "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", - "@pybind11", - ], -) - -symlink_inputs( - name = "_mlirExecutionEnginePyi", - rule = filegroup, - symlinked_inputs = {"srcs": { - ".": [ - "@llvm-project//mlir/python:ExecutionEnginePyIFiles", - ], - }}, -) - py_extension( name = "_mlirDialectsGPU", srcs = [ @@ -118,7 +89,7 @@ py_extension( ) py_extension( - name = "_mlirDialectsNvgpu", + name = "_mlirDialectsNVGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp", ], @@ -225,33 +196,6 @@ pybind_extension( ], ) -pybind_extension( - name = "_mosaic_gpu_ext", - srcs = ["mosaic_gpu_ext.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - linkopts = select({ - "@xla//xla/python:use_jax_cuda_pip_rpaths": [ - "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib", - ], - "//conditions:default": [], - }), - visibility = ["//third_party/py/jax:__subpackages__"], - deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib:kernel_nanobind_helpers", - "//jaxlib/cuda:cuda_vendor", - "//jaxlib/mosaic/gpu:mlir_capi", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/synchronization", - "@nanobind", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - ], -) - symlink_inputs( name = "_mlir_libs", rule = py_library, @@ -295,28 +239,25 @@ cc_library( py_extension( name = "register_jax_dialects", srcs = ["register_jax_dialects.cc"], - copts = COPTS + if_building_mosaic_gpu(["-DJAXLIB_MOSAIC_GPU"]), + copts = COPTS, linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPIArithHeaders", + "@llvm-project//mlir:CAPIGPUHeaders", "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPILLVMHeaders", "@llvm-project//mlir:CAPIMathHeaders", "@llvm-project//mlir:CAPIMemRefHeaders", + "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:CAPINVVMHeaders", "@llvm-project//mlir:CAPISCFHeaders", "@llvm-project//mlir:CAPITransformsHeaders", "@llvm-project//mlir:CAPIVectorHeaders", "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@local_config_python//:headers", "@pybind11", - ] + if_building_mosaic_gpu([ - ":jaxlib_mlir_capi_shims_hdrs", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", - "@llvm-project//mlir:CAPINVVMHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", - "@llvm-project//mlir:CAPIConversionHeaders", - ]), + ], ) ##---------------------------------------------------------------------------## @@ -404,9 +345,13 @@ cc_library( deps = [ "//jaxlib/mosaic:tpu_dialect_capi_objects", "@llvm-project//mlir:CAPIArithObjects", + "@llvm-project//mlir:CAPIGPUObjects", "@llvm-project//mlir:CAPIIRObjects", + "@llvm-project//mlir:CAPILLVMObjects", "@llvm-project//mlir:CAPIMathObjects", "@llvm-project//mlir:CAPIMemRefObjects", + "@llvm-project//mlir:CAPINVGPUObjects", + "@llvm-project//mlir:CAPINVVMObjects", "@llvm-project//mlir:CAPISCFObjects", "@llvm-project//mlir:CAPISparseTensorObjects", "@llvm-project//mlir:CAPITransformsObjects", @@ -420,16 +365,7 @@ cc_library( [ "//jaxlib/triton:triton_dialect_capi_objects", ], - ) + if_building_mosaic_gpu([ - ":jaxlib_mlir_capi_shims", - "//jaxlib/mosaic/gpu:mlir_capi_objects", - "@llvm-project//mlir:CAPIConversionObjects", - "@llvm-project//mlir:CAPIExecutionEngineObjects", - "@llvm-project//mlir:CAPIGPUObjects", - "@llvm-project//mlir:CAPILLVMObjects", - "@llvm-project//mlir:CAPINVGPUObjects", - "@llvm-project//mlir:CAPINVVMObjects", - ]), + ), ) cc_binary( diff --git a/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.cc b/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.cc deleted file mode 100644 index 1f13f48d77b6..000000000000 --- a/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h" - -#include "mlir-c/IR.h" -#include "mlir/CAPI/IR.h" -#include "mlir/Dialect/GPU/Pipelines/Passes.h" -#include "mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" - -extern "C" { - -void jaxMlirRegisterMemRefPasses() { - mlir::memref::registerMemRefPasses(); -} - -void jaxMlirRegisterInterfaceExternalModels(MlirDialectRegistry registry) { - mlir::NVVM::registerNVVMTargetInterfaceExternalModels(*unwrap(registry)); - mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels( - *unwrap(registry)); - mlir::registerBuiltinDialectTranslation(*unwrap(registry)); - mlir::registerGPUDialectTranslation(*unwrap(registry)); - mlir::registerLLVMDialectTranslation(*unwrap(registry)); - mlir::registerNVVMDialectTranslation(*unwrap(registry)); -} -void jaxMlirRegisterGPUToNVVMPipeline() { - mlir::gpu::registerGPUToNVVMPipeline(); -} - -} diff --git a/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h b/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h deleted file mode 100644 index bebf40a7a350..000000000000 --- a/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_MLIR_CAPI_SHIMS -#define JAXLIB_MLIR_CAPI_SHIMS - -#include "mlir-c/IR.h" -#include "mlir-c/Support.h" - -#ifdef __cplusplus -extern "C" { -#endif - -MLIR_CAPI_EXPORTED void jaxMlirRegisterMemRefPasses(); -MLIR_CAPI_EXPORTED void jaxMlirRegisterInterfaceExternalModels(MlirDialectRegistry registry); -MLIR_CAPI_EXPORTED void jaxMlirRegisterGPUToNVVMPipeline(); - -#ifdef __cplusplus -} -#endif - -#endif // JAXLIB_MLIR_CAPI_SHIMS diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc deleted file mode 100644 index 8afd1b21a60b..000000000000 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ /dev/null @@ -1,108 +0,0 @@ -#include - -#include "nanobind/nanobind.h" -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "jaxlib/mosaic/gpu/integrations/c/passes.h" -#include "xla/service/custom_call_status.h" - -namespace jax::cuda { -namespace { - -namespace nb = nanobind; -using MosaicInitFunc = void(void***); -using MosaicHostFunc = void(void**); - -std::pair*, absl::Mutex*> -GetContextCache() { - static absl::Mutex mutex; - static auto& context_cache = *new absl::flat_hash_map; - return std::make_pair(&context_cache, &mutex); -} - -void InvalidateCache(MosaicInitFunc* init) { - auto cache = GetContextCache(); - absl::MutexLock lock(cache.second); - // TODO(apaszke): Free all the resources instead of leaking. - cache.first->erase(reinterpret_cast(init)); -} - -// Each compiled kernel has a unique init func, and each kernel is used from -// a single HLO module. So it should be safe to not include the CUDA context -// in the key. -void* InitOnce(MosaicInitFunc* init) { - auto cache_and_mutex = GetContextCache(); - auto* cache = cache_and_mutex.first; - auto* mutex = cache_and_mutex.second; - - uintptr_t key = reinterpret_cast(init); - - { - // Fast path uses reader lock (as hash map look-up is relatively slow). - absl::ReaderMutexLock lock(mutex); - auto it = cache->find(key); - if (ABSL_PREDICT_TRUE(it != cache->end())) return it->second; - } - - absl::MutexLock lock(mutex); - void*& ctx = (*cache)[key]; - // We released the reader lock, another thread might have initialized it. - if (ctx == nullptr) { - void** ptr_to_ctx = &ctx; - init(&ptr_to_ctx); - } - return ctx; -} - -void MosaicKernelCall(void* stream, void** buffers, char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - void** static_args = *reinterpret_cast(opaque); - MosaicHostFunc* func = reinterpret_cast(static_args[0]); - MosaicInitFunc* init = reinterpret_cast(static_args[1]); - void* ctx = InitOnce(init); - void* args[3] = {&ctx, &stream, &buffers}; - func(args); -} - -void EventRecordCall(void* stream, void** buffers, char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto* event = reinterpret_cast(opaque); - if (gpuEventRecord(**event, reinterpret_cast(stream)) != - gpuSuccess) { - const char message[] = "Failed to record event"; - XlaCustomCallStatusSetFailure(status, message, sizeof(message)); - } -} - -NB_MODULE(_mosaic_gpu_ext, m) { - m.def("_custom_call_capsule", - []() { return EncapsulateFunction(MosaicKernelCall); }); - m.def("register_passes", []() { return mlirMosaicGpuRegisterPasses(); }); - m.def("_gpu_event_create", []() { - gpuEvent_t* event = new gpuEvent_t(); - gpuEventCreate(event, GPU_EVENT_DEFAULT); - return reinterpret_cast(event); - }); - m.def("_gpu_event_destroy", [](uintptr_t event) { - gpuEventDestroy(*reinterpret_cast(event)); - }); - m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) { - float elapsed_ms = -1; - if (gpuEventElapsedTime( - &elapsed_ms, *reinterpret_cast(start_event), - *reinterpret_cast(end_event)) != gpuSuccess) { - throw std::runtime_error("Failed to get elapsed time between events"); - } - return elapsed_ms; - }); - m.def("_record_event_capsule", - []() { return EncapsulateFunction(EventRecordCall); }); - m.def("invalidate_cache", [](uintptr_t init_func_ptr) { - return InvalidateCache(reinterpret_cast(init_func_ptr)); - }); -} - -} // namespace -} // namespace jax::cuda diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index cb407a72b5ac..e1958c211b33 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -3,20 +3,16 @@ #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" +#include "mlir-c/Dialect/GPU.h" +#include "mlir-c/Dialect/LLVM.h" #include "mlir-c/Dialect/Math.h" #include "mlir-c/Dialect/MemRef.h" +#include "mlir-c/Dialect/NVGPU.h" +#include "mlir-c/Dialect/NVVM.h" #include "mlir-c/Dialect/SCF.h" #include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" -#ifdef JAXLIB_MOSAIC_GPU -#include "mlir-c/Dialect/GPU.h" -#include "mlir-c/Dialect/NVGPU.h" -#include "mlir-c/Dialect/NVVM.h" -#include "mlir-c/Dialect/LLVM.h" -#include "mlir-c/Conversion.h" -#include "jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi_shims.h" -#endif #define REGISTER_DIALECT(name) \ MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ @@ -32,21 +28,13 @@ PYBIND11_MODULE(register_jax_dialects, m) { REGISTER_DIALECT(memref); REGISTER_DIALECT(scf); REGISTER_DIALECT(vector); - mlirRegisterTransformsPasses(); - // Transforms used by JAX. - mlirRegisterTransformsStripDebugInfo(); - // TODO(apaszke): Move those to Mosaic GPU C bindings. -#ifdef JAXLIB_MOSAIC_GPU + // For Mosaic GPU REGISTER_DIALECT(gpu); REGISTER_DIALECT(nvgpu); REGISTER_DIALECT(nvvm); REGISTER_DIALECT(llvm); - mlirRegisterGPUPasses(); - mlirRegisterConversionPasses(); - // TODO(apaszke): Upstream and remove those. - jaxMlirRegisterMemRefPasses(); - jaxMlirRegisterInterfaceExternalModels(registry); - jaxMlirRegisterGPUToNVVMPipeline(); -#endif + mlirRegisterTransformsPasses(); + // Transforms used by JAX. + mlirRegisterTransformsStripDebugInfo(); }); } diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 004bd531ced6..a1531ff9f6de 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -35,6 +36,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" @@ -70,7 +72,9 @@ constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128}; // TODO(tlongeri): For our use-case, we don't really need C++ exceptions - just // setting the exception object and returning NULL to Python should suffice, but // not sure if this is possible with pybind. -class NotImplementedException : public std::exception {}; +class NotImplementedException : public std::runtime_error { + using runtime_error::runtime_error; +}; } // namespace template <> @@ -92,7 +96,7 @@ struct py::detail::type_caster { } else if (src.is(implicit_dim_cls.attr("SECOND_MINOR"))) { value = MlirTpuImplicitDimSecondMinor; } else { - throw NotImplementedException(); + throw py::value_error(); } return true; } @@ -156,37 +160,59 @@ struct py::detail::type_caster { }; namespace { -class NotImplementedDetector { +// Handler for use with MLIR C API print functions. The 2nd parameter is an +// opaque pointer to "user data" that should always be a string. +void printToString(MlirStringRef c_mlir_str, void* opaque_string) { + std::string* str = static_cast(opaque_string); + CHECK(str != nullptr); + str->append(c_mlir_str.data, c_mlir_str.length); +} + +class DiagnosticCapture { public: - NotImplementedDetector(MlirContext ctx) + DiagnosticCapture(MlirContext ctx) : ctx_(ctx), id_(mlirContextAttachDiagnosticHandler(ctx, handleDiagnostic, this, nullptr)) {} - ~NotImplementedDetector() { mlirContextDetachDiagnosticHandler(ctx_, id_); } - bool detected() const { return detected_; } - - private: - static void handleDiagnosticMessage(MlirStringRef str, - void* opaque_detector) { - // Note that we receive each argument to the stream separately. - // "Not implemented" must be entirely in a single argument. - NotImplementedDetector* detector = - static_cast(opaque_detector); - if (llvm::StringRef(str.data, str.length).contains("Not implemented")) { - detector->detected_ = true; + ~DiagnosticCapture() { mlirContextDetachDiagnosticHandler(ctx_, id_); } + + void throwIfError() { + if (error_messages_.size() == 1) { + // Throw NotImplementedException if we got a single diagnostic that + // contains "Not implemented". + llvm::StringRef ref = error_messages_.front(); + constexpr llvm::StringRef not_implemented = "Not implemented"; + if (const size_t pos = ref.find(not_implemented); + pos != llvm::StringRef::npos) { + // We strip "Not implemented" only if it is a prefix. Sometimes it may + // come after another prefix (e.g. op prefix), in which case we leave it + if (pos == 0) { + ref = ref.drop_front(not_implemented.size()); + ref.consume_front(": "); + } + throw NotImplementedException(ref.str()); + } + } + if (!error_messages_.empty()) { + // Note that it is unusual/unexpected to get multiple diagnostics, so we + // just forward all the error messages. + throw std::runtime_error(llvm::join(error_messages_, "\n")); } } + + private: static MlirLogicalResult handleDiagnostic(MlirDiagnostic diag, void* opaque_detector) { - NotImplementedDetector* detector = - static_cast(opaque_detector); + DiagnosticCapture* detector = + static_cast(opaque_detector); if (mlirDiagnosticGetSeverity(diag) == MlirDiagnosticError) { - mlirDiagnosticPrint(diag, handleDiagnosticMessage, detector); + std::string& message = detector->error_messages_.emplace_back(); + mlirDiagnosticPrint(diag, printToString, &message); } return mlirLogicalResultFailure(); // Propagate to other handlers } - bool detected_ = false; + llvm::SmallVector error_messages_; const MlirContext ctx_; const MlirDiagnosticHandlerID id_; }; @@ -562,7 +588,13 @@ PYBIND11_MODULE(_tpu_ext, m) { " shape: An optional shape of the vector to which both layouts " "apply. More layouts are considered equivalent when the shape is " "specified. Also see the docstring of the generalizes method.") - .def("__eq__", mlirTpuVectorLayoutEquals); + .def("__eq__", mlirTpuVectorLayoutEquals) + .def("__repr__", + [](MlirTpuVectorLayout self) { + std::string str; + mlirTpuVectorLayoutPrint(self, printToString, &str); + return str; + }); // TODO(tlongeri): Can we make the first parameter a VectorType? m.def("assemble", @@ -589,13 +621,11 @@ PYBIND11_MODULE(_tpu_ext, m) { TARGET_SHAPE); }); m.def("disassemble", [](MlirTpuVectorLayout layout, MlirValue val) { - NotImplementedDetector detector(getDefaultContext()); + DiagnosticCapture diag_capture(getDefaultContext()); MlirTpuValueArray val_arr = mlirTpuDisassemble(getDefaultInsertionPoint(), layout, val, TARGET_SHAPE); if (val_arr.vals == nullptr) { - if (detector.detected()) { - throw NotImplementedException(); - } + diag_capture.throwIfError(); throw py::value_error("Failed to disassemble"); } py::array_t np_vals( @@ -609,25 +639,21 @@ PYBIND11_MODULE(_tpu_ext, m) { }); m.def("apply_layout_op", [](int hardware_generation, const MlirOperation c_op) { - NotImplementedDetector detector(getDefaultContext()); + DiagnosticCapture diag_capture(getDefaultContext()); MlirLogicalResult res = mlirTpuApplyLayoutOp(hardware_generation, c_op, TARGET_SHAPE); if (mlirLogicalResultIsFailure(res)) { - if (detector.detected()) { - throw NotImplementedException(); - } + diag_capture.throwIfError(); throw std::runtime_error("applyLayoutOp failed"); } }); m.def("relayout", [](MlirValue v, MlirTpuVectorLayout src, MlirTpuVectorLayout dst) { - NotImplementedDetector detector(getDefaultContext()); + DiagnosticCapture diag_capture(getDefaultContext()); MlirValue new_v = mlirTpuRelayout(getDefaultInsertionPoint(), v, src, dst, TARGET_SHAPE); if (new_v.ptr == nullptr) { - if (detector.detected()) { - throw NotImplementedException(); - } + diag_capture.throwIfError(); throw py::value_error("Failed to relayout"); } return new_v; @@ -636,7 +662,7 @@ PYBIND11_MODULE(_tpu_ext, m) { try { if (p) std::rethrow_exception(p); } catch (const NotImplementedException& e) { - PyErr_SetNone(PyExc_NotImplementedError); + PyErr_SetString(PyExc_NotImplementedError, e.what()); } }); diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index eb8410e1f7df..a14d69881d05 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -73,6 +73,7 @@ cc_library( "@xla//xla:array", "@xla//xla:shape_util", "@xla//xla:util", + "@tsl//tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index 4ae2d738d93b..73b1b1e56ef2 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -24,10 +24,12 @@ limitations under the License. #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MemAlloc.h" +#include "llvm/Support/raw_ostream.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Registration.h" +#include "mlir/CAPI/Utils.h" #include "mlir/CAPI/Wrap.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -295,6 +297,12 @@ bool mlirTpuVectorLayoutEquivalentTo(MlirTpuVectorLayout layout, unwrap(target_shape)); } +void mlirTpuVectorLayoutPrint( + MlirTpuVectorLayout layout, MlirStringCallback callback, void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + unwrap(layout)->print(stream); +} + void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) { delete unwrap(data_bounds); } diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h index 60774147abe9..d1f126db4566 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h @@ -176,6 +176,9 @@ MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquivalentTo( MlirTpuVectorLayout layout, MlirTpuVectorLayout other, MlirTpuI64ArrayRef shape, MlirTpuI64TargetTuple target_shape); +MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutPrint( + MlirTpuVectorLayout layout, MlirStringCallback callback, void* user_data); + MLIR_CAPI_EXPORTED void mlirTpuVregDataBoundsDestroy( MlirTpuVregDataBounds data_bounds); diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 5923fe0766b9..7d11e6f49985 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -40,7 +40,6 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Support/MathExtras.h" #include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" @@ -577,10 +576,12 @@ bool VectorLayout::generalizes( } // Since we do not reorder axes, if the shapes resulting from inserting // implicit dimensions resulting are the same in the 2 minormost dimensions - // for both layouts, then the elements must be laid out the same way (i.e. - // layouts are equivalent). - return getImplicitTiledDims(shape, 1) == - other.getImplicitTiledDims(shape, 1); + // for both layouts, then the elements must be laid out the same way (before + // tiling). + if (getImplicitTiledDims(shape, 1) != + other.getImplicitTiledDims(shape, 1)) { + return false; + } } if (tiling_ != other.tiling_) { // Don't fail yet! diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 35d1cabe9fa6..93e4abe3e422 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -232,6 +232,7 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> { let hasVerifier = 1; } +// TODO(jevinjiang): deprecate to use dynamic_rotate. def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { let arguments = (ins AnyVector:$value, @@ -249,6 +250,23 @@ def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { let hasVerifier = 1; } +def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { + let arguments = (ins + AnyVector:$value, + I32:$amount, + SI32Attr:$dimension, + // When the stride is specified, the rotation amount for each index on the + // stride dimension will be (amount + stride * index). + OptionalAttr:$stride, + OptionalAttr:$stride_dimension + ); + let results = (outs AnyVector:$result); + let assemblyFormat = [{ + $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result) + }]; + let hasVerifier = 1; +} + def TPU_IotaOp : TPU_Op<"iota", [Pure]> { let arguments = (ins OptionalAttr:$dimension); let results = (outs AnyVector:$output); @@ -392,6 +410,35 @@ def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> { } def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> { + let summary = "Create a mask masking contiguous rows of subelements."; + // TODO(tlongeri): Why don't we just get `num_subelems` from the result type? + // Taking a parameter and allowing a mismatch is confusing. + let description = [{ + The "half-sublanes", "quarter-sublanes", etc. (unit is determined by + `num_subelems`) of the mask are masked in the range specified by `from` and + `to`. + + - If `from <= to`, the range `[from, to)` is set and the rest is unset. + - If `to <= from`, the range `[to, from)` is unset and the rest is set. + + All lanes are set identically. + + Example: + + ```mlir + %msk = tpu.create_subelement_mask 3, 9, 2 : vector<8x128x2xi1> + ``` + + This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is: + + ``` + [[0, 0], [0, 1], [1, 1], [1, 1], [1, 0], [0, 0], [0, 0], [0, 0]] + ``` + + It is currently only supported: + - In TPU v4, for `num_subelems` of 1 and 2. + - In TPU v5, for `num_subelems` of 1, 2, and 4. + }]; let arguments = (ins I32Attr:$from, // inclusive I32Attr:$to, // exclusive @@ -576,6 +623,12 @@ def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> { let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; } +def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> { + let arguments = (ins); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { let arguments = (ins Variadic:$seeds); let results = (outs); @@ -666,6 +719,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">, Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">, + Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">, ]; } @@ -679,7 +733,10 @@ def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp "::mlir::tpu::TPUDialect", ]; let constructor = "::mlir::tpu::createLinalgVectorizationPass(false)"; - let options = [Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">]; + let options = [ + Option<"supports_bf16_alu_instructions", "supports-bf16-alu-instructions", "bool", "", "">, + Option<"supports_bf16_matmul", "supports-bf16-matmul", "bool", "", "">, + ]; } #endif // TPU_ATTRS diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index abc7cc595cd6..df00093fabe6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -189,6 +189,9 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (fuel <= 0) { return false; } + if (divisor == 1) { + return true; + } if (auto assume_op = value.getDefiningOp()) { return assume_op.getMultiple() % divisor == 0; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index f8a4905d31d7..dc5b68246e3f 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -58,13 +58,15 @@ std::unique_ptr> createInferVectorLayoutPass( std::unique_ptr> createApplyVectorLayoutPass( int hardware_generation = -1, int lane_count = 128, int sublane_count = 8, - int mxu_contracting_size = 128, int mxu_noncontracting_size = 128); + int mxu_contracting_size = 128, int mxu_noncontracting_size = 128, + int max_sublanes_in_scratch = 0); std::unique_ptr> createLogicalToPhysicalDeviceIdPass(int64_t total_devices); std::unique_ptr> createLinalgVectorizationPass( - bool supports_bf16_alu_instructions = false); + bool supports_bf16_alu_instructions = false, + bool supports_bf16_matmul = false); std::unique_ptr> createDebugAssertInsertionPass(); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 9ed7c43f1767..8cea43739719 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -230,28 +230,26 @@ LogicalResult ReinterpretCastOp::verify() { source_type.getMemorySpace() == target_type.getMemorySpace()); } -LogicalResult RotateOp::verify() { - auto vty = getResult().getType(); - if (vty.getRank() <= getDimension() || getDimension() < 0) { - emitOpError("Invalid dimension: ") << getDimension(); - return failure(); - } - if (getAmount() < 0) { - emitOpError("Rotate amount must be >= 0"); +template +LogicalResult verifyRotateOp(Op op) { + auto vty = op.getResult().getType(); + if (vty.getRank() <= op.getDimension() || op.getDimension() < 0) { + op.emitOpError("Invalid dimension: ") << op.getDimension(); return failure(); } - if (getStride().has_value() && getStride().value() < 0) { - emitOpError("Rotate stride must be >= 0 if it is specified"); + if (op.getStride().has_value() && op.getStride().value() < 0) { + op.emitOpError("Rotate stride must be >= 0 if it is specified"); return failure(); } - if (getStrideDimension().has_value() && - (vty.getRank() <= getStrideDimension().value() || - getStrideDimension().value() < 0)) { - emitOpError("Invalid stride dimension: ") << getStrideDimension().value(); + if (op.getStrideDimension().has_value() && + (vty.getRank() <= op.getStrideDimension().value() || + op.getStrideDimension().value() < 0)) { + op.emitOpError("Invalid stride dimension: ") + << op.getStrideDimension().value(); return failure(); } - if (getStride().has_value() != getStrideDimension().has_value()) { - emitOpError( + if (op.getStride().has_value() != op.getStrideDimension().has_value()) { + op.emitOpError( "Expected either none or both stride and stride dimension are " "present"); return failure(); @@ -259,6 +257,13 @@ LogicalResult RotateOp::verify() { return success(); } +// TODO(b/347016737): deprecate static rotate +LogicalResult RotateOp::verify() { return verifyRotateOp(*this); } + +LogicalResult DynamicRotateOp::verify() { + return verifyRotateOp(*this); +} + // a + matmul(l, r, 0) == matmul(l, r, a) template class CanonicalizeAddOfMatmul : public OpRewritePattern { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index d90733d7791c..f0c568aeea48 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -7,10 +7,10 @@ #include #include #include +#include #include #include #include -#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -110,6 +110,14 @@ namespace mlir::tpu { #define TPU_ASSERT_LE_LOC(loc, lhs, rhs) \ TPU_ASSERT_CMP_LOC_IMPL(mlir::emitError(loc), lhs, rhs, <=) +// The minimum bound required to rotate with scratch space. The bound refers to +// the number of VREGs on rotation dim. This number was concluded from some cost +// analysis for comparing different dynamic rotation implementations. If +// actual bound is greater than this, dynamic rotation with internal scratch +// space is more efficient. +// TODO(jevinjiang): need to update it based on the generation. +static constexpr int kMinBoundToRotateWithScratch = 27; + LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block); namespace { @@ -157,6 +165,36 @@ FailureOr maskOOB(RewriteContext &ctx, OpBuilder &builder, .getResult(); } +// Get the address of pre-allocated internal scratch space with requested shape. +// +// Arguments: +// shape: The shape of the requested scratch space. +// elem_ty: The type of the elements in the requested scratch space. +// +// Returns: +// A memref of the requested shape and type. +FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder, + Location loc, ArrayRef shape, + Type elem_ty) { + if (shape.empty()) { + return failure(); + } + if (shape.back() % ctx.target_shape[1] != 0) { + return failure(); + } + int sublane_count = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) / + ctx.target_shape[1]; + if (sublane_count > ctx.max_sublanes_in_scratch) { + return failure(); + } + FAILUREOR_ASSIGN_OR_RETURN( + MemRefType scratch_ref_ty, + inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation)); + return builder.create(loc, scratch_ref_ty) + .getResult(); +} + // Models Numpy's np.repeat, repeating each element `repeats` times along the // specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is // 3, this will return [1, 1, 1, 2, 2, 2]. @@ -204,6 +242,21 @@ xla::Array concatenate(const ArrayRef> arrays, return res; } +SmallVector> split(const xla::Array &vregs, int axis) { + CHECK(axis >= 0 && axis < vregs.num_dimensions()); + SmallVector> chunks; + chunks.reserve(vregs.dim(axis)); + SmallVector starts(vregs.num_dimensions(), 0); + SmallVector limits(vregs.dimensions().begin(), + vregs.dimensions().end()); + for (int64_t i = 0; i < vregs.dim(axis); ++i) { + starts[axis] = i; + limits[axis] = i + 1; + chunks.push_back(vregs.Slice(starts, limits)); + } + return chunks; +}; + template ArrayRef XlaArrayToFlatArrayRef(xla::Array xla_array) { return ArrayRef(xla_array.data(), xla_array.num_elements()); @@ -719,27 +772,39 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, const VectorType res_vreg_ty, getNativeVregType(result_ty.getElementType(), ctx.target_shape)); if (layout_in.implicit_dim() != layout_out.implicit_dim()) { - return op.emitOpError("Not implemented: Change of layout during the cast"); + return op.emitOpError( + "Not implemented: Change of implicit dim during the cast"); } if (layout_in.offsets() != layout_out.offsets()) { return op.emitOpError("Not implemented: Change of offsets during the cast"); } - if (layout_in.tiling() != layout_out.tiling()) { - return op.emitOpError("Not implemented: Changing tiling during the cast"); - } - auto tiling = layout_in.tiling(); - if (ctx.target_shape[0] % tiling[0] != 0 || - ctx.target_shape[1] != tiling[1]) { - return op.emitOpError("Not implemented: tiling not supported"); - } const int packing = layout_in.packing(); - output_vregs.Each([&](absl::Span idxs, Value *v) { - SmallVector input_vreg_idxs(toArrayRef(idxs)); - input_vreg_idxs.back() /= packing; - const int64_t vreg_part = idxs.back() % packing; - *v = builder.create( - res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); - }); + if (layout_in.hasNativeTiling(ctx.target_shape) && + layout_out.hasNativeTiling(ctx.target_shape)) { + output_vregs.Each([&](absl::Span idxs, Value *v) { + SmallVector input_vreg_idxs(toArrayRef(idxs)); + int64_t vreg_part = *(input_vreg_idxs.end() - 2) % packing; + *(input_vreg_idxs.end() - 2) /= packing; + *v = builder.create( + res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + }); + } else { + if (layout_in.tiling() != layout_out.tiling()) { + return op.emitOpError("Not implemented: Changing tiling during the cast"); + } + auto tiling = layout_in.tiling(); + if (ctx.target_shape[0] % tiling[0] != 0 || + ctx.target_shape[1] != tiling[1]) { + return op.emitOpError("Not implemented: tiling not supported"); + } + output_vregs.Each([&](absl::Span idxs, Value *v) { + SmallVector input_vreg_idxs(toArrayRef(idxs)); + input_vreg_idxs.back() /= packing; + const int64_t vreg_part = idxs.back() % packing; + *v = builder.create( + res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + }); + } if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { output_vregs.Reshape(output_vregs_shape); } @@ -875,9 +940,10 @@ LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(layouts_out.front().has_value()); auto truncf_op = cast(op); if (layouts_in.front()->bitwidth() != 32 || - layouts_out.front()->bitwidth() != 16) { + (layouts_out.front()->bitwidth() != 16 && + layouts_out.front()->bitwidth() != 8)) { return op.emitOpError( - "Not implemented: Only 32-bit to 16-bit conversion supported"); + "Not implemented: Only 32-bit to 16-or-8-bit conversion supported"); } return trunc_op_rule_impl(ctx, truncf_op, *layouts_in.front(), *layouts_out.front()); @@ -913,16 +979,35 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op, scf::ForOp for_op = cast(op); TPU_ASSERT_EQ_OP(layouts_in.size(), for_op->getNumOperands()); TPU_ASSERT_EQ_OP(layouts_out.size(), for_op->getNumResults()); - if (!llvm::equal(layouts_in.drop_front(3), layouts_out)) { - return op.emitOpError( - "Expected matched layouts in scf.for's inputs and outputs"); - } FAILUREOR_ASSIGN_OR_RETURN( const SmallVector yield_in_layouts, getInLayouts(*for_op.getBody()->getTerminator(), ctx.target_shape)); - if (!llvm::equal(ArrayRef(yield_in_layouts), layouts_out)) { - return op.emitOpError( - "Expected matched layouts in scf.yield operands and scf.for's results"); + int out_idx = 0; + for (auto [in_layout, yield_layout, out_layout, result] : + llvm::zip_equal(layouts_in.drop_front(3), yield_in_layouts, layouts_out, + op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + TPU_ASSERT_OP(in_layout.has_value()); + TPU_ASSERT_OP(yield_layout.has_value()); + TPU_ASSERT_OP(out_layout.has_value()); + if (in_layout.value() != yield_layout.value()) { + return op.emitOpError( + "Not implemented: for loop input layout does not match with " + "yield layout ") + << out_idx; + } + if (in_layout.value() != out_layout.value()) { + return op.emitOpError( + "Not implemented: for loop input layout does not match with " + "out layout ") + << out_idx; + } + } else { + TPU_ASSERT_EQ_OP(in_layout, kNoLayout); + TPU_ASSERT_EQ_OP(yield_layout, kNoLayout); + TPU_ASSERT_EQ_OP(out_layout, kNoLayout); + } + ++out_idx; } if (failed(applyLayoutBlock(ctx, *for_op.getBody()))) { @@ -1047,30 +1132,53 @@ LogicalResult scf_while_rule(RewriteContext &ctx, Operation &op, // It takes multiple arguments -- the first being the decision to execute the // after region or branch to the exit. FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector condition_in_layouts, + const SmallVector cond_in_layouts, getInLayouts(*while_op.getBeforeBody()->getTerminator(), ctx.target_shape)); - if (!llvm::equal(ArrayRef(condition_in_layouts).drop_front(1), - layouts_out)) { - return op.emitOpError( - "Mismatched layouts between scf.while result and its before region " - "condition."); + + FAILUREOR_ASSIGN_OR_RETURN( + const SmallVector yield_in_layouts, + getInLayouts(*while_op.getYieldOp(), ctx.target_shape)); + int out_idx = 0; + for (auto [in_layout, cond_layout, yield_layout, out_layout, result] : + llvm::zip_equal(layouts_in, + ArrayRef(cond_in_layouts).drop_front(1), + yield_in_layouts, layouts_out, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + TPU_ASSERT_OP(in_layout.has_value()); + TPU_ASSERT_OP(yield_layout.has_value()); + TPU_ASSERT_OP(out_layout.has_value()); + if (in_layout.value() != cond_layout.value()) { + return op.emitOpError( + "Not implemented: while loop input layout does not match " + "with condition layout ") + << out_idx; + } + if (in_layout.value() != yield_layout.value()) { + return op.emitOpError( + "Not implemented: while loop input layout does not match " + "with yield layout ") + << out_idx; + } + if (in_layout.value() != out_layout.value()) { + return op.emitOpError( + "Not implemented: while loop input layout does not match " + "with output layout ") + << out_idx; + } + } else { + TPU_ASSERT_EQ_OP(in_layout, kNoLayout); + TPU_ASSERT_EQ_OP(cond_layout, kNoLayout); + TPU_ASSERT_EQ_OP(yield_layout, kNoLayout); + TPU_ASSERT_EQ_OP(out_layout, kNoLayout); + } + ++out_idx; } if (failed(applyLayoutBlock(ctx, *while_op.getBeforeBody()))) { return failure(); } - FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector after_yield_in_layouts, - getInLayouts(*while_op.getYieldOp(), ctx.target_shape)); - if (!layouts_out.empty() && - ArrayRef(after_yield_in_layouts) != layouts_out) { - return op.emitOpError( - "Not implemented: different layouts while's yield's operands and " - "results"); - } - if (failed(applyLayoutBlock(ctx, *while_op.getAfterBody()))) { return failure(); } @@ -1221,17 +1329,42 @@ LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(!layouts_in.front().has_value()); ImplicitLocOpBuilder builder(op.getLoc(), &op); scf::IfOp if_op = cast(op); + SmallVector then_yield_in_layouts; + SmallVector else_yield_in_layouts; FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector then_yield_in_layouts, + then_yield_in_layouts, getInLayouts(*if_op.thenYield(), ctx.target_shape)); - // TODO(tlongeri): ArrayRef conversion should not be necessary, fix - // after LLVM adds const qualifiers to ==/!= operators. Also - // applies to else_yield_in_layouts comparison below. - if (!layouts_out.empty() && - ArrayRef(then_yield_in_layouts) != layouts_out) { - return op.emitOpError( - "Not implemented: different layouts in then yield's operands and if's " - "results"); + if (!if_op.getElseRegion().empty()) { + FAILUREOR_ASSIGN_OR_RETURN( + else_yield_in_layouts, + getInLayouts(*if_op.elseYield(), ctx.target_shape)); + } + int out_idx = 0; + for (auto [then_layout, else_layout, result_layout, result] : + llvm::zip_equal(then_yield_in_layouts, else_yield_in_layouts, + layouts_out, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + TPU_ASSERT_OP(then_layout.has_value()); + TPU_ASSERT_OP(else_layout.has_value()); + TPU_ASSERT_OP(result_layout.has_value()); + if (result_layout.value() != then_layout.value()) { + return op.emitOpError( + "Not implemented: yield layout from then branch does not " + "match with output layout ") + << out_idx; + } + if (result_layout.value() != else_layout.value()) { + return op.emitOpError( + "Not implemented: yield layout from else branch does not " + "match with output layout ") + << out_idx; + } + } else { + TPU_ASSERT_EQ_OP(then_layout, kNoLayout); + TPU_ASSERT_EQ_OP(else_layout, kNoLayout); + TPU_ASSERT_EQ_OP(result_layout, kNoLayout); + } + ++out_idx; } if (failed(applyLayoutBlock(ctx, *if_op.thenBlock()))) { return failure(); @@ -1241,15 +1374,6 @@ LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 0); return success(); } - FAILUREOR_ASSIGN_OR_RETURN( - const SmallVector else_yield_in_layouts, - getInLayouts(*if_op.elseYield(), ctx.target_shape)); - if (!layouts_out.empty() && - ArrayRef(else_yield_in_layouts) != layouts_out) { - return op.emitOpError( - "Not implemented: different layouts in else yield's operands and if's " - "results"); - } if (failed(applyLayoutBlock(ctx, *if_op.elseBlock()))) { return failure(); } @@ -1930,19 +2054,13 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, return success(); } -LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - CHECK_EQ(layouts_in.size(), 1); - CHECK_EQ(layouts_out.size(), 1); - if (!layouts_in.front().has_value()) { - return op.emitOpError("Expected non-null input layout"); - } - if (!layouts_out.front().has_value()) { - return op.emitOpError("Expected non-null output layout"); - } - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); +// TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate. So +// we do not need template for the op type and to explicitly force amount +// argument to dynamic. +template +LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, + const VectorLayout &layout_in, + const VectorLayout &layout_out) { auto layout = VectorLayout(32, {0, 0}, ctx.target_shape, VectorLayout::ImplicitDim::kNone); if (layout_in != layout) { @@ -1951,8 +2069,7 @@ LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, if (layout_out != layout) { return op.emitOpError("Not implemented: unsupported layout for output"); } - tpu::RotateOp rotate_op = cast(op); - auto vty = rotate_op.getResult().getType(); + auto vty = op.getResult().getType(); if (vty.getRank() < 2) { return op.emitOpError("Not implemented: unsupported 1D shape"); } @@ -1961,23 +2078,77 @@ LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: unsupported unaliged shape"); } - ImplicitLocOpBuilder builder(op.getLoc(), &op); + ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); FAILUREOR_ASSIGN_OR_RETURN( VectorType res_vreg_ty, getNativeVregType(vty.getElementType(), ctx.target_shape)); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array in_tiles, - disassemble(builder, layout_in, rotate_op.getValue(), ctx.target_shape)); + disassemble(builder, layout_in, op.getValue(), ctx.target_shape)); FAILUREOR_ASSIGN_OR_RETURN( const VectorType i32_vreg, getNativeVregType(builder.getI32Type(), ctx.target_shape)); - auto getVmaskByPaddingEnd = [&](int dim, int padding, int stride = 0) { + + // Some helper functions for math ops. + auto mlirI32Const = [&](int d) { + return builder.create( + builder.getIntegerAttr(builder.getI32Type(), d)); + }; + auto mlirIndexConst = [&](int d) { + return builder.create( + builder.getIntegerAttr(builder.getIndexType(), d)); + }; + auto modI = [&](const Value &v, unsigned d) -> Value { + if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + return mlirI32Const(cst.value() % d); + } + return builder.create(v, mlirI32Const(d)); + }; + auto divI = [&](const Value &v, unsigned d) -> Value { + if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + return mlirI32Const(cst.value() / d); + } + return builder.create(v, mlirI32Const(d)); + }; + auto addI = [&](const Value &v, unsigned d) -> Value { + if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + return mlirI32Const(cst.value() + d); + } + return builder.create(v, mlirI32Const(d)); + }; + + // A helper function that creates a VMASK with false flags to bottom (dim = 0) + // or right (dim = 1) where the flag count corresponds to the (dim_size - + // padding). If stride is provided, the padding value is sequentially + // increased by the stride value along the dim. + // + // For example, assume VMASK shape is (4, 8) + // + // getVmaskByPaddingEnd(padding=3, dim=1) creates: + // [T, T, T, T, T, F, F, F] + // [T, T, T, T, T, F, F, F] + // [T, T, T, T, T, F, F, F] + // [T, T, T, T, T, F, F, F] + // + // getVmaskByPaddingEnd(padding=3, dim=1, stride=1) creates: + // [T, T, T, T, T, F, F, F] + // [T, T, T, T, T, T, F, F] + // [T, T, T, T, T, T, T, F] + // [T, T, T, T, T, T, T, T] + auto getVmaskByPaddingEnd = [&](Value padding, int dim, int stride = 0) { CHECK(dim == 0 || dim == 1); - CHECK(padding >= 0 && padding <= ctx.target_shape[dim]); - Value padding_vreg = builder.create( - DenseElementsAttr::get(i32_vreg, builder.getI32IntegerAttr( - ctx.target_shape[dim] - padding))); + Value padding_vreg; + if (auto padding_cst = getIntConst(padding, /*silent=*/true); + succeeded(padding_cst)) { + CHECK_GE(padding_cst.value(), 0); + CHECK_LE(padding_cst.value(), ctx.target_shape[dim]); + padding_vreg = builder.create(DenseElementsAttr::get( + i32_vreg, builder.getI32IntegerAttr(padding_cst.value()))); + } else { + padding_vreg = builder.create(i32_vreg, padding); + } + if (stride > 0) { auto offset = builder.create( i32_vreg, @@ -1994,77 +2165,155 @@ LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, padding_vreg); }; - auto splitVregs = [](const xla::Array &vregs, int axis) { - CHECK(axis >= 0 && axis < vregs.num_dimensions()); - SmallVector> chunks; - chunks.reserve(vregs.dim(axis)); - for (int64_t i = 0; i < vregs.dim(axis); ++i) { - SmallVector starts(vregs.num_dimensions(), 0); - starts[axis] = i; - SmallVector limits(vregs.dimensions().begin(), - vregs.dimensions().end()); - limits[axis] = i + 1; - chunks.push_back(vregs.Slice(starts, limits)); + // Apply rotation on each vreg with the assumption that shift <= VREG dim size + // and blend the data from contiguous vregs to emulate circular rotation. + auto rotateOnTilingDim = [&](const xla::Array &vregs, + const Value &shift, int axis, int stride = 0) { + if (auto shift_cst = getIntConst(shift, /*silent=*/true); + succeeded(shift_cst)) { + if (shift_cst.value() == 0 && stride == 0) { + return vregs; + } + } + int tiling_dim = axis - (vregs.num_dimensions() - 2); + CHECK((tiling_dim == 0 && stride == 0) || (tiling_dim == 1 && stride >= 0)); + xla::Array result(vregs.dimensions()); + auto chunks = split(vregs, axis); + for (int64_t i = 0; i < chunks.size(); ++i) { + chunks[i].Each([&](absl::Span idxs, Value *v) { + auto stride_attr = + stride > 0 ? builder.getSI32IntegerAttr(stride) : nullptr; + auto stride_dimension_attr = + stride > 0 ? builder.getSI32IntegerAttr(0) : nullptr; + *v = builder.create(res_vreg_ty, *v, shift, + tiling_dim, stride_attr, + stride_dimension_attr); + }); } - return chunks; + auto mask = getVmaskByPaddingEnd(shift, tiling_dim, stride); + xla::Array last_chunk_copy(chunks[chunks.size() - 1]); + for (int64_t i = chunks.size() - 1; i > 0; --i) { + chunks[i].Each([&](absl::Span idxs, Value *v) { + *v = builder.create(mask, chunks[i - 1](idxs), *v); + }); + } + chunks[0].Each([&](absl::Span idxs, Value *v) { + *v = builder.create(mask, last_chunk_copy(idxs), *v); + }); + return concatenate(chunks, axis); }; - auto roll = [&](const xla::Array &vregs, int64_t shift, int axis, - int stride = 0) { + + std::function(const xla::Array &, Value, int, int)> + rotate; + rotate = [&](const xla::Array &vregs, Value shift, int axis, + int stride) { xla::Array result(vregs.dimensions()); CHECK(axis >= 0 && axis < vregs.num_dimensions()); - auto chunks = splitVregs(vregs, axis); - if (axis >= vregs.num_dimensions() - 2) { - int tiling_dim = axis - (vregs.num_dimensions() - 2); - int64_t shift_in_vreg = shift % ctx.target_shape[tiling_dim]; - shift /= ctx.target_shape[tiling_dim]; - CHECK((tiling_dim == 0 && stride == 0) || - (tiling_dim == 1 && stride >= 0)); - if (shift_in_vreg != 0 || stride != 0) { - for (int64_t i = 0; i < chunks.size(); ++i) { - chunks[i].Each([&](absl::Span idxs, Value *v) { - auto stride_attr = - stride > 0 ? builder.getSI32IntegerAttr(stride) : nullptr; - auto stride_dimension_attr = - stride > 0 ? builder.getSI32IntegerAttr(0) : nullptr; - *v = builder.create(res_vreg_ty, *v, shift_in_vreg, - tiling_dim, stride_attr, - stride_dimension_attr); - }); - } - // After rotation on each vreg, we need to select the wrapped data - // from the previous vreg and overwrite them to the current vreg. - auto mask = getVmaskByPaddingEnd( - tiling_dim, ctx.target_shape[tiling_dim] - shift_in_vreg, stride); - xla::Array last_chunk_copy(chunks[chunks.size() - 1]); - for (int64_t i = chunks.size() - 1; i > 0; --i) { - chunks[i].Each([&](absl::Span idxs, Value *v) { - *v = builder.create(mask, chunks[i - 1](idxs), *v); - }); - } + int tiling_dim = axis - (vregs.num_dimensions() - 2); + CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0)); + SmallVector, 4> chunks; + // Handle rotation with static shift. + if (auto shift_cst = getIntConst(shift, /*silent=*/true); + succeeded(shift_cst)) { + int64_t static_shift = shift_cst.value(); + if (tiling_dim >= 0) { + shift = mlirI32Const(static_shift % ctx.target_shape[tiling_dim]); + static_shift /= ctx.target_shape[tiling_dim]; + chunks = split(rotateOnTilingDim(vregs, shift, axis, stride), axis); + } else { + chunks = split(vregs, axis); + } + // Now we only need to shuffle vregs. + for (int64_t i = 0; i < chunks.size(); ++i) { + SmallVector starts(result.num_dimensions(), 0); + starts[axis] = (i + static_shift) % result.dim(axis); + result.UpdateSlice(chunks[i], starts); + } + return result; + } + // Handle rotation with dynamic shift. + // TODO(jevinjiang): consider optimize with assume_multiple op. + Value in_vreg_shift = tiling_dim >= 0 + ? modI(shift, ctx.target_shape[tiling_dim]) + : mlirI32Const(0); + Value vreg_shift = + tiling_dim >= 0 ? divI(shift, ctx.target_shape[tiling_dim]) : shift; + result = tiling_dim >= 0 + ? rotateOnTilingDim(vregs, in_vreg_shift, axis, stride) + : vregs; + int bound = vregs.dim(axis); + if (bound <= ctx.max_sublanes_in_scratch / ctx.target_shape[0] && + bound >= kMinBoundToRotateWithScratch) { + // Use static store + dynamic load to implement dynamic shift. + if (auto scratch_ref = getInternalScratch( + ctx, builder, op.getLoc(), + {ctx.max_sublanes_in_scratch / ctx.target_shape[0], + ctx.target_shape[0], ctx.target_shape[1]}, + vty.getElementType()); + succeeded(scratch_ref)) { + auto cst_0 = mlirIndexConst(0); + SmallVector scratch_indices(3, cst_0); + SmallVector sublane_mask(ctx.target_shape[0], true); + const auto sublane_mask_attr = + DenseBoolArrayAttr::get(op.getContext(), sublane_mask); + chunks = split(result, axis); chunks[0].Each([&](absl::Span idxs, Value *v) { - *v = builder.create(mask, last_chunk_copy(idxs), *v); + // Static store vregs. + for (int i = 0; i < bound; ++i) { + scratch_indices[0] = mlirIndexConst(i); + builder.create(chunks[i](idxs), scratch_ref.value(), + scratch_indices, sublane_mask_attr, + /*mask=*/nullptr, + /*sublane_stride=*/nullptr); + } + // Dynamic load vregs back from a circular buffer. + for (int i = 0; i < bound; ++i) { + scratch_indices[0] = builder.create( + builder.getIndexType(), + modI(builder.create(mlirI32Const(bound + i), + vreg_shift), + bound)); + chunks[i](idxs) = + builder.create(v->getType(), scratch_ref.value(), + scratch_indices, sublane_mask_attr, + /*sublane_stride=*/nullptr); + } }); + return concatenate(chunks, axis); } - } else { - CHECK_EQ(stride, 0); } - // Now we only need to shuffle vregs. - for (int64_t i = 0; i < chunks.size(); ++i) { - SmallVector starts(result.num_dimensions(), 0); - starts[axis] = (i + shift) % result.dim(axis); - result.UpdateSlice(chunks[i], starts); + // Convert dynamic shift to log(bound) static ops. + int roll_by = 1; + Value cst_1 = mlirI32Const(1); + while (bound > 0) { + auto new_result = rotate( + result, + mlirI32Const(tiling_dim >= 0 ? roll_by * ctx.target_shape[tiling_dim] + : roll_by), + axis, /*stride=*/0); + auto mask = builder.create( + arith::CmpIPredicate::ne, + builder.create( + i32_vreg, builder.create(vreg_shift, cst_1)), + builder.create( + DenseElementsAttr::get(i32_vreg, builder.getI32IntegerAttr(0)))); + result.Each([&](absl::Span idxs, Value *v) { + *v = builder.create(mask, new_result(idxs), *v); + }); + roll_by *= 2; + bound /= 2; + vreg_shift = divI(vreg_shift, 2); } return result; }; xla::Array out_tiles(in_tiles.dimensions()); - const auto dim = rotate_op.getDimension(); - const auto amount = rotate_op.getAmount() % vty.getDimSize(dim); + const auto dim = op.getDimension(); + amount = modI(amount, vty.getDimSize(dim)); - if (rotate_op.getStride().has_value() && - rotate_op.getStrideDimension().has_value()) { - auto stride_dim = rotate_op.getStrideDimension().value(); - auto stride = rotate_op.getStride().value() % vty.getDimSize(stride_dim); + if (op.getStride().has_value() && op.getStrideDimension().has_value()) { + auto stride_dim = op.getStrideDimension().value(); + auto stride = op.getStride().value() % vty.getDimSize(stride_dim); if (stride_dim == dim) { return op.emitOpError( "Expected rotation dimension and stride dimension are not equal"); @@ -2079,46 +2328,96 @@ LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, "is the minor most when stride dimension is the second minor most"); } CHECK_GE(stride, 0); - auto chunks = splitVregs(in_tiles, stride_dim); + auto chunks = split(in_tiles, stride_dim); for (int64_t i = 0; i < chunks.size(); ++i) { - int64_t base_amount = - (ctx.target_shape[0] * i * stride + amount) % vty.getDimSize(dim); + Value base_amount = modI(addI(amount, ctx.target_shape[0] * i * stride), + vty.getDimSize(dim)); // After applying stride, we expect all shifts in a vreg are less or // equal to the vreg's lane count for now. - auto max_shift_in_vreg = base_amount % ctx.target_shape[1] + - (ctx.target_shape[0] - 1) * stride; - if (max_shift_in_vreg > ctx.target_shape[1]) { - return op.emitOpError("Not implemented: the max shift in a vreg ") - << max_shift_in_vreg << " is larger than the vreg's width " - << ctx.target_shape[1]; + if (auto base_amount_cst = getIntConst(base_amount, /*silent=*/true); + succeeded(base_amount_cst)) { + int64_t static_base_amount = base_amount_cst.value(); + auto max_shift_in_vreg = static_base_amount % ctx.target_shape[1] + + (ctx.target_shape[0] - 1) * stride; + if (max_shift_in_vreg > ctx.target_shape[1]) { + return op.emitOpError("Not implemented: the max shift in a vreg ") + << max_shift_in_vreg << " is larger than the vreg's width " + << ctx.target_shape[1]; + } } SmallVector starts(out_tiles.num_dimensions(), 0); starts[stride_dim] = i; - out_tiles.UpdateSlice(roll(chunks[i], base_amount, dim, stride), + out_tiles.UpdateSlice(rotate(chunks[i], base_amount, dim, stride), starts); } } else { // Split vregs along the stride dimension. - auto chunks = splitVregs(in_tiles, stride_dim); + auto chunks = split(in_tiles, stride_dim); for (int64_t i = 0; i < chunks.size(); ++i) { SmallVector starts(out_tiles.num_dimensions(), 0); starts[stride_dim] = i; - out_tiles.UpdateSlice(roll(chunks[i], amount + i * stride, dim), - starts); + out_tiles.UpdateSlice( + rotate(chunks[i], addI(amount, i * stride), dim, /*stride=*/0), + starts); } } } else { // No stride. - out_tiles = roll(in_tiles, amount, dim); + out_tiles = rotate(in_tiles, amount, dim, /*stride=*/0); } const RollVectorsOp rolled_op = - assemble(builder, rotate_op.getResult().getType(), layout_out, out_tiles, + assemble(builder, op.getResult().getType(), layout_out, out_tiles, ctx.target_shape); op.replaceAllUsesWith(rolled_op); op.erase(); return success(); } +// TODO(b/347016737): deprecate the static rotate. +LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK_EQ(layouts_in.size(), 1); + CHECK_EQ(layouts_out.size(), 1); + if (!layouts_in.front().has_value()) { + return op.emitOpError("Expected non-null input layout"); + } + if (!layouts_out.front().has_value()) { + return op.emitOpError("Expected non-null output layout"); + } + auto rotate_op = cast(op); + if (rotate_op.getAmount() < 0) { + return op.emitOpError("Not implemented: shifting by negative amount"); + } + ImplicitLocOpBuilder builder(op.getLoc(), &op); + Value shift = builder.create( + builder.getIntegerAttr(builder.getI32Type(), rotate_op.getAmount())); + const VectorLayout &layout_in = *layouts_in.front(); + const VectorLayout &layout_out = *layouts_out.front(); + return rotate_rule_impl(ctx, rotate_op, shift, layout_in, layout_out); +} + +LogicalResult tpu_dynamic_rotate_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK_EQ(layouts_in.size(), 2); + CHECK_EQ(layouts_out.size(), 1); + if (!layouts_in.front().has_value()) { + return op.emitOpError("Expected non-null layout for the value to rotate"); + } + if (layouts_in[1].has_value()) { + return op.emitOpError("Expected null layout for the shift"); + } + if (!layouts_out.front().has_value()) { + return op.emitOpError("Expected non-null output layout"); + } + auto rotate_op = cast(op); + const VectorLayout &layout_in = *layouts_in.front(); + const VectorLayout &layout_out = *layouts_out.front(); + return rotate_rule_impl(ctx, rotate_op, rotate_op.getAmount(), layout_in, + layout_out); +} + LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -2461,7 +2760,9 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( Tiling memref_tiling, getMemRefTiling(load_op.getBase(), ctx.target_shape)); - if (layout_out.tiling() != memref_tiling) { + if (memref_tiling != layout_out.tiling() && + !(memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 && + memref_tiling[1] % layout_out.tiling()[1] == 0)) { // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes). // TODO(b/295393167): need to support strided load for bitwidth < 32. if (layout_out.bitwidth() != 32 || @@ -3259,12 +3560,10 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, case vector::CombiningKind::ADD: neutral = builder.getF32FloatAttr(0); break; - case vector::CombiningKind::MAXNUMF: case vector::CombiningKind::MAXIMUMF: { // TODO(b/322836633): The semantics of maximumf don't match the lowering // for older TPU versions because older TPU versions don't respect the - // -0.0 vs +0.0 ordering. Keeping MAXNUMF for backward compatibility of - // serialized artifacts. + // -0.0 vs +0.0 ordering. neutral = builder.getFloatAttr( builder.getF32Type(), APFloat::getInf(APFloat::IEEEsingle(), /*Negative=*/true)); @@ -3358,7 +3657,6 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, case vector::CombiningKind::ADD: tpu_kind = tpu::ReductionKind::SUM; break; - case vector::CombiningKind::MAXNUMF: case vector::CombiningKind::MAXIMUMF: tpu_kind = tpu::ReductionKind::MAX; break; @@ -3466,53 +3764,29 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, const ArrayRef src_shape = src_ty.getShape(); const VectorType dst_ty = shape_cast_op.getResultVectorType(); const ArrayRef dst_shape = dst_ty.getShape(); - const int layout_rank = layout_in.layout_rank(); bool no_op = false; + const std::array src_tiled_dims = + layout_in.getImplicitTiledDims(src_shape, 1); + const std::array dst_tiled_dims = + layout_out.getImplicitTiledDims(dst_shape, 1); const std::array src_vreg_slice = layout_in.vregSlice(ctx.target_shape); const std::array dst_vreg_slice = layout_out.vregSlice(ctx.target_shape); - // TODO(tlongeri): It looks like this could probably be simplified by using - // VectorLayout::implicitShape() - if (layout_in == layout_out && src_ty.getShape().take_back(layout_rank) == - dst_ty.getShape().take_back(layout_rank)) { - no_op = true; - } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_out.implicit_dim() == - VectorLayout::ImplicitDim::kSecondMinor && - layout_in.hasNativeTiling(ctx.target_shape) && - layout_in.tiling() == layout_out.tiling() && - layout_in.offsets() == layout_out.offsets() && - *(src_shape.end() - 1) == *(dst_shape.end() - 1) && - *(src_shape.end() - 2) == 1) { - no_op = true; - } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_out.implicit_dim() == VectorLayout::ImplicitDim::kMinor && - layout_in.hasNaturalTopology(ctx.target_shape) && - layout_in.tiling() == layout_out.tiling() && - layout_in.offsets() == layout_out.offsets() && - src_shape == - ArrayRef(layout_out.implicitShape(dst_shape))) { - no_op = true; - } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kMinor && - layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_out.hasNaturalTopology(ctx.target_shape) && - layout_in.tiling() == layout_out.tiling() && - layout_in.offsets() == layout_out.offsets() && - dst_shape == - ArrayRef(layout_in.implicitShape(src_shape))) { + if (layout_in.tiling() == layout_out.tiling() && + layout_in.offsets() == layout_out.offsets() && + src_tiled_dims == dst_tiled_dims) { no_op = true; } else if ( // Fold or unfold sublane dim, but keeping a whole number of // vregs. layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone && - layout_in.offsets() == LayoutOffsets{0, 0} && - layout_out.offsets() == LayoutOffsets{0, 0} && + layout_in.offsets()[0] == 0 && + layout_in.offsets() == layout_out.offsets() && layout_in.tiling() == layout_out.tiling() && - layout_in.tiling()[1] == ctx.target_shape[1] && *(dst_shape.end() - 1) == *(src_shape.end() - 1) && - *(dst_shape.end() - 2) % layout_in.tiling()[0] == 0 && - *(src_shape.end() - 2) % layout_in.tiling()[0] == 0) { + *(dst_shape.end() - 2) % dst_vreg_slice[0] == 0 && + *(src_shape.end() - 2) % src_vreg_slice[0] == 0) { no_op = true; } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone && @@ -3659,7 +3933,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( const Tiling memref_tiling, getMemRefTiling(store_op.getBase(), ctx.target_shape)); - if (to_store_layout.tiling() != memref_tiling) { + if (memref_tiling != to_store_layout.tiling() && + !(memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 && + memref_tiling[1] % to_store_layout.tiling()[1] == 0)) { // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes). // TODO(b/295393167): need to support strided store for bitwidth < 32. if (to_store_layout.bitwidth() != 32 || @@ -4005,6 +4281,7 @@ const llvm::StringMap &rules() { {scf::IfOp::getOperationName(), scf_if_rule}, {scf::YieldOp::getOperationName(), scf_yield_rule}, {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, + {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, {tpu::IotaOp::getOperationName(), tpu_iota_rule}, {tpu::GatherOp::getOperationName(), tpu_gather_rule}, @@ -4988,12 +5265,14 @@ struct ApplyVectorLayoutPass : public impl::ApplyVectorLayoutPassBase { ApplyVectorLayoutPass(int hardware_generation_, int lane_count_, int sublane_count_, int mxu_contracting_size_, - int mxu_noncontracting_size_) { + int mxu_noncontracting_size_, + int max_sublanes_in_scratch_) { hardware_generation = hardware_generation_; sublane_count = sublane_count_; lane_count = lane_count_; mxu_contracting_size = mxu_contracting_size_; mxu_noncontracting_size = mxu_noncontracting_size_; + max_sublanes_in_scratch = max_sublanes_in_scratch_; } void runOnOperation() override { // Fail if hardware_generation has not been set from the default value. @@ -5005,7 +5284,8 @@ struct ApplyVectorLayoutPass RewriteContext ctx{func, hardware_generation, {sublane_count, lane_count}, - {mxu_contracting_size, mxu_noncontracting_size}}; + {mxu_contracting_size, mxu_noncontracting_size}, + max_sublanes_in_scratch}; if (failed(applyLayoutFunc(ctx, func))) { signalPassFailure(); return; @@ -5015,10 +5295,11 @@ struct ApplyVectorLayoutPass std::unique_ptr> createApplyVectorLayoutPass( int hardware_generation, int lane_count, int sublane_count, - int mxu_contracting_size, int mxu_noncontracting_size) { + int mxu_contracting_size, int mxu_noncontracting_size, + int max_sublanes_in_scratch) { return std::make_unique( hardware_generation, lane_count, sublane_count, mxu_contracting_size, - mxu_noncontracting_size); + mxu_noncontracting_size, max_sublanes_in_scratch); } } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h index 75fb5e7904a1..547a8a00c10c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h @@ -21,6 +21,7 @@ struct RewriteContext { const int hardware_generation; const std::array target_shape = {8, 128}; const std::array mxu_shape = {128, 128}; + const int max_sublanes_in_scratch = 0; MLIRContext *getMLIRContext() { return func.getContext(); } }; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 5efd4496d1a9..54add5fe469e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -23,6 +23,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" @@ -33,6 +34,23 @@ namespace mlir::tpu { #define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" +SmallVector ComputeTileStrides(MemRefType memref_ty, + int64_t leading_tile_rows) { + SmallVector tile_strides(memref_ty.getRank()); + int64_t stride = 1; + for (int i = memref_ty.getRank() - 1; i >= 0; --i) { + tile_strides[i] = stride; + if (i == memref_ty.getRank() - 1) { + stride *= llvm::divideCeil(memref_ty.getShape()[i], 128); + } else if (i == memref_ty.getRank() - 2) { + stride *= llvm::divideCeil(memref_ty.getShape()[i], leading_tile_rows); + } else { + stride *= memref_ty.getShape()[i]; + } + } + return tile_strides; +} + // Returns the number of 128-element groups in a tile. // // Arguments: @@ -55,7 +73,8 @@ int getTilingFactor(const int num_128s, const int hardware_generation, } FailureOr inferLayout(MemRefType memref_ty, - const int hardware_generation) { + const int hardware_generation, + int64_t leading_tile_rows = 0) { if (auto tiled_layout_attr = dyn_cast(memref_ty.getLayout())) { return tiled_layout_attr; @@ -91,11 +110,14 @@ FailureOr inferLayout(MemRefType memref_ty, } return TiledLayoutAttr::get(memref_ty.getContext(), tiles, {1}); } + // memref.getRank() > 1 const ArrayRef shape = memref_ty.getShape(); const int64_t second_minor = shape[shape.size() - 2]; - const int64_t leading_tile_rows = - getTilingFactor(second_minor, hardware_generation, bitwidth); + if (leading_tile_rows == 0) { + leading_tile_rows = + getTilingFactor(second_minor, hardware_generation, bitwidth); + } SmallVector tiles{xla::Tile({leading_tile_rows, 128})}; if (bitwidth != 32) { if (!llvm::has_single_bit(bitwidth) || bitwidth > 32) { @@ -105,19 +127,7 @@ FailureOr inferLayout(MemRefType memref_ty, } tiles.push_back(xla::Tile({32 / bitwidth, 1})); } - SmallVector tile_strides(memref_ty.getRank()); - int64_t stride = 1; - for (int i = memref_ty.getRank() - 1; i >= 0; --i) { - tile_strides[i] = stride; - if (i == memref_ty.getRank() - 1) { - stride *= (memref_ty.getShape()[i] + 127) / 128; - } else if (i == memref_ty.getRank() - 2) { - stride *= (memref_ty.getShape()[i] + leading_tile_rows - 1) / - leading_tile_rows; - } else { - stride *= memref_ty.getShape()[i]; - } - } + auto tile_strides = ComputeTileStrides(memref_ty, leading_tile_rows); return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides); } return emitError(UnknownLoc::get(memref_ty.getContext()), @@ -149,7 +159,8 @@ LogicalResult checkTiles(MLIRContext *mlir_ctx, } FailureOr inferMemref(MemRefType memref, - const int hardware_generation) { + const int hardware_generation, + int64_t leading_tile_rows) { if (isa(memref.getElementType())) { const Attribute semaphore_mem = tpu::MemorySpaceAttr::get( memref.getContext(), MemorySpace::kSemaphoreMem); @@ -169,8 +180,9 @@ FailureOr inferMemref(MemRefType memref, tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem); const Attribute memory_space = memref.getMemorySpace() == nullptr ? vmem : memref.getMemorySpace(); - FAILUREOR_ASSIGN_OR_RETURN(const TiledLayoutAttr layout, - inferLayout(memref, hardware_generation)); + FAILUREOR_ASSIGN_OR_RETURN( + const TiledLayoutAttr layout, + inferLayout(memref, hardware_generation, leading_tile_rows)); const ArrayRef tiles = layout.getTiles(); if (failed(checkTiles(memref.getContext(), tiles))) { @@ -244,14 +256,24 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) { Block &entry = f.getBody().front(); SmallVector new_arg_types; auto builder = OpBuilder::atBlockBegin(&entry); - for (BlockArgument arg : entry.getArguments()) { + for (int i = 0; i < entry.getNumArguments(); ++i) { + BlockArgument arg = entry.getArgument(i); const auto memref_ty = dyn_cast(arg.getType()); if (memref_ty == nullptr) { new_arg_types.push_back(arg.getType()); continue; } - FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation)); + int64_t leading_tile_rows = 0; + auto leading_tile_rows_attr = + f.getArgAttrOfType(i, kLeadingTileRows); + if (leading_tile_rows_attr != nullptr) { + leading_tile_rows = leading_tile_rows_attr.getInt(); + f.removeArgAttr(i, kLeadingTileRows); + } + + FAILUREOR_ASSIGN_OR_RETURN( + const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation, leading_tile_rows)); arg.setType(new_memref_ty); new_arg_types.push_back(arg.getType()); if (memref_ty != new_memref_ty) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h index 724a09fffa19..2ad0afbb690d 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h @@ -1,12 +1,17 @@ #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ +#include + #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" namespace mlir::tpu { -FailureOr inferMemref(MemRefType memref, int hardware_generation); +FailureOr inferMemref(MemRefType memref, int hardware_generation, + int64_t leading_tile_rows = 0); + +const std::string_view kLeadingTileRows = "leading_tile_rows"; } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 114371db8df1..fe02b4270a40 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -43,10 +43,12 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/IR/Attributes.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/include/mlir/IR/OpDefinition.h" +#include "mlir/include/mlir/IR/Visitors.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" @@ -220,11 +222,11 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { + } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { + } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); } @@ -427,19 +429,7 @@ class VectorLayoutInferer { auto then_yield = op.thenBlock()->getTerminator(); TPU_CHECK_OP(then_yield->getOperandTypes() == op->getResultTypes(), "scf if results and then branch yield operands do not match"); - SmallVector result_layout; - result_layout.reserve(then_yield->getNumOperands()); - for (const auto &operand : then_yield->getOperands()) { - if (operand.getType().isSignlessIntOrIndexOrFloat()) { - result_layout.push_back(kNoLayout); - } else if (isa(operand.getType())) { - result_layout.push_back(getLayout(operand)); - } else { - op.emitOpError("unsupported scf.yield type"); - return failure(); - } - } - + auto then_yield_in_layouts = getLayoutFromOperands(then_yield); if (auto else_block = op.elseBlock()) { if (inferBlock(*else_block, match_yield).failed()) { op.emitOpError("failed to infer layout for else branch"); @@ -454,32 +444,53 @@ class VectorLayoutInferer { auto else_yield = op.elseBlock()->getTerminator(); TPU_CHECK_OP(else_yield->getOperandTypes() == op->getResultTypes(), "scf if results and else branch yield operands do not match"); - - // Check each layout of the yield in else branch and override the - // result_layout if else branch's yield layout is less general. For example, - // if we yield offset (*, *) in then branch and offset (*, 0) in else - // branch, the result offset should be (*, 0). - for (int i = 0; i < else_yield->getNumOperands(); ++i) { - const auto &operand = else_yield->getOperand(i); - if (!isa(operand.getType())) { - continue; - } - auto shape = dyn_cast(operand.getType()).getShape(); - auto layout = getLayout(operand); - CHECK(result_layout[i].has_value() && layout.has_value()); - result_layout[i] = - VectorLayout::join(result_layout[i].value(), layout.value(), shape); - if (!result_layout[i].has_value()) { - op.emitOpError( - "failed to find a compatible layout in then and else branch for " - "output ") - << i; - return failure(); + auto else_yield_in_layouts = getLayoutFromOperands(else_yield); + // Find a compatible layout from then and else branches for each reuslt. For + // example, if we yield offset (*, *) in then branch and offset (*, 0) in + // else branch, the result offset should be (*, 0). + SmallVector out_layouts; + out_layouts.reserve(op->getNumResults()); + int out_idx = 0; + for (auto [then_layout, else_layout, result] : llvm::zip_equal( + then_yield_in_layouts, else_yield_in_layouts, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + if (!then_layout.has_value()) { + return op.emitOpError( + "expected a vector layout for then yield input ") + << out_idx; + } + if (!else_layout.has_value()) { + return op.emitOpError( + "expected a vector layout for else yield input ") + << out_idx; + } + auto compatible_layout = VectorLayout::join( + then_layout.value(), else_layout.value(), vty.getShape()); + // If no compatible layout is found in layouts for then and else + // branches, the output layout falls back to a normalized layout which + // has offsets 0 and the native tiling. + if (!compatible_layout.has_value()) { + compatible_layout = VectorLayout( + then_layout->bitwidth(), {0, 0}, + nativeTiling(then_layout->bitwidth()), ImplicitDim::kNone); + } + out_layouts.push_back(compatible_layout); + } else { + if (then_layout.has_value()) { + return op.emitOpError("expected no layout for then yield input ") + << out_idx; + } + if (else_layout.has_value()) { + return op.emitOpError("expected no layout for else yield input ") + << out_idx; + } + out_layouts.push_back(kNoLayout); } + ++out_idx; } - setInLayout(then_yield, result_layout); - setInLayout(else_yield, result_layout); - setOutLayout(op, result_layout); + setInLayout(then_yield, out_layouts); + setInLayout(else_yield, out_layouts); + setOutLayout(op, out_layouts); return success(); } @@ -497,48 +508,85 @@ class VectorLayoutInferer { op->getNumOperands() == 3 + op.getNumResults(), "expected num_operands is equal to 3 + num_results in scf.for"); - SmallVector in_layouts; - in_layouts.reserve(op->getNumOperands()); - in_layouts.push_back(kNoLayout); // Lower bound. - in_layouts.push_back(kNoLayout); // Upper bound. - in_layouts.push_back(kNoLayout); // Step. - for (const auto &arg : op.getInitArgs()) { - if (arg.getType().isSignlessIntOrIndexOrFloat()) { - in_layouts.push_back(kNoLayout); - } else if (isa(arg.getType())) { - auto layout = getLayout(arg); - in_layouts.push_back(layout); - } else { - op.emitOpError() << "unsupported arg type " << arg.getType() - << " in scf::for"; - return failure(); - } + auto in_layouts = getLayoutFromOperands(op); + // Drop the input layouts for lower bound, upper bound. But keep the layout + // for step because it matches with induction variable in arguments. + auto arg_layouts = ArrayRef(in_layouts).drop_front(2); + if (assumeLayoutsForBlockArgs(*op.getBody(), arg_layouts).failed() || + inferBlock(*op.getBody(), match_yield).failed()) { + return op.emitOpError( + "failed to infer layout with initial layouts for body in " + "scf.for op"); } - ArrayRef out_layouts = ArrayRef(in_layouts).drop_front(3); - // Use tpu.assume_layout to annotate every block argument with the layout of - // the corresponding operand in forOp and replace all uses of the block - // argument with the result of tpu.assume_layout. - ImplicitLocOpBuilder builder = - ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBody()); + auto yield_op = op.getBody()->getTerminator(); + auto yield_in_layouts = getLayoutFromOperands(yield_op); - // Drop the induction_variable and layouts of bounds+step (respectively). - for (auto [iter_arg, layout] : llvm::zip_equal( - op.getBody()->getArguments().drop_front(1), out_layouts)) { - if (!dyn_cast(iter_arg.getType())) { - continue; + SmallVector out_layouts; + out_layouts.reserve(op->getNumResults()); + int out_idx = 0; + bool require_reinfer = false; + for (auto [in_layout, yield_layout, result] : + llvm::zip_equal(arg_layouts.drop_front( + 1), // Drop the layout for induction variable. + yield_in_layouts, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + if (!in_layout.has_value()) { + return op.emitOpError("expected a vector layout for input ") + << out_idx; + } + if (!yield_layout.has_value()) { + return op.emitOpError("expected a vector layout for yield input ") + << out_idx; + } + auto compatible_layout = VectorLayout::join( + in_layout.value(), yield_layout.value(), vty.getShape()); + // If no compatible layout is found in layouts for input and + // yield, the output layout falls back to a normalized layout which + // has offsets 0 and the native tiling. + if (!compatible_layout.has_value()) { + compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0}, + nativeTiling(in_layout->bitwidth()), + ImplicitDim::kNone); + } + if (!require_reinfer && + (compatible_layout.value() != in_layout.value() || + compatible_layout.value() != yield_layout.value())) { + require_reinfer = true; + } + out_layouts.push_back(compatible_layout); + } else { + if (in_layout.has_value()) { + return op.emitOpError("expected no layout for input ") << out_idx; + } + if (yield_layout.has_value()) { + return op.emitOpError("expected no layout for yield input ") + << out_idx; + } + out_layouts.push_back(kNoLayout); + } + ++out_idx; + } + if (require_reinfer) { + // Force same layouts in input layout but skip the first 3 layouts for + // lower bound, upper bound and step. + std::copy(out_layouts.begin(), out_layouts.end(), in_layouts.begin() + 3); + + // Terminator in the loop will carry layouts to the next loop but + // the loop's block args' layouts are determined by the initial inputs. We + // need to force the same layouts for all in order to make layouts be + // consistent across all branches. To ensure that, we need to reprocess + // layout inference for the entire body with the final consolidated + // layout. + clearBlockLayouts(*op.getBody()); + if (assumeLayoutsForBlockArgs(*op.getBody(), + ArrayRef(in_layouts).drop_front(2)) + .failed() || + inferBlock(*op.getBody(), match_yield).failed()) { + return op.emitOpError( + "failed to infer layout with compatible layouts for body in " + "scf.for op"); } - auto assume_layout_op = - builder.create(iter_arg.getType(), iter_arg); - setLayout(assume_layout_op, layout, layout); - iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) { - return operand.getOwner() != assume_layout_op; - }); - } - - if (inferBlock(*op.getBody(), match_yield).failed()) { - return failure(); } - auto yield_op = op.getBody()->getTerminator(); setInLayout(yield_op, out_layouts); setLayout(op, in_layouts, out_layouts); return success(); @@ -555,107 +603,120 @@ class VectorLayoutInferer { }; TPU_CHECK_OP(op.getNumRegions() == 2, "expected two blocks for scf.while"); - const auto layout_for_type = [&op, this](const ::mlir::Value &arg, - SmallVector *layouts) { - if (arg.getType().isSignlessIntOrIndexOrFloat()) { - layouts->push_back(kNoLayout); - } else if (isa(arg.getType())) { - auto layout = getLayout(arg); - layouts->push_back(layout); - } else { - op.emitOpError() << "unsupported arg type " << arg.getType() - << " in scf.while"; - return failure(); - } - return success(); - }; + SmallVector in_layouts = getLayoutFromOperands(op); - SmallVector in_layouts; - in_layouts.reserve(op->getNumOperands()); - for (const auto &arg : op.getInits()) { - const auto status = layout_for_type(arg, &in_layouts); - if (status.failed()) return status; - } - - // Formally, the types and layouts of the results should follow the layout - // of the condition op in the Before region, rather than mimicking the input - // layouts. In practice these are constrained to be the same for our current - // pipelines, but doesn't represent the full expressiveness of scf.while. - // TODO(hmckenzie): Base output layout on ConditionOp, not inputs. - SmallVector out_layouts = in_layouts; - - // Use tpu.assume_layout to annotate every block argument with the layout of - // the corresponding operand in WhileOp and replace all uses of the block - // argument with the result of tpu.assume_layout. - ImplicitLocOpBuilder builder = - ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBeforeBody()); - for (auto [iter_arg, layout] : - llvm::zip_equal(op.getBeforeBody()->getArguments(), in_layouts)) { - if (!dyn_cast(iter_arg.getType())) { - continue; - } - auto assume_layout_op = - builder.create(iter_arg.getType(), iter_arg); - setLayout(assume_layout_op, layout, layout); - iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) { - return operand.getOwner() != assume_layout_op; - }); - } - if (inferBlock(*op.getBeforeBody(), match_condition).failed()) { - return failure(); + if (assumeLayoutsForBlockArgs(*op.getBeforeBody(), in_layouts).failed() || + inferBlock(*op.getBeforeBody(), match_condition).failed()) { + return op.emitOpError( + "failed to infer layout with initial layouts for before body in " + "scf.while op"); } - builder = - ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getAfterBody()); - for (auto [iter_arg, layout] : - llvm::zip_equal(op.getAfterBody()->getArguments(), out_layouts)) { - if (!dyn_cast(iter_arg.getType())) { - continue; - } - auto assume_layout_op = - builder.create(iter_arg.getType(), iter_arg); - setLayout(assume_layout_op, layout, layout); - iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) { - return operand.getOwner() != assume_layout_op; - }); + if (assumeLayoutsForBlockArgs(*op.getAfterBody(), in_layouts).failed() || + inferBlock(*op.getAfterBody(), match_yield).failed()) { + return op.emitOpError( + "failed to infer layout with initial layouts for after body in " + "scf.while op"); } - if (inferBlock(*op.getAfterBody(), match_yield).failed()) { - return failure(); - } - - auto *condition_op = op.getBeforeBody()->getTerminator(); - SmallVector cond_layout; - cond_layout.reserve(out_layouts.size() + 1); - cond_layout.push_back(kNoLayout); - cond_layout.append(out_layouts); - setInLayout(condition_op, cond_layout); - + auto *cond_op = op.getBeforeBody()->getTerminator(); + auto cond_in_layouts = getLayoutFromOperands(cond_op); auto *yield_op = op.getAfterBody()->getTerminator(); - setInLayout(yield_op, in_layouts); + auto yield_in_layouts = getLayoutFromOperands(yield_op); - setLayout(op, in_layouts, out_layouts); - return success(); - } - LogicalResult infer(scf::ConditionOp op) { - SmallVector in_layouts; - in_layouts.reserve(op->getNumOperands()); - for (const auto &arg : op.getOperands()) { - if (arg.getType().isSignlessIntOrIndexOrFloat()) { - in_layouts.push_back(kNoLayout); - } else if (isa(arg.getType())) { - auto layout = getLayout(arg); - in_layouts.push_back(layout); + // Find a compatible layout from condition body and loop body for each + // reuslt. For example, if we yield offset (*, *) in condition body and + // offset (*, 0) in loop body, the result offset should be (*, 0). + SmallVector out_layouts; + out_layouts.reserve(op->getNumResults()); + int out_idx = 0; + bool require_reinfer = false; + for (auto [in_layout, cond_layout, yield_layout, result] : llvm::zip_equal( + in_layouts, ArrayRef(cond_in_layouts).drop_front(1), + yield_in_layouts, op.getResults())) { + if (auto vty = dyn_cast(result.getType())) { + if (!in_layout.has_value()) { + return op.emitOpError("expected a vector layout for whileOp input ") + << out_idx; + } + if (!cond_layout.has_value()) { + return op.emitOpError("expected a vector layout for condition input ") + << out_idx + 1; // ConditionOp's first input is 1 bit bool. + } + if (!yield_layout.has_value()) { + return op.emitOpError("expected a vector layout for yield input ") + << out_idx; + } + auto compatible_layout = VectorLayout::join( + cond_layout.value(), yield_layout.value(), vty.getShape()); + if (compatible_layout.has_value()) { + compatible_layout = VectorLayout::join( + in_layout.value(), compatible_layout.value(), vty.getShape()); + } + // If no compatible layout is found in layouts for input, condition and + // yield, the output layout falls back to a normalized layout which + // has offsets 0 and the native tiling. + if (!compatible_layout.has_value()) { + compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0}, + nativeTiling(in_layout->bitwidth()), + ImplicitDim::kNone); + } + if (!require_reinfer && + (compatible_layout.value() != in_layout.value() || + compatible_layout.value() != cond_layout.value() || + compatible_layout.value() != yield_layout.value())) { + require_reinfer = true; + } + out_layouts.push_back(compatible_layout); } else { - op.emitOpError() << "unsupported arg type " << arg.getType() - << " in scf::condition"; - return failure(); + if (in_layout.has_value()) { + return op.emitOpError("expected no layout for whileOp input ") + << out_idx; + } + if (cond_layout.has_value()) { + return op.emitOpError("expected no layout for condition input ") + << out_idx + 1; // ConditionOp's first input is 1 bit bool. + } + if (yield_layout.has_value()) { + return op.emitOpError("expected no layout for yield input ") + << out_idx; + } + out_layouts.push_back(kNoLayout); + } + ++out_idx; + } + if (require_reinfer) { + clearBlockLayouts(*op.getBeforeBody()); + clearBlockLayouts(*op.getAfterBody()); + // Terminator in the loop will carry layouts to the next loop but + // the loop's block args' layouts are determined by the initial inputs. We + // need to force the same layouts for all in order to make layouts be + // consistent across all branches. To ensure that, we need to reprocess + // layout inference for the entire body with the final consolidated + // layout. + if (assumeLayoutsForBlockArgs(*op.getBeforeBody(), out_layouts) + .failed() || + inferBlock(*op.getBeforeBody(), match_condition).failed()) { + return op.emitOpError( + "failed to infer layout with compatible layouts for before body in " + "scf.while op"); + } + if (assumeLayoutsForBlockArgs(*op.getAfterBody(), out_layouts).failed() || + inferBlock(*op.getAfterBody(), match_yield).failed()) { + return op.emitOpError( + "failed to infer layout with compatible layouts for after body in " + "scf.while op"); } } - setLayout(op, in_layouts, ArrayRef(in_layouts).drop_front(1)); + std::copy(out_layouts.begin(), out_layouts.end(), + cond_in_layouts.begin() + 1); // Skip the first 1 bit bool. + setInLayout(cond_op, cond_in_layouts); + setInLayout(yield_op, out_layouts); + setLayout(op, out_layouts, out_layouts); return success(); } + // TODO(b/347016737): deprecate the static rotate. LogicalResult infer(tpu::RotateOp op) { auto bitwidth = op.getType().getElementTypeBitWidth(); if (bitwidth != 32) { @@ -670,6 +731,21 @@ class VectorLayoutInferer { return success(); } + LogicalResult infer(tpu::DynamicRotateOp op) { + auto bitwidth = op.getType().getElementTypeBitWidth(); + // TODO(b/347067057): Support dynamic rotate with packed dtype. + if (bitwidth != 32) { + NYI("Rotate with non-32-bit data"); + } + if (op.getType().getRank() < 2) { + NYI("Unsupported 1D shape"); + } + auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), + ImplicitDim::kNone); + setLayout(op, {layout, kNoLayout}, layout); + return success(); + } + LogicalResult infer(tpu::ConcatenateOp op) { TPU_CHECK_OP(!op.getSources().empty(), "Need at least one vector to concatenate"); @@ -1019,6 +1095,10 @@ class VectorLayoutInferer { "memref and vector rank mismatch"); int64_t rank = res_ty.getRank(); int8_t bitwidth = res_ty.getElementTypeBitWidth(); + if (kNativeBitwidth % bitwidth != 0) { + return op.emitOpError("Unsupported bitwidth"); + } + const int packing = kNativeBitwidth / bitwidth; auto maybe_tiling = verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(), src_ty.getRank(), src_ty.getElementTypeBitWidth()); @@ -1050,12 +1130,10 @@ class VectorLayoutInferer { } if (rank == 1) { TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D loads"); + const int64_t lane_tiling = packing * target_shape_[1]; auto tile = tiling.front(); - TPU_CHECK_OP(tile % target_shape_[1] == 0, - "Unsupported tiling for 1D load"); + TPU_CHECK_OP(tile % lane_tiling == 0, "Unsupported tiling for 1D load"); CHECK_EQ(tile_offsets.size(), 1); - // TODO(tlongeri): Also pick a unique (canonical) tiling for packed types - const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile; // TODO(apaszke): We could generate replicated loads for short values. setLayout(op, in_layout, VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling}, @@ -1109,8 +1187,6 @@ class VectorLayoutInferer { LogicalResult infer(vector::ExtractStridedSliceOp op) { auto input_layout = getLayout(op.getVector()); TPU_CHECK_OP(input_layout, "missing vector layout"); - TPU_CHECK_OP(op.getType().getElementTypeBitWidth() == 32, - "Only 32-bit types supported"); auto offsets_attr = op.getOffsets().getValue(); auto strides_attr = op.getStrides().getValue(); auto offsets = llvm::map_to_vector(offsets_attr, [](auto attr) { @@ -1219,6 +1295,8 @@ class VectorLayoutInferer { auto some_src_layout = getLayout(op.getSource()); TPU_CHECK_OP(some_src_layout, "missing vector layout"); auto layout = *some_src_layout; + const unsigned bitwidth = src_ty.getElementTypeBitWidth(); + const std::array vreg_slice = layout.vregSlice(target_shape_); if (layout.implicit_dim() == ImplicitDim::kNone) { // Nothing changes in the last two dims. if (res_rank >= 2 && src_shape.take_back(2) == res_shape.take_back(2)) { @@ -1226,17 +1304,17 @@ class VectorLayoutInferer { return success(); } // Sublane (un)tiling. - if (res_rank >= 2 && layout.tiling()[1] == target_shape_[1] && - src_ty.getDimSize(src_ty.getRank() - 1) == - res_shape[res_shape.size() - 1] && - src_ty.getDimSize(src_ty.getRank() - 2) % layout.tiling()[0] == 0 && - res_shape[res_shape.size() - 2] % layout.tiling()[0] == 0) { - layout = VectorLayout(layout.bitwidth(), {0, 0}, layout.tiling(), - layout.implicit_dim()); + if (res_rank >= 2 && *(src_shape.end() - 1) == *(res_shape.end() - 1) && + *(src_shape.end() - 2) % vreg_slice[0] == 0 && + *(res_shape.end() - 2) % vreg_slice[0] == 0) { + // TODO(b/343808585): We shouldn't force second minor offset to 0 when + // unfolding, it's still a no-op, but we need to add + // support in apply-vector-layout. + layout = VectorLayout(layout.bitwidth(), {0, layout.offsets()[1]}, + layout.tiling(), layout.implicit_dim()); setLayout(op, layout, layout); return success(); } - const unsigned bitwidth = src_ty.getElementTypeBitWidth(); const auto native_tiling = nativeTiling(bitwidth); // Lane (un)tiling. if (src_ty.getDimSize(src_ty.getRank() - 1) != @@ -1294,7 +1372,6 @@ class VectorLayoutInferer { if (res_ty.getRank() >= 2) { // Squeeze out the sublane dim. if (layout_shape[0] == 1 && - res_shape.drop_back(1) == src_shape.drop_back(2) && res_shape.back() == src_shape.back()) { setLayout(op, layout, VectorLayout(bitwidth, layout.offsets(), layout.tiling(), @@ -1312,28 +1389,28 @@ class VectorLayoutInferer { return success(); } } else if (res_ty.getRank() == 1) { - bool all_one = true; - for (int64_t s : src_ty.getShape().drop_back(2)) { - all_one &= s == 1; - } - // Squeeze out everything, but lanes - if (layout_shape[0] == 1 && all_one && - res_ty.getShape().back() == layout_shape[1]) { + // All dimensions have been folded into a single one + + // Squeeze all but minor dimension + if (res_ty.getShape().back() == layout_shape[1]) { + // The condition implies that everything apart from the minor + // dimension is 1 in the source. setLayout(op, layout, VectorLayout(bitwidth, layout.offsets(), layout.tiling(), ImplicitDim::kSecondMinor)); return success(); } - // Squeeze out everything, but sublanes - if (layout_shape[1] == 1 && all_one && - res_ty.getShape().back() == layout_shape[0]) { - TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == kNativeBitwidth, - "only 32-bit shape casts supported"); + // Squeeze all but second minor dimension + if (res_ty.getShape().back() == layout_shape[0]) { + // The condition implies that everything apart from the second minor + // dimension is 1 in the source setLayout(op, layout, VectorLayout(kNativeBitwidth, layout.offsets(), layout.tiling(), ImplicitDim::kMinor)); return success(); } + // TODO(b/340625465): Add case where layout_shape is (1, 1) and we fold + // batch dimensions once we support 0-D layouts. } } else { // Nothing changes in the last dim. @@ -1341,22 +1418,23 @@ class VectorLayoutInferer { setLayout(op, layout, layout); return success(); } - TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == kNativeBitwidth, - "only 32-bit shape casts supported"); // Insert a singleton innermost dim. if (res_ty.getRank() == src_ty.getRank() + 1 && src_ty.getDimSize(src_rank - 1) == res_ty.getDimSize(res_rank - 2) && res_ty.getDimSize(res_rank - 1) == 1) { if (layout.implicit_dim() == ImplicitDim::kMinor) { setLayout(op, layout, - VectorLayout(kNativeBitwidth, layout.offsets(), - default_tiling_, ImplicitDim::kNone)); + VectorLayout(bitwidth, layout.offsets(), layout.tiling(), + ImplicitDim::kNone)); } else { + TPU_CHECK_OP(bitwidth == kNativeBitwidth, + "Insertion of minor dim that is not a no-op only " + "supported for 32-bit types"); TPU_CHECK_OP(layout.implicit_dim() == ImplicitDim::kSecondMinor, "unexpected implicit dim value"); setLayout(op, layout, - VectorLayout(kNativeBitwidth, {0, std::nullopt}, - default_tiling_, ImplicitDim::kNone)); + VectorLayout(bitwidth, {0, std::nullopt}, default_tiling_, + ImplicitDim::kNone)); } return success(); } @@ -1372,6 +1450,10 @@ class VectorLayoutInferer { "memref and vector rank mismatch"); int64_t rank = ref_ty.getRank(); int8_t bitwidth = store_ty.getElementTypeBitWidth(); + if (kNativeBitwidth % bitwidth != 0) { + return op.emitOpError("Unsupported bitwidth"); + } + const int packing = kNativeBitwidth / bitwidth; auto maybe_tiling = verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(), ref_ty.getRank(), ref_ty.getElementTypeBitWidth()); @@ -1402,11 +1484,10 @@ class VectorLayoutInferer { } if (rank == 1) { TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D store"); + const int64_t lane_tiling = packing * target_shape_[1]; auto tile = tiling.front(); - TPU_CHECK_OP(tile % target_shape_[1] == 0, + TPU_CHECK_OP(tile % lane_tiling == 0, "Unsupported 1D tiling for 1D store"); - // TODO(tlongeri): Also pick a unique (canonical) tiling for packed types - const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile; CHECK_EQ(tile_offsets.size(), 1); store_layout = VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling}, @@ -1510,13 +1591,21 @@ class VectorLayoutInferer { if (default_tiling_[0] % layout.tiling()[0] == 0 && default_tiling_[1] == layout.tiling()[1]) { src_layout = layout; + dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(), + layout.implicit_dim()); + } else if (layout.tiling() == + nativeTiling(src_ty.getElementTypeBitWidth())) { + // If the source is already in native tiling, we can unpack it directly. + src_layout = layout; + dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, + layout.implicit_dim()); } else { // TODO(b/335863273): we should also reduce offsets. src_layout = VectorLayout(layout.bitwidth(), layout.offsets(), default_tiling_, layout.implicit_dim()); + dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, + layout.implicit_dim()); } - dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(), - layout.implicit_dim()); setLayout(op, src_layout, dst_layout); return success(); } @@ -1534,8 +1623,9 @@ class VectorLayoutInferer { TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); if (dyn_cast(op)) { TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 && - dst_ty.getElementTypeBitWidth() == 16, - "Only 32-bit to 16-bit truncation supported"); + (dst_ty.getElementTypeBitWidth() == 16 || + dst_ty.getElementTypeBitWidth() == 8), + "Only 32-bit to 8-bit or 16-bit truncation supported"); } else { TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32, "Only 32-bit truncation supported"); @@ -1638,25 +1728,31 @@ class VectorLayoutInferer { } LogicalResult inferMatmul(Operation *op) { - auto get_unpadded_layout = - [&](Value v, std::optional major_multiple = std::nullopt, + auto get_operand_layout = + [&](Value v, llvm::StringRef operand_name, + std::optional major_multiple = std::nullopt, std::optional minor_multiple = std::nullopt) -> std::optional { - auto pad = getLayout(v); - if (!pad.has_value() || pad->implicit_dim() != ImplicitDim::kNone) { + auto layout = getLayout(v); + if (!layout.has_value()) { + op->emitOpError("Internal error: assert failed: Operand ") + << operand_name << " has no vector layout"; return std::nullopt; } auto vty = cast(v.getType()); auto tiling = nativeTiling(vty.getElementTypeBitWidth()); auto shape = vty.getShape().take_back(2); - if (pad->offsets()[0].value_or(0) != 0 || - pad->offsets()[1].value_or(0) != 0 || - shape[0] % major_multiple.value_or(tiling[0]) != 0 || + if (shape[0] % major_multiple.value_or(tiling[0]) != 0 || shape[1] % minor_multiple.value_or(tiling[1]) != 0) { + op->emitOpError("Matmul operand") + << operand_name << " must have a shape divisible by (" + << major_multiple.value_or(tiling[0]) << ", " + << minor_multiple.value_or(tiling[1]) << "), but got: (" << shape[0] + << ", " << shape[1] << ")"; return std::nullopt; } // Override tiling to match the native one. - return VectorLayout(pad->bitwidth(), pad->offsets(), tiling, + return VectorLayout(layout->bitwidth(), {0, 0}, tiling, ImplicitDim::kNone); }; auto res_ty = dyn_cast(op->getResult(0).getType()); @@ -1678,15 +1774,18 @@ class VectorLayoutInferer { rhs_major_multiple = 1; } in_layout[0] = - get_unpadded_layout(op->getOperand(0), lhs_major_multiple, 1); + get_operand_layout(op->getOperand(0), "lhs", lhs_major_multiple, 1); + if (!in_layout[0].has_value()) { + return failure(); + } in_layout[1] = - get_unpadded_layout(op->getOperand(1), rhs_major_multiple, 1); - in_layout[2] = get_unpadded_layout(op->getOperand(2), 1, 1); - for (Layout &l : in_layout) { - if (!l.has_value()) { - op->emitOpError("unsupported operand shapes or layouts"); - return failure(); - } + get_operand_layout(op->getOperand(1), "rhs", rhs_major_multiple, 1); + if (!in_layout[1].has_value()) { + return failure(); + } + in_layout[2] = get_operand_layout(op->getOperand(2), "result", 1, 1); + if (!in_layout[2].has_value()) { + return failure(); } setLayout(op, in_layout, VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_, @@ -1725,6 +1824,53 @@ class VectorLayoutInferer { return true; } + LogicalResult assumeLayoutsForBlockArgs(Block &block, + ArrayRef layouts) { + auto op = block.getParentOp(); + if (layouts.size() != block.getNumArguments()) { + return op->emitOpError( + "Block arguments must have the same number of layouts"); + } + // Use tpu.assume_layout to annotate every block argument with the layout of + // the corresponding operand and replace all uses of the block argument with + // the result of tpu.assume_layout. + ImplicitLocOpBuilder builder = + ImplicitLocOpBuilder::atBlockBegin(op->getLoc(), &block); + for (auto [iter_arg, layout] : + llvm::zip_equal(block.getArguments(), layouts)) { + if (!dyn_cast(iter_arg.getType())) { + continue; + } + if (llvm::any_of(iter_arg.getUsers(), [](Operation *user) { + return isa(user); + })) { + return op->emitOpError("Expected no assume layout for block arguments"); + } + auto assume_layout_op = + builder.create(iter_arg.getType(), iter_arg); + setLayout(assume_layout_op, layout, layout); + iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) { + return operand.getOwner() != assume_layout_op; + }); + } + return success(); + } + + void clearBlockLayouts(Block &block) { + block.walk([&](Operation *op) { + // We need to remove assume_layout ops in each block. Otherwise, we will + // create extra assume_layout ops for nested blocks. + if (auto assume_op = dyn_cast(op)) { + assume_op.getResult().replaceAllUsesWith(assume_op.getInput()); + assume_op->erase(); + return WalkResult::advance(); + } + op->removeAttr("in_layout"); + op->removeAttr("out_layout"); + return WalkResult::advance(); + }); + } + void setInLayout(Operation *op, ArrayRef in) { CHECK_EQ(in.size(), op->getNumOperands()) << Print(op); SmallVector in_attrs; @@ -1802,10 +1948,24 @@ class VectorLayoutInferer { return cast(out_attrs[result_index]).getLayout(); } + SmallVector getLayoutFromOperands(Operation *op) { + SmallVector layouts; + layouts.reserve(op->getNumOperands()); + for (const auto &operand : op->getOperands()) { + if (isa(operand.getType())) { + layouts.push_back(getLayout(operand)); + } else { + layouts.push_back(kNoLayout); + } + } + return layouts; + } + private: std::optional> verifyMemoryTiling( Operation *op, ArrayRef mem_tiling, int64_t rank, int8_t bitwidth) { + const int packing = kNativeBitwidth / bitwidth; if (bitwidth == 32) { if (mem_tiling.size() != 1) { op->emitOpError("Only one-level tiling supported for 32-bit loads"); @@ -1822,7 +1982,7 @@ class VectorLayoutInferer { } auto first = mem_tiling[0].dimensions(); auto second = mem_tiling[1].dimensions(); - if (first.size() != 1 || first[0] % target_shape_[1] != 0) { + if (first.size() != 1 || first[0] % (packing * target_shape_[1]) != 0) { op->emitOpError("Invalid first-level tile in 1D memory op"); return std::nullopt; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index ca9dd260c6e4..4d5e62049098 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -181,17 +181,15 @@ struct TransferReadOfConstant // Rewrite `vector.transfer_read(arith.select)` as `arith.select` with // `transfer_read` applied to its operands. -struct TransferReadOfSelect - : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> { - using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern; +struct TransferReadOfSelect : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - ::mlir::LogicalResult matchAndRewrite( - ::mlir::vector::TransferReadOp op, - ::mlir::PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { if (failed(checkPreconditions(op, rewriter))) { return failure(); } - auto select = op.getSource().getDefiningOp<::mlir::arith::SelectOp>(); + auto select = op.getSource().getDefiningOp(); if (!select) { return rewriter.notifyMatchFailure(op, "source not an arith.select"); } @@ -214,27 +212,25 @@ struct TransferReadOfSelect auto transfer_read = [&](Value value, RankedTensorType type) { return createTransferReadOp(op, value, type, rewriter); }; - rewriter.replaceOpWithNewOp<::mlir::arith::SelectOp>( + rewriter.replaceOpWithNewOp( op, transfer_read(select.getCondition(), condition_type), transfer_read(select.getTrueValue(), true_value_ty), transfer_read(select.getFalseValue(), false_value_ty)); - return ::mlir::success(); + return success(); } }; // Rewrite `vector.transfer_read(arith.cmpi)` as `arith.cmpi` with // `transfer_read` applied to its operands. -struct TransferReadOfCmpI - : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> { - using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern; +struct TransferReadOfCmpI : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - ::mlir::LogicalResult matchAndRewrite( - ::mlir::vector::TransferReadOp op, - ::mlir::PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { if (failed(checkPreconditions(op, rewriter))) { return failure(); } - auto cmp = op.getSource().getDefiningOp<::mlir::arith::CmpIOp>(); + auto cmp = op.getSource().getDefiningOp(); if (!cmp) { return rewriter.notifyMatchFailure(op, "source not an arith.cmpi"); } @@ -249,25 +245,23 @@ struct TransferReadOfCmpI auto transfer_read = [&](Value value, RankedTensorType type) { return createTransferReadOp(op, value, type, rewriter); }; - rewriter.replaceOpWithNewOp<::mlir::arith::CmpIOp>( + rewriter.replaceOpWithNewOp( op, cmp.getPredicate(), transfer_read(cmp.getLhs(), lhs_type), transfer_read(cmp.getRhs(), rhs_type)); - return ::mlir::success(); + return success(); } }; // Rewrite `vector.transfer_read(tensor.splat)` as `vector.broadcast`. -struct TransferReadOfSplat - : public ::mlir::OpRewritePattern<::mlir::vector::TransferReadOp> { - using OpRewritePattern<::mlir::vector::TransferReadOp>::OpRewritePattern; +struct TransferReadOfSplat : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - ::mlir::LogicalResult matchAndRewrite( - ::mlir::vector::TransferReadOp op, - ::mlir::PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { if (failed(checkPreconditions(op, rewriter))) { return failure(); } - auto splat = op.getSource().getDefiningOp<::mlir::tensor::SplatOp>(); + auto splat = op.getSource().getDefiningOp(); if (!splat) { return rewriter.notifyMatchFailure(op, "source not a tensor.splat"); } @@ -276,7 +270,7 @@ struct TransferReadOfSplat } rewriter.replaceOpWithNewOp(op, op.getVectorType(), splat.getInput()); - return ::mlir::success(); + return success(); } }; @@ -354,6 +348,116 @@ class GenericBitwidthConvert : public RewritePattern { const bool supports_bf16_alu_instructions_; }; +// Rewrite `vector.contraction` with bf16 accumulator and output into a +// contraction with f32 accumulator and output, where the accumulator is +// extended and the output truncated. For targets that do not support bf16 +// matmul, the lhs and rhs are extended to f32. +struct ContractionBitwidthConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ContractionBitwidthConvert(bool supports_bf16_matmul, MLIRContext *ctx) + : OpRewritePattern(ctx), supports_bf16_matmul_(supports_bf16_matmul) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + // The ContractionOp contract is that (1) lhs and rhs have same element + // type, and (2) the accumulator and result have the same element type. + + // If the target does not support bf16 matmul and we have bf16 operands, we + // need to extend the lhs and rhs to f32. + const bool extend_operands = + op.getLhsType().getElementType().isBF16() && !supports_bf16_matmul_; + // Determine if the accumulator is bf16 and hence needs to be extended to + // f32. + ShapedType acc_ty = dyn_cast(op.getAccType()); + if (acc_ty == nullptr) { + return rewriter.notifyMatchFailure(op, + "accumulator is not a shaped type"); + } + const bool extend_acc = acc_ty.getElementType().isBF16(); + + if (!extend_operands && !extend_acc) { + return rewriter.notifyMatchFailure(op, "no bf16 operands or accumulator"); + } + + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + if (extend_operands) { + lhs = rewriter.create( + op.getLoc(), + VectorType::get(op.getLhsType().getShape(), rewriter.getF32Type()), + lhs); + rhs = rewriter.create( + op.getLoc(), + VectorType::get(op.getRhsType().getShape(), rewriter.getF32Type()), + rhs); + } + + Value acc = op.getAcc(); + if (extend_acc) { + acc = rewriter.create( + op.getLoc(), + VectorType::get(acc_ty.getShape(), rewriter.getF32Type()), + op.getAcc()); + } + + vector::ContractionOp contraction = rewriter.create( + op.getLoc(), lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes(), + op.getKind()); + + if (extend_acc) { + rewriter.replaceOpWithNewOp( + op, dyn_cast(op.getResultType()), contraction); + } else { + rewriter.replaceOp(op, contraction); + } + return success(); + } + + private: + const bool supports_bf16_matmul_; +}; + +// Rewrite `vector.multi_dim_reduction` with bf16 source/accumulator/output into +// a multi_dim_reduction with f32 source/accumulator/output, where the source +// and accumulator are extended and the result is truncated. +// TODO(b/324596736): Make the rewrite conditional on the target supporting +// bf16 reductions. +struct MultiDimReductionBitwidthConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, + PatternRewriter &rewriter) const override { + // Below we rely on the contract that the source operand, accumulator, and + // result have the same element type. + auto src_ty = op.getSourceVectorType(); + if (!src_ty.getElementType().isBF16()) { + return rewriter.notifyMatchFailure(op, "not bf16 reduction"); + } + + auto res_ty = dyn_cast(op.getResult().getType()); + if (!res_ty) { + return rewriter.notifyMatchFailure(op, "not vector reduction"); + } + + auto reduction = rewriter.create( + op.getLoc(), + rewriter.create( + op.getLoc(), + VectorType::get(src_ty.getShape(), rewriter.getF32Type()), + op.getSource()), + rewriter.create( + op.getLoc(), + VectorType::get(res_ty.getShape(), rewriter.getF32Type()), + op.getAcc()), + op.getReductionMask(), op.getKind()); + rewriter.replaceOpWithNewOp(op, res_ty, reduction); + return success(); + } +}; + struct LinalgVectorizationPass : public impl::LinalgVectorizationPassBase { explicit LinalgVectorizationPass( @@ -406,6 +510,8 @@ struct LinalgVectorizationPass patterns.add(ternary_op_name, ctx, supports_bf16_alu_instructions); } + patterns.add(supports_bf16_matmul, ctx); + patterns.add(ctx); // We do not want to apply the vector patterns above to the ops that are // unrelated to the original linalg op. @@ -413,7 +519,9 @@ struct LinalgVectorizationPass func.walk([&](Operation *op) { if (dyn_cast(op) || dyn_cast(op) || - dyn_cast(op)) { + dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op)) { linalgOps.push_back(op); } }); @@ -426,9 +534,10 @@ struct LinalgVectorizationPass } // namespace std::unique_ptr> createLinalgVectorizationPass( - bool supports_bf16_alu_instructions) { + bool supports_bf16_alu_instructions, bool supports_bf16_matmul) { LinalgVectorizationPassOptions options; options.supports_bf16_alu_instructions = supports_bf16_alu_instructions; + options.supports_bf16_matmul = supports_bf16_matmul; return std::make_unique(options); } diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 9ae5f9a59619..4ffbf160b1c9 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -14,6 +14,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/types/span.h" +#include "tsl/platform/statusor.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with // MLIR diagnostics? @@ -57,6 +58,12 @@ FailureOr getTypeBitwidth(Type ty) { if (auto bf16_ty = dyn_cast(ty)) { return 16; } + if (auto f8e5m2_ty = dyn_cast(ty)) { + return 8; + } + if (auto f8e4m3fn_ty = dyn_cast(ty)) { + return 8; + } return emitError(UnknownLoc::get(ty.getContext()), "Unsupported type: ") << ty; } diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 4830622f1c2d..20fcf2b4ce74 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "pybind_extension") package( default_applicable_licenses = [], @@ -22,29 +23,32 @@ package( py_library( name = "mosaic_gpu", data = [":libmosaic_gpu_runtime.so"], - deps = [ - "//jaxlib/mlir:execution_engine", - "//jaxlib/mlir:gpu_dialect", - "//jaxlib/mlir:llvm_dialect", - "//jaxlib/mlir:nvgpu_dialect", - "//jaxlib/mlir:nvvm_dialect", - "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", - ], + deps = [":_mosaic_gpu_ext"], ) cc_library( name = "passes", - srcs = ["launch_lowering.cc"], - hdrs = ["launch_lowering.h"], + srcs = [ + "launch_lowering.cc", + "passes.cc", + ], + hdrs = [ + "launch_lowering.h", + "pass_boilerplate.h", + "passes.h", + ], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], ) @@ -95,6 +99,83 @@ cc_library( ], ) +cc_library( + name = "custom_call", + srcs = ["custom_call.cc"], + deps = [ + ":passes", + "//jaxlib/cuda:cuda_vendor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToLLVMIRTranslation", + "@llvm-project//mlir:GPUTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefToLLVM", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:NVGPUDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVVMTarget", + "@llvm-project//mlir:NVVMToLLVM", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", + "@llvm-project//mlir:VectorDialect", + "@xla//xla/service:custom_call_status", + "@xla//xla/service:custom_call_target_registry", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], + alwayslink = True, +) + +pybind_extension( + name = "_mosaic_gpu_ext", + srcs = ["mosaic_gpu_ext.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + deps = [ + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cuda:cuda_vendor", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", + "@nanobind", + ], +) + cc_binary( name = "libmosaic_gpu_runtime.so", srcs = ["runtime.cc"], diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc new file mode 100644 index 000000000000..56b3d2312c19 --- /dev/null +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -0,0 +1,446 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "jaxlib/gpu/vendor.h" +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "llvm/include/llvm/ADT/SmallVector.h" +#include "llvm/include/llvm/Support/CodeGen.h" +#include "llvm/include/llvm/Support/TargetSelect.h" +#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/include/mlir/Conversion/Passes.h" +#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/include/mlir/Dialect/Math/IR/Math.h" +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/include/mlir/ExecutionEngine/OptUtils.h" +#include "mlir/include/mlir/IR/AsmState.h" +#include "mlir/include/mlir/IR/DialectRegistry.h" +#include "mlir/include/mlir/IR/MLIRContext.h" +#include "mlir/include/mlir/Parser/Parser.h" +#include "mlir/include/mlir/Pass/PassManager.h" +#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/include/mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/include/mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/include/mlir/Transforms/Passes.h" +#include "jaxlib/mosaic/gpu/launch_lowering.h" +#include "jaxlib/mosaic/gpu/passes.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_target_registry.h" + +namespace { + +using MosaicInitFunc = void(void****); +using MosaicHostFunc = void(void**); + +mlir::FailureOr GetPassPipeline( + mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target) { + static bool register_once = []() { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerCanonicalizer(); + mlir::registerCSE(); + mlir::registerStripDebugInfo(); + mlir::registerConvertNVGPUToNVVMPass(); + mlir::registerConvertVectorToSCF(); + mlir::registerSCFToControlFlow(); + mlir::registerConvertNVVMToLLVMPass(); + mlir::registerArithToLLVMConversionPass(); + mlir::registerConvertIndexToLLVMPass(); + mlir::registerConvertGpuOpsToNVVMOps(); + mlir::registerConvertMathToLLVMPass(); + mlir::registerConvertFuncToLLVMPass(); + mlir::registerConvertAffineToStandard(); + mlir::registerReconcileUnrealizedCasts(); + // TODO(apaszke): Only register the passes we actually use. + mlir::memref::registerMemRefPasses(); + mlir::registerConvertToLLVMPass(); + mlir::registerGPUPasses(); + mosaic::gpu::registerGpuLaunchLoweringPass(); + mosaic::gpu::registerConvertGpuToLLVMPass(); + return true; + }(); + (void)register_once; + return mlir::parsePassPipeline( + R"( + builtin.module( + convert-nvgpu-to-nvvm, + gpu-kernel-outlining{data-layout-str=}, + convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1}, + convert-scf-to-cf, + convert-nvvm-to-llvm, + expand-strided-metadata, + nvvm-attach-target{O=3 chip=sm_90a fast=false features=+ptx80 ftz=false module= triple=nvptx64-nvidia-cuda}, + lower-affine, + convert-arith-to-llvm{index-bitwidth=0}, + convert-index-to-llvm{index-bitwidth=64}, + canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, + cse, + gpu.module(strip-debuginfo), + gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}), + gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}), + gpu.module(cse), + gpu.module(reconcile-unrealized-casts), + mosaic-convert-gpu-to-llvm, + gpu-module-to-binary{format=)" + + mlir::gpu::stringifyCompilationTarget(target).str() + R"(}, + convert-math-to-llvm{approximate-log1p=true}, + canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, + cse, + )" + + (target != mlir::gpu::CompilationTarget::Assembly ? "gpu-launch-lowering," + : "") + + R"( + convert-to-llvm, + reconcile-unrealized-casts + ) + )"); +} + +mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes, + mlir::ModuleOp module) { + mlir::PassManager pm(module.getContext()); + *static_cast(&pm) = std::move(passes); + if (getenv("MOSAIC_GPU_DUMP_MLIR_PASSES") != nullptr) { + pm.enableIRPrinting(); + } + return pm.run(module); +} + +void InitContext(mlir::MLIRContext* context) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::registerConvertNVVMToLLVMInterface(registry); + mlir::registerConvertComplexToLLVMInterface(registry); + mlir::registerConvertMemRefToLLVMInterface(registry); + mlir::registerConvertMathToLLVMInterface(registry); + mlir::registerConvertFuncToLLVMInterface(registry); + mlir::index::registerConvertIndexToLLVMInterface(registry); + mlir::cf::registerConvertControlFlowToLLVMInterface(registry); + mlir::ub::registerConvertUBToLLVMInterface(registry); + mlir::arith::registerConvertArithToLLVMInterface(registry); + mlir::registerConvertMemRefToLLVMInterface(registry); + mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); + mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerGPUDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); +} + +absl::Status RunCUDATool(const char* tool, + const std::vector& args, + bool stderr_to_stdout = false) { + CHECK(!args.empty() && args.back() == nullptr); + const char * cuda_path_ptr = getenv("CUDA_ROOT"); + if (!cuda_path_ptr) return absl::InternalError("Failed to get CUDA_ROOT"); + std::string tool_path(cuda_path_ptr); + tool_path += "/bin/"; + tool_path += tool; + pid_t child_pid; + posix_spawn_file_actions_t file_actions; + if (posix_spawn_file_actions_init(&file_actions)) { + return absl::InternalError("Failed to initialize spawn file actions"); + } + if (posix_spawn_file_actions_adddup2(&file_actions, STDOUT_FILENO, + STDERR_FILENO)) { + return absl::InternalError("Failed to set up spawn file actions"); + } + // execv is guaranteed by POSIX to not modify the args (other than + // replacing the whole process image), so the const_cast is valid. + if (posix_spawn(&child_pid, tool_path.c_str(), &file_actions, nullptr, + const_cast(args.data()), environ)) { + return absl::InternalError("Process spawn failed"); + } + int status; + if (waitpid(child_pid, &status, 0) == -1) { + return absl::InternalError("Failed to wait for CUDA tool invocation"); + } + if (status != 0) return absl::InternalError("CUDA tool failed"); + if (posix_spawn_file_actions_destroy(&file_actions) != 0) { + return absl::InternalError("Failed to clean up after posix_spawn"); + } + return absl::OkStatus(); +} + +class TemporaryDirectory { + private: + TemporaryDirectory(std::string path) : path(std::move(path)) {} + // TODO(apaszke): Unlink in destructor. + + public: + static absl::StatusOr Create() { + std::string pattern = "/tmp/mosaic-gpu-XXXXXX"; + if (mkdtemp(pattern.data()) == NULL) { + return absl::InternalError("Failed to create temporary directory"); + } + return TemporaryDirectory(std::move(pattern)); + } + + std::string_view GetPath() { return path; } + + private: + std::string path; +}; + +void DumpCompilationOutput(mlir::ModuleOp module) { + bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; + bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; + bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr; + if (!dump_ptx && !dump_ptxas && !dump_sass) { + return; + } + + module = module.clone(); // Prevent accidental modification. + auto passes = GetPassPipeline(module.getContext(), + mlir::gpu::CompilationTarget::Assembly); + if (mlir::failed(passes) || + mlir::failed(RunPasses(std::move(*passes), module))) { + return; + } + for (mlir::Operation& op : module.getBody()->getOperations()) { + auto binary = mlir::dyn_cast(&op); + if (!binary) { continue; } + auto objects = binary.getObjects(); + if (objects.size() != 1) { + std::cerr << "Multiple objects per gpu.binary unsupported" << std::endl; + continue; + } + auto object = mlir::cast(*objects.begin()); + std::string ptx = object.getObject().getValue().str(); + if (dump_ptx) { + std::cout << ptx << std::endl; + } + if (!dump_ptxas && !dump_sass) { continue; } // We're done. + auto tmpdir = TemporaryDirectory::Create(); + if (!tmpdir.ok()) { + std::cerr << "Failed to create a temporary directory" << std::endl; + continue; + } + std::string ptx_path = std::string(tmpdir->GetPath()) + "/kernel.ptx"; + std::string elf_path = std::string(tmpdir->GetPath()) + "/kernel.o"; + // Dump PTX into a file. + std::ofstream ptx_out(ptx_path.c_str()); + if (!ptx_out) { + std::cerr << "Failed to write PTX to a file" << std::endl; + continue; + } + ptx_out << ptx << std::endl; + // Run ptxas to generate SASS. + std::vector ptxas_args = { + "ptxas", "--opt-level", "3", + "--gpu-name", "sm_90a", "--output-file", + elf_path.c_str(), ptx_path.c_str()}; + if (dump_ptxas) { + ptxas_args.push_back("-v"); + } + ptxas_args.push_back(nullptr); + if (auto status = RunCUDATool("ptxas", ptxas_args); !status.ok()) { + std::cerr << "ptxas invocation failed: " << status.message() << std::endl; + continue; + } + if (!dump_sass) { continue; } // We're done. + // Call nvdisasm to pretty-print SASS. + if (auto status = RunCUDATool( + "nvdisasm", {"nvdisasm", "-ndf", "-c", elf_path.c_str(), nullptr}); + !status.ok()) { + std::cerr << "nvdisasm invocation failed: " << status.message() + << std::endl; + continue; + } + } +} + +absl::StatusOr> Compile( + mlir::ModuleOp module) { + DumpCompilationOutput(module); + auto passes = GetPassPipeline(module.getContext(), + mlir::gpu::CompilationTarget::Binary); + if (mlir::failed(passes)) { + return absl::InternalError("Failed to construct pass pipeline"); + } + if (mlir::failed(RunPasses(std::move(*passes), module))) { + return absl::InternalError("Pass pipeline failed"); + } + + llvm::SmallVector runtime_lib; + if (const char* lib_path = getenv("MOSAIC_GPU_RUNTIME_LIB_PATH")) { + runtime_lib.emplace_back(lib_path); + } + // Create a transformer to run all LLVM optimization passes at the + // specified optimization level. + mlir::ExecutionEngineOptions options; + options.transformer = mlir::makeOptimizingTransformer(3, 0, nullptr); + options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; + options.sharedLibPaths = runtime_lib; + auto maybe_execution_engine = mlir::ExecutionEngine::create(module, options); + if (!maybe_execution_engine) { + return absl::InternalError("Failed to compile kernel"); + } + return std::move(*maybe_execution_engine); +} + +class CompiledKernel { + public: + CompiledKernel(std::unique_ptr engine, void* ctx, + void* scratch_addr, MosaicHostFunc* host_launch) + : engine_(std::move(engine)), + ctx_(ctx), + scratch_addr_(scratch_addr), + host_launch_(host_launch) {} + + std::tuple GetHostLaunch() { + return std::make_tuple(ctx_, scratch_addr_, host_launch_); + } + + private: + std::unique_ptr engine_; + void* ctx_; // TODO(apaszke): Destroy this properly + void* scratch_addr_; + MosaicHostFunc* host_launch_; +}; + +std::pair*, absl::Mutex*> +GetKernelCache() { + static absl::Mutex mutex; + static auto& context_cache = + *new absl::flat_hash_map; + return std::make_pair(&context_cache, &mutex); +} + +// Each compiled kernel has a unique init func, and each kernel is used from +// a single HLO module. So it should be safe to not include the CUDA context +// in the key. +absl::StatusOr> CompileAndInit( + uint64_t kernel_id, const char* module) { + auto cache_and_mutex = GetKernelCache(); + auto* cache = cache_and_mutex.first; + auto* mutex = cache_and_mutex.second; + + { + // Fast path uses reader lock (as hash map look-up is relatively slow). + absl::ReaderMutexLock lock(mutex); + auto it = cache->find(kernel_id); + if (ABSL_PREDICT_TRUE(it != cache->end())) + return it->second.GetHostLaunch(); + } + + absl::MutexLock lock(mutex); + // We released the reader lock, another thread might have initialized it. + if (cache->find(kernel_id) == cache->end()) { + mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); + InitContext(&context); + mlir::ParserConfig parse_config(&context); + auto module_op = + mlir::parseSourceString(module, parse_config); + if (!module_op) { + return absl::InternalError("Failed to parse module"); + } + auto maybe_engine = Compile(*module_op); + if (!maybe_engine.ok()) { + return maybe_engine.status(); + } + mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + auto main = execution_engine->lookupPacked("_mlir_ciface_main"); + auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); + if (!init || !main) { + return absl::InternalError("Failed to retrieve kernel function"); + } + void* module_ptr = nullptr; + void* kernel_ptr = nullptr; + void** module_ptr_ptr = &module_ptr; + void** kernel_ptr_ptr = &kernel_ptr; + void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; + reinterpret_cast(*init)(init_args); + CUmodule module = static_cast(module_ptr); + CUdeviceptr scratch_addr; + cuModuleGetGlobal(&scratch_addr, nullptr, module, "global_scratch"); + cache->insert_or_assign( + kernel_id, + CompiledKernel(std::move(*maybe_engine), kernel_ptr, + reinterpret_cast(scratch_addr), + reinterpret_cast(*main))); + } + return cache->at(kernel_id).GetHostLaunch(); +} + +void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + uint64_t kernel_id = *reinterpret_cast(opaque); + auto ctx_and_kernel = CompileAndInit(kernel_id, opaque + sizeof(uint64_t)); + if (!ctx_and_kernel.ok()) { + XlaCustomCallStatusSetFailure(status, + ctx_and_kernel.status().message().data(), + ctx_and_kernel.status().message().size()); + return; + } + void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers, + &std::get<1>(*ctx_and_kernel)}; + std::get<2>(*ctx_and_kernel)(args); +} + +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, + "CUDA"); + +} // namespace diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 0c66a0325581..f3a1ce4f439f 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -128,8 +128,7 @@ void buildInitFunction(mlir::OpBuilder &module_builder, auto i32 = mlir::IntegerType::get(init_func.getContext(), 32); auto ptr_ty = mlir::LLVM::LLVMPointerType::get(init_func.getContext()); mlir::Location loc = init_func.getLoc(); - auto builder = - mlir::OpBuilder::atBlockBegin(&init_func.getBody().emplaceBlock()); + auto builder = mlir::OpBuilder::atBlockBegin(init_func.addEntryBlock()); auto binary_global_decl = module_builder.create( loc, mlir::LLVM::LLVMArrayType::get(builder.getI8Type(), @@ -171,7 +170,7 @@ void buildInitFunction(mlir::OpBuilder &module_builder, used_smem = builder.create( loc, i32, builder.getI32IntegerAttr( - mlir::cast(const_smem.getValue()).getSInt())); + mlir::cast(const_smem.getValue()).getInt())); } } mlir::Value kernel_handle = @@ -180,7 +179,11 @@ void buildInitFunction(mlir::OpBuilder &module_builder, loc, "mosaic_gpu_get_function", ptr_ty, mlir::ValueRange{module_handle, kernel_name_ptr, used_smem}) .getResult(0); - builder.create(loc, kernel_handle); + builder.create(loc, module_handle, + init_func.getArgument(0)); + builder.create(loc, kernel_handle, + init_func.getArgument(1)); + builder.create(loc); } mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, @@ -253,7 +256,7 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { auto module_builder = mlir::OpBuilder::atBlockBegin(module.getBody()); auto init_func = module_builder.create( op.getLoc(), func.getName().str() + "_init", - mlir::FunctionType::get(func->getContext(), {}, {ptr_ty})); + mlir::FunctionType::get(func->getContext(), {ptr_ty, ptr_ty}, {})); init_func->setAttr(mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), mlir::UnitAttr::get(func->getContext())); bool had_launch = false; diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc new file mode 100644 index 000000000000..ec574de4368f --- /dev/null +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -0,0 +1,63 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "nanobind/nanobind.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/service/custom_call_status.h" + +namespace jax::cuda { +namespace { + +namespace nb = nanobind; + +void EventRecordCall(void* stream, void** buffers, char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto* event = reinterpret_cast(opaque); + if (gpuEventRecord(**event, reinterpret_cast(stream)) != + gpuSuccess) { + const char message[] = "Failed to record event"; + XlaCustomCallStatusSetFailure(status, message, sizeof(message)); + } +} + +NB_MODULE(_mosaic_gpu_ext, m) { + m.def("_gpu_event_create", []() { + gpuEvent_t* event = new gpuEvent_t(); + gpuEventCreate(event, GPU_EVENT_DEFAULT); + return reinterpret_cast(event); + }); + m.def("_gpu_event_destroy", [](uintptr_t event) { + gpuEventDestroy(*reinterpret_cast(event)); + }); + m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) { + float elapsed_ms = -1; + if (gpuEventElapsedTime( + &elapsed_ms, *reinterpret_cast(start_event), + *reinterpret_cast(end_event)) != gpuSuccess) { + throw std::runtime_error("Failed to get elapsed time between events"); + } + return elapsed_ms; + }); + m.def("_record_event_capsule", + []() { return EncapsulateFunction(EventRecordCall); }); +} + +} // namespace +} // namespace jax::cuda diff --git a/jaxlib/mosaic/gpu/pass_boilerplate.h b/jaxlib/mosaic/gpu/pass_boilerplate.h new file mode 100644 index 000000000000..b0241fca97ab --- /dev/null +++ b/jaxlib/mosaic/gpu/pass_boilerplate.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ +#define JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ + +#include "mlir/include/mlir/IR/DialectRegistry.h" +#include "mlir/include/mlir/Pass/Pass.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Support/TypeID.h" +namespace mosaic { +namespace gpu { + +template +class Pass : public ::mlir::OperationPass { + public: + Pass() : ::mlir::OperationPass(::mlir::TypeID::get()) {} + Pass(const Pass &other) : ::mlir::OperationPass(other) {} + Pass &operator=(const Pass &) = delete; + Pass(Pass &&) = delete; + Pass &operator=(Pass &&) = delete; + ~Pass() = default; + + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral(Derived::kArgumentName); + } + ::llvm::StringRef getArgument() const override { return getArgumentName(); } + ::llvm::StringRef getDescription() const override { return ""; } + static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral(Derived::kPassName); + } + ::llvm::StringRef getName() const override { return getPassName(); } + static bool classof(const ::mlir::Pass *pass) { + return pass->getTypeID() == ::mlir::TypeID::get(); + } + std::unique_ptr<::mlir::Pass> clonePass() const override { + return std::make_unique(*static_cast(this)); + } + void getDependentDialects(::mlir::DialectRegistry ®istry) const override {} + + private: + using This = + Pass; // Can't have a comma in the macro instantiation + + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(This) +}; + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc new file mode 100644 index 000000000000..9c9d82ac4f3f --- /dev/null +++ b/jaxlib/mosaic/gpu/passes.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/passes.h" +#include +#include +#include + +#include "llvm/include/llvm/ADT/StringRef.h" +#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/IR/SymbolTable.h" +#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Transforms/DialectConversion.h" +#include "jaxlib/mosaic/gpu/pass_boilerplate.h" + +namespace mosaic { +namespace gpu { + +namespace { + +class ConvertGpuToLLVMPass + : public mosaic::gpu::Pass { + public: + using mosaic::gpu::Pass::Pass; + static constexpr llvm::StringLiteral kArgumentName = + "mosaic-convert-gpu-to-llvm"; + static constexpr llvm::StringLiteral kPassName = "ConvertGpuToLLVMPass"; + + void runOnOperation() override { + mlir::MLIRContext *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + mlir::LLVMTypeConverter converter(ctx); + mlir::ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](mlir::gpu::LaunchFuncOp op) -> bool { + return converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()); + }); + auto symtab = mlir::SymbolTable(getOperation()); + mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false); + if (mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)) + .failed()) { + signalPassFailure(); + } + } +}; + +} // namespace + +void registerConvertGpuToLLVMPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return std::make_unique(); + }); +} + +} // namespace gpu +} // namespace mosaic diff --git a/jaxlib/cpu/ducc_fft.fbs b/jaxlib/mosaic/gpu/passes.h similarity index 64% rename from jaxlib/cpu/ducc_fft.fbs rename to jaxlib/mosaic/gpu/passes.h index bc8572ad5414..bf7a804ee217 100644 --- a/jaxlib/cpu/ducc_fft.fbs +++ b/jaxlib/mosaic/gpu/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The JAX Authors. +/* Copyright 2024 The JAX Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,25 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace jax; +#ifndef JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ +#define JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ -enum DuccFftDtype : byte { - COMPLEX64 = 0, - COMPLEX128 = 1, -} +namespace mosaic { +namespace gpu { -enum DuccFftType : byte { - C2C = 0, - C2R = 1, - R2C = 2, -} +void registerConvertGpuToLLVMPass(); -table DynamicDuccFftDescriptor { - ndims:uint32; - dtype:DuccFftDtype; - fft_type:DuccFftType; - axes:[uint32]; - forward:bool; -} +} // namespace gpu +} // namespace mosaic -root_type DynamicDuccFftDescriptor; +#endif // JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 6e952b546866..13e63208632c 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -201,6 +201,10 @@ cc_library( ":hip_gpu_kernel_helpers", ":hip_lu_pivot_kernels_impl", ":hip_vendor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", @@ -215,7 +219,7 @@ rocm_library( ":hip_gpu_kernel_helpers", ":hip_vendor", "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", ], ) @@ -398,3 +402,12 @@ py_library( ":_triton", ], ) + +py_library( + name = "gpu_only_test_deps", + # `if_rocm_is_configured` will default to `[]`. + deps = if_rocm_is_configured([ + ":rocm_gpu_support", + "//jaxlib:rocm_plugin_extension", + ]), +) diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc new file mode 100644 index 000000000000..9b0743d27cd9 --- /dev/null +++ b/jaxlib/rocm_plugin_extension.cc @@ -0,0 +1,152 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include + +#include "nanobind/nanobind.h" +#include "absl/status/status.h" +#include "rocm/include/hip/hip_runtime.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/c_api.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/py_client_gpu.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" + +namespace nb = nanobind; + +namespace xla { +namespace { +Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, + nb::capsule fn, int api_version, + XLA_FFI_Handler_Traits traits) { + if (c_api->extension_start == nullptr) { + return Unimplemented("The plugin does not have extension."); + } + const PJRT_Extension_Base* next = + reinterpret_cast(c_api->extension_start); + while (next != nullptr && + next->type != + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { + next = next->next; + } + if (next == nullptr) { + return Unimplemented("The plugin does not have a custom call extension."); + } + + if (traits != 0) { + return Unimplemented("The plugin does not support custom call traits."); + } + + PJRT_Gpu_Register_Custom_Call_Args args; + args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; + args.function_name = fn_name.c_str(); + args.function_name_size = nb::len(fn_name); +#if PJRT_API_GPU_EXTENSION_VERSION >= 1 + args.api_version = api_version; +#endif + args.custom_call_function = static_cast(fn.data()); + RETURN_STATUS_IF_PJRT_ERROR( + reinterpret_cast(next)->custom_call(&args), + c_api); + return OkStatus(); +} + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(xla::XlaPythonGpuCallback); + return dict; +} + +std::string ToString(hipError_t result) { +#define OSTREAM_ROCM_ERROR(__name) \ + case hipError##__name: \ + return "HIP_ERROR_" #__name; + + switch (result) { + OSTREAM_ROCM_ERROR(InvalidValue) + OSTREAM_ROCM_ERROR(OutOfMemory) + OSTREAM_ROCM_ERROR(NotInitialized) + OSTREAM_ROCM_ERROR(Deinitialized) + OSTREAM_ROCM_ERROR(NoDevice) + OSTREAM_ROCM_ERROR(InvalidDevice) + OSTREAM_ROCM_ERROR(InvalidImage) + OSTREAM_ROCM_ERROR(InvalidContext) + OSTREAM_ROCM_ERROR(InvalidHandle) + OSTREAM_ROCM_ERROR(NotFound) + OSTREAM_ROCM_ERROR(NotReady) + OSTREAM_ROCM_ERROR(NoBinaryForGpu) + + // Encountered an uncorrectable ECC error during execution. + OSTREAM_ROCM_ERROR(ECCNotCorrectable) + + // Load/store on an invalid address. Must reboot all context. + case 700: + return "ROCM_ERROR_ILLEGAL_ADDRESS"; + // Passed too many / wrong arguments, too many threads for register count. + case 701: + return "ROCM_ERROR_LAUNCH_OUT_OF_RESOURCES"; + + OSTREAM_ROCM_ERROR(ContextAlreadyInUse) + OSTREAM_ROCM_ERROR(PeerAccessUnsupported) + OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM. + default: + return absl::StrCat("hipError_t(", static_cast(result), ")"); + } +} +} // namespace + +NB_MODULE(rocm_plugin_extension, m) { + tsl::ImportNumpy(); + m.def( + "register_custom_call_target", + [](nb::capsule c_api, nb::str fn_name, nb::capsule fn, + nb::str xla_platform_name, int api_version, + XLA_FFI_Handler_Traits traits) { + xla::ThrowIfError(RegisterCustomCallTarget( + static_cast(c_api.data()), fn_name, std::move(fn), + api_version, traits)); + }, + nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), + nb::arg("xla_platform_name"), nb::arg("api_version") = 0, + nb::arg("traits") = 0); + m.def("registrations", &Registrations); + m.def( + "get_device_ordinal", + [](std::intptr_t data_value) { + if (data_value == 0) { + return 0; + } + int device_ordinal; + void* data_ptr = reinterpret_cast(data_value); + hipError_t result = + hipPointerGetAttribute(static_cast(&device_ordinal), + HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(data_ptr)); + if (result != hipSuccess) { + LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << data_ptr + << ". Error: " << ToString(result); + } + return device_ordinal; + }, + nb::arg("data_value")); +} +} // namespace xla diff --git a/jaxlib/setup.py b/jaxlib/setup.py index e3699674915e..adc3ba452111 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -59,32 +59,16 @@ def has_ext_modules(self): author='JAX team', author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension'], - python_requires='>=3.9', + python_requires='>=3.10', install_requires=[ 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", - 'numpy>=1.22', - 'ml_dtypes>=0.4.0', + 'numpy>=1.24', + 'ml_dtypes>=0.2.0', ], - extras_require={ - 'cuda12_pip': [ - "nvidia-cublas-cu12>=12.1.3.1", - "nvidia-cuda-cupti-cu12>=12.1.105", - "nvidia-cuda-nvcc-cu12>=12.1.105", - "nvidia-cuda-runtime-cu12>=12.1.105", - # https://docs.nvidia.com/deeplearning/cudnn/developer/misc.html#cudnn-api-compatibility - "nvidia-cudnn-cu12>=9.0,<10.0", - "nvidia-cufft-cu12>=11.0.2.54", - "nvidia-cusolver-cu12>=11.4.5.107", - "nvidia-cusparse-cu12>=12.1.0.106", - "nvidia-nccl-cu12>=2.18.1", - "nvidia-nvjitlink-cu12>=12.1.105", - ], - }, url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 832f53249b8a..089cba21dc7b 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -61,12 +61,31 @@ py_test( ], ) +cc_binary( + name = "pjrt_c_api_gpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + "//jaxlib/mosaic/gpu:custom_call", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", + "@xla//xla/service:gpu_plugin", + ] + if_cuda([ + "@xla//xla/stream_executor:cuda_platform", + ]) + if_rocm([ + "@xla//xla/stream_executor:rocm_platform", + ]), +) + py_binary( name = "build_gpu_plugin_wheel", srcs = ["build_gpu_plugin_wheel.py"], data = [ "LICENSE.txt", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so", + ":pjrt_c_api_gpu_plugin.so", ] + if_cuda([ "//jaxlib:version", "//jaxlib/cuda:cuda_gpu_support", @@ -74,6 +93,12 @@ py_binary( "//jax_plugins/cuda:setup.py", "//jax_plugins/cuda:__init__.py", "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jaxlib:version", + "//jaxlib/rocm:rocm_gpu_support", + "//jax_plugins/rocm:pyproject.toml", + "//jax_plugins/rocm:setup.py", + "//jax_plugins/rocm:__init__.py", ]), deps = [ "//jax/tools:build_utils", @@ -85,17 +110,24 @@ py_binary( ) py_binary( - name = "build_cuda_kernels_wheel", - srcs = ["build_cuda_kernels_wheel.py"], + name = "build_gpu_kernels_wheel", + srcs = ["build_gpu_kernels_wheel.py"], data = [ "LICENSE.txt", ] + if_cuda([ + "//jaxlib/mosaic/gpu:mosaic_gpu", "//jaxlib:cuda_plugin_extension", "//jaxlib:version", "//jaxlib/cuda:cuda_gpu_support", "//jax_plugins/cuda:plugin_pyproject.toml", "//jax_plugins/cuda:plugin_setup.py", "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jaxlib:rocm_plugin_extension", + "//jaxlib:version", + "//jaxlib/rocm:rocm_gpu_support", + "//jax_plugins/rocm:plugin_pyproject.toml", + "//jax_plugins/rocm:plugin_setup.py", ]), deps = [ "//jax/tools:build_utils", diff --git a/jaxlib/tools/LICENSE.txt b/jaxlib/tools/LICENSE.txt index 71123c553fab..6c7416993e2d 100644 --- a/jaxlib/tools/LICENSE.txt +++ b/jaxlib/tools/LICENSE.txt @@ -4331,34 +4331,6 @@ Copyright 2019 The TensorFlow Authors. All rights reserved. See the License for the specific language governing permissions and limitations under the License. --------------------------------------------------------------------------------- -License for the FFT components of ducc0: -Copyright (C) 2010-2022 Max-Planck-Society -All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, this - list of conditions and the following disclaimer in the documentation and/or - other materials provided with the distribution. -* Neither the name of the copyright holder nor the names of its contributors may - be used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -------------------------------------------------------------------------------- License for pybind11: Copyright (c) 2016 Wenzel Jakob , All rights reserved. diff --git a/jaxlib/tools/build_cuda_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py similarity index 60% rename from jaxlib/tools/build_cuda_kernels_wheel.py rename to jaxlib/tools/build_gpu_kernels_wheel.py index 34280ff1ffbf..28d2806a7da9 100644 --- a/jaxlib/tools/build_cuda_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -43,16 +43,24 @@ "--cpu", default=None, required=True, help="Target CPU architecture. Required." ) parser.add_argument( - "--cuda_version", + "--platform_version", default=None, required=True, - help="Target CUDA version. Required.", + help="Target CUDA/ROCM version. Required.", ) parser.add_argument( "--editable", action="store_true", - help="Create an 'editable' jax cuda plugin build instead of a wheel.", + help="Create an 'editable' jax cuda/rocm plugin build instead of a wheel.", ) +parser.add_argument( + "--enable-cuda", + default=False, + help="Should we build with CUDA enabled? Requires CUDA and CuDNN.") +parser.add_argument( + "--enable-rocm", + default=False, + help="Should we build with ROCM enabled?") args = parser.parse_args() r = runfiles.Create() @@ -70,7 +78,7 @@ def write_setup_cfg(sources_path, cpu): """) -def prepare_wheel( +def prepare_wheel_cuda( sources_path: pathlib.Path, *, cpu, cuda_version ): """Assembles a source tree for the cuda kernel wheel in `sources_path`.""" @@ -90,10 +98,6 @@ def prepare_wheel( write_setup_cfg(sources_path, cpu) plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin" - copy_runfiles( - dst_dir=plugin_dir / "nvvm" / "libdevice", - src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"], - ) copy_runfiles( dst_dir=plugin_dir, src_files=[ @@ -106,19 +110,64 @@ def prepare_wheel( f"__main__/jaxlib/cuda/_triton.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", f"__main__/jaxlib/cuda_plugin_extension.{pyext}", + f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", + "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", + "__main__/jaxlib/version.py", + ], + ) + +def prepare_wheel_rocm( + sources_path: pathlib.Path, *, cpu, rocm_version +): + """Assembles a source tree for the rocm kernel wheel in `sources_path`.""" + copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + + copy_runfiles( + "__main__/jax_plugins/rocm/plugin_pyproject.toml", + dst_dir=sources_path, + dst_filename="pyproject.toml", + ) + copy_runfiles( + "__main__/jax_plugins/rocm/plugin_setup.py", + dst_dir=sources_path, + dst_filename="setup.py", + ) + build_utils.update_setup_with_rocm_version(sources_path, rocm_version) + write_setup_cfg(sources_path, cpu) + + plugin_dir = sources_path / f"jax_rocm{rocm_version}_plugin" + copy_runfiles( + dst_dir=plugin_dir, + src_files=[ + f"__main__/jaxlib/rocm/_solver.{pyext}", + f"__main__/jaxlib/rocm/_blas.{pyext}", + f"__main__/jaxlib/rocm/_linalg.{pyext}", + f"__main__/jaxlib/rocm/_prng.{pyext}", + f"__main__/jaxlib/rocm/_sparse.{pyext}", + f"__main__/jaxlib/rocm/_triton.{pyext}", + f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", ], ) # Build wheel for cuda kernels -tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin") +if args.enable_rocm: + tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") +else: + tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin") sources_path = tmpdir.name try: os.makedirs(args.output_path, exist_ok=True) - prepare_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version - ) - package_name = f"jax cuda{args.cuda_version} plugin" + if args.enable_cuda: + prepare_wheel_cuda( + pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + ) + package_name = f"jax cuda{args.platform_version} plugin" + elif args.enable_rocm: + prepare_wheel_rocm( + pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + ) + package_name = f"jax rocm{args.platform_version} plugin" if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 7e178e3ad2a4..73cb8a9e020d 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Script that builds a jax cuda plugin wheel, intended to be run via bazel run -# as part of the jax cuda plugin build process. +# Script that builds a jax cuda/rocm plugin wheel, intended to be run via bazel run +# as part of the jax cuda/rocm plugin build process. # Most users should not run this script directly; use build.py instead. @@ -49,16 +49,24 @@ "--cpu", default=None, required=True, help="Target CPU architecture. Required." ) parser.add_argument( - "--cuda_version", + "--platform_version", default=None, required=True, - help="Target CUDA version. Required.", + help="Target CUDA/ROCM version. Required.", ) parser.add_argument( "--editable", action="store_true", - help="Create an 'editable' jax cuda plugin build instead of a wheel.", + help="Create an 'editable' jax cuda/rocm plugin build instead of a wheel.", ) +parser.add_argument( + "--enable-cuda", + default=False, + help="Should we build with CUDA enabled? Requires CUDA and CuDNN.") +parser.add_argument( + "--enable-rocm", + default=False, + help="Should we build with ROCM enabled?") args = parser.parse_args() r = runfiles.Create() @@ -100,24 +108,62 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): ], ) copy_runfiles( - "xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so", + "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_cuda_plugin.so", ) +def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): + """Assembles a source tree for the ROCm wheel in `sources_path`.""" + copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + + plugin_dir = sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" + copy_runfiles( + dst_dir=sources_path, + src_files=[ + "__main__/jax_plugins/rocm/pyproject.toml", + "__main__/jax_plugins/rocm/setup.py", + ], + ) + build_utils.update_setup_with_rocm_version(sources_path, rocm_version) + write_setup_cfg(sources_path, cpu) + copy_runfiles( + dst_dir=plugin_dir, + src_files=[ + "__main__/jax_plugins/rocm/__init__.py", + "__main__/jaxlib/version.py", + ], + ) + copy_runfiles( + "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + dst_dir=plugin_dir, + dst_filename="xla_rocm_plugin.so", + ) + + tmpdir = None sources_path = args.sources_path if sources_path is None: - tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudapjrt") + tmpdir = tempfile.TemporaryDirectory(prefix="jaxgpupjrt") sources_path = tmpdir.name try: os.makedirs(args.output_path, exist_ok=True) - prepare_cuda_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version - ) - package_name = "jax cuda plugin" + + if args.enable_cuda: + prepare_cuda_plugin_wheel( + pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + ) + package_name = "jax cuda plugin" + elif args.enable_rocm: + prepare_rocm_plugin_wheel( + pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + ) + package_name = "jax rocm plugin" + else: + raise ValueError("Unsupported backend. Choose either 'cuda' or 'rocm'.") + if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 7c9bfa12099d..62864f7ad30d 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -195,7 +195,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/utils.{pyext}", "__main__/jaxlib/lapack.py", "__main__/jaxlib/hlo_helpers.py", - "__main__/jaxlib/ducc_fft.py", "__main__/jaxlib/gpu_prng.py", "__main__/jaxlib/gpu_linalg.py", "__main__/jaxlib/gpu_rnn.py", @@ -218,15 +217,10 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): dst_dir=jaxlib_dir / "cpu", src_files=[ f"__main__/jaxlib/cpu/_lapack.{pyext}", - f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", ], ) if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels: - copy_runfiles( - dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice", - src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"], - ) copy_runfiles( dst_dir=jaxlib_dir / "cuda", src_files=[ @@ -240,7 +234,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/cuda/_versions.{pyext}", ], ) - if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"): + if exists(f"__main__/jaxlib/rocm/_solver.{pyext}") and not skip_gpu_kernels: copy_runfiles( dst_dir=jaxlib_dir / "rocm", src_files=[ @@ -266,18 +260,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): "__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir ) - has_mosaic_gpu = exists(f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}") - def if_has_mosaic_gpu(extras): - return extras if has_mosaic_gpu else [] - - if has_mosaic_gpu: - copy_runfiles( - dst_dir=jaxlib_dir / "mosaic" / "gpu", - src_files=[ - "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", - ], - ) - copy_runfiles( dst_dir=jaxlib_dir / "mlir", src_files=[ @@ -285,9 +267,7 @@ def if_has_mosaic_gpu(extras): "__main__/jaxlib/mlir/ir.pyi", "__main__/jaxlib/mlir/passmanager.py", "__main__/jaxlib/mlir/passmanager.pyi", - ] + if_has_mosaic_gpu([ - "__main__/jaxlib/mlir/execution_engine.py", - ]), + ], ) copy_runfiles( dst_dir=jaxlib_dir / "mlir" / "dialects", @@ -307,6 +287,14 @@ def if_has_mosaic_gpu(extras): "__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", "__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", "__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py", + "__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py", + "__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py", "__main__/jaxlib/mlir/dialects/arith.py", "__main__/jaxlib/mlir/dialects/builtin.py", "__main__/jaxlib/mlir/dialects/chlo.py", @@ -318,19 +306,10 @@ def if_has_mosaic_gpu(extras): "__main__/jaxlib/mlir/dialects/sparse_tensor.py", "__main__/jaxlib/mlir/dialects/stablehlo.py", "__main__/jaxlib/mlir/dialects/vector.py", - ] + if_has_mosaic_gpu([ - "__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py", "__main__/jaxlib/mlir/dialects/nvgpu.py", "__main__/jaxlib/mlir/dialects/nvvm.py", "__main__/jaxlib/mlir/dialects/llvm.py", - ]), + ], ) copy_runfiles( dst_dir=jaxlib_dir / "mlir" / "extras", @@ -338,19 +317,18 @@ def if_has_mosaic_gpu(extras): "__main__/jaxlib/mlir/extras/meta.py", ], ) - if has_mosaic_gpu: - copy_runfiles( - dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu", - src_files=[ - "__main__/jaxlib/mlir/dialects/gpu/__init__.py", - ], - ) - copy_runfiles( - dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes", - src_files=[ - "__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py", - ], - ) + copy_runfiles( + dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu", + src_files=[ + "__main__/jaxlib/mlir/dialects/gpu/__init__.py", + ], + ) + copy_runfiles( + dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes", + src_files=[ + "__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py", + ], + ) if build_utils.is_windows(): @@ -373,6 +351,10 @@ def if_has_mosaic_gpu(extras): f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsNVGPU.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mlirGPUPasses.{pyext}", ] + ( [] @@ -381,15 +363,7 @@ def if_has_mosaic_gpu(extras): f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", ] - ) + if_has_mosaic_gpu([ - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsNvgpu.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirExecutionEngine.{pyext}", - "__main__/jaxlib/mlir/_mlir_libs/_mlirExecutionEngine.pyi", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirGPUPasses.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}", - ]), + ), ) triton_dir = jaxlib_dir / "triton" diff --git a/pyproject.toml b/pyproject.toml index c06c16029560..e81320c13421 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ warn_unused_ignores = true module = [ "absl.*", "colorama.*", - "importlib_metadata.*", + "filelock.*", "IPython.*", "numpy.*", "opt_einsum.*", @@ -40,17 +40,10 @@ module = [ "jax.experimental.jax2tf.tests.flax_models", "jax.experimental.jax2tf.tests.back_compat_testdata", "setuptools.*", + "jax_cuda12_plugin.*", ] ignore_missing_imports = true -[[tool.mypy.overrides]] -module = [ - "jax.interpreters.autospmd", - "jax.lax.lax_parallel", - "jax._src.internal_test_util.test_harnesses", -] -ignore_errors = true - [tool.pytest.ini_options] markers = [ "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators", @@ -60,8 +53,8 @@ filterwarnings = [ "error", "default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'", "default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'", - "default:backend and device argument on jit is deprecated.*:DeprecationWarning", "default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning", + "default:jax.xla_computation is deprecated. Please use the AOT APIs.*:DeprecationWarning", # TODO(jakevdp): remove when array_api_tests stabilize # start array_api_tests-related warnings "default:.*not machine-readable.*:UserWarning", @@ -108,7 +101,7 @@ exclude = [ ] line-length = 88 indent-width = 2 -target-version = "py39" +target-version = "py310" [tool.ruff.lint] ignore = [ @@ -122,6 +115,8 @@ ignore = [ "F841", # Raise with from clause inside except block "B904", + # Zip without explicit strict parameter + "B905", ] select = [ "B9", diff --git a/setup.py b/setup.py index e1cb3e2e38b0..cc2c75ab7ff4 100644 --- a/setup.py +++ b/setup.py @@ -19,12 +19,10 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.28' +_current_jaxlib_version = '0.4.30' # The following should be updated with each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.4.28' -_default_cuda12_cudnn_version = '91' -_available_cuda12_cudnn_versions = [_default_cuda12_cudnn_version] -_libtpu_version = '0.1.dev20240508' +_latest_jaxlib_version_on_pypi = '0.4.30' +_libtpu_version = '0.1.dev20240617' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( @@ -35,6 +33,7 @@ def load_version_module(pkg_path): _version_module = load_version_module(project_name) __version__ = _version_module._get_version_for_build() +_jax_version = _version_module._version # JAX version, with no .dev suffix. _cmdclass = _version_module._get_cmdclass(project_name) _minimum_jaxlib_version = _version_module._minimum_jaxlib_version @@ -52,27 +51,23 @@ def load_version_module(pkg_path): author_email='jax-dev@google.com', packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]), package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, - python_requires='>=3.9', + python_requires='>=3.10', install_requires=[ - 'ml_dtypes>=0.4.0', - 'numpy>=1.22', - "numpy>=1.23.2; python_version>='3.11'", + f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', + 'ml_dtypes>=0.2.0', + 'numpy>=1.24', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum', 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", - # Required by xla_bridge.discover_pjrt_plugins for forwards compat with - # Python versions < 3.10. Can be dropped when 3.10 is the minimum - # required Python version. - 'importlib_metadata>=4.6;python_version<"3.10"', ], extras_require={ # Minimum jaxlib version; used in testing. 'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'], - # CPU-only jaxlib can be installed via: - # $ pip install jax[cpu] - 'cpu': [f'jaxlib=={_current_jaxlib_version}'], + # A CPU-only jax doesn't require any extras, but we keep this extra + # around for compatibility. + 'cpu': [], # Used only for CI builds that install JAX from github HEAD. 'ci': [f'jaxlib=={_latest_jaxlib_version_on_pypi}'], @@ -80,76 +75,38 @@ def load_version_module(pkg_path): # Cloud TPU VM jaxlib can be installed via: # $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 'tpu': [ - f'jaxlib=={_current_jaxlib_version}', + f'jaxlib>={_current_jaxlib_version},<={_jax_version}', f'libtpu-nightly=={_libtpu_version}', 'requests', # necessary for jax.distributed.initialize ], - # CUDA installations require adding the JAX CUDA releases URL, e.g., - # Cuda installation defaulting to a CUDA and Cudnn version defined above. - # $ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - 'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}"], - - - 'cuda12_pip': [ - f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}", - "nvidia-cublas-cu12>=12.1.3.1", - "nvidia-cuda-cupti-cu12>=12.1.105", - "nvidia-cuda-nvcc-cu12>=12.1.105", - "nvidia-cuda-runtime-cu12>=12.1.105", - # https://docs.nvidia.com/deeplearning/cudnn/developer/misc.html#cudnn-api-compatibility - "nvidia-cudnn-cu12>=9.0,<10.0", - "nvidia-cufft-cu12>=11.0.2.54", - "nvidia-cusolver-cu12>=11.4.5.107", - "nvidia-cusparse-cu12>=12.1.0.106", - "nvidia-nccl-cu12>=2.18.1", - # nvjitlink is not a direct dependency of JAX, but it is a transitive - # dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages - # do not have a version constraint on their dependencies, so the - # package doesn't get upgraded even though not doing that can cause - # problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196) - # Until NVIDIA add version constraints, add a version constraint - # here. - "nvidia-nvjitlink-cu12>=12.1.105", - ], + 'cuda': [ + f"jaxlib=={_current_jaxlib_version}", + f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + ], 'cuda12': [ f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin=={_current_jaxlib_version}", - "nvidia-cublas-cu12>=12.1.3.1", - "nvidia-cuda-cupti-cu12>=12.1.105", - "nvidia-cuda-nvcc-cu12>=12.1.105", - "nvidia-cuda-runtime-cu12>=12.1.105", - "nvidia-cudnn-cu12>=9.0,<10.0", - "nvidia-cufft-cu12>=11.0.2.54", - "nvidia-cusolver-cu12>=11.4.5.107", - "nvidia-cusparse-cu12>=12.1.0.106", - "nvidia-nccl-cu12>=2.18.1", - # nvjitlink is not a direct dependency of JAX, but it is a transitive - # dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages - # do not have a version constraint on their dependencies, so the - # package doesn't get upgraded even though not doing that can cause - # problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196) - # Until NVIDIA add version constraints, add a version constraint - # here. - "nvidia-nvjitlink-cu12>=12.1.105", + f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + ], + + # Deprecated alias for cuda12, kept to avoid breaking users who wrote + # cuda12_pip in their CI. + 'cuda12_pip': [ + f"jaxlib=={_current_jaxlib_version}", + f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Target that does not depend on the CUDA pip wheels, for those who want # to use a preinstalled CUDA. 'cuda12_local': [ - f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}", + f"jaxlib=={_current_jaxlib_version}", + f"jax-cuda12-plugin=={_current_jaxlib_version}", ], - - # CUDA installations require adding jax releases URL; e.g. - # $ pip install jax[cuda12_cudnn89] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - **{f'cuda12_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{cudnn_version}" - for cudnn_version in _available_cuda12_cudnn_versions} }, url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/tests/BUILD b/tests/BUILD index 83cbc368f9f3..036946529de1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -35,6 +35,7 @@ jax_test( name = "api_test", srcs = ["api_test.py"], shard_count = 10, + tags = ["test_cpu_thunks"], ) jax_test( @@ -74,6 +75,11 @@ jax_test( }, ) +jax_test( + name = "config_test", + srcs = ["config_test.py"], +) + jax_test( name = "core_test", srcs = ["core_test.py"], @@ -126,15 +132,6 @@ jax_test( deps = ["//jax:extend"], ) -py_test( - name = "ffi_test", - srcs = ["ffi_test.py"], - deps = [ - "//jax", - "//jax:test_util", - ], -) - jax_test( name = "fft_test", srcs = ["fft_test.py"], @@ -269,6 +266,9 @@ jax_test( jax_test( name = "layout_test", srcs = ["layout_test.py"], + backend_tags = { + "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. + }, tags = ["multiaccelerator"], ) @@ -318,6 +318,9 @@ jax_test( jax_test( name = "array_test", srcs = ["array_test.py"], + backend_tags = { + "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. + }, tags = ["multiaccelerator"], deps = [ "//jax:experimental", @@ -349,6 +352,7 @@ jax_test( jax_test( name = "infeed_test", srcs = ["infeed_test.py"], + tags = ["test_cpu_thunks"], deps = [ "//jax:experimental_host_callback", ], @@ -358,6 +362,7 @@ jax_test( name = "jax_jit_test", srcs = ["jax_jit_test.py"], main = "jax_jit_test.py", + tags = ["test_cpu_thunks"], ) py_test( @@ -785,6 +790,14 @@ jax_test( jax_test( name = "qdwh_test", srcs = ["qdwh_test.py"], + backend_tags = { + "tpu": [ + "noasan", # Times out + "nomsan", # Times out + "notsan", # Times out + ], + }, + shard_count = 10, ) jax_test( @@ -1079,6 +1092,11 @@ jax_test( main = "third_party/scipy/line_search_test.py", ) +jax_test( + name = "blocked_sampler_test", + srcs = ["blocked_sampler_test.py"], +) + py_test( name = "tree_util_test", srcs = ["tree_util_test.py"], @@ -1136,6 +1154,16 @@ py_test( ], ) +py_test( + name = "lru_cache_test", + srcs = ["lru_cache_test.py"], + deps = [ + "//jax", + "//jax:lru_cache", + "//jax:test_util", + ] + py_deps("filelock"), +) + jax_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.py"], diff --git a/tests/api_test.py b/tests/api_test.py index 71dba9535ef2..3929a29f9c30 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -16,6 +16,7 @@ import collections import collections.abc +from collections.abc import Callable import concurrent.futures from contextlib import contextmanager import copy @@ -33,7 +34,7 @@ import subprocess import sys import types -from typing import Callable, NamedTuple +from typing import NamedTuple import unittest import weakref @@ -53,6 +54,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src import debugging +from jax._src import pjit as pjit_lib from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -229,7 +231,9 @@ def f(x, y, z): def test_jit_device(self): device = jax.devices()[-1] - x = jit(lambda x: x, device=device)(3.) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + x = jit(lambda x: x, device=device)(3.) _check_instance(self, x) self.assertEqual(x.devices(), {device}) @@ -260,10 +264,12 @@ def test_jit_default_device(self, module): with jax.default_device(test_device): # Explicit `device` or `backend` argument to jit overrides default_device - self.assertEqual( - module(f, device=system_default_device)(1).devices(), - system_default_devices) - out = module(f, backend="cpu")(1) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + self.assertEqual( + module(f, device=system_default_device)(1).devices(), + system_default_devices) + out = module(f, backend="cpu")(1) self.assertEqual(next(iter(out.devices())).platform, "cpu") # Sticky input device overrides default_device @@ -703,7 +709,6 @@ def test_trivial_computations(self): self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer()) self.assertEqual(z2, 1) - @unittest.skipIf(xla_extension_version < 264, "jaxlib version too old") def test_print_token_buffer_error(self): token = jax.lax.create_token() with self.assertRaisesRegex( @@ -856,8 +861,10 @@ def test_cpp_jitted_function_returns_PyBuffer(self): @jtu.skip_on_devices("cpu") def test_explicit_backend(self, module): f = lambda x: x + 1 - jitted_f = module(f, backend=jtu.device_under_test()) - jitted_f_cpu = module(f, backend="cpu") + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + jitted_f = module(f, backend=jtu.device_under_test()) + jitted_f_cpu = module(f, backend="cpu") result = jitted_f(1.) result_cpu = jitted_f_cpu(1.) @@ -872,8 +879,10 @@ def test_explicit_backend(self, module): def test_device_to_device_copy_between_backends(self, module): # b/186624243 f = lambda x: x + 1 - jitted_f = module(f, backend=jtu.device_under_test()) - jitted_f_cpu = module(f, backend="cpu") + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + jitted_f = module(f, backend=jtu.device_under_test()) + jitted_f_cpu = module(f, backend="cpu") x = np.arange(30).reshape(1, 10, 3) result = jitted_f(x) @@ -884,6 +893,8 @@ def test_device_to_device_copy_between_backends(self, module): self.assertAllClose(result_cpu_2, x + 4) @jtu.skip_on_devices("cpu") + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_mismatched_nested_backends(self): @partial(jax.jit, backend=jtu.device_under_test()) def f(x): @@ -2215,12 +2226,10 @@ def f(x, y): return x + y def test_vjp_mismatched_arguments(self): _, pullback = api.vjp(lambda x, y: x * y, np.float32(3), np.float32(4)) self.assertRaisesRegex( - TypeError, - "Tree structure of cotangent input.*does not match", + ValueError, "unexpected tree structure", lambda: pullback((np.float32(7), np.float32(100)))) self.assertRaisesRegex( - TypeError, - "Type of cotangent input to vjp pullback.*is not the expected tangent type", + ValueError, "unexpected JAX type", lambda: pullback(np.float16(42))) def test_vjp_bad_cotangent_shape(self): @@ -2229,9 +2238,7 @@ def test_vjp_bad_cotangent_shape(self): def f_jax(x, y): return jnp.matmul(x, y) res, pullback = jax.vjp(f_jax, x, y) - with self.assertRaisesRegex( - ValueError, - "Shape of cotangent input to vjp pullback function .* must be the same as the shape of corresponding primal input .*"): + with self.assertRaisesRegex(ValueError, "unexpected JAX type"): pullback(np.ones((2, 4), dtype=np.float32)) def test_jvp_jit_cached(self): @@ -2583,7 +2590,7 @@ def fun(x, y): def test_eval_shape_trace_cache_share(self): def f(x): - return x * 2 + return x inp = np.arange(8) @@ -2591,8 +2598,32 @@ def f(x): jax.eval_shape(f, inp) jax.jit(f)(inp) - # one for `f` and another for mul (`x * 2`) which is jitted. - self.assertEqual(count[0], 2) + self.assertEqual(count[0], 1) + + @unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31") + def test_jit_infer_params_cache(self): + def f(x): + return x + + f_jit = jax.jit(f) + + def g(x): + x = f_jit(x) # noqa: F821 + x = f_jit(x) # noqa: F821 + return x + + g_jit = jax.jit(g) + + inp = np.arange(8) + with jtu.count_jit_infer_params_cache_miss() as count: + g_jit(inp) + + self.assertDictEqual(count, {f: 1, g: 1}) + cache_size = pjit_lib._infer_params_cached.cache_info().currsize + del count, f, f_jit, g, g_jit + # Cache should only keep a weak reference to f and g. + self.assertLess(pjit_lib._infer_params_cached.cache_info().currsize, + cache_size, msg=pjit_lib._infer_params_cached.cache_keys()) def test_eval_shape_out_shardings(self): s = jax.sharding.SingleDeviceSharding(jax.devices()[0]) @@ -2709,14 +2740,14 @@ def test_vjp_of_int_index(self): self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0)) def test_vjp_of_int_shapes(self): - out, fn_vjp = api.vjp(lambda x: lax.reshape(x, (2, 2)), np.ones((4, 1), - dtype=int)) - tangent, = fn_vjp(out) + out, fn_vjp = api.vjp( + lambda x: lax.reshape(x, (2, 2)), np.ones((4, 1), dtype=int)) + tangent, = fn_vjp(np.zeros((2, 2), dtypes.float0)) self.assertArraysEqual(tangent, np.zeros(shape=(4, 1), dtype=float0)) def test_jit_vjp_of_int(self): primal, fn_vjp = api.vjp(lambda x, y: x+y, 2, 1) - tangent_x, tangent_i = jax.jit(fn_vjp)(1) + tangent_x, tangent_i = jax.jit(fn_vjp)(np.zeros((), dtypes.float0)) self.assertEqual(primal, 3) self.assertEqual(tangent_x, np.zeros(shape=(), dtype=float0)) self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0)) @@ -2949,6 +2980,7 @@ def fn(x): axis_env = [(axis_name, jax.local_device_count())] _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x) + @jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation') def test_xla_computation_axis_env(self): def fn(x): z = x * jax.lax.axis_index('i').astype(jnp.float32) @@ -4062,6 +4094,16 @@ def f(): return jnp.exp(dtype(0)) f() # doesn't error + def test_vmap_make_jaxpr_close_over_tracer(self): + def run(inp): + def f(x, y): + return x + y + g = lambda x: f(x, inp) + jaxpr = jax.make_jaxpr(g)(1) + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1) + + jax.vmap(run)(jnp.arange(2)) # doesn't crash + def test_large_python_ints(self): with self.assertRaises(OverflowError): jnp.multiply(2 ** 100, 3.) @@ -4339,9 +4381,14 @@ def f(x, y): g = jax.grad(f, argnums=-1) g(x, y) # doesn't crash + @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31") def test_jit_negative_static_argnums(self): - g = jax.jit(lambda x, y: x * y, static_argnums=-1) - g(1, 2) # doesn't crash + @partial(jax.jit, static_argnums=-1) + def g(x, y): + assert isinstance(y, int) + return x * y + for i in range(3): # Loop verifies we exercise both Python and C++ dispatch + self.assertEqual(2 * i, g(2, i), msg=i) def test_fastpath_cache_confusion(self): # https://github.com/google/jax/issues/12542 @@ -4463,6 +4510,10 @@ def f(i): jax.clear_caches() self.assertEqual(f._cache_size, 0) + def test_invalid_value_device_put(self): + with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"): + jax.device_put(jnp.arange(8), 'cpu') + def test_clear_cache(self): @jax.jit def add(x): @@ -4708,6 +4759,34 @@ def test_inner_jit_forwarded_consts_stay_const(self): out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))() # don't crash self.assertEqual(out, 3) + def test_lowering_platform_aot(self): + @jax.jit + def f(x): + return x * 2 + + f.trace(jnp.arange(8)).lower(lowering_platforms=('tpu',)) # doesn't crash + + def test_no_double_dots_in_error_message(self): + @jax.jit + def f(x): + return 1 if x > 0 else 0 + + with self.assertRaisesRegex(TracerBoolConversionError, r"with shape bool\[\]\.[^\.]"): + f(0) + + def test_inlined_literals_with_error(self): + @jax.jit + def f(): + @partial(jax.jit, inline=True) + def g(): + return jnp.sin(1.) + if g() > 0: + return 1. + return 0. + + with self.assertRaisesRegex(TracerBoolConversionError, "Attempted boolean"): + f() + class RematTest(jtu.JaxTestCase): @@ -9462,6 +9541,22 @@ def foo_bwd(_, g): finally: jax.config.update('jax_custom_vjp_disable_shape_check', False) + def test_bwd_rule_can_produce_list_or_tuple(self): + @jax.custom_vjp + def f(x, y): + return x * y + + def f_fwd(x, y): + return f(x, y), (x, y) + + def f_bwd(xy, g): + x, y = xy + return [g * y, x * g] # list, not tuple + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)(1., 2.) # don't crash + def transpose_unary(f, x_example): def transposed(y): @@ -10523,8 +10618,8 @@ def test_pmap_nested_donate_ignored(self): class NamedCallTest(jtu.JaxTestCase): + @jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation') def test_default_name(self): - @api.named_call def my_test_function(x): return x**2 @@ -10695,9 +10790,11 @@ def wsc_as_noop(ctx, operand, *args, **kwargs): rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),) lowered_ir = ( jax.jit(f) - .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16), - _experimental_lowering_parameters=mlir.LoweringParameters( - override_lowering_rules=rules)).as_text()) + .trace(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16)) + .lower(_private_parameters=mlir.LoweringParameters( + override_lowering_rules=rules)) + .as_text() + ) self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index b555576b3261..c2cd4c0f968d 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -214,7 +214,6 @@ def testNumpyToJax(self, shape, dtype, copy): shape=all_shapes, dtype=numpy_dtypes, ) - @unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer") @jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks def testJaxToNumpy(self, shape, dtype): rng = jtu.rand_default(self.rng()) diff --git a/tests/array_test.py b/tests/array_test.py index 7bb79f212772..fe7c4208dca8 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -15,7 +15,6 @@ import contextlib import math -import os import unittest from absl.testing import absltest @@ -25,15 +24,18 @@ import jax import jax.numpy as jnp from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import op_shardings from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc from jax._src.util import safe_zip +from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import (_op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, - NamedSharding, GSPMDSharding) + NamedSharding, GSPMDSharding, + PositionalSharding) from jax.experimental.pjit import pjit from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P @@ -42,33 +44,18 @@ jax.config.parse_flags_with_absl() - -prev_xla_flags = None - with contextlib.suppress(ImportError): import pytest pytestmark = pytest.mark.multiaccelerator - # Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xb.get_backend.cache_clear() - -# Reset to previous configuration in case other test modules will be run. + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xb.get_backend.cache_clear() + _exit_stack.close() def create_array(shape, sharding, global_data=None): @@ -610,6 +597,22 @@ def test_array_addressable_shards(self): x = jnp.array([1, 2, 3]) self.assertIsInstance(x.addressable_data(0), array.ArrayImpl) + def test_array_not_hashable(self): + x = jnp.arange(4) + with self.assertRaisesRegex(TypeError, "unhashable type"): + hash(x) + + @jax.jit + def check_tracer_hash(x): + self.assertIsInstance(hash(x), int) + + if deprecations.is_accelerated('tracer-hash'): + with self.assertRaisesRegex(TypeError, "unhashable type"): + check_tracer_hash(x) + else: + with self.assertWarnsRegex(FutureWarning, "unhashable type"): + check_tracer_hash(x) + def test_shape_dtype_struct_sharding_jit(self): mesh = jtu.create_global_mesh((8,), ('x')) s = jax.sharding.NamedSharding(mesh, P('x')) @@ -845,6 +848,15 @@ def test_mesh_pspec_sharding_interface(self): self.assertListEqual(hlo_sharding.tile_assignment_devices(), [0, 2, 4, 6, 1, 3, 5, 7]) + def test_util_clear_cache(self): + mesh = jtu.create_global_mesh((1,), ('x',)) + s = NamedSharding(mesh, P()) + s.devices_indices_map((8,)) + jax.clear_caches() + s.devices_indices_map((8,)) + c = common_devices_indices_map.cache_info() + self.assertEqual(c.currsize, 1) + @parameterized.named_parameters( ("mesh_x_y", P("x", "y")), ("mesh_x", P("x")), @@ -922,7 +934,7 @@ def test_is_compatible_error(self): r"Sharding NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), " r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only " "valid for values of rank at least 4, but was applied to a value of rank 2"): - new_mps.is_compatible_aval(shape) + new_mps.check_compatible_aval(shape) def test_is_subclass(self): # array version of api_test.py::APITest::test_is_subclass @@ -947,6 +959,10 @@ def test_gspmd_sharding_repr(self): # memory kind also appears in the repr but only for TPU. self.assertIn('GSPMDSharding({replicated}', repr(s2)) + def test_positional_sharding_fully_replicated(self): + sharding = PositionalSharding(jax.devices()) + jax.device_put(jnp.array(1), sharding.replicate()) # doesn't crash + @parameterized.named_parameters( ("mesh_x_y", P("x", "y"), (4, 2), (), False), ("mesh_x", P("x"), (4, 2), (1,), False), @@ -976,6 +992,17 @@ def test_positional_sharding_op_sharding_lowering( devices_sharding.shard_shape(value_shape)) self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) + def test_positional_sharding_aval_compatible(self): + if jax.device_count() < 2: + self.skipTest('Requires >=2 devices') + sharding = PositionalSharding(jax.devices()).reshape(1, jax.device_count()) + x = jax.random.uniform(jax.random.key(42), (256, 20, 1000)) + with self.assertRaisesRegex( + ValueError, + 'Sharding PositionalSharding.*is only valid for values of rank 2, but' + ' was applied to a value of rank 3'): + jax.lax.with_sharding_constraint(x, sharding) + @parameterized.named_parameters( ("2d_mesh_x_y", (4, 2), P("x", "y")), ("2d_mesh_x", (4, 2), P("x")), @@ -1178,7 +1205,7 @@ def test_scalar_input_wrong_pspec(self): with self.assertRaisesRegex( ValueError, r"For scalars the PartitionSpec should be P()"): - s.is_compatible_aval(shape) + s.check_compatible_aval(shape) def test_mesh_caching_during_construction(self): if jax.device_count() < 2: @@ -1222,6 +1249,18 @@ def f(x): with self.assertRaisesRegex(ValueError, msg): jax.jit(f)(x) + def test_make_array_from_single_device_arrays_bad_inputs(self): + x = jnp.arange(10) + mesh = jtu.create_global_mesh((2,), ('x',)) + s = jax.sharding.NamedSharding(mesh, P('x')) + x = jax.device_put(x, s) + + msg = ("When making an array from single-device arrays the input arrays " + "must have one shard each. An argument array had 2 shard\\(s\\).") + with self.assertRaisesRegex(ValueError, msg): + jax.make_array_from_single_device_arrays(x.shape, s, [x, x]) + + def test_gspmd_sharding_hash_eq(self): mesh = jtu.create_global_mesh((1, 1, 1), ('x', 'y', 'z')) ns = NamedSharding(mesh, P('x', 'y', 'z')) diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 1f372907596e..5c834f314270 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -45,7 +45,6 @@ class Thing: class AttrsTest(jtu.JaxTestCase): - @parameterized.parameters([True, False]) def test_jit_basic(self, jit: bool): thing = Thing(1.0) @@ -67,6 +66,100 @@ def double_it() -> None: double_it() self.assertEqual(thing.x, 16.0) + @parameterized.parameters([True, False]) + def test_jit_basic_tree(self, jit: bool): + thing = Thing((1.0, 2.0)) + + def double_it() -> None: + (cur_x, cur_y) = jax_getattr(thing, "x") + jax_setattr(thing, "x", (cur_x * 2, cur_y * 2)) + + if jit: + double_it = jax.jit(double_it) + + self.assertEqual(thing.x, (1.0, 2.0)) + double_it() + self.assertEqual(thing.x, (2.0, 4.0)) + double_it() + self.assertEqual(thing.x, (4.0, 8.0)) + double_it() + self.assertEqual(thing.x, (8.0, 16.0)) + double_it() + self.assertEqual(thing.x, (16.0, 32.0)) + + @parameterized.parameters([True, False]) + def test_jit_basic_tree_changes(self, jit: bool): + thing = Thing(None) + count = 0 + + def double_it() -> None: + nonlocal count + count += 1 + maybe_x = jax_getattr(thing, "x") + x = 1.0 if maybe_x is None else maybe_x + jax_setattr(thing, "x", 2 * x) + + if jit: + double_it = jax.jit(double_it) + + self.assertEqual(thing.x, None) + double_it() + self.assertEqual(thing.x, 2.0) + self.assertEqual(count, 1) + double_it() + self.assertEqual(thing.x, 4.0) + self.assertEqual(count, 2) + double_it() + self.assertEqual(thing.x, 8.0) + self.assertEqual(count, 2 + (not jit)) + + def test_jit_basic_tree_changes_multiple(self): + thing1 = Thing(None) + thing2 = Thing(0) + count = 0 + + @jax.jit + def double_it() -> None: + nonlocal count + count += 1 + + x1 = jax_getattr(thing1, "x") + if x1 is None: + jax_setattr(thing1, 'x', (None,)) + elif isinstance(x1, tuple): + # depend on a new value + jax_setattr(thing1, 'x', jax_getattr(thing2, 'x') + 1) + else: + jax_setattr(thing2, 'x', jax_getattr(thing1, 'x')) + jax_setattr(thing1, 'x', None) + + self.assertEqual(thing1.x, None) + self.assertEqual(thing2.x, 0) + double_it() + self.assertEqual(thing1.x, (None,)) + self.assertEqual(thing2.x, 0) + self.assertEqual(count, 1) + double_it() + self.assertEqual(thing1.x, 1) + self.assertEqual(thing2.x, 0) + self.assertEqual(count, 2) + double_it() + self.assertEqual(thing1.x, None) + self.assertEqual(thing2.x, 1) + self.assertEqual(count, 3) + double_it() + self.assertEqual(thing1.x, (None,)) + self.assertEqual(thing2.x, 1) + self.assertEqual(count, 3) + double_it() + self.assertEqual(thing1.x, 2) + self.assertEqual(thing2.x, 1) + self.assertEqual(count, 3) + double_it() + self.assertEqual(thing1.x, None) + self.assertEqual(thing2.x, 2) + self.assertEqual(count, 3) + def test_jit_nesting_basic(self): thing = Thing(1.0) diff --git a/tests/batching_test.py b/tests/batching_test.py index 36e686443ac7..4d912bfca206 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -14,10 +14,11 @@ from __future__ import annotations +from collections.abc import Callable from contextlib import contextmanager from functools import partial import itertools as it -from typing import Any, Callable, TypeVar, Union +from typing import Any, TypeVar, Union import numpy as np from absl.testing import absltest diff --git a/tests/blocked_sampler_test.py b/tests/blocked_sampler_test.py new file mode 100644 index 000000000000..1f8f2b645f06 --- /dev/null +++ b/tests/blocked_sampler_test.py @@ -0,0 +1,90 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import numpy as jnp +from jax._src import blocked_sampler +from jax._src import config +from jax._src import test_util as jtu +import numpy as np + + +config.parse_flags_with_absl() + + +def call_kernel( + kernel, + grid: tuple[int, int], + transpose_grid: bool, + *args + ): + """Calls a kernel over a grid and concatenates results to a single array.""" + if transpose_grid: + grid = (grid[1], grid[0]) + m, n = grid + return jnp.concatenate([ + jnp.concatenate([ + kernel(i, j, *args) for j in range(n)], axis=1) + for i in range(m)], axis=0) + + +def uniform_kernel(i: int, j: int, total_size, block_size, tile_size): + """Uniform random sampling kernel function.""" + global_key = jax.random.key(0) + keys = blocked_sampler.blocked_fold_in(global_key, + total_size=total_size, + block_size=block_size, + tile_size=tile_size, + block_index=(i, j)) + return blocked_sampler.sample_block(jax.random.uniform, + keys, + block_size=block_size, + tile_size=tile_size, + minval=0.0, maxval=1.0) + + +class BlockedSamplerTest(jtu.JaxTestCase): + + @parameterized.named_parameters( + dict(testcase_name='8x128_vs_16x256', total_size=(32, 256), + block_size_a=(8, 128), block_size_b=(16, 256), + tile_size=(8, 128), transpose_grid=False), + dict(testcase_name='transpose_8x128_vs_16x256', total_size=(32, 256), + block_size_a=(8, 128), block_size_b=(16, 256), + tile_size=(8, 128), transpose_grid=True), + dict(testcase_name='8x128_vs_32x128', total_size=(32, 128), + block_size_a=(8, 128), block_size_b=(32, 128), + tile_size=(8, 128), transpose_grid=False), + dict(testcase_name='16x256_vs_32x128', total_size=(32, 256), + block_size_a=(16, 256), block_size_b=(32, 128), + tile_size=(8, 128), transpose_grid=False), + ) + def test_block_shape_invariance(self, total_size, block_size_a, + block_size_b, tile_size, transpose_grid): + grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a)) + result_a = call_kernel( + uniform_kernel, grid_a, transpose_grid, + total_size, block_size_a, tile_size) + + grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b)) + result_b = call_kernel( + uniform_kernel, grid_b, transpose_grid, + total_size, block_size_b, tile_size) + np.testing.assert_array_equal(result_a, result_b) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/cholesky_update_test.py b/tests/cholesky_update_test.py index f37ea1191d41..63f732dcd55d 100644 --- a/tests/cholesky_update_test.py +++ b/tests/cholesky_update_test.py @@ -19,16 +19,11 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lax import linalg as lax_linalg -from jax._src.lib import version as jaxlib_version # pylint: disable=g-importing-member import numpy as np config.parse_flags_with_absl() class CholeskyUpdateTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - if jaxlib_version < (0, 4, 29): - self.skipTest("Requires jaxlib 0.4.29 or newer") @jtu.sample_product( shape=[ diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 805a8590edea..d5bf671b1f83 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -60,6 +60,22 @@ def tearDownModule(): def increment_event_count(event): _counts[event] += 1 +class CompilationCacheTestCase(jtu.JaxTestCase): + tmpdir: str + + def setUp(self): + super().setUp() + cc.reset_cache() + tmpdir = tempfile.TemporaryDirectory() + self.enter_context(tmpdir) + self.enter_context(config.compilation_cache_dir(tmpdir.name)) + self.tmpdir = tmpdir.name + + def tearDown(self): + cc.reset_cache() + self.tmpdir = "" + super().tearDown() + @jtu.with_config( jax_enable_compilation_cache=True, @@ -67,8 +83,7 @@ def increment_event_count(event): jax_persistent_cache_min_compile_time_secs=0, jax_persistent_cache_min_entry_size_bytes=0, ) -class CompilationCacheTest(jtu.JaxTestCase): - +class CompilationCacheTest(CompilationCacheTestCase): def setUp(self): super().setUp() supported_platforms = ["tpu", "gpu", "cpu"] @@ -78,255 +93,220 @@ def setUp(self): "serialize executable only works on " + ",".join(supported_platforms) ) - cc.reset_cache() - - def tearDown(self): - cc.reset_cache() - super().tearDown() - def test_get_no_executable(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() - devices = np.array([[jax.local_devices()[0]]]) - compile_options = compiler.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - key = cc.get_cache_key(computation, devices, compile_options, backend) - executable, compile_time = cc.get_executable_and_time( - key, compile_options, backend) - self.assertIsNone(executable) - self.assertIsNone(compile_time) + computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() + devices = np.array([[jax.local_devices()[0]]]) + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + key = cc.get_cache_key(computation, devices, compile_options, backend) + executable, compile_time = cc.get_executable_and_time( + key, compile_options, backend) + self.assertIsNone(executable) + self.assertIsNone(compile_time) def test_diff_executables(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()) - computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()) - compile_options = compiler.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - executable1 = backend.compile(computation1, compile_options) - executable2 = backend.compile(computation2, compile_options) - cc.put_executable_and_time( - "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) - cc.put_executable_and_time( - "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) - self.assertNotEqual( - cc.get_executable_and_time("key1", compile_options, backend)[0], - cc.get_executable_and_time("key2", compile_options, backend)[0] - ) + computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()) + computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()) + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + executable1 = backend.compile(computation1, compile_options) + executable2 = backend.compile(computation2, compile_options) + cc.put_executable_and_time( + "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) + cc.put_executable_and_time( + "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) + self.assertNotEqual( + cc.get_executable_and_time("key1", compile_options, backend)[0], + cc.get_executable_and_time("key2", compile_options, backend)[0] + ) def test_put_executable(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - computation = ( - jax.jit(lambda x, y: x + y) - .lower(np.int32(1), np.int32(1)) - .compiler_ir() - ) - devices = np.array([[jax.local_devices()[0]]]) - compile_options = compiler.get_compile_options( - num_replicas=1, num_partitions=1 - ) - backend = xla_bridge.get_backend() - executable = backend.compile(str(computation), compile_options) - key = cc.get_cache_key(computation, devices, compile_options, backend) - cc.put_executable_and_time( - key, "alambda", executable, backend, FAKE_COMPILE_TIME) - executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( - key, compile_options, backend) - inputs_to_executable = ( - np.array(1, dtype=np.int32), - np.array(2, dtype=np.int32), - ) - expected = xla_client.execute_with_python_values( - executable, inputs_to_executable, backend - ) - actual = xla_client.execute_with_python_values( - executable_retrieved, inputs_to_executable, backend - ) - self.assertEqual(expected, actual) - self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved) + computation = ( + jax.jit(lambda x, y: x + y) + .lower(np.int32(1), np.int32(1)) + .compiler_ir() + ) + devices = np.array([[jax.local_devices()[0]]]) + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + executable = backend.compile(str(computation), compile_options) + key = cc.get_cache_key(computation, devices, compile_options, backend) + cc.put_executable_and_time( + key, "alambda", executable, backend, FAKE_COMPILE_TIME) + executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( + key, compile_options, backend) + inputs_to_executable = ( + np.array(1, dtype=np.int32), + np.array(2, dtype=np.int32), + ) + expected = xla_client.execute_with_python_values( + executable, inputs_to_executable, backend + ) + actual = xla_client.execute_with_python_values( + executable_retrieved, inputs_to_executable, backend + ) + self.assertEqual(expected, actual) + self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved) def test_pmap(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i") - x = np.arange(jax.device_count(), dtype=np.int64) - f(x) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) - x = np.arange(jax.device_count(), dtype=np.float32) - f(x) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 2) - # TODO: create a test for calling pmap with the same input more than once + f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i") + x = np.arange(jax.device_count(), dtype=np.int64) + f(x) + files_in_directory = len(os.listdir(self.tmpdir)) + self.assertEqual(files_in_directory, 1) + x = np.arange(jax.device_count(), dtype=np.float32) + f(x) + files_in_directory = len(os.listdir(self.tmpdir)) + self.assertEqual(files_in_directory, 2) + # TODO: create a test for calling pmap with the same input more than once def test_jit(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - f = jit(lambda x: x * x) - f(1) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) - f(1.0) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 2) + f = jit(lambda x: x * x) + f(1) + files_in_directory = len(os.listdir(self.tmpdir)) + self.assertEqual(files_in_directory, 1) + f(1.0) + files_in_directory = len(os.listdir(self.tmpdir)) + self.assertEqual(files_in_directory, 2) def test_xla_autofdo_profile_version(self): original_profile_version = config.jax_xla_profile_version.value - with (tempfile.TemporaryDirectory() as tmpdir, - config.jax_xla_profile_version(original_profile_version + 1)): - cc.set_cache_dir(tmpdir) + with config.jax_xla_profile_version(original_profile_version + 1): f = jit(lambda x: x * x) f(1) - files_in_cache_directory = os.listdir(tmpdir) + files_in_cache_directory = os.listdir(self.tmpdir) self.assertLen(files_in_cache_directory, 1) # Clear the cache directory, then update the profile version and execute # again. The in-memory caches should be invalidated and a new persistent # cache entry created. - os.unlink(os.path.join(tmpdir, files_in_cache_directory[0])) + os.unlink(os.path.join(self.tmpdir, files_in_cache_directory[0])) with config.jax_xla_profile_version(original_profile_version + 2): f(1) - files_in_directory = len(os.listdir(tmpdir)) + files_in_directory = len(os.listdir(self.tmpdir)) self.assertEqual(files_in_directory, 1) @jtu.with_mesh([("x", 2)]) def test_pjit(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - - @partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None) - def f(x, y): - return x + y - - shape = (8, 8) - x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape) - f(x, x + 1) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) - x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - f(x, x + 1) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 2) + @partial(pjit, in_shardings=(P("x"), P("x")), out_shardings=None) + def f(x, y): + return x + y + + shape = (8, 8) + x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape) + f(x, x + 1) + files_in_directory = len(os.listdir(self.tmpdir)) + self.assertEqual(files_in_directory, 1) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + f(x, x + 1) + files_in_directory = len(os.listdir(self.tmpdir)) + self.assertEqual(files_in_directory, 2) @jtu.with_mesh([("x", 2)]) def test_xmap(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - - def f(x): - return x * 2 - - devices = np.array(jax.local_devices()[:2]) - if devices.size < 2: - raise SkipTest("Test requires 2 devices") - x = np.arange(8, dtype=np.int64).reshape((2, 2, 2)) - xmap( - f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"} - )(x) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 1) - x = np.arange(8, dtype=np.float32).reshape((2, 2, 2)) - xmap( - f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"} - )(x) - files_in_directory = len(os.listdir(tmpdir)) - self.assertEqual(files_in_directory, 2) + def f(x): + return x * 2 + + devices = np.array(jax.local_devices()[:2]) + if devices.size < 2: + raise SkipTest("Test requires 2 devices") + x = np.arange(8, dtype=np.int64).reshape((2, 2, 2)) + xmap( + f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"} + )(x) + files_in_directory = len(os.listdir(self.tmpdir)) + self.assertEqual(files_in_directory, 1) + x = np.arange(8, dtype=np.float32).reshape((2, 2, 2)) + xmap( + f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"} + )(x) + files_in_directory = len(os.listdir(self.tmpdir)) + self.assertEqual(files_in_directory, 2) def test_cache_write_warning(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - f = jit(lambda x: x * x) + f = jit(lambda x: x * x) - with ( - config.raise_persistent_cache_errors(False), - mock.patch.object(cc._get_cache().__class__, "put") as mock_put, - warnings.catch_warnings(record=True) as w, - ): - mock_put.side_effect = RuntimeError("test error") - self.assertEqual(f(2).item(), 4) - if len(w) != 1: - print("Warnings:", [str(w_) for w_ in w], flush=True) - self.assertLen(w, 1) - self.assertIn( - ( - "Error writing persistent compilation cache entry " - "for 'jit__lambda_': RuntimeError: test error" - ), - str(w[0].message), - ) + backend = xla_bridge.get_backend() + with ( + config.raise_persistent_cache_errors(False), + mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put, + warnings.catch_warnings(record=True) as w, + ): + mock_put.side_effect = RuntimeError("test error") + self.assertEqual(f(2).item(), 4) + if len(w) != 1: + print("Warnings:", [str(w_) for w_ in w], flush=True) + self.assertLen(w, 1) + self.assertIn( + ( + "Error writing persistent compilation cache entry " + "for 'jit__lambda_': RuntimeError: test error" + ), + str(w[0].message), + ) def test_cache_read_warning(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - f = jit(lambda x: x * x) + f = jit(lambda x: x * x) - with ( - config.raise_persistent_cache_errors(False), - mock.patch.object(cc._get_cache().__class__, "get") as mock_get, - warnings.catch_warnings(record=True) as w, - ): - mock_get.side_effect = RuntimeError("test error") - # Calling assertEqual with the jitted f will generate two PJIT - # executables: Equal and the lambda function itself. - self.assertEqual(f(2).item(), 4) - if len(w) != 1: - print("Warnings:", [str(w_) for w_ in w], flush=True) - self.assertLen(w, 1) - self.assertIn( - ( - "Error reading persistent compilation cache entry " - "for 'jit__lambda_': RuntimeError: test error" - ), - str(w[0].message), - ) + backend = xla_bridge.get_backend() + with ( + config.raise_persistent_cache_errors(False), + mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get, + warnings.catch_warnings(record=True) as w, + ): + mock_get.side_effect = RuntimeError("test error") + # Calling assertEqual with the jitted f will generate two PJIT + # executables: Equal and the lambda function itself. + self.assertEqual(f(2).item(), 4) + if len(w) != 1: + print("Warnings:", [str(w_) for w_ in w], flush=True) + self.assertLen(w, 1) + self.assertIn( + ( + "Error reading persistent compilation cache entry " + "for 'jit__lambda_': RuntimeError: test error" + ), + str(w[0].message), + ) def test_min_entry_size(self): with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(0), config.persistent_cache_min_entry_size_bytes(1048576), # 1MiB ): - cc.set_cache_dir(tmpdir) - jit(lambda x: x + 1)(1) - files_in_cache = len(os.listdir(tmpdir)) + files_in_cache = len(os.listdir(self.tmpdir)) self.assertEqual(files_in_cache, 0) def test_min_compile_time(self): with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.set_cache_dir(tmpdir) - # Mock time to progress in small intervals so compilation time is small. with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)): jit(lambda x: x + 1)(1) - files_in_cache = len(os.listdir(tmpdir)) + files_in_cache = len(os.listdir(self.tmpdir)) self.assertEqual(files_in_cache, 0) # Mock time to progress in large intervals so compilation time is large. with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): jit(lambda x: x + 2)(1) - files_in_cache = len(os.listdir(tmpdir)) + files_in_cache = len(os.listdir(self.tmpdir)) self.assertEqual(files_in_cache, 1) # This is perhaps related to mocking time.monotonic? @unittest.skipIf(platform.system() == "Windows", "Test fails on Windows") def test_cache_saving_metric(self): with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.set_cache_dir(tmpdir) - durations = Counter() # Map metric name to time duration. def append_metric_duration(metric, duration): durations[metric] += duration @@ -356,29 +336,24 @@ def append_metric_duration(metric, duration): durations["/jax/compilation_cache/compile_time_saved_sec"], 0) def test_task_using_cache_metric(self): - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - count_before_first_use = _counts[ - "/jax/compilation_cache/tasks_using_cache"] - jit(lambda x: x + 1)(1) - count_after_first_use = _counts[ - "/jax/compilation_cache/tasks_using_cache"] - self.assertEqual(count_after_first_use, count_before_first_use + 1) - - # Verify that the count is incremented only once per task. - jit(lambda x: x + 3)(3) - count_after_second_use = _counts[ - "/jax/compilation_cache/tasks_using_cache"] - self.assertEqual(count_after_second_use, count_after_first_use) + count_before_first_use = _counts[ + "/jax/compilation_cache/tasks_using_cache"] + jit(lambda x: x + 1)(1) + count_after_first_use = _counts[ + "/jax/compilation_cache/tasks_using_cache"] + self.assertEqual(count_after_first_use, count_before_first_use + 1) + + # Verify that the count is incremented only once per task. + jit(lambda x: x + 3)(3) + count_after_second_use = _counts[ + "/jax/compilation_cache/tasks_using_cache"] + self.assertEqual(count_after_second_use, count_after_first_use) def test_compile_requests_use_cache_metric(self): previous_counts = Counter(_counts) - with tempfile.TemporaryDirectory() as tmpdir: - cc.set_cache_dir(tmpdir) - - jit(lambda x: x + 1)(1) - jit(lambda x: x + 2)(1) - jit(lambda x: x + 1)(1) + jit(lambda x: x + 1)(1) + jit(lambda x: x + 2)(1) + jit(lambda x: x + 1)(1) self.assertEqual( _counts["/jax/compilation_cache/compile_requests_use_cache"] @@ -389,12 +364,9 @@ def test_compile_requests_use_cache_metric(self): def test_cache_misses_metric(self, min_entry_size): previous_counts = Counter(_counts) with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(min_entry_size), ): - cc.set_cache_dir(tmpdir) - # Mock time to create a long compilation time and make cache misses. with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): jit(lambda x: x + 1)(1) @@ -414,12 +386,9 @@ def test_cache_misses_metric(self, min_entry_size): def test_cache_hits_metric(self): previous_counts = Counter(_counts) with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(2), config.persistent_cache_min_entry_size_bytes(0), ): - cc.set_cache_dir(tmpdir) - # Mock time to create a long compilation time, cache saved. with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)): jit(lambda x: x + 1)(1) @@ -433,16 +402,13 @@ def test_cache_hits_metric(self): @parameterized.parameters(0, 1) def test_cache_write_with_process_restriction(self, process_id): with ( - tempfile.TemporaryDirectory() as tmpdir, config.persistent_cache_min_compile_time_secs(0), config.persistent_cache_min_entry_size_bytes(0), mock.patch.object(distributed.global_state, "process_id", process_id), ): - cc.set_cache_dir(tmpdir) - jit(lambda x: x + 1)(1) - files_in_directory = len(os.listdir(tmpdir)) + files_in_directory = len(os.listdir(self.tmpdir)) if process_id == 0: self.assertEqual(files_in_directory, 1) elif process_id == 1: @@ -468,17 +434,7 @@ def test_backend_serialization_deserialization(self): jax_persistent_cache_min_compile_time_secs=0, jax_persistent_cache_min_entry_size_bytes=0, ) -class CompilationCacheDisabledTest(jtu.JaxTestCase): - - def setUp(self): - super().setUp() - - cc.reset_cache() - - def tearDown(self): - cc.reset_cache() - super().tearDown() - +class CompilationCacheDisabledTest(CompilationCacheTestCase): # If the cache is disabled, there should be no files in the cache directory. # A call to set_cache_dir() does not affect this. def test_jit(self): @@ -487,14 +443,10 @@ def test_jit(self): # 2. Flag is enabled by JaxTestCase for some test configs # (see test_util.py). # We need the flag disabled for this test, so disable it below. - with ( - tempfile.TemporaryDirectory() as tmpdir, - config.enable_compilation_cache(False), - ): - cc.set_cache_dir(tmpdir) + with config.enable_compilation_cache(False): f = jit(lambda x: x * x) f(1) - files_in_directory = len(os.listdir(tmpdir)) + files_in_directory = len(os.listdir(self.tmpdir)) self.assertEqual(files_in_directory, 0) diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 000000000000..0f49d988a46c --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,73 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +from jax._src import test_util as jtu +from jax._src import config + +jax.config.parse_flags_with_absl() + + +jax_test_enum_config = config.enum_state( + name='jax_test_enum_config', + enum_values=['default', 'xxx', 'yyy'], + default='default', + help='Configuration only used for tests.') + + +class ConfigTest(jtu.JaxTestCase): + def test_config_setting_via_update(self): + self.assertEqual(jax_test_enum_config.value, 'default') + + jax.config.update('jax_test_enum_config', 'xxx') + self.assertEqual(jax_test_enum_config.value, 'xxx') + + jax.config.update('jax_test_enum_config', 'yyy') + self.assertEqual(jax_test_enum_config.value, 'yyy') + + jax.config.update('jax_test_enum_config', 'default') + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_setting_via_context(self): + self.assertEqual(jax_test_enum_config.value, 'default') + + with jax_test_enum_config('xxx'): + self.assertEqual(jax_test_enum_config.value, 'xxx') + + with jax_test_enum_config('yyy'): + self.assertEqual(jax_test_enum_config.value, 'yyy') + + self.assertEqual(jax_test_enum_config.value, 'xxx') + + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_update_validation(self): + self.assertEqual(jax_test_enum_config.value, 'default') + with self.assertRaisesRegex(ValueError, 'new enum value must be in.*'): + jax.config.update('jax_test_enum_config', 'invalid') + # Error should raise before changing the value + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_context_validation(self): + self.assertEqual(jax_test_enum_config.value, 'default') + with self.assertRaisesRegex(ValueError, 'new enum value must be in.*'): + with jax_test_enum_config('invalid'): + pass + self.assertEqual(jax_test_enum_config.value, 'default') + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index e5743944401b..1cc342e26784 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -29,17 +29,9 @@ jax.config.parse_flags_with_absl() +@jtu.with_config(jax_debug_nans=True) class DebugNaNsTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - self.cfg = jax.config._read("jax_debug_nans") - jax.config.update("jax_debug_nans", True) - - def tearDown(self): - jax.config.update("jax_debug_nans", self.cfg) - super().tearDown() - def testSinc(self): # Regression test for #6936 self.assertEqual(jnp.sinc(0.0), 1.0) @@ -65,8 +57,8 @@ def testJitComputationNaN(self): ans = jax.jit(lambda x: 0. / x)(A) ans.block_until_ready() + @jax.debug_nans(False) def testJitComputationNaNContextManager(self): - jax.config.update("jax_debug_nans", False) A = jnp.array(0.) f = jax.jit(lambda x: 0. / x) ans = f(A) @@ -205,17 +197,9 @@ def f(x, y): jax.jit(f)(inp, inp) +@jtu.with_config(jax_debug_infs=True) class DebugInfsTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - self.cfg = jax.config._read("jax_debug_infs") - jax.config.update("jax_debug_infs", True) - - def tearDown(self): - jax.config.update("jax_debug_infs", self.cfg) - super().tearDown() - def testSingleResultPrimitiveNoInf(self): A = jnp.array([[1., 2.], [2., 3.]]) ans = jnp.tanh(A) diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 91319e2557ad..18693a7bb2c3 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -13,6 +13,7 @@ # limitations under the License. from collections.abc import Sequence +import contextlib import io import re import textwrap @@ -40,16 +41,13 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI def _format_multiline(text): return textwrap.dedent(text).lstrip() -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() foo = 2 diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index b9d12e4a26ba..c00253385792 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +import contextlib import functools import textwrap import unittest @@ -41,16 +42,13 @@ def _format_multiline(text): return textwrap.dedent(text).lstrip() -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() class DummyDevice: def __init__(self, platform, id): @@ -807,7 +805,8 @@ def foo(x): lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12", "7: 14", "Out: 7.0", ""] jax.effects_barrier() - self._assertLinesEqual(output(), "\n".join(lines)) + + self._assertLinesEqual(output(), "\n".join(lines)) def test_unordered_print_with_xmap(self): def f(x): diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 363bb39fe1df..4712e6aec652 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -558,6 +558,22 @@ def inner_bwd(prev_scale, grads): _, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale) self.assertAllClose(new_scale, jnp.float32(1.0)) + def test_check_dtype_non_hashable(self): + # regression test for issue with checking non-hashable custom dtype + class MyDtype: + __hash__ = None + dtype = np.dtype('float32') + dtypes.check_user_dtype_supported(MyDtype()) + + def test_check_dtype_array(self): + x = jnp.arange(4) + msg = "Passing an array as a dtype argument is deprecated" + with self.assertWarnsRegex(DeprecationWarning, msg): + dtypes.check_user_dtype_supported(x) + with self.assertWarnsRegex(DeprecationWarning, msg): + jax.jit(dtypes.check_user_dtype_supported)(x) + + class EArrayTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) @@ -578,15 +594,6 @@ def convert_to(foo_dtype, target_dtype): def physical_element_aval(foo_dtype): return core.ShapedArray((), dtypes.dtype('float32')) - @staticmethod - def replicate_trailing_dims(ctx, val, aval): - del ctx, aval - return val - - @staticmethod - def logical_sharding(aval, phys_sharding): - return phys_sharding - @staticmethod def global_sharded_result_handler(aval, out_sharding, committed): phys_sharding = out_sharding # unlike KeyTyRules, assume same shape @@ -595,10 +602,6 @@ def global_sharded_result_handler(aval, out_sharding, committed): phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) return lambda bufs: earray.EArray(aval, phys_handler(bufs)) - @staticmethod - def physical_sharding(aval, sharding): - return sharding # unlike KeyTyRules, assume same shape - @dataclasses.dataclass(frozen=True) class FooTy(dtypes.ExtendedDType): name: str = 'foo' diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index acb7dec33aaa..4a87bc1fa23d 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -27,11 +27,10 @@ import jax from jax import lax -from jax.experimental.export import _export +from jax._src.export import _export from jax._src.internal_test_util import export_back_compat_test_util as bctu -from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev @@ -51,6 +50,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import tpu_stablehlo_dynamic_reduce_window from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_rng_bit_generator from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_top_k +from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_approx_top_k from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -108,7 +108,6 @@ def test_custom_call_coverage(self): # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ - cpu_ducc_fft.data_2023_06_14, cpu_cholesky_lapack_potrf.data_2023_06_19, cpu_eig_lapack_geev.data_2023_06_19, cpu_eigh_lapack_syev.data_2023_03_17, @@ -126,6 +125,7 @@ def test_custom_call_coverage(self): stablehlo_dynamic_rng_bit_generator.data_2023_06_17, stablehlo_dynamic_top_k.data_2023_07_16, stablehlo_dynamic_top_k.data_2023_08_11, # with shape_assertion + stablehlo_dynamic_approx_top_k.data_2024_05_30, ] # Some of the above are nested structures. covering_testdatas = itertools.chain( @@ -146,16 +146,6 @@ def test_custom_call_coverage(self): "stable but are not covered by any tests: " f"{not_covered}")) - def test_ducc_fft(self): - def func(x): - return lax.fft(x, fft_type="fft", fft_lengths=(4,)) - - # TODO(b/311175955): Remove this test and the corresponding custom calls. - # A newer lowering, with dynamic_ducc_fft. - data = self.load_testdata(cpu_ducc_fft.data_2023_06_14) - # FFT no longer lowers to a custom call. - self.run_one_test(func, data, expect_current_custom_calls=[]) - def cholesky_input(self, shape, dtype): a = jtu.rand_default(self.rng())(shape, dtype) return np.matmul(a, np.conj(np.swapaxes(a, -1, -2))) @@ -596,7 +586,7 @@ def func(x): def test_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: - self.skipTest("Test runs only on TPU with at least 2 devices") + self.skipTest("Test runs only on TPU with at least 2 devices") # Must use exactly 2 devices for expected outputs from ppermute devices = jax.devices()[:2] @@ -721,6 +711,44 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol): polymorphic_shapes=("_, b",), check_results=check_top_k_results) + def test_dynamic_approx_top_k(self): + # stablehlo.dynamic_approx_top_k is used temporarily for a approx_top_k + # with dynamism + # This is the input that was used to generate the test_data + _ = np.arange(24, dtype=np.float32) + + def func(a): # a: f32[b + 4] + return lax.approx_max_k(a, k=a.shape[0] - 4) + + data = self.load_testdata(stablehlo_dynamic_approx_top_k.data_2024_05_30) + + def check_top_k_results(res_run, res_expected, *, rtol, atol): + a = data.inputs[0] + # The order of the results may be different, but should be the same ones + values_expected, _ = res_expected + values_run, indices_run = res_run + # Check that indices are correct + self.assertAllClose( + values_run, + a[indices_run], + atol=atol, + rtol=rtol, + ) + self.assertAllClose( + np.sort(values_run), np.sort(values_expected), atol=atol, rtol=rtol + ) + + self.run_one_test( + func, + data, + polymorphic_shapes=("b + 4,",), + check_results=check_top_k_results, + expect_current_custom_calls=[ + "stablehlo.dynamic_approx_top_k", + "shape_assertion", + ], + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 44a3070f8b86..035905d3f9e5 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -21,9 +21,9 @@ from __future__ import annotations +from collections.abc import Callable import math import re -from typing import Callable from absl import logging from absl.testing import absltest @@ -31,9 +31,9 @@ import numpy as np import jax +from jax import export from jax import lax from jax._src import test_util as jtu -from jax.experimental import export from jax._src.internal_test_util import test_harnesses @@ -152,7 +152,8 @@ def export_and_compare_to_native( ) logging.info("Exporting harness for %s", lowering_platforms) - exp = export.export(func_jax, lowering_platforms=lowering_platforms)(*args) + exp = export.export(jax.jit(func_jax), + lowering_platforms=lowering_platforms)(*args) for device in devices: if device.platform in skip_run_on_platforms: @@ -164,7 +165,7 @@ def export_and_compare_to_native( logging.info("Running harness natively on %s", device) native_res = func_jax(*device_args) logging.info("Running exported harness on %s", device) - exported_res = export.call_exported(exp)(*device_args) + exported_res = exp.call(*device_args) if tol is not None: logging.info(f"Using non-standard tolerance {tol}") self.assertAllClose(native_res, exported_res, atol=tol, rtol=tol) diff --git a/tests/export_test.py b/tests/export_test.py index bb025d6b7028..8ccf3bb35849 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +from collections.abc import Callable, Sequence import contextlib import dataclasses import functools @@ -25,7 +26,7 @@ import jax from jax import lax from jax import numpy as jnp -from jax.experimental import export +from jax import export from jax.experimental import pjit from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding @@ -46,15 +47,13 @@ config.parse_flags_with_absl() -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() + def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() ### Setup for testing lowering with effects @dataclasses.dataclass(frozen=True) @@ -139,24 +138,19 @@ def _testing_multi_platform_fun_expected(x, ] -def get_exported(fun, vjp_order=0, +def get_exported(fun: Callable, vjp_order=0, **export_kwargs): """Like export.export but with serialization + deserialization.""" def serde_exported(*fun_args, **fun_kwargs): exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs) - serialized = export.serialize(exp, vjp_order=vjp_order) + serialized = exp.serialize(vjp_order=vjp_order) return export.deserialize(serialized) return serde_exported -class JaxExportTest(jtu.JaxTestCase): - def override_serialization_version(self, version_override: int): - version = config.jax_serialization_version.value - if version != version_override: - self.enter_context(config.jax_serialization_version(version_override)) - logging.info( - "Using JAX serialization version %s", - config.jax_serialization_version.value) +# Run tests with the maximum supported version by default +@jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version) +class JaxExportTest(jtu.JaxTestCase): @classmethod def setUpClass(cls): @@ -170,19 +164,15 @@ def setUpClass(cls): cls.platforms.append(backend) super().setUpClass() - def setUp(self): - super().setUp() - # Run tests with the maximum supported version by default - self.override_serialization_version( - export.maximum_supported_serialization_version) - def test_basic_export_only(self): + @jax.jit def my_fun(x): return jnp.sin(x) exp = get_exported(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32)) self.assertEqual("my_fun", exp.fun_name) - self.assertEqual((export.default_lowering_platform(),), - exp.lowering_platforms) + expected_lowering_platform = xb.canonicalize_platform(jax.default_backend()) + self.assertEqual((expected_lowering_platform,), + exp.platforms) self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals) self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals) @@ -193,10 +183,10 @@ def test_pytree_export_only(self): def f(a_b_pair, *, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp = get_exported(f, lowering_platforms=("cpu",))((a, b), a=a, b=b) + exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b) a_aval = core.ShapedArray(a.shape, a.dtype) b_aval = core.ShapedArray(b.shape, b.dtype) - self.assertEqual(exp.lowering_platforms, ("cpu",)) + self.assertEqual(exp.platforms, ("cpu",)) args = ((a, b),) kwargs = dict(a=a, b=b) self.assertEqual(exp.in_tree, jax.tree.flatten((args, kwargs))[1]) @@ -209,8 +199,7 @@ def test_basic(self): x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) - f1 = export.call_exported(exp_f) - self.assertAllClose(f(x), f1(x)) + self.assertAllClose(f(x), exp_f.call(x)) def test_jit_static_arg(self): @@ -223,8 +212,7 @@ def f(x, *, c): x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x, c=0.1) - f1 = export.call_exported(exp_f) - self.assertAllClose(f(x, c=0.1), f1(x)) + self.assertAllClose(f(x, c=0.1), exp_f.call(x)) with self.subTest("static_argnums"): @@ -235,16 +223,34 @@ def g(x, c): x = np.arange(4, dtype=np.float32) exp_g = get_exported(g)(x, 0.1) - g1 = export.call_exported(exp_g) - self.assertAllClose(g(x, 0.1), g1(x)) + self.assertAllClose(g(x, 0.1), exp_g.call(x)) + + def test_export_error_no_jit(self): + # Can export a lambda, without jit + with self.assertRaisesRegex(ValueError, + "Function to be exported must be the result of `jit`"): + _ = export.export(lambda x: jnp.sin(x)) + + @jtu.ignore_warning(category=DeprecationWarning, + message="The jax.experimental.export module is deprecated") + def test_export_experimental_back_compat(self): + from jax.experimental import export + # Can export a lambda, without jit + exp = export.export(lambda x: jnp.sin(x))(.1) + self.assertAllClose(exp.call(1.), np.sin(1.)) + + blob = export.serialize(exp, vjp_order=1) + rehydrated = export.deserialize(blob) + + self.assertAllClose(export.call(exp)(1.), np.sin(1.)) + self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.)) def test_call_exported_lambda(self): # When we export a lambda, the exported.fun_name is not a valid MLIR function name - f = lambda x: jnp.sin(x) + f = jax.jit(lambda x: jnp.sin(x)) x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) - f1 = export.call_exported(exp_f) - self.assertAllClose(f(x), f1(x)) + self.assertAllClose(f(x), exp_f.call(x)) def test_call_name_conflict(self): @jax.jit @@ -259,7 +265,7 @@ def inner(x): @jax.jit def outer(x): # There should be no conflict on _where - x = export.call(exp_inner)(x) + x = exp_inner.call(x) return inner(x) export.export(outer)(x) @@ -270,19 +276,18 @@ def f(x): return jnp.sin(x) @jax.jit def f1(x): - exp_f = get_exported(f)(x) - return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x) + exp_f = get_exported(jax.jit(f))(x) + return exp_f.call(x) + exp_f.call(x) self.assertAllClose(2. * f(x), f1(x)) def test_unused_args(self): - f = lambda x, y: jnp.sin(x) + f = jax.jit(lambda x, y: jnp.sin(x)) x = np.arange(4, dtype=np.float32) y = np.arange(6, dtype=np.float32) exp_f = get_exported(f)(x, y) - f1 = export.call_exported(exp_f) - self.assertAllClose(f(x, y), f1(x, y)) + self.assertAllClose(f(x, y), exp_f.call(x, y)) def test_pytree(self): a = np.arange(4, dtype=np.float32) @@ -290,43 +295,50 @@ def test_pytree(self): def f(a_b_pair, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp_f = get_exported(f)((a, b), a=a, b=b) - f1 = export.call_exported(exp_f) + exp_f = get_exported(jax.jit(f))((a, b), a=a, b=b) self.assertAllClose(f((a, b), a=a, b=b), - f1((a, b), a=a, b=b)) + exp_f.call((a, b), a=a, b=b)) def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c a = b = c = np.arange(4, dtype=np.float32) - exp_f = get_exported(f)((a, b), c=c) + exp_f = get_exported(jax.jit(f))((a, b), c=c) with self.assertRaisesRegex( ValueError, "The invocation args and kwargs must have the same pytree structure"): - export.call_exported(exp_f)(a, b, c=(a, b)) + exp_f.call(a, b, c=(a, b)) def test_error_wrong_avals(self): def f(a, *, b): # a: f32[4] and b: f32[4] return jnp.sin(a) + jnp.cos(b) f32_4 = np.arange(4, dtype=np.float32) - exp_f = get_exported(f)(f32_4, b=f32_4) + exp_f = get_exported(jax.jit(f))(f32_4, b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for args\[0\].shape\[0\]"): - export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4) + exp_f.call(np.arange(6, dtype=np.float32), b=f32_4) with self.assertRaisesRegex(ValueError, r"Shape mismatch for kwargs\['b'\].shape\[0\]"): - export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32)) + exp_f.call(f32_4, b=np.arange(6, dtype=np.float32)) with self.assertRaisesRegex(ValueError, r"Rank mismatch for args\[0\]"): - export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4) + exp_f.call(f32_4.reshape((1, 4)), b=f32_4) with self.assertRaisesRegex(ValueError, r"Dtype mismatch for args\[0\]"): - export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4) + exp_f.call(f32_4.astype(np.float16), b=f32_4) + + def test_default_export_platform(self): + test_platform = jtu.device_under_test() + if test_platform == "gpu": + test_platform = "rocm" if jtu.is_device_rocm() else "cuda" + self.assertEqual(export.default_export_platform(), test_platform) + exp = export.export(jnp.sin)(1.) + self.assertEqual(exp.platforms, (export.default_export_platform(),)) @jtu.parameterized_filterable( testcase_name=lambda kw: kw["platform"], @@ -340,14 +352,14 @@ def test_error_wrong_platform(self, platform): raise unittest.SkipTest("Uninteresting scenario") with self.assertRaisesRegex( - ValueError, "The exported function .* was lowered for platform"): - export.call_exported(exp_f)(a) + ValueError, "Function .* was exported for platform"): + exp_f.call(a) # Now try with the platform check disabled exp_f_no_platform_check = get_exported( jnp.sin, lowering_platforms=(platform,), disabled_checks=[export.DisabledSafetyCheck.platform()])(a) - res = export.call_exported(exp_f_no_platform_check)(a) + res = exp_f_no_platform_check.call(a) self.assertAllClose(res, jnp.sin(a)) @jtu.parameterized_filterable( @@ -370,30 +382,52 @@ def test_primitive_lowering(ctx, arg): with self.assertRaisesRegex(ValueError, "Cannot serialize code with custom calls whose targets .*"): get_exported( - lambda a: a + test_primitive.bind(a) + jax.jit(lambda a: a + test_primitive.bind(a)) )(a) # Now try again with the safety check disabled exp = get_exported( - lambda a: a + test_primitive.bind(a), + jax.jit(lambda a: a + test_primitive.bind(a)), disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")] )(a) self.assertIn("disallowed_call_target", exp.mlir_module()) + def test_lowering_parameters_for_export(self): + # Test that we propagate properly the LoweringParameters.for_export + test_primitive = core.Primitive("_test_primitive_for_export") + test_primitive.def_abstract_eval(lambda in_aval: in_aval) + def test_primitive_lowering(ctx, arg): + if ctx.module_context.lowering_parameters.for_export: + raise ValueError("Lowering for export not supported") + return mlir.hlo.AddOp(arg, arg).results + + mlir.register_lowering(test_primitive, test_primitive_lowering) + self.addCleanup(lambda: mlir.register_lowering(test_primitive, None)) + + f = test_primitive.bind + a = np.arange(3, dtype=np.float32) + res = jax.jit(f)(a) # Works with JIT + self.assertAllClose(res, a + a) + jax.jit(f).lower(a) # Works with most AOT + + with self.assertRaisesRegex(ValueError, + "Lowering for export not supported"): + export.export(jax.jit(f))(a) + def test_grad(self): f = lambda x: jnp.sum(jnp.sin(x)) x = np.arange(4, dtype=np.float32) - exp_f = get_exported(f, vjp_order=1)(x) + exp_f = get_exported(jax.jit(f), vjp_order=1)(x) - f1 = export.call_exported(exp_f) + f1 = exp_f.call self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) def test_higher_order_grad(self): f = lambda x: x ** 3 x = np.float32(4.) - exp_f = get_exported(f, vjp_order=3)(x) + exp_f = get_exported(jax.jit(f), vjp_order=3)(x) - f1 = export.call_exported(exp_f) + f1 = exp_f.call self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x)) self.assertAllClose(jax.grad(jax.grad(f))(x), @@ -419,8 +453,8 @@ def f(xi, xf): self.assertAllClose(res, (xi_ct, xf_ct)) (f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct)) - exp = get_exported(f, vjp_order=2)(xi, xf) - fr = export.call_exported(exp) + exp = get_exported(jax.jit(f), vjp_order=2)(xi, xf) + fr = exp.call res = fr(xi, xf) self.assertAllClose(res, (f_outi, f_outf)) @@ -447,14 +481,14 @@ def f(a_b_pair, *, a, b): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) - exp_f = get_exported(f, vjp_order=1)((a, b), a=a, b=b) + exp_f = get_exported(jax.jit(f), vjp_order=1)((a, b), a=a, b=b) out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent def f1_jax(a, b): # For VJP, make a function without kwargs res = f((a, b), a=a, b=b) return res def f1_exp(a, b): # For VJP, make a function without kwargs - res = export.call_exported(exp_f)((a, b), a=a, b=b) + res = exp_f.call((a, b), a=a, b=b) return res jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct) exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct) @@ -464,15 +498,15 @@ def test_roundtrip(self): def f1(x): return jnp.sin(x) a = np.arange(4, dtype=np.float32) - exp_f1 = get_exported(f1)(a) + exp_f1 = get_exported(jax.jit(f1))(a) def f2(x): - res1 = export.call_exported(exp_f1)(x) - res2 = export.call_exported(exp_f1)(res1) + res1 = exp_f1.call(x) + res2 = exp_f1.call(res1) return jnp.cos(res2) - exp_f2 = get_exported(f2)(a) + exp_f2 = get_exported(jax.jit(f2))(a) self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))), - export.call_exported(exp_f2)(a)) + exp_f2.call(a)) def test_poly_export_only(self): a = np.arange(12, dtype=np.float32).reshape((3, 4)) @@ -480,7 +514,7 @@ def f(a, b): # a: f32[2w,h] b: f32[w,h] return jnp.concatenate([a, b], axis=0) scope = export.SymbolicScope() - exp = get_exported(f)( + exp = get_exported(jax.jit(f))( jax.ShapeDtypeStruct(export.symbolic_shape("(2*w, h)", scope=scope), a.dtype), jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)", scope=scope), a.dtype)) self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape)) @@ -489,7 +523,7 @@ def f(a, b): # a: f32[2w,h] b: f32[w,h] # Peek at the module module_str = exp.mlir_module() - self.assertEqual(config.jax_serialization_version.value >= 7, + self.assertEqual(config.jax_export_calling_convention_version.value >= 7, "shape_assertion" in module_str) self.assertIn("jax.uses_shape_polymorphism = true", module_str) wrapped_main_expected_re = ( @@ -519,7 +553,7 @@ def f(a0, a1, *, ak): return jnp.concatenate([a0, a1, ak], axis=0) a_poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)"), a.dtype) - exp = get_exported(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec) + exp = get_exported(jax.jit(f))(a_poly_spec, a_poly_spec, ak=a_poly_spec) self.assertEqual("(w, h)", str(exp.in_avals[0].shape)) self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape)) @@ -536,30 +570,56 @@ def f(x, y): "Invalid mixing of symbolic scopes when exporting f.*" r"Expected current \(from args\[0\]\) scope .*" r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)): - get_exported(f)(x_poly_spec, y_poly_spec) + get_exported(jax.jit(f))(x_poly_spec, y_poly_spec) + + def test_poly_export_callable_with_no_name(self): + # This was reported by a user + class MyCallable: + def __call__(self, x): + return jnp.sin(x) + + # This makes it look like a jitted-function + def lower(self, x, _experimental_lowering_parameters=None): + return jax.jit(self.__call__).lower( + x, + _experimental_lowering_parameters=_experimental_lowering_parameters) + + def trace(self, x, _experimental_lowering_parameters=None): + return jax.jit(self.__call__).trace( + x, + _experimental_lowering_parameters=_experimental_lowering_parameters) + + a, = export.symbolic_shape("a,") + # No error + _ = get_exported(jax.jit(MyCallable()))( + jax.ShapeDtypeStruct((a, a), dtype=np.float32) + ) @jtu.parameterized_filterable( kwargs=[ dict(v=v) - for v in range(export.minimum_supported_serialization_version - 1, - export.maximum_supported_serialization_version + 2)]) + for v in range(export.minimum_supported_calling_convention_version - 1, + export.maximum_supported_calling_convention_version + 2)]) def test_poly_basic_versions(self, v: int): - self.override_serialization_version(v) - with contextlib.ExitStack() as e: - if not (export.minimum_supported_serialization_version <= v - <= export.maximum_supported_serialization_version): - e.enter_context(self.assertRaisesRegex( - ValueError, - f"The requested jax_serialization version {v} is outside the range of supported versions")) - - exp = get_exported(jnp.sin)( - jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32)) - x = np.arange(30, dtype=np.float32).reshape((5, 6)) - res = export.call_exported(exp)(x) - self.assertAllClose(res, np.sin(x)) + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX calling convention version %s", + config.jax_export_calling_convention_version.value) + with contextlib.ExitStack() as e: + if not (export.minimum_supported_calling_convention_version <= v + <= export.maximum_supported_calling_convention_version): + e.enter_context(self.assertRaisesRegex( + ValueError, + f"The requested export calling convention version {v} is outside the range of supported versions")) + + exp = get_exported(jnp.sin)( + jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32)) + x = np.arange(30, dtype=np.float32).reshape((5, 6)) + res = exp.call(x) + self.assertAllClose(res, np.sin(x)) # A function is exported with f32[poly_spec] and is called with different arg - # shapes. We use export.call_exported and we also run the shape check + # shapes. We use export.call and we also run the shape check # module. @jtu.parameterized_filterable( testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore @@ -596,9 +656,9 @@ def f(x): # x: f32[poly_spec] return jnp.reshape(x, (-1, x.shape[1])) disabled_checks = () - exp_f = get_exported(f, disabled_checks=disabled_checks)( + exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)( jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32)) - self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12") + self.assertEqual(exp_f.uses_global_constants, poly_spec != "3,4,12") arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12] @@ -607,7 +667,7 @@ def f(x): # x: f32[poly_spec] stack.push(self.assertRaisesRegex(Exception, expect_error)) assert core.is_constant_shape(arg.shape) - res = export.call_exported(exp_f)(arg) + res = exp_f.call(arg) if not expect_error: self.assertAllClose(res, f(arg)) @@ -698,35 +758,35 @@ def inner(x): # x: inner_poly_spec arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12] - inner_exp = get_exported(inner)( + inner_exp = get_exported(jax.jit(inner))( jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32)) - self.assertEqual(inner_exp.uses_shape_polymorphism, + self.assertEqual(inner_exp.uses_global_constants, (inner_poly_spec != "3,4,12")) def outer(x): # x: outer_poly_spec # Use an addition to test that the shapes are refined properly for the # result of the call_exported. - return export.call_exported(inner_exp)(x) + inner(x) + return inner_exp.call(x) + inner(x) with contextlib.ExitStack() as stack: if expect_error_outer_exp is not None: stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp)) # Call it after exporting again, with polymorphic shapes - outer_exp = get_exported(outer)( + outer_exp = get_exported(jax.jit(outer))( jax.ShapeDtypeStruct(export.symbolic_shape(outer_poly_spec), arg.dtype)) if expect_error_outer_exp is not None: return - self.assertEqual(outer_exp.uses_shape_polymorphism, + self.assertEqual(outer_exp.uses_global_constants, (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12")) with contextlib.ExitStack() as stack: if expect_error_run is not None: stack.push(self.assertRaisesRegex(Exception, expect_error_run)) - res = export.call_exported(outer_exp)(arg) + res = outer_exp.call(arg) if expect_error_run is not None: return @@ -748,7 +808,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -757,7 +817,7 @@ def outer(x): # x: outer_poly_spec "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -767,7 +827,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -776,7 +836,7 @@ def outer(x): # x: outer_poly_spec "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." )), ]) def test_shape_constraints_errors(self, *, @@ -788,20 +848,21 @@ def f_jax(x): # x: f32[a + 2*b, a, a + b + c] with contextlib.ExitStack() as stack: if expect_error is not None: stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error))) - exp = get_exported(f_jax)( + exp = get_exported(jax.jit(f_jax))( jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype)) - export.call_exported(exp)(x) + exp.call(x) def test_poly_booleans(self): # For booleans we use a special case ConvertOp to cast to and from # dynamic shapes arguments. + @jax.jit def f_jax(x): # x: bool[b] return jnp.logical_not(x) x = np.array([True, False, True, False], dtype=np.bool_) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) - res = export.call_exported(exp)(x) + res = exp.call(x) self.assertAllClose(f_jax(x), res) @jtu.parameterized_filterable( @@ -816,13 +877,14 @@ def test_poly_numeric_dtypes(self, dtype=np.int32): "int4", "uint4"}: self.skipTest(f"TODO: serialization not supported for {str(dtype)}") + @jax.jit def f_jax(x): return x + x x = np.arange(6, dtype=dtype) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) - res = export.call_exported(exp)(x) + res = exp.call(x) self.assertAllClose(f_jax(x), res) def test_poly_expressions(self): @@ -832,6 +894,7 @@ def output_shape(b): return (b + b, b - b, b * b, (b + 13) // b, (b + 13) % b, core.max_dim(b - 5, 0)) + @jax.jit def f(x): # x: f32[b] b = x.shape[0] return jnp.ones(output_shape(b), dtype=x.dtype) @@ -839,15 +902,26 @@ def f(x): # x: f32[b] exp = get_exported(f)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype)) # Call with static shapes - res = export.call_exported(exp)(x) + res = exp.call(x) self.assertAllClose(res, f(x)) # Now re-export with shape polymorphism x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), x.dtype) - exp2 = get_exported(export.call_exported(exp))(x_spec) + exp2 = get_exported(jax.jit(exp.call))(x_spec) a = exp2.in_avals[0].shape[0] self.assertEqual(exp2.out_avals[0].shape, output_shape(a)) + def test_with_donation(self): + f = jax.jit(jnp.sin, donate_argnums=(0,)) + x = np.arange(3, dtype=np.float32) + exp = export.export(f)(x) + + def caller(x): + y = exp.call(x) + return x + y + res = jax.jit(caller)(x) + self.assertAllClose(res, x + np.sin(x)) + def test_poly_call_pmap(self): if len(jax.devices()) < 2: self.skipTest("Need at least 2 devices") @@ -855,9 +929,9 @@ def f(x): # x: f32[a, 4] return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1)) a, = export.symbolic_shape("a") - exp = export.export(f)( + exp = export.export(jax.jit(f))( jax.ShapeDtypeStruct((a, 4), np.float32)) - f_exp = export.call_exported(exp) + f_exp = exp.call x_jit = np.arange(12, dtype=np.float32).reshape((3, 4)) res_jit = jax.jit(f_exp)(x_jit) self.assertAllClose(res_jit, f(x_jit)) @@ -896,27 +970,52 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] # We apply the out_shardings for f_jax r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*", re.DOTALL) - hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text() + hlo = jax.jit(exp.call).lower(a_device).as_text() self.assertRegex(hlo, expected_re) - res_exported = export.call_exported(exp)(a_device) + res_exported = exp.call(a_device) self.assertAllClose(res_native, res_exported) # Test error reporting with self.assertRaisesRegex( - NotImplementedError, - "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): - _ = export.call_exported(exp)(a) + ValueError, + "Function .* was exported for 2 devices and is called in a context with 1 device"): + _ = exp.call(a) with self.assertRaisesRegex( - NotImplementedError, - "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): + ValueError, + "Function .* was exported for 2 devices and is called in a context with 1 device"): mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",)) _ = jax.jit( - export.call_exported(exp), + exp.call, in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),) )(a) + def test_input_shardings_unused_args(self): + nr_devices = 2 + if len(jax.devices()) < nr_devices: + self.skipTest("Need at least 2 devices") + devices = jax.devices()[0:nr_devices] + export_mesh = Mesh(np.array(devices), + axis_names=("x",)) + a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4)) + + f = jax.jit(lambda x, y: jnp.sin(x), + in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),), + None), + out_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),))) + exp = get_exported(f)(a, a) + + # We can use other devices and other meshes for running + run_devices = devices[::-1] + run_mesh = Mesh(run_devices, "a") + run_input_shardings = exp.in_shardings_jax(run_mesh) + a_run = jax.device_put(a, run_input_shardings[0]) + b_run = jax.device_put(a, run_input_shardings[1]) + res = exp.call(a_run, b_run) + self.assertEqual(res.addressable_shards[0].device, run_devices[0]) + self.assertEqual(res.addressable_shards[1].device, run_devices[1]) + def test_call_with_different_no_of_devices(self): if jax.local_device_count() < 2: self.skipTest("Need at least 2 devices") @@ -936,7 +1035,7 @@ def f_without_shardings(x): run_mesh = Mesh(run_devices, "i") b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) - res_exported = export.call_exported(exp)(b) + res_exported = exp.call(b) self.assertAllClose(res_native, res_exported) def test_call_with_different_no_of_devices_error_has_in_shardings(self): @@ -960,11 +1059,11 @@ def f_with_sharding(x): b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) with self.assertRaisesRegex( - NotImplementedError, - "Exported module .* was lowered for 1 devices and is called in a " - f"context with {jax.local_device_count()} devices.* module contains " + ValueError, + "Function .* was exported for 1 devices and is called in a " + f"context with {jax.local_device_count()} devices.* function contains " "non-replicated sharding annotations"): - export.call_exported(exp)(b) + exp.call(b) def test_call_with_different_no_of_devices_pmap(self): if len(jax.devices()) < 2: @@ -982,7 +1081,7 @@ def f_jax(x): b = jnp.arange(jax.device_count() * 100, dtype=jnp.float32).reshape( (-1, 1, 100) ) - res_exported = jax.pmap(export.call_exported(exp))(b) + res_exported = jax.pmap(exp.call)(b) self.assertAllClose(res_native, res_exported[0]) def test_call_with_different_no_of_devices_error_has_sharding_constraint(self): @@ -1006,11 +1105,11 @@ def f_with_sharding(x): b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) with self.assertRaisesRegex( - NotImplementedError, - "Exported module .* was lowered for 1 devices and is called in a " - f"context with {jax.local_device_count()} devices.* module contains " + ValueError, + "Function .* was exported for 1 devices and is called in a " + f"context with {jax.local_device_count()} devices.* function contains " "non-replicated sharding annotations"): - export.call_exported(exp)(b) + exp.call(b) @jtu.parameterized_filterable( kwargs=[ @@ -1036,7 +1135,7 @@ def f_jax(b): # b: f32[2, 4] perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, "x", perm=perm) - args_specs = export.symbolic_args_specs((a,), polymorphic_shapes=poly) + args_specs = export.symbolic_args_specs((a,), poly) exp = get_exported(f_jax)(*args_specs) # Test JAX native execution @@ -1048,10 +1147,10 @@ def f_jax(b): # b: f32[2, 4] self.assertLen(res_jax.addressable_shards, len(devices)) # Test reloaded execution. - f_r = export.call_exported(exp) + f_r = exp.call with self.assertRaisesRegex( Exception, - "Exported module .* was lowered for 2 devices and is " + "Function .* was exported for 2 devices and is " "called in a context with 1 devices"): _ = f_r(a) # A is all on the default device @@ -1181,14 +1280,14 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] self.assertEqual(exp_vjp2.nr_devices, 2) call_mesh = Mesh(jax.devices()[:2], "e") - g1 = pjit.pjit(export.call_exported(exp_vjp), + g1 = pjit.pjit(exp_vjp.call, in_shardings=(NamedSharding(call_mesh, None), NamedSharding(call_mesh, None)))(x, x.T) _, f_jax_vjp = jax.vjp(f_jax, x) xbar = f_jax_vjp(x.T) self.assertAllClose(xbar, g1) - g2 = pjit.pjit(export.call_exported(exp_vjp2), + g2 = pjit.pjit(exp_vjp2.call, in_shardings=(NamedSharding(call_mesh, None), NamedSharding(call_mesh, None), NamedSharding(call_mesh, None)))(x, x.T, x) @@ -1214,18 +1313,18 @@ def f(x): exp = export.export(pjit.pjit(f, in_shardings=shardings))(input) exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards) - _ = export.serialize(exp, vjp_order=1) - _ = export.serialize(exp_rev, vjp_order=1) + _ = exp.serialize(vjp_order=1) + _ = exp_rev.serialize(vjp_order=1) - g = jax.grad(export.call(exp_rev))(input_rev) - g_rev = jax.grad(export.call(exp))(input) + g = jax.grad(exp_rev.call)(input_rev) + g_rev = jax.grad(exp.call)(input) self.assertAllClose(g, g_rev) def test_multi_platform(self): x = np.arange(8, dtype=np.float32) - exp = get_exported(_testing_multi_platform_func, - lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x) - self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm")) + exp = get_exported(jax.jit(_testing_multi_platform_func), + lowering_platforms=("tpu", "cpu", "cuda", "rocm"))(x) + self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "rocm")) module_str = str(exp.mlir_module()) expected_main_re = ( r"@main\(" @@ -1239,22 +1338,22 @@ def test_multi_platform(self): # Call with argument placed on different plaforms for platform in self.__class__.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) - res_exp = export.call_exported(exp)(x_device) + res_exp = exp.call(x_device) self.assertAllClose( res_exp, _testing_multi_platform_fun_expected(x, platform=platform)) def test_multi_platform_nested(self): x = np.arange(5, dtype=np.float32) - exp = get_exported(lambda x: _testing_multi_platform_func(jnp.sin(x)), - lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x) - self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda","rocm")) + exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))), + lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x) + self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call to the exported using a different sequence of # lowering platforms, but included in the lowering platforms for the # nested exported. - exp2 = get_exported(export.call_exported(exp), - lowering_platforms=("cpu", "cuda","rocm"))(x) + exp2 = get_exported(jax.jit(exp.call), + lowering_platforms=("cpu", "cuda", "rocm"))(x) # Ensure that we do not have multiple lowerings of the exported function exp2_module_str = str(exp2.mlir_module()) @@ -1265,39 +1364,130 @@ def test_multi_platform_nested(self): for platform in self.__class__.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) - res_exp = export.call_exported(exp2)(x_device) + res_exp = exp2.call(x_device) self.assertAllClose( res_exp, _testing_multi_platform_fun_expected(np.sin(x), platform=platform)) def test_multi_platform_nested_inside_single_platform_export(self): x = np.arange(5, dtype=np.float32) - exp = get_exported(_testing_multi_platform_func, - lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x) - self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm")) + exp = get_exported(jax.jit(_testing_multi_platform_func), + lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x) + self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call for the current platform. - exp2 = get_exported(export.call_exported(exp))(x) + exp2 = get_exported(jax.jit(exp.call))(x) module_str = str(exp2.mlir_module()) self.assertIn("jax.uses_shape_polymorphism = true", module_str) - res2 = export.call_exported(exp2)(x) + res2 = exp2.call(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x)) + def test_multi_platform_mlir_lower_fun_with_platform_specific_primitives(self): + # A primitive with multiple lowering rules, which themselves involve + # tracing primitives with per-platform rules, using mlir.lower_fun. + # This situation arises for Pallas lowering. + def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, + x: mlir.ir.Value) -> Sequence[mlir.ir.Value]: + # Lowering n * x + res = x + for i in range(n - 1): + res = mlir.hlo.AddOp(res, x) + return res.results + + times_2 = core.Primitive("__testing_times_2") # x2 for cpu + times_2.def_abstract_eval(lambda x: x) + # Define lowering rules only for the relevant platforms, ensure there + # is no error about missing lowering rules + mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2), + "cpu") + + times_3 = core.Primitive("__testing_times_3") # x3 for cuda and rocm + times_3.def_abstract_eval(lambda x: x) + + mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3), + "rocm") + mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3), + "cuda") + + times_4 = core.Primitive("__testing_times_4") # x4 for tpu + times_4.def_abstract_eval(lambda x: x) + mlir.register_lowering(times_4, functools.partial(times_n_lowering, 4), + "tpu") + + times_2_or_3 = core.Primitive("__testing_times_2_or_3") # x2 for cpu, x3 for cuda and rocm + times_2_or_3.def_abstract_eval(lambda x: x) + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_2.bind, + multiple_results=False), "cpu") + + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_3.bind, + multiple_results=False), "rocm") + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_3.bind, + multiple_results=False), "cuda") + + times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4") # x2 for cpu, x3 for cuda and rocm, x4 for tpu + times_2_or_3_or_4.def_abstract_eval(lambda x: x) + times_2_or_3_or_4_lowering_cpu_gpu = mlir.lower_fun(times_2_or_3.bind, + multiple_results=False) + + for platform in ["cpu", "cuda", "rocm"]: + mlir.register_lowering(times_2_or_3_or_4, + times_2_or_3_or_4_lowering_cpu_gpu, + platform) + mlir.register_lowering(times_2_or_3_or_4, mlir.lower_fun(times_4.bind, + multiple_results=False), + "tpu") + + @jax.jit + def f(x): + return times_2_or_3_or_4.bind(x) + x = np.float32(42.) + exp = export.export(f, lowering_platforms=["cpu", "cuda", "rocm", "tpu"])(x) + expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()]) + self.assertAllClose(exp.call(x), expected) + + def test_multi_platform_unknown_platform(self): + x = np.arange(8, dtype=np.float32) + exp = get_exported(jax.jit(jnp.sin), + lowering_platforms=("tpu", "cpu", "cuda", "other"))(x) + self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "other")) + + + def test_multi_platform_with_donation(self): + f = jax.jit(jnp.sin, donate_argnums=(0,)) + x = np.arange(3, dtype=np.float32) + exp = export.export(f, platforms=["cpu", "tpu"])(x) + if jtu.device_under_test() not in ["cpu", "tpu"]: + self.skipTest("other platform") + + def caller(x): + y = exp.call(x) + return x + y + res = jax.jit(caller)(x) + self.assertAllClose(res, x + np.sin(x)) + + with self.assertRaisesRegex( + NotImplementedError, + "In multi-platform lowering either all or no lowering platforms should support donation"): + export.export(f, platforms=["cpu", "tpu", "other"])(x) + def test_multi_platform_and_poly(self): if jtu.test_device_matches(["gpu"]): # The export is not applicable to GPU raise unittest.SkipTest("Not intended for running on GPU") - exp = get_exported(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)), - lowering_platforms=("cpu", "tpu"))( + exp = get_exported(jax.jit(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,))), + lowering_platforms=("cpu", "tpu"))( jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32) ) x = np.arange(12, dtype=np.float32).reshape((3, 4)) - res = export.call_exported(exp)(x) + res = exp.call(x) self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,))) # Now serialize the call to the exported - exp2 = get_exported(export.call_exported(exp))(x) - res2 = export.call_exported(exp2)(x) + exp2 = get_exported(jax.jit(exp.call))(x) + res2 = exp2.call(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,))) def test_multi_platform_and_sharding(self): @@ -1322,179 +1512,192 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] continue run_mesh = Mesh(run_devices, ("x",)) a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None)) - res_exp = export.call_exported(exp)(a_device) + res_exp = exp.call(a_device) self.assertArraysAllClose(res_native, res_exp) @jtu.parameterized_filterable( kwargs=[ dict(v=v) - for v in range(export.minimum_supported_serialization_version, - export.maximum_supported_serialization_version + 1)]) + for v in range(export.minimum_supported_calling_convention_version, + export.maximum_supported_calling_convention_version + 1)]) def test_ordered_effects_basic(self, *, v: int): - self.override_serialization_version(v) - x = np.arange(3, dtype=np.float32) - def f_jax(x): # x: f32[3] - # Test also the calling convention for inner functions - def f_jax_inner(x): + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX serialization version %s", + config.jax_export_calling_convention_version.value) + x = np.arange(3, dtype=np.float32) + def f_jax(x): # x: f32[3] + # Test also the calling convention for inner functions + def f_jax_inner(x): + return ( + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1")) return ( - testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + - testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1")) - return ( - 10. + - jax.jit(f_jax_inner)(x) + - testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") + - testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + 10. + + jax.jit(f_jax_inner)(x) + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + ) + + exp = get_exported(jax.jit(f_jax))(x) + self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], + sorted(str(e) for e in exp.ordered_effects)) + self.assertEqual(["ForTestingUnorderedEffect1()"], + [str(e) for e in exp.unordered_effects]) + mlir_module_str = str(exp.mlir_module()) + + # Inner functions use stablehlo.token for all versions + inner_fun_expected_re = ( + r"func.func private @f_jax_inner\(" + r"%arg0: !stablehlo.token .*jax.token = true.*" + r"%arg1: tensor<3xf32>.*->.*" + # Results + r"!stablehlo.token .*jax.token = true.*" + r"tensor<3xf32>" ) + self.assertRegex(mlir_module_str, inner_fun_expected_re) + + # The wrapped_main function takens tokens after version 9, and takes + # i1[0] before version 9. + wrapped_main_expected_re = ( + r"@_wrapped_jax_export_main\(" + r"%arg0: !stablehlo.token .*jax.token = true.*" + r"%arg1: !stablehlo.token .*jax.token = true.*->.*" + # Results + r"!stablehlo.token .*jax.token = true.*" + r"!stablehlo.token .*jax.token = true.*") + self.assertRegex(mlir_module_str, wrapped_main_expected_re) + + # The main function takes tokens and has the same type as the wrapped main + main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main") + self.assertRegex(mlir_module_str, main_expected_re) + + # Now call the exported from a function that uses its own effects + def f_outer(x): + return ( + testing_primitive_with_effect_p.bind( + x, effect_class_name="ForTestingOrderedEffect2") + + testing_primitive_with_effect_p.bind( + x, effect_class_name="ForTestingUnorderedEffect1") + + exp.call(x)) - exp = get_exported(f_jax)(x) - self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], - sorted(str(e) for e in exp.ordered_effects)) - self.assertEqual(["ForTestingUnorderedEffect1()"], - [str(e) for e in exp.unordered_effects]) - mlir_module_str = str(exp.mlir_module()) - - # Inner functions use stablehlo.token for all versions - inner_fun_expected_re = ( - r"func.func private @f_jax_inner\(" - r"%arg0: !stablehlo.token .*jax.token = true.*" - r"%arg1: tensor<3xf32>.*->.*" - # Results - r"!stablehlo.token .*jax.token = true.*" - r"tensor<3xf32>" - ) - self.assertRegex(mlir_module_str, inner_fun_expected_re) + lowered_outer = jax.jit(f_outer).lower(x) + self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], + sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"])) + self.assertEqual(["ForTestingUnorderedEffect1()"], + sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) - # The wrapped_main function takens tokens after version 9, and takes - # i1[0] before version 9. - wrapped_main_expected_re = ( - r"@_wrapped_jax_export_main\(" - r"%arg0: !stablehlo.token .*jax.token = true.*" - r"%arg1: !stablehlo.token .*jax.token = true.*->.*" - # Results - r"!stablehlo.token .*jax.token = true.*" - r"!stablehlo.token .*jax.token = true.*") - self.assertRegex(mlir_module_str, wrapped_main_expected_re) - - # The main function takes tokens and has the same type as the wrapped main - main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main") - self.assertRegex(mlir_module_str, main_expected_re) - - # Now call the exported from a function that uses its own effects - def f_outer(x): - return ( - testing_primitive_with_effect_p.bind( - x, effect_class_name="ForTestingOrderedEffect2") + - testing_primitive_with_effect_p.bind( - x, effect_class_name="ForTestingUnorderedEffect1") + - export.call_exported(exp)(x)) - - lowered_outer = jax.jit(f_outer).lower(x) - self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], - sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"])) - self.assertEqual(["ForTestingUnorderedEffect1()"], - sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) - - mlir_outer_module_str = str(lowered_outer.compiler_ir()) - self.assertRegex(mlir_outer_module_str, main_expected_re) - - res = jax.jit(f_outer)(x) - self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res) + mlir_outer_module_str = str(lowered_outer.compiler_ir()) + self.assertRegex(mlir_outer_module_str, main_expected_re) + + res = jax.jit(f_outer)(x) + self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res) @jtu.parameterized_filterable( - kwargs=[ - dict(v=v) - for v in range(export.minimum_supported_serialization_version, - export.maximum_supported_serialization_version + 1)]) + kwargs=[ + dict(v=v) + for v in range(export.minimum_supported_calling_convention_version, + export.maximum_supported_calling_convention_version + 1)]) def test_ordered_effects_poly(self, *, v: int): - self.override_serialization_version(v) - x = np.arange(12, dtype=np.float32).reshape((3, 4)) - def f_jax(x): # x: f32[b1, b2] - return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") - exp = get_exported(f_jax)(jax.ShapeDtypeStruct( - export.symbolic_shape("b2, b1"), x.dtype)) - mlir_module_str = str(exp.mlir_module()) - wrapped_main_expected_re = ( - r"@_wrapped_jax_export_main\(" - r"%arg0: tensor {jax.global_constant = \"b1\".* " - r"%arg1: tensor {jax.global_constant = \"b2\".* " - r"%arg2: !stablehlo.token {jax.token = true.* " - r"%arg3: tensor<\?x\?xf32>.*\) -> \(" - # Results - r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") - self.assertRegex(mlir_module_str, wrapped_main_expected_re) - - main_expected_re = ( - r"@main\(" - r"%arg0: !stablehlo.token {jax.token = true.*, " - r"%arg1: tensor<\?x\?xf32>.*\) -> \(" - # Results - r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") - self.assertRegex(mlir_module_str, main_expected_re) - - res = export.call_exported(exp)(x) - self.assertAllClose(10. + 2. * x, res) + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX serialization version %s", + config.jax_export_calling_convention_version.value) + x = np.arange(12, dtype=np.float32).reshape((3, 4)) + def f_jax(x): # x: f32[b1, b2] + return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") + exp = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct( + export.symbolic_shape("b2, b1"), x.dtype)) + mlir_module_str = str(exp.mlir_module()) + wrapped_main_expected_re = ( + r"@_wrapped_jax_export_main\(" + r"%arg0: tensor {jax.global_constant = \"b1\".* " + r"%arg1: tensor {jax.global_constant = \"b2\".* " + r"%arg2: !stablehlo.token {jax.token = true.* " + r"%arg3: tensor<\?x\?xf32>.*\) -> \(" + # Results + r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") + self.assertRegex(mlir_module_str, wrapped_main_expected_re) + + main_expected_re = ( + r"@main\(" + r"%arg0: !stablehlo.token {jax.token = true.*, " + r"%arg1: tensor<\?x\?xf32>.*\) -> \(" + # Results + r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") + self.assertRegex(mlir_module_str, main_expected_re) + + res = exp.call(x) + self.assertAllClose(10. + 2. * x, res) @jtu.parameterized_filterable( kwargs=[ dict(v=v) - for v in range(export.minimum_supported_serialization_version, - export.maximum_supported_serialization_version + 1)]) + for v in range(export.minimum_supported_calling_convention_version, + export.maximum_supported_calling_convention_version + 1)]) def test_ordered_effects_multi_platform_and_poly(self, *, v: int): - self.override_serialization_version(v) - if jtu.device_under_test() == "gpu": - # The export is not applicable to GPU - raise unittest.SkipTest("Not intended for running on GPU") - x = np.ones((3, 4), dtype=np.float32) - def f_jax(x): # x: f32[b1, b2] - return 10. + _testing_multi_platform_func(x, - effect_class_name="ForTestingOrderedEffect1") - exp = get_exported( - f_jax, - lowering_platforms=("cpu", "tpu") - )(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype)) - mlir_module_str = str(exp.mlir_module()) - wrapped_main_expected_re = ( - r"@_wrapped_jax_export_main\(" - r"%arg0: tensor {jax.global_constant = \"_platform_index\".*, " - r"%arg1: tensor {jax.global_constant = \"b1\".*, " - r"%arg2: tensor {jax.global_constant = \"b2\".*, " - r"%arg3: !stablehlo.token {jax.token = true.*, " - r"%arg4: tensor<\?x\?xf32>.*\) -> \(" - # Results - r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") - self.assertRegex(mlir_module_str, wrapped_main_expected_re) - - main_expected_re = ( - r"@main\(" - r"%arg0: tensor {jax.global_constant = \"_platform_index\".*, " - r"%arg1: !stablehlo.token {jax.token = true.*, " - r"%arg2: tensor<\?x\?xf32>.*\) -> \(" - # Results - r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") - self.assertRegex(mlir_module_str, main_expected_re) - res = export.call_exported(exp)(x) - self.assertAllClose(10. + _testing_multi_platform_fun_expected(x), - res) + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX serialization version %s", + config.jax_export_calling_convention_version.value) + if jtu.device_under_test() == "gpu": + # The export is not applicable to GPU + raise unittest.SkipTest("Not intended for running on GPU") + x = np.ones((3, 4), dtype=np.float32) + def f_jax(x): # x: f32[b1, b2] + return 10. + _testing_multi_platform_func(x, + effect_class_name="ForTestingOrderedEffect1") + exp = get_exported( + jax.jit(f_jax), + lowering_platforms=("cpu", "tpu") + )(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype)) + mlir_module_str = str(exp.mlir_module()) + wrapped_main_expected_re = ( + r"@_wrapped_jax_export_main\(" + r"%arg0: tensor {jax.global_constant = \"_platform_index\".*, " + r"%arg1: tensor {jax.global_constant = \"b1\".*, " + r"%arg2: tensor {jax.global_constant = \"b2\".*, " + r"%arg3: !stablehlo.token {jax.token = true.*, " + r"%arg4: tensor<\?x\?xf32>.*\) -> \(" + # Results + r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") + self.assertRegex(mlir_module_str, wrapped_main_expected_re) + + main_expected_re = ( + r"@main\(" + r"%arg0: tensor {jax.global_constant = \"_platform_index\".*, " + r"%arg1: !stablehlo.token {jax.token = true.*, " + r"%arg2: tensor<\?x\?xf32>.*\) -> \(" + # Results + r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)") + self.assertRegex(mlir_module_str, main_expected_re) + res = exp.call(x) + self.assertAllClose(10. + _testing_multi_platform_fun_expected(x), + res) @jtu.parameterized_filterable( kwargs=[ dict(v=v) - for v in range(export.minimum_supported_serialization_version, - export.maximum_supported_serialization_version + 1)]) + for v in range(export.minimum_supported_calling_convention_version, + export.maximum_supported_calling_convention_version + 1)]) def test_ordered_effects_with_donation(self, *, v: int): - self.override_serialization_version(v) - x = np.arange(3, dtype=np.float32) + with config.jax_export_calling_convention_version(v): + logging.info( + "Using JAX serialization version %s", + config.jax_export_calling_convention_version.value) - def f_jax(x): - return testing_primitive_with_effect_p.bind( - x, effect_class_name="ForTestingOrderedEffect1" - ) + x = np.arange(3, dtype=np.float32) - f_jax = jax.jit(f_jax, donate_argnums=(0,)) - exp = export.export(f_jax)(x) - mlir_module_str = str(exp.mlir_module()) - self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1") - self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") + def f_jax(x): + return testing_primitive_with_effect_p.bind( + x, effect_class_name="ForTestingOrderedEffect1" + ) + + f_jax = jax.jit(f_jax, donate_argnums=(0,)) + exp = export.export(f_jax)(x) + mlir_module_str = str(exp.mlir_module()) + self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1") + self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") @jtu.parameterized_filterable( kwargs=[ @@ -1513,7 +1716,7 @@ def f_jax(x): x, effect_class_name="ForTestingOrderedEffect" + name) with self.assertRaisesRegex(Exception, expect_error): - _ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype)) + _ = get_exported(jax.jit(f_jax))(jax.ShapeDtypeStruct((3, 4), x.dtype)) @jtu.parameterized_filterable( kwargs=[ @@ -1530,13 +1733,14 @@ def f_jax(x, y, gs): rhs = np.arange(num_groups * k * n, dtype=dtype).reshape((num_groups, k, n)) res_native = f_jax(lhs, rhs, group_sizes) - exp_f = get_exported(f_jax)( + exp_f = get_exported(jax.jit(f_jax))( jax.ShapeDtypeStruct(lhs.shape, dtype=lhs.dtype), jax.ShapeDtypeStruct(rhs.shape, dtype=rhs.dtype), jax.ShapeDtypeStruct(group_sizes.shape, dtype=group_sizes.dtype), ) - res_exported = export.call_exported(exp_f)(lhs, rhs, group_sizes) + res_exported = exp_f.call(lhs, rhs, group_sizes) self.assertAllClose(res_native, res_exported) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/extend_test.py b/tests/extend_test.py index a926861ebece..790b4fa2d774 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -12,17 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl.testing import absltest +import os + +import numpy as np +from absl.testing import absltest, parameterized import jax import jax.extend as jex import jax.numpy as jnp -from jax._src import api from jax._src import abstract_arrays +from jax._src import api +from jax._src import core from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.extend import ffi jax.config.parse_flags_with_absl() @@ -82,5 +89,33 @@ def no_rule(*args, **kwargs): self.assertEqual(impl, jax.random.key_impl(k)) +class FfiTest(jtu.JaxTestCase): + + def testHeadersExist(self): + base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api") + for header in ["c_api.h", "api.h", "ffi.h"]: + self.assertTrue(os.path.exists(os.path.join(base_dir, header))) + + @parameterized.parameters( + [True, int(1), float(5.0), + np.int32(-5), np.float32(0.5)]) + def testIrAttribute(sel, value): + with mlir.make_ir_context(), ir.Location.unknown(): + const = mlir.ir_constant(value) + attr = ffi._ir_attribute(value) + assert const.type.element_type == attr.type + + @parameterized.parameters([True, 1, 5.0, "param", np.float32(0.5)]) + def testParams(self, param): + prim = core.Primitive("test_ffi") + prim.def_abstract_eval(lambda *args, **kwargs: args[0]) + mlir.register_lowering(prim, jex.ffi.ffi_lowering("test_ffi")) + + # TODO(dfm): Currently testing that lowering works with different types of + # parameters, but we should probably actually check the emitted HLO. + func = jax.jit(lambda *args: prim.bind(*args, param=param)) + func.lower(jnp.linspace(0, 5, 10)) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/ffi_test.py b/tests/ffi_test.py deleted file mode 100644 index e0e6521317a2..000000000000 --- a/tests/ffi_test.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import os -import unittest - -from jax import ffi -from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version - - -class IncludeDirTest(jtu.JaxTestCase): - - @unittest.skipIf(xla_extension_version < 265, "Requires jaxlib 0.4.29") - def testHeadersExist(self): - base_dir = os.path.join(ffi.include_dir(), "xla", "ffi", "api") - for header in ["c_api.h", "api.h", "ffi.h"]: - print(os.path.join(base_dir, header)) - self.assertTrue(os.path.exists(os.path.join(base_dir, header))) diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index 725a2340a2ed..b79c233e6f2e 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -296,7 +296,7 @@ def step(x, _): A = jnp.zeros((3, 3)) # The second DUS was unnecessarily replicating A across time. # We check XLA because _scan_impl is "underneath" the jaxpr language. - s = str(jax.xla_computation(jax.grad(loss))(A).as_hlo_text()) + s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo') assert s.count("dynamic-update-slice(") < 2 @_for_loop_impls diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 70873d19cdab..9cf7ba80dff5 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -14,7 +14,6 @@ from functools import partial from absl.testing import absltest -from typing import Optional import os os.environ["XLA_FLAGS"] = \ @@ -43,8 +42,8 @@ def sdpa_train(query: Array, key: Array, value: Array, grad: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, is_bnth: bool = False, @@ -63,8 +62,7 @@ def sdpa_train(query: Array, out, sdpa_vjp = jax.vjp( partial(dot_product_attention, scale=scale, mask_type=mask_type, dropout_rate=dropout_rate, - qkv_layout="BNTH" if is_bnth else "BTNH", - is_training=True), + qkv_layout="BNTH" if is_bnth else "BTNH"), query, key, value, bias, mask, q_seqlen, kv_seqlen) query_grad, key_grad, value_grad, bias_grad, _, _, _ = sdpa_vjp(grad) if bias is not None and len(bias.shape) == 3: @@ -75,8 +73,8 @@ def sdpa_train(query: Array, def sdpa_ref(query: Array, key: Array, value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, dropout_rate: float = 0.1) -> Array: @@ -151,8 +149,8 @@ def sdpa_train_ref(query: Array, key: Array, value: Array, grad: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, + bias: Array | None = None, + mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, dropout_rate: float = 0.1) -> Array: diff --git a/tests/gpu_memory_flags_test.py b/tests/gpu_memory_flags_test.py index d788d881fe2d..308fff257348 100644 --- a/tests/gpu_memory_flags_test.py +++ b/tests/gpu_memory_flags_test.py @@ -13,7 +13,6 @@ # limitations under the License. import os -import sys import unittest from absl.testing import absltest @@ -27,10 +26,7 @@ class GpuMemoryAllocationTest(absltest.TestCase): # This test must be run in its own subprocess. - @unittest.skipIf( - "pytest" in sys.modules, - "Test must run in an isolated process", - ) + @jtu.skip_under_pytest("Test must run in an isolated process") @unittest.skipIf( "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ, "Test does not work if the python client allocator has been overriden", diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index f3960e0105be..1ad59103c3e7 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -14,7 +14,8 @@ from __future__ import annotations -from collections.abc import Sequence +import contextlib +from collections.abc import Callable, Sequence from functools import partial import itertools import logging @@ -22,7 +23,6 @@ import re import threading import time -from typing import Callable import unittest from unittest import skip, SkipTest @@ -196,18 +196,13 @@ def helper_log_ir(name, logging.info(f"Optimized HLO[{name}]: {jax_optimized_hlo}") -prev_xla_flags = None - +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) - + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase, @@ -248,14 +243,16 @@ def test_deprecated_imports(self): class HostCallbackTapTest(jtu.JaxTestCase): def setUp(self): - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) + # skipping here skips teardown, so do this before super().setUp(). if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") if xla_bridge.using_pjrt_c_api(): raise SkipTest("host_callback not implemented in PJRT C API") - + super().setUp() + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="The host_callback APIs are deprecated")) + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="backend and device argument")) testing_stream.reset() testing_stream._test_method_name = self._testMethodName self.old_flags = os.getenv("XLA_FLAGS", "") @@ -443,8 +440,12 @@ def func(x): logging.info("%s: %s", self._testMethodName, jax.make_jaxpr(func)(1)) - logging.info("%s: %s", self._testMethodName, - jax.xla_computation(func, backend=jtu.device_under_test())(1).as_hlo_text()) + logging.info( + "%s: %s", + self._testMethodName, + jax.jit(func) + .trace(1) + .lower(lowering_platforms=(jtu.device_under_test(),)).as_text("hlo")) self.assertEqual(2, jax.jit(func)(1)) hcb.barrier_wait() @@ -2044,13 +2045,16 @@ class HostCallbackCallTest(jtu.JaxTestCase): """Tests for hcb.call""" def setUp(self): - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) + # skipping here skips teardown, so do this before super().setUp(). if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") if xla_bridge.using_pjrt_c_api(): raise SkipTest("host_callback not implemented in PJRT C API") + super().setUp() + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="The host_callback APIs are deprecated")) + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message="backend and device argument")) testing_stream.reset() testing_stream._test_method_name = self._testMethodName diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py index 05a25bd33b50..fe80c90ace68 100644 --- a/tests/host_callback_to_tf_test.py +++ b/tests/host_callback_to_tf_test.py @@ -18,7 +18,7 @@ This is separate from host_callback_test because it needs a TF dependency. """ -from typing import Callable +from collections.abc import Callable import unittest from absl.testing import absltest diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index d6b62bf6defa..88fd7a334048 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools +import contextlib import threading import unittest @@ -133,18 +133,13 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out mlir.register_lowering(callback_p, callback_effect_lowering) -prev_xla_flags = None - +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) - + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() class JaxprEffectsTest(jtu.JaxTestCase): @@ -566,23 +561,22 @@ def log_value(x): log.append(x) return () - @functools.partial(jax.jit, device=jax.devices()[0]) + @jax.jit def f(x): # Expensive computation x = x.dot(x) x = jnp.log(x.sum()) return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) - @functools.partial(jax.jit, device=jax.devices()[1]) + @jax.jit def g(x): return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) - f(jnp.ones((500, 500))) - g(3.) - f(jnp.ones((500, 500))) - g(3.) - f(jnp.ones((500, 500))) - g(3.) + x = jax.device_put(jnp.ones((500, 500)), jax.devices()[0]) + y = jax.device_put(3., jax.devices()[1]) + for _ in range(3): + f(x) + g(y) jax.effects_barrier() f_, g_ = float(jnp.log(1.25e8)), 3. expected_log = [f_, g_, f_, g_, f_, g_] diff --git a/tests/jet_test.py b/tests/jet_test.py index 79132174e13d..b1e2ef3f8380 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -243,6 +243,8 @@ def test_floor(self): self.unary_check(jnp.floor) @jtu.skip_on_devices("tpu") def test_ceil(self): self.unary_check(jnp.ceil) @jtu.skip_on_devices("tpu") + def test_trunc(self): self.unary_check(jnp.trunc) + @jtu.skip_on_devices("tpu") def test_round(self): self.unary_check(lax.round) @jtu.skip_on_devices("tpu") def test_sign(self): self.unary_check(lax.sign) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 52a1923c92eb..040603555ff5 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -40,6 +40,7 @@ import jax.scipy as jsp from jax._src.lax import control_flow as lax_control_flow from jax._src.lax.control_flow import for_loop +from jax._src.interpreters import mlir from jax._src.maps import xmap jax.config.parse_flags_with_absl() @@ -130,6 +131,15 @@ def scan_reference(f, init, xs): ignore_jit_of_pmap_warning = partial( jtu.ignore_warning, message=".*jit-of-pmap.*") +# A JAX primitive whose lowering is a custom call to a non-existent function. +prim_non_existent_custom_call = core.Primitive("__testing_non_existent_custom_call") +prim_non_existent_custom_call.def_abstract_eval(lambda x_aval: x_aval) +mlir.register_lowering( + prim_non_existent_custom_call, + lambda ctx, x: mlir.hlo.CustomCallOp( + [x.type], [x], + call_target_name=mlir.ir.StringAttr.get("__testing_non_existent_custom_call")).results) + class LaxControlFlowTest(jtu.JaxTestCase): @@ -2322,7 +2332,7 @@ def step(x, i): A = jnp.zeros((3, 3)) # The second DUS was unnecessarily replicating A across time. # We check XLA because _scan_impl is "underneath" the jaxpr language. - s = str(jax.xla_computation(jax.grad(loss))(A).as_hlo_text()) + s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo') assert s.count("dynamic-update-slice(") < 2 def testScanLengthArg(self): @@ -2417,8 +2427,8 @@ def f(c, a): # but HLO should grow due to unrolling self.assertLess( - len(str(jax.xla_computation(scan)(c, xs).as_hlo_text())), - len(str(jax.xla_computation(scan_unrolled)(c, xs).as_hlo_text()))) + len(str(jax.jit(scan).lower(c, xs).as_text('hlo'))), + len(str(jax.jit(scan_unrolled).lower(c, xs).as_text('hlo')))) def test_scan_xs_none(self): def f(h, _): @@ -2518,7 +2528,7 @@ def f(c, a): scan_fun = lambda c, xs: lax.scan(f, c, xs) def new_jaxpr(): - jaxpr = jax.make_jaxpr(scan_fun)(c, xs).jaxpr + jaxpr = jax.make_jaxpr(partial(scan_fun))(c, xs).jaxpr scan = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'scan') return jaxpr, scan @@ -2844,6 +2854,44 @@ def f(x): self.assertNotIn(" sine", hlo) self.assertIn(" cosine", hlo) + def test_platform_dependent_with_non_existent_custom_call(self): + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Only for CPU") + + def f(x): + # One use with the bad custom call on a different platform branch + x1 = lax.platform_dependent(x, + cpu=jnp.sin, + other=prim_non_existent_custom_call.bind) + # and with the bad custom call in the default branch + x2 = lax.platform_dependent(x, + cpu=jnp.sin, + default=prim_non_existent_custom_call.bind) + # and one use where the current platform is the default + x3 = lax.platform_dependent(x, + other=prim_non_existent_custom_call.bind, + default=jnp.sin) + return x1 + x2 + x3 + + x = np.arange(3, dtype=np.float32) + hlo = str(jax.jit(f).lower(x).compiler_ir()) + occurrences = re.findall(prim_non_existent_custom_call.name, hlo) + self.assertLen(occurrences, 3) + + res_eager = f(x) + self.assertAllClose(res_eager, 3. * np.sin(x)) + res_jit = jax.jit(f)(x) + self.assertAllClose(res_jit, 3 * np.sin(x)) + + res_vmap = jax.vmap(f)(x) + self.assertAllClose(res_vmap, 3. * np.sin(x)) + + _, res_jvp = jax.jvp(f, (x,), (np.full(x.shape, .1, dtype=x.dtype),)) + self.assertAllClose(res_jvp, .3 * np.cos(x)) + + res_grad = jax.grad(f)(1.) + self.assertAllClose(res_grad, 3. * np.cos(1.)) + def test_platform_dependent_multiple_identical_branches(self): x = np.arange(3, dtype=np.float32) def f(x): @@ -2928,6 +2976,31 @@ def body(carry, x): hlo_text = fn.lower(init).as_text('hlo') self.assertNotIn('4,1,2,2', hlo_text) + def test_cond_vmap_forwarding_doesnt_promote(self): + def f(x, y): + x, y = jax.lax.cond( + x < 3, + lambda x, y: (x * 2, y), + lambda x, y: (x * 3, y), + x, y + ) + return x, y + + x = jnp.arange(3) + y = jnp.array(3.) + + x2, y2 = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(x, y) # don't crash + + assert x is not x2 + assert y is y2 + + def test_cond_casting(self): + x = 1.0 + identity = lambda x: x + + y = lax.cond(True, identity, identity, x) + self.assertEqual(y, x) + self.assertIsInstance(y, jax.Array) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 5069187d2334..dab26d86c0a2 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -1679,9 +1679,6 @@ def testDeleteMaskArray(self, shape, dtype, axis): rng = jtu.rand_default(self.rng()) mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) - if numpy_version == (1, 23, 0) and mask.shape == (1,): - # https://github.com/numpy/numpy/issues/21840 - self.skipTest("test fails for numpy v1.23.0") args_maker = lambda: [rng(shape, dtype)] np_fun = lambda arg: np.delete(arg, mask, axis=axis) jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) @@ -1943,9 +1940,6 @@ def np_fun(x, fill_value=fill_value): @unittest.skip("jax-metal fail.") @jtu.sample_product(dtype=inexact_dtypes) def testUniqueNans(self, dtype): - if numpy_version == (1, 23, 0) and dtype == np.float16: - # https://github.com/numpy/numpy/issues/21838 - self.skipTest("Known failure on numpy 1.23.0") def args_maker(): x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] if np.issubdtype(dtype, np.complexfloating): @@ -1966,8 +1960,6 @@ def np_fun(x): @unittest.skip("jax-metal fail.") @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) def testUniqueEqualNan(self, dtype, equal_nan): - if numpy_version < (1, 24, 0): - self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.") shape = (20,) rng = jtu.rand_some_nan(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -2669,10 +2661,7 @@ def testStack(self, shape, axis, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24): - np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype)) - else: - np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) + np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) with jtu.strict_promotion_if_dtypes_match(dtypes): @@ -2699,7 +2688,7 @@ def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24) or op == "dstack": + if op == "dstack": np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) else: np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 79866d8ee22f..4c31684e145f 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -423,6 +423,17 @@ def _shapes_are_equal_length(shapes): return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) +def _get_testcase_name(index, params): + dtypes = "_".join(str(dt.__name__) for dt in params['dtypes']) + name = params['op_name'] if "op_name" in params else params["name"] + return f"{index}_{name}_{dtypes}" + + +def _create_named_parameters(iter_params): + for i, params in enumerate(iter_params): + yield dict(params, **{'testcase_name': _get_testcase_name(i, params)}) + + class JaxNumpyOperatorTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy operators.""" @@ -436,7 +447,7 @@ def f(): for a in out] return f - @parameterized.parameters(itertools.chain.from_iterable( + @parameterized.named_parameters(_create_named_parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(op_name=rec.name, rng_factory=rec.rng_factory, check_dtypes=rec.check_dtypes, tolerance=rec.tolerance, @@ -449,7 +460,7 @@ def f(): *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))], ) for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, - JAX_COMPOUND_OP_RECORDS))) + JAX_COMPOUND_OP_RECORDS)))) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, tolerance, inexact, kwargs, alias): @@ -477,7 +488,7 @@ def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, self._CompileAndCheck(jnp_op, args_maker, check_dtypes=check_dtypes, atol=tol, rtol=tol) - @parameterized.parameters(itertools.chain.from_iterable( + @parameterized.named_parameters(_create_named_parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, tol=rec.tolerance)], [dict(shapes=shapes, dtypes=dtypes) @@ -487,7 +498,7 @@ def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes, for dtypes in itertools.product( *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))], ) - for rec in JAX_OPERATOR_OVERLOADS)) + for rec in JAX_OPERATOR_OVERLOADS))) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): rng = rng_factory(self.rng()) @@ -498,7 +509,7 @@ def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): with jtu.strict_promotion_if_dtypes_match(dtypes): self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol) - @parameterized.parameters(itertools.chain.from_iterable( + @parameterized.named_parameters(_create_named_parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, op_tolerance=rec.tolerance)], @@ -509,7 +520,7 @@ def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): for dtypes in itertools.product( *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))], ) - for rec in JAX_RIGHT_OPERATOR_OVERLOADS)) + for rec in JAX_RIGHT_OPERATOR_OVERLOADS))) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes, op_tolerance): @@ -579,7 +590,7 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): with self.assertRaises(TypeError): op(arg, other) - @parameterized.parameters(itertools.chain.from_iterable( + @parameterized.named_parameters(_create_named_parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, alias=rec.alias)], shapes=filter( @@ -589,7 +600,7 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): _dtypes_are_compatible_for_bitwise_ops, itertools.combinations_with_replacement(rec.dtypes, rec.nargs)), ) - for rec in JAX_BITWISE_OP_RECORDS)) + for rec in JAX_BITWISE_OP_RECORDS))) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testBitwiseOp(self, name, rng_factory, shapes, dtypes, alias): np_op = getattr(np, name) if hasattr(np, name) else getattr(np, alias) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 37e8410ddbfb..588368cd8553 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -507,14 +507,12 @@ def testReductionWithRepeatedAxisError(self): for weights_shape in ([None, shape] if axis is None or len(shape) == 1 or isinstance(axis, tuple) else [None, (shape[axis],), shape]) ], - keepdims=([False, True] if numpy_version >= (1, 23) else [None]), + keepdims=[False, True], returned=[False, True], ) def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): rng = jtu.rand_default(self.rng()) - kwds = dict(returned=returned) - if keepdims is not None: - kwds['keepdims'] = keepdims + kwds = dict(returned=returned, keepdims=keepdims) if weights_shape is None: np_fun = lambda x: np.average(x, axis, **kwds) jnp_fun = lambda x: jnp.average(x, axis, **kwds) @@ -527,15 +525,11 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-12, np.complex64: 1e-5} check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - if numpy_version == (1, 23, 0) and keepdims and weights_shape is not None and axis is not None: - # Known failure: https://github.com/numpy/numpy/issues/21850 - pass - else: - try: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=check_dtypes, tol=tol) - except ZeroDivisionError: - self.skipTest("don't support checking for ZeroDivisionError") + try: + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + check_dtypes=check_dtypes, tol=tol) + except ZeroDivisionError: + self.skipTest("don't support checking for ZeroDivisionError") self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, rtol=tol, atol=tol) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 050f39cedd15..45cc177fbfd1 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2048,9 +2048,6 @@ def np_fun(x, fill_value=fill_value): @jtu.sample_product(dtype=inexact_dtypes) def testUniqueNans(self, dtype): - if numpy_version == (1, 23, 0) and dtype == np.float16: - # https://github.com/numpy/numpy/issues/21838 - self.skipTest("Known failure on numpy 1.23.0") def args_maker(): x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] if np.issubdtype(dtype, np.complexfloating): @@ -2070,8 +2067,6 @@ def np_fun(x): @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) def testUniqueEqualNan(self, dtype, equal_nan): - if numpy_version < (1, 24, 0): - self.skipTest("np.unique equal_nan requires NumPy 1.24 or newer.") shape = (20,) rng = jtu.rand_some_nan(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -2660,13 +2655,18 @@ def np_fun(arg): side=['left', 'right'], dtype=number_dtypes, method=['sort', 'scan', 'scan_unrolled', 'compare_all'], + use_sorter=[True, False], ) - def testSearchsorted(self, ashape, vshape, side, dtype, method): + def testSearchsorted(self, ashape, vshape, side, dtype, method, use_sorter): rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)] - def np_fun(a, v): - return np.searchsorted(a, v, side=side).astype('int32') - jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method) + def args_maker(): + a = rng(ashape, dtype) + v = rng(vshape, dtype) + return (a, v, np.argsort(a)) if use_sorter else (np.sort(a), v) + def np_fun(a, v, sorter=None): + return np.searchsorted(a, v, side=side, sorter=sorter).astype('int32') + def jnp_fun(a, v, sorter=None): + return jnp.searchsorted(a, v, side=side, method=method, sorter=sorter) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -2779,10 +2779,7 @@ def testStack(self, shape, axis, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24): - np_fun = jtu.promote_like_jnp(lambda *args: np.stack(*args, axis=axis).astype(out_dtype)) - else: - np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) + np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) with jtu.strict_promotion_if_dtypes_match(dtypes): @@ -2809,7 +2806,7 @@ def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): else: args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - if numpy_version < (1, 24) or op == "dstack": + if op == "dstack": np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) else: np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, @@ -2991,6 +2988,30 @@ def testArrayCreationWithSharding(self, func, shape, dtype): out = func(**kwds, shape=shape, dtype=dtype, device=sharding) self.assertEqual(out.sharding, sharding) + @jtu.sample_product( + func=[ + lambda dtype, device: jnp.arange(5, dtype=dtype, device=device), + lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device), + ], + dtype=default_dtypes, + ) + def testArangeEyeWithDevice(self, func, dtype): + device = jax.devices()[-1] + out = func(dtype=dtype, device=device) + self.assertEqual(out.devices(), {device}) + + @jtu.sample_product( + func=[ + lambda dtype, device: jnp.arange(5, dtype=dtype, device=device), + lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device), + ], + dtype=default_dtypes, + ) + def testArangeEyeWithSharding(self, func, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + out = func(dtype=dtype, device=sharding) + self.assertEqual(out.sharding, sharding) + @jtu.sample_product( func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], shape=array_shapes, @@ -4799,10 +4820,19 @@ def testArangeJit(self): expected = jtu.with_jax_dtype_defaults(np.arange)(5) self.assertAllClose(ans, expected) - @jtu.sample_product(args=[(5,), (0, 5)]) - def testArangeJaxpr(self, args): - jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))() - self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + @jtu.sample_product( + args=[(5,), (0, 5)], + specify_device=[True, False], + ) + def testArangeJaxpr(self, args, specify_device): + device = jax.devices()[-1] if specify_device else None + kwargs = {"device": device} + jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args, **kwargs))() + # We have 2 statements in jaxpr: + # [a:i32[5] = iota[dimension=0 dtype=int32 shape=(5,)], + # a:i32[5] = device_put[devices=[None] srcs=[None]] b] + num_eqs = 2 if device is not None else 1 + self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs) self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) def testIssue830(self): @@ -5944,7 +5974,7 @@ def testWrappedSignaturesMatch(self): 'empty_like': ['subok', 'order'], 'einsum': ['kwargs'], 'einsum_path': ['einsum_call'], - 'eye': ['device', 'order', 'like'], + 'eye': ['order', 'like'], 'hstack': ['casting'], 'identity': ['like'], 'isin': ['kind'], @@ -5987,7 +6017,7 @@ def testWrappedSignaturesMatch(self): mismatches = {} for name, (jnp_fun, np_fun) in func_pairs.items(): - if numpy_version >= (1, 24) and name in ['histogram', 'histogram2d', 'histogramdd']: + if name in ['histogram', 'histogram2d', 'histogramdd']: # numpy 1.24 re-orders the density and weights arguments. # TODO(jakevdp): migrate histogram APIs to match newer numpy versions. continue @@ -6085,6 +6115,8 @@ def test_lax_numpy_docstrings(self): # Test that docstring wrapping & transformation didn't fail. unimplemented = ['fromfile', 'fromiter'] + aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', + 'amax', 'amin'] for name in dir(jnp): if name.startswith('_') or name in unimplemented: @@ -6107,6 +6139,8 @@ def test_lax_numpy_docstrings(self): raise Exception(f"jnp.{name} does not contain wrapped docstring.") if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__: raise Exception(f"jnp.{name} does not have a wrapped docstring.") + elif name in aliases: + assert "Alias of" in obj.__doc__ else: # Other functions should have nontrivial docs including "Args" and "Returns". doc = obj.__doc__ diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index f950d1048df4..38607cae883b 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -23,6 +23,7 @@ import scipy.special as osp_special import jax +from jax._src import deprecations from jax._src import test_util as jtu from jax.scipy import special as lsp_special @@ -147,7 +148,12 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t "rel_entr", 2, float_dtypes, jtu.rand_positive, True, ), op_record("poch", 2, float_dtypes, jtu.rand_positive, True), - op_record("hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True) + op_record( + "hyp1f1", 3, float_dtypes, + functools.partial(jtu.rand_uniform, low=0.5, high=30), True + ), + op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True), + op_record("softmax", 1, float_dtypes, jtu.rand_default, True), ] @@ -237,6 +243,28 @@ def testRelEntrExtremeValues(self): self._CheckAgainstNumpy(osp_special.rel_entr, lsp_special.rel_entr, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.rel_entr, args_maker, rtol=rtol) + def testBetaParameterDeprecation(self): + with self.assertNoWarnings(): + lsp_special.beta(1, 1) + lsp_special.beta(1, b=1) + lsp_special.beta(a=1, b=1) + if deprecations.is_accelerated('jax-scipy-beta-args'): + with self.assertRaises(ValueError): + lsp_special.beta(x=1, y=1) + else: + with self.assertWarns(DeprecationWarning): + lsp_special.beta(1, y=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(a=1, y=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(x=1, b=1) + with self.assertWarns(DeprecationWarning): + lsp_special.beta(x=1, y=1) + with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): + lsp_special.beta(1, x=1) + with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): + lsp_special.beta(b=1, y=1) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_test.py b/tests/lax_test.py index fdcf71b99226..ce1a2d4ff897 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -45,7 +45,7 @@ from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal -from jax._src.util import NumpyComplexWarning +from jax._src.util import NumpyComplexWarning, safe_zip from jax._src.tree_util import tree_map config.parse_flags_with_absl() @@ -3315,14 +3315,6 @@ class FooTyRules: def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((2,), jnp.dtype('uint32')) - @staticmethod - def logical_sharding(aval, phys_sharding): - return phys_sharding - - @staticmethod - def physical_sharding(aval, sharding): - return sharding - @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): @@ -3341,14 +3333,6 @@ def handler(arr): return FooArray(aval.shape, buf) return handler - @staticmethod - def replicate_trailing_dims(ctx, val, aval): - return val - - @staticmethod - def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval): - pass - class FooTy(dtypes.ExtendedDType): type = dtypes.extended @@ -3410,11 +3394,14 @@ def __repr__(self) -> str: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(x, sharding): - device, = sharding._addressable_device_assignment - aval = core.raise_to_shaped(core.get_aval(x.data)) - return pxla.batched_device_put( - aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]) +def shard_foo_array_handler(xs, shardings): + results = [] + for x, sharding in safe_zip(xs, shardings): + device, = sharding._addressable_device_assignment + aval = core.raise_to_shaped(core.get_aval(x.data)) + results.append(pxla.batched_device_put( + aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) + return results def foo_array_constant_handler(x): return array._array_mlir_constant_handler(x.data) diff --git a/tests/layout_test.py b/tests/layout_test.py index 44a7dde65e80..d071d6cb7149 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import math -import os import re from absl.testing import absltest import numpy as np @@ -25,30 +25,16 @@ from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip -from jax._src import xla_bridge config.parse_flags_with_absl() -prev_xla_flags = None +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() pattern = re.compile(r"\{(.*?):") @@ -276,8 +262,11 @@ def f(x): ' compiled with'): compiled(arr) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_cpu_default_backend_layout(self): - out_cpu = jax.jit(jnp.dot, backend='cpu')(np.ones((8, 8)), np.ones((8, 8))) + inp = jax.device_put(np.ones((8, 8)), device=jax.devices('cpu')[0]) + out_cpu = jax.jit(jnp.dot)(inp, inp) jax.jit(jnp.dot, backend=jax.default_backend()).lower( out_cpu, out_cpu).compile() # doesn't crash diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 52d359ec2ffe..fe7ddc83e8b6 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -290,7 +290,6 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, for result in results: self.assertTrue(np.all(np.isnan(result))) - @jtu.sample_product( shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], dtype=float_types + complex_types, @@ -345,13 +344,13 @@ def testEigBatching(self, shape, dtype): np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3)) @jtu.sample_product( - n=[0, 4, 5, 50, 512], - dtype=float_types + complex_types, - lower=[True, False], + n=[0, 4, 5, 50, 512], + dtype=float_types + complex_types, + lower=[True, False], ) def testEigh(self, n, dtype, lower): rng = jtu.rand_default(self.rng()) - tol = 0.5 * np.maximum(n, 80) * np.finfo(dtype).eps + eps = np.finfo(dtype).eps args_maker = lambda: [rng((n, n), dtype)] uplo = "L" if lower else "U" @@ -361,15 +360,36 @@ def testEigh(self, n, dtype, lower): w, v = jnp.linalg.eigh(np.tril(a) if lower else np.triu(a), UPLO=uplo, symmetrize_input=False) w = w.astype(v.dtype) - self.assertLessEqual( - np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), 4 * tol + tol = 2 * n * eps + self.assertAllClose( + np.eye(n, dtype=v.dtype), + np.matmul(np.conj(T(v)), v), + atol=tol, + rtol=tol, ) + with jax.numpy_rank_promotion('allow'): - self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v), - tol * np.linalg.norm(a)) + tol = 100 * eps + self.assertLessEqual( + np.linalg.norm(np.matmul(a, v) - w * v), tol * np.linalg.norm(a) + ) self._CompileAndCheck( - partial(jnp.linalg.eigh, UPLO=uplo), args_maker, rtol=tol + partial(jnp.linalg.eigh, UPLO=uplo), args_maker, rtol=eps + ) + + # Compare eigenvalues against Numpy using double precision. We do not compare + # eigenvectors because they are not uniquely defined, but the two checks above + # guarantee that that they satisfy the conditions for being eigenvectors. + double_type = dtype + if dtype == np.float32: + double_type = np.float64 + if dtype == np.complex64: + double_type = np.complex128 + w_np = np.linalg.eigvalsh(a.astype(double_type)) + tol = 8 * eps + self.assertAllClose( + w_np.astype(w.dtype), w, atol=tol * np.linalg.norm(a), rtol=tol ) @jtu.sample_product( @@ -383,7 +403,7 @@ def testEighSubsetByIndex(self, start, end): dtype = np.float32 n = 256 rng = jtu.rand_default(self.rng()) - tol = np.maximum(n, 80) * np.finfo(dtype).eps + eps = np.finfo(dtype).eps args_maker = lambda: [rng((n, n), dtype)] subset_by_index = (start, end) k = end - start @@ -397,21 +417,36 @@ def testEighSubsetByIndex(self, start, end): self.assertEqual(v.shape, (n, k)) self.assertEqual(w.shape, (k,)) - self.assertLessEqual( - np.linalg.norm(np.eye(k) - np.matmul(np.conj(T(v)), v)), 3 * tol - ) with jax.numpy_rank_promotion("allow"): + tol = 200 * eps self.assertLessEqual( np.linalg.norm(np.matmul(a, v) - w * v), tol * np.linalg.norm(a) ) + tol = 3 * n * eps + self.assertAllClose( + np.eye(k, dtype=v.dtype), + np.matmul(np.conj(T(v)), v), + atol=tol, + rtol=tol, + ) - self._CompileAndCheck(partial(jnp.linalg.eigh), args_maker, rtol=tol) + self._CompileAndCheck(partial(jnp.linalg.eigh), args_maker, rtol=eps) # Compare eigenvalues against Numpy. We do not compare eigenvectors because # they are not uniquely defined, but the two checks above guarantee that # that they satisfy the conditions for being eigenvectors. - w_np = np.linalg.eigvalsh(a)[subset_by_index[0] : subset_by_index[1]] - self.assertAllClose(w_np, w, atol=tol, rtol=tol) + double_type = dtype + if dtype == np.float32: + double_type = np.float64 + if dtype == np.complex64: + double_type = np.complex128 + w_np = np.linalg.eigvalsh(a.astype(double_type))[ + subset_by_index[0] : subset_by_index[1] + ] + tol = 20 * eps + self.assertAllClose( + w_np.astype(w.dtype), w, atol=tol * np.linalg.norm(a), rtol=tol + ) def testEighZeroDiagonal(self): a = np.array([[0., -1., -1., 1.], @@ -568,7 +603,7 @@ def testEighBatching(self, shape, dtype): ws, vs = vmap(jsp.linalg.eigh)(args) ws = ws.astype(vs.dtype) norm = np.max(np.linalg.norm(np.matmul(args, vs) - ws[..., None, :] * vs)) - self.assertLess(norm, 1e-2) + self.assertLess(norm, 1.4e-2) @jtu.sample_product( shape=[(1,), (4,), (5,)], @@ -795,7 +830,7 @@ def compute_max_backward_error(operand, reconstructed_operand): max_backward_error = np.amax(backward_error) return max_backward_error - tol = 80 * jnp.finfo(dtype).eps + tol = 100 * jnp.finfo(dtype).eps reconstruction_tol = 2 * tol unitariness_tol = 3 * tol @@ -1305,6 +1340,18 @@ def testDiagonal(self, shape, dtype, offset): self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) + def testTrace(self): + shape, dtype, offset, out_dtype = (3, 4), "float32", 0, None + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + lax_fun = partial(jnp.linalg.trace, offset=offset, dtype=out_dtype) + if jtu.numpy_version() >= (2, 0, 0): + np_fun = partial(np.linalg.trace, offset=offset) + else: + np_fun = partial(np.trace, offset=offset, axis1=-2, axis2=-1, dtype=out_dtype) + self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker) + class ScipyLinalgTest(jtu.JaxTestCase): @@ -1380,6 +1427,8 @@ def testLuBatching(self, shape, dtype): self.assertAllClose(us, actual_us) @jtu.skip_on_devices("cpu", "tpu") + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testLuCPUBackendOnGPU(self): # tests running `lu` on cpu when a gpu is present. jit(jsp.linalg.lu, backend="cpu")(np.ones((2, 2))) # does not crash diff --git a/tests/logging_test.py b/tests/logging_test.py index 454b525f2e8f..5a495d47d31b 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -36,6 +36,19 @@ jax.config.parse_flags_with_absl() +@contextlib.contextmanager +def jax_debug_log_modules(value): + # jax_debug_log_modules doesn't have a context manager, because it's + # not thread-safe. But since tests are always single-threaded, we + # can define one here. + original_value = jax.config.jax_debug_log_modules + jax.config.update("jax_debug_log_modules", value) + try: + yield + finally: + jax.config.update("jax_debug_log_modules", original_value) + + @contextlib.contextmanager def capture_jax_logs(): log_output = io.StringIO() @@ -104,30 +117,30 @@ def test_debug_logging(self): self.assertEmpty(log_output.getvalue()) # Turn on all debug logging. - jax.config.update("jax_debug_log_modules", "jax") - with capture_jax_logs() as log_output: - jax.jit(lambda x: x + 1)(1) - self.assertIn("Finished tracing + transforming", log_output.getvalue()) - self.assertIn("Compiling ", log_output.getvalue()) + with jax_debug_log_modules("jax"): + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertIn("Finished tracing + transforming", log_output.getvalue()) + self.assertIn("Compiling ", log_output.getvalue()) # Turn off all debug logging. - jax.config.update("jax_debug_log_modules", None) - with capture_jax_logs() as log_output: - jax.jit(lambda x: x + 1)(1) - self.assertEmpty(log_output.getvalue()) + with jax_debug_log_modules(""): + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertEmpty(log_output.getvalue()) # Turn on one module. - jax.config.update("jax_debug_log_modules", "jax._src.dispatch") - with capture_jax_logs() as log_output: - jax.jit(lambda x: x + 1)(1) - self.assertIn("Finished tracing + transforming", log_output.getvalue()) - self.assertNotIn("Compiling ", log_output.getvalue()) + with jax_debug_log_modules("jax._src.dispatch"): + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertIn("Finished tracing + transforming", log_output.getvalue()) + self.assertNotIn("Compiling ", log_output.getvalue()) # Turn everything off again. - jax.config.update("jax_debug_log_modules", None) - with capture_jax_logs() as log_output: - jax.jit(lambda x: x + 1)(1) - self.assertEmpty(log_output.getvalue()) + with jax_debug_log_modules(""): + with capture_jax_logs() as log_output: + jax.jit(lambda x: x + 1)(1) + self.assertEmpty(log_output.getvalue()) if __name__ == "__main__": diff --git a/tests/lru_cache_test.py b/tests/lru_cache_test.py new file mode 100644 index 000000000000..fb999cbef0cf --- /dev/null +++ b/tests/lru_cache_test.py @@ -0,0 +1,155 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib.util +import tempfile +import time + +from absl.testing import absltest + +from jax._src import path as pathlib +from jax._src.lru_cache import LRUCache +import jax._src.test_util as jtu + + +class LRUCacheTestCase(jtu.JaxTestCase): + name: str | None + path: pathlib.Path | None + + def setUp(self): + if importlib.util.find_spec("filelock") is None: + self.skipTest("filelock is not installed") + + super().setUp() + tmpdir = tempfile.TemporaryDirectory() + self.enter_context(tmpdir) + self.name = tmpdir.name + self.path = pathlib.Path(self.name) + + def tearDown(self): + self.path = None + self.name = None + super().tearDown() + + +class LRUCacheTest(LRUCacheTestCase): + + def test_get_nonexistent_key(self): + cache = LRUCache(self.name, max_size=-1) + self.assertIsNone(cache.get("cache-a")) + + def test_put_and_get_key(self): + cache = LRUCache(self.name, max_size=-1) + + cache.put("cache-a", b"a") + self.assertEqual(cache.get("cache-a"), b"a") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a"}) + + cache.put("cache-b", b"b") + self.assertEqual(cache.get("cache-a"), b"a") + self.assertEqual(cache.get("cache-b"), b"b") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"}) + + def test_put_empty_value(self): + cache = LRUCache(self.name, max_size=-1) + + cache.put("cache-a", b"") + self.assertEqual(cache.get("cache-a"), b"") + + def test_put_empty_key(self): + cache = LRUCache(self.name, max_size=-1) + + with self.assertRaisesRegex(ValueError, r"key cannot be empty"): + cache.put("", b"a") + + def test_eviction(self): + cache = LRUCache(self.name, max_size=2) + + cache.put("cache-a", b"a") + cache.put("cache-b", b"b") + + # `sleep()` is necessary to guarantee that `cache-b`"s timestamp is strictly greater than `cache-a`"s + time.sleep(1) + cache.get("cache-b") + + # write `cache-c`, evict `cache-a` + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-c"}) + + # calling `get()` on `cache-b` makes `cache-c` least recently used + time.sleep(1) + cache.get("cache-b") + + # write `cache-d`, evict `cache-c` + cache.put("cache-d", b"d") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-d"}) + + def test_eviction_with_empty_value(self): + cache = LRUCache(self.name, max_size=1) + + cache.put("cache-a", b"a") + + # write `cache-b` with length 0 + # eviction should not happen even though the cache is full + cache.put("cache-b", b"") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"}) + + # calling `get()` on `cache-a` makes `cache-b` least recently used + time.sleep(1) + cache.get("cache-a") + + # writing `cache-c` should result in evicting the + # least recent used file (`cache-b`) first, + # but this is not sufficient to make room for `cache-c`, + # so `cache-a` should be evicted as well + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-c"}) + + def test_existing_cache_dir(self): + cache = LRUCache(self.name, max_size=2) + + cache.put("cache-a", b"a") + + # simulates reinitializing the cache in another process + del cache + cache = LRUCache(self.name, max_size=2) + + self.assertEqual(cache.get("cache-a"), b"a") + + # ensure that the LRU policy survives cache reinitialization + cache.put("cache-b", b"b") + + # calling `get()` on `cache-a` makes `cache-b` least recently used + time.sleep(1) + cache.get("cache-a") + + # write `cache-c`, evict `cache-b` + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-c"}) + + def test_max_size(self): + cache = LRUCache(self.name, max_size=1) + + msg = (r"Cache value for key .+? of size \d+ bytes exceeds the maximum " + r"cache size of \d+ bytes") + with self.assertWarnsRegex(UserWarning, msg): + cache.put("cache-a", b"aaaa") + self.assertIsNone(cache.get("cache-a")) + self.assertEqual(set(self.path.glob("cache-*")), set()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/memories_test.py b/tests/memories_test.py index 782a2ec59725..d556a2d5bafb 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import functools import math import re @@ -25,16 +26,15 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src import config -from jax._src.lib import xla_extension_version from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp -from jax.sharding import PartitionSpec as P from jax.ad_checkpoint import Offloadable, remat, Recompute +from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import (NamedSharding, PositionalSharding, SingleDeviceSharding, GSPMDSharding, - TransferToMemoryKind, - common_devices_indices_map) + TransferToMemoryKind, PartitionSpec as P) from jax.experimental.compute_on import compute_on +from jax.experimental.shard_map import shard_map import numpy as np config.parse_flags_with_absl() @@ -64,26 +64,19 @@ def _create_inputs(shape, pspec, mem_kind=None): # * nested jit +@jtu.with_config(jax_enable_memories=True) class ShardingMemoriesTest(jtu.JaxTestCase): def setUp(self): - if xla_extension_version < 265 and not jtu.test_device_matches(["tpu"]): - self.skipTest("Memories do not work on CPU and GPU backends yet.") # TODO(b/311021572) if jtu.is_cloud_tpu(): self.skipTest("Experimental feature not yet implemented on Cloud TPU") super().setUp() - self.orig_memories_flag = config.enable_memories.value - jax.config.update('jax_enable_memories', True) if jtu.test_device_matches(["cpu"]): self._default_memory_kind = "unpinned_host" else: self._default_memory_kind = "device" - def tearDown(self): - jax.config.update('jax_enable_memories', self.orig_memories_flag) - super().tearDown() - @parameterized.named_parameters( ("named_sharding", "named_sharding"), ("positional_sharding", "positional_sharding"), @@ -206,18 +199,13 @@ def test_default_memory_kind(self): self.assertEqual(dev.default_memory().kind, self._default_memory_kind) +@jtu.with_config(jax_enable_memories=True) class DevicePutTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Memories do not work on CPU and GPU backends yet.") super().setUp() - self.orig_memories_flag = config.enable_memories.value - jax.config.update('jax_enable_memories', True) - - def tearDown(self): - jax.config.update('jax_enable_memories', self.orig_memories_flag) - super().tearDown() def _check_device_put_addressable_shards( self, out, inp, expected_sharding, expected_mem_kind, index=True): @@ -409,6 +397,20 @@ def test_device_put_python_int(self, host_memory_kind: str): self._check_device_put_addressable_shards( out_host, py_inp, s_host, host_memory_kind, index=False) + def test_device_put_inside_jit(self): + _, s_host, np_inp, inp_host = _create_inputs( + (8, 2), P("x", "y"), mem_kind="pinned_host") + s_dev = s_host.with_memory_kind("device") + + @jax.jit + def f(a, b): + x, y = jax.device_put((a, b), s_dev) + return x * y + + out = f(inp_host, inp_host) + self._check_device_put_addressable_shards( + out, np_inp * np_inp, s_dev, "device") + def test_parameter_streaming(self): _, s_host, np_inp, inp_host = _create_inputs( (8, 2), P("x", "y"), mem_kind="pinned_host") @@ -456,6 +458,58 @@ def f(scalar_input): out2, 2, s_host, "pinned_host", index=False ) + def test_parameter_and_output_streaming_with_array(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + np_inp = np.arange(16).reshape(8, 2) + s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") + inp_host = jax.device_put(np_inp, s_host) + + @functools.partial(jax.jit, out_shardings=(s_host, s_host)) + def f(x): + return (x, x) + + compiled = f.lower(inp_host).compile() # doesn't crash + compiled_text = compiled.as_text() + if compiled_text is not None: + self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}") + + out1, out2 = f(inp_host) + self._check_device_put_addressable_shards( + out1, np_inp, s_host, "pinned_host" + ) + self._check_device_put_addressable_shards( + out2, np_inp, s_host, "pinned_host" + ) + + def test_parameter_and_output_streaming_with_scalar(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + + mesh = jax.sharding.Mesh(jax.devices(), "axis") + s_host = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(), memory_kind="pinned_host" + ) + scalar_inp = 1 + + @functools.partial(jax.jit, out_shardings=(s_host, s_host)) + def f(x): + return (x, x) + + compiled = f.lower(scalar_inp).compile() # doesn't crash + compiled_text = compiled.as_text() + if compiled_text is not None: + self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}") + + out1, out2 = f(scalar_inp) + self._check_device_put_addressable_shards( + out1, scalar_inp, s_host, "pinned_host", index=False + ) + self._check_device_put_addressable_shards( + out2, scalar_inp, s_host, "pinned_host", index=False + ) + def test_identity_jit_host_to_device_and_vice_versa(self): mesh = jtu.create_global_mesh((2, 2), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) @@ -528,6 +582,8 @@ def f(x): out_host, np_inp * 2, s_host, 'pinned_host') def test_output_streaming_inside_scan(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) np_inp = np.arange(4096).reshape(16, 16, 16) s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device") @@ -546,19 +602,25 @@ def body(carry, x): self.assertArraysEqual(out, np_inp + 1) self.assertEqual(out.sharding.memory_kind, 'pinned_host') + def test_deepcopy(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") + mesh = jax.sharding.Mesh(jax.devices(), "x") + s_host = NamedSharding(mesh, P(), memory_kind="pinned_host") + t = jax.device_put(jnp.zeros((8, 2)), s_host) + t_copy = copy.deepcopy(t) + self.assertArraysEqual(t, t_copy) + self.assertEqual(t.shape, t_copy.shape) + + +@jtu.with_config(jax_enable_memories=True) class ComputeOffload(jtu.BufferDonationTestCase): def setUp(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Memories do not work on CPU and GPU backends yet.") super().setUp() - self.orig_memories_flag = config.enable_memories.value - jax.config.update('jax_enable_memories', True) - - def tearDown(self): - jax.config.update('jax_enable_memories', self.orig_memories_flag) - super().tearDown() def _check_mem_kind(self, executable_kind, out_sharding, expected_kind): out_kind = out_sharding.memory_kind @@ -566,6 +628,40 @@ def _check_mem_kind(self, executable_kind, out_sharding, expected_kind): self.assertEqual(out_kind, expected_kind) self.assertEqual(executable_kind, expected_kind) + def test_compute_no_inputs(self): + mesh = jtu.create_global_mesh((4,), ('data')) + + tpu_sharding = NamedSharding(mesh, P('data')) + cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host') + + @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding)) + def init(): + tpu_array = jax.random.normal(jax.random.key(42), (16,16)) + cpu_array = jax.random.normal(jax.random.key(42), (16,16)) + return tpu_array, cpu_array + + tpu_array, cpu_array = init() + self.assertEqual(tpu_array.sharding, tpu_sharding) + self.assertEqual(cpu_array.sharding, cpu_sharding) + + def test_compute_no_inputs_host_replicated(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: + self.skipTest("This test requires an xla_version >= 3.") + mesh = jtu.create_global_mesh((4,), ('data')) + + tpu_sharding = NamedSharding(mesh, P('data')) + cpu_sharding = NamedSharding(mesh, P(), memory_kind='pinned_host') + + @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding)) + def init(): + tpu_array = jax.random.normal(jax.random.key(42), (16,16)) + cpu_array = jax.random.normal(jax.random.key(42), (16,16)) + return tpu_array, cpu_array + + tpu_array, cpu_array = init() + self.assertEqual(tpu_array.sharding, tpu_sharding) + self.assertEqual(cpu_array.sharding, cpu_sharding) + def test_compute_on_basic(self): out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host') @@ -709,6 +805,8 @@ def f(x): self.assertArraysEqual(out, expected_out) def test_host_offload_in_custom_vjp(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") @jax.custom_vjp def f(x): return jnp.sin(x) @@ -737,6 +835,8 @@ def f_bwd(res, tx): self.assertArraysEqual(g(x), all_true) def test_host_offload_in_custom_vjp_sharded(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") mesh = jtu.create_global_mesh((2, 2), ("x", "y")) s = NamedSharding(mesh, P('x')) @@ -767,7 +867,38 @@ def f_bwd(res, tx): all_true = jnp.ones(4) self.assertArraysEqual(g(arr), all_true) + def test_scan_offload(self): + np_inp = jnp.arange(4096).reshape(16, 16, 16) + + @jax.jit + def f(xs): + def body(carry, x): + with compute_on('device_host'): + out_tpu = x + carry + return carry, out_tpu + _, res = jax.lax.scan(body, 1, xs) + return res + + out = f(np_inp) + self.assertArraysEqual(out, np_inp + 1) + + @compute_on('device_host') + @jax.jit + def body2(carry, x): + out_tpu = x + carry + return carry, out_tpu + + @jax.jit + def f2(xs): + _, res = jax.lax.scan(body2, 1, xs) + return res + + out2 = f2(np_inp) + self.assertArraysEqual(out2, np_inp + 1) + def test_pure_host_data_and_compute(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') np_inp = np.arange(16).reshape(8, 2) @@ -813,6 +944,8 @@ def f(x): self.assertArraysAllClose(out2, np.sin(np_inp * 2)) def test_jit_host_multi_outputs(self): + if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: + self.skipTest("This test requires an xla_version >= 2.") _, s, np_inp, inp = _create_inputs((8, 2), P("x")) @jax.jit @@ -1033,19 +1166,34 @@ def f(inp1): self.assertIn("input_output_alias", lowered_text) self.assertDeleted(x) + def test_compute_offload_inside_shmap(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, s) + @compute_on('device_host') + @jax.jit + def g(x): + return x * 2 + + def f(x): + x = x * 3 + y = g(x) + return y * 4 + + out = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x', 'y'), + out_specs=P('x', 'y')))(arr) + self.assertArraysEqual(out, np_inp * 24) + + +@jtu.with_config(jax_enable_memories=True) class ActivationOffloadingTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(["tpu", "gpu"]): self.skipTest("Memories do not work on CPU backend.") super().setUp() - self.orig_memories_flag = config.enable_memories.value - jax.config.update('jax_enable_memories', True) - - def tearDown(self): - jax.config.update('jax_enable_memories', self.orig_memories_flag) - super().tearDown() def test_remat_jaxpr_offloadable(self): mesh = jtu.create_global_mesh((2,), ("x",)) diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index ce077812960d..1c8893f36070 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -21,13 +21,23 @@ from absl import logging from absl.testing import absltest from absl.testing import parameterized +from jax._src import mesh as mesh_lib from jax._src import test_util +from jax._src.sharding_impls import NamedSharding, PartitionSpec, local_to_global_shape from jax.experimental import mesh_utils from jax.sharding import Mesh # pylint: disable=g-importing-member import numpy as np # pyformat: disable + +@dataclasses.dataclass(frozen=True) +class MockClient: + """Mock client for testing, everything is done as process index 0.""" + def process_index(self) -> int: + return 0 + + @dataclasses.dataclass(frozen=True) class MockTpuDevice: """Mock TPU device for testing.""" @@ -38,6 +48,7 @@ class MockTpuDevice: coords: Sequence[int] core_on_chip: int slice_index: int = 0 + client: MockClient = dataclasses.field(default_factory=MockClient) def mock_tpu_devices(x, y, z, dev_kind, one_device_per_chip, num_slices=1, @@ -207,6 +218,41 @@ def test_create_device_mesh_for_nd_torus( ) self.assertArraysEqual(assignment, expected_assignment_matrix) + @parameterized.named_parameters( + ('2x2x1', mock_2x2x1_devices,), + ('2x2x4', mock_2x2x4_devices, ), + ('4x4x4', mock_4x4x4_devices,), + ('4x4x8', mock_4x4x8_devices,), + ('4x8x8', mock_4x8x8_devices, ), + ('8x8', mock_8x8_devices), + ) + def test_create_device_mesh_has_computable_global_shape(self, devices): + def factorize(n, max_factors=3): + if max_factors == 1 or n == 1: + yield (n, ) * max_factors + return + for i in range(2, n+1): + if n % i == 0: + for remaining in factorize(n // i, max_factors=max_factors - 1): + yield (i, *remaining) + jax_devices = devices(True) + for mesh_shape in factorize(len(jax_devices), max_factors=3): + mesh = mesh_utils.create_device_mesh(mesh_shape, devices=jax_devices, + allow_split_physical_axes=True) + mesh = mesh_lib.Mesh(mesh, ('a', 'b', 'c')) + sharding = NamedSharding(mesh, PartitionSpec('a', 'b', 'c')) + computed_global_shape = local_to_global_shape(sharding, (1, 1, 1)) + self.assertFalse( + np.any([x is None for x in computed_global_shape]), + f'{mesh_shape=}, {computed_global_shape=} is not uniform') + + sharding = NamedSharding(mesh, PartitionSpec(('a', 'c',), 'b')) + computed_global_shape = local_to_global_shape(sharding, (1, 1, 1)) + self.assertFalse( + np.any([x is None for x in computed_global_shape]), + f'{mesh_shape=}, {computed_global_shape=} is not uniform') + + @parameterized.named_parameters( ('2x2x1', mock_2x2x1_devices, [1, 1, 4], [(), (), (0, 1, 2)]), ('2x2x4', mock_2x2x4_devices, [1, 4, 4], [(), (2,), (0, 1)]), diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index ba735775beab..fd1fe560ff4b 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -24,23 +24,14 @@ import numpy as np jax.config.parse_flags_with_absl() +NUM_SHARDS = 16 +@jtu.with_config(use_mock_gpu_client=True, mock_num_gpus=NUM_SHARDS) class MockGPUTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - jax.config.update('use_mock_gpu_client', True) - - def tearDown(self): - jax.config.update('use_mock_gpu_client', False) - jax.config.update('mock_num_gpus', 1) - super().tearDown() - def testMockWithSharding(self): - num_shards = 16 - jax.config.update('mock_num_gpus', num_shards) - mesh = jtu.create_global_mesh((num_shards,), ('x',)) + mesh = jtu.create_global_mesh((NUM_SHARDS,), ('x',)) @partial( jax.jit, in_shardings=NamedSharding(mesh, P('x',)), diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index e3bfc14f15f9..d182c99be7b1 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -66,28 +66,24 @@ jax_test( ) jax_test( - name = "matmul", - srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:matmul.py"], + name = "flash_attention", + srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py"], disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, - main = "//third_party/py/jax/experimental/mosaic/gpu/examples:matmul.py", - tags = [ - "manual", - "notap", - ], + main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py", + tags = ["notap"], deps = [ "//jax:mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps("numpy"), ) jax_test( - name = "flash_attention", - srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py"], + name = "flash_attention_test", + srcs = ["flash_attention_test.py"], disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, - main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py", - tags = ["notap"], deps = [ "//jax:mosaic_gpu", - ] + py_deps("numpy"), + "//jax/experimental/mosaic/gpu/examples:flash_attention", + ] + py_deps("absl/testing"), ) diff --git a/tests/mosaic/flash_attention_test.py b/tests/mosaic/flash_attention_test.py new file mode 100644 index 000000000000..1d15159ca44e --- /dev/null +++ b/tests/mosaic/flash_attention_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of FlashAttention.""" + +import os + +from absl.testing import absltest, parameterized +from jax._src import config +from jax._src import test_util as jtu + +# pylint: disable=g-import-not-at-top +try: + # We only import this to see if Mosaic is available. + import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + flash_attention = None +else: + from jax.experimental.mosaic.gpu.examples import flash_attention + + +config.parse_flags_with_absl() +os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") + + +@jtu.with_config(jax_traceback_filtering="off") +class FlashAttentionTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if flash_attention is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + + @parameterized.product( + batch_size=(1,), + q_seq_len=(4096,), + kv_seq_len=(4096,), + num_q_and_kv_heads=((4, 1), # MQA + (6, 3), # GQA + (4, 4),), # MHA + head_dim=(64, 128, 256), + # Provide a default value for exp_impl if 'flash_attention' is not + # available. Bypasses test failures when Mosaic is not available. + exp_impl=[*(flash_attention.ExpImplementation + if flash_attention is not None else (NotImplementedError,))], + ) + def test_flash_attention(self, batch_size, q_seq_len, kv_seq_len, + num_q_and_kv_heads, head_dim, exp_impl): + num_q_heads, num_kv_heads = num_q_and_kv_heads + flash_attention.benchmark_and_verify( + batch_size=batch_size, + q_seq_len=q_seq_len, + kv_seq_len=kv_seq_len, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + exp_impl=exp_impl, + blocks=flash_attention.BlockSizes(stages=2, q=64, kv=64) + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index cdca9821541c..0f60a69a3702 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -14,14 +14,11 @@ # ============================================================================== """Tests for Mosaic GPU DSL functions and utilities.""" -import operator from functools import partial -from typing import Optional +import operator from absl.testing import absltest, parameterized -import numpy as np import jax -import jax.numpy as jnp from jax._src import config from jax._src import test_util as jtu from jax._src.interpreters import mlir @@ -29,6 +26,8 @@ from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +import jax.numpy as jnp +import numpy as np try: import jax._src.lib.mosaic_gpu # noqa: F401 HAS_MOSAIC_GPU = True @@ -44,7 +43,6 @@ # ruff: noqa: F405 -config.update("jax_traceback_filtering", "off") config.parse_flags_with_absl() def nd_loop(bounds, body, *, _idxs = ()): @@ -66,7 +64,7 @@ def mlir_sum(elems): return total -def copy(src: ir.Value, dst: ir.Value, swizzle: Optional[int] = None): +def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): index = ir.IndexType.get() thread_id = gpu.thread_id(gpu.Dimension.x) stride = gpu.block_dim(gpu.Dimension.x) @@ -156,23 +154,19 @@ class TestCase(parameterized.TestCase): def setUp(self): if not HAS_MOSAIC_GPU: self.skipTest("jaxlib built without Mosaic GPU") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") super().setUp() self.prng = np.random.default_rng(1234) - self.ctx = mlir.make_ir_context() - self.ctx.__enter__() - self.loc = ir.Location.unknown() - self.loc.__enter__() - - def tearDown(self): - self.loc.__exit__(None, None, None) - self.ctx.__exit__(None, None, None) - del self.loc, self.ctx - super().tearDown() + self.enter_context(jtu.global_config_context(jax_traceback_filtering="off")) + self.enter_context(mlir.make_ir_context()) + self.enter_context(ir.Location.unknown()) class TestUtilTest(TestCase): - def test_copy(self): + def test_copy_basic(self): def kernel(ctx, src, dst, _): copy(src, dst) x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3) @@ -483,42 +477,48 @@ def kernel(ctx, in_, out, smem): @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type), + in_mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type), m=(64, 128, 192), - n=(32, 64, 128, 192), + n=(64, 128, 192), k_steps=(1, 2), tma_inputs=(False, True), + jax_out_dtype=(jnp.float16, jnp.float32), ) def test_wgmma( self, m, n, k_steps, - mlir_dtype_cls, + in_mlir_dtype_cls, lhs_transpose, rhs_transpose, tma_inputs, + jax_out_dtype, ): - mlir_dtype = mlir_dtype_cls.get() - if ir.F32Type.isinstance(mlir_dtype): # We actually use tf32 instead - jax_dtype = jnp.float32 + if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type: + raise self.skipTest("Only f16 input is supported for f16 output.") + + in_mlir_dtype = in_mlir_dtype_cls.get() + out_mlir_dtype = mlir.dtype_to_ir_type(jnp.dtype(jax_out_dtype)) + if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead + in_jax_dtype = jnp.float32 if lhs_transpose or not rhs_transpose: self.skipTest("Transpose only supported in 16-bit WGMMA") exponent_bits, mantissa_bits = 8, 10 # Use tf32 - elif bytewidth(mlir_dtype) == 2: + elif bytewidth(in_mlir_dtype) == 2: if n % 64 != 0: self.skipTest("16-bit WGMMA only supports n % 64 == 0") - if ir.F16Type.isinstance(mlir_dtype): - jax_dtype = jnp.float16 + if ir.F16Type.isinstance(in_mlir_dtype): + in_jax_dtype = jnp.float16 exponent_bits, mantissa_bits = 5, 10 - elif ir.BF16Type.isinstance(mlir_dtype): - jax_dtype = jnp.bfloat16 + elif ir.BF16Type.isinstance(in_mlir_dtype): + in_jax_dtype = jnp.bfloat16 exponent_bits, mantissa_bits = 8, 7 else: - raise NotImplementedError(mlir_dtype) + raise NotImplementedError(in_mlir_dtype) else: - raise NotImplementedError(mlir_dtype) - nk_tile = 128 // bytewidth(mlir_dtype) + raise NotImplementedError(in_mlir_dtype) + nk_tile = 128 // bytewidth(in_mlir_dtype) k = nk_tile * k_steps assert m % 64 == 0 and n % nk_tile == 0 index = ir.IndexType.get() @@ -580,7 +580,7 @@ def kernel(ctx, lhs, rhs, out, scratch): dst=memref_slice(rhs_smem, (ki, ni)), swizzle=128, ) - init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype) acc = mgpu.wgmma( init_acc, lhs_smem, rhs_smem, a_order=lhs_order, b_order=rhs_order, @@ -594,14 +594,14 @@ def quantize(x): return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(jax_dtype) + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(jax_dtype) - out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype) scratch_shape = [ - jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), jax_dtype), + jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype), jax.ShapeDtypeStruct( - (k // nk_tile, n // nk_tile, nk_tile, nk_tile), jax_dtype + (k // nk_tile, n // nk_tile, nk_tile, nk_tile), in_jax_dtype ), ] z = mosaic_gpu.as_gpu_kernel( @@ -609,7 +609,8 @@ def quantize(x): )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) - np.testing.assert_allclose(z, ref, atol=5e-6) + atol = 2e-2 if jax_out_dtype == jnp.float16 else 5e-6 + np.testing.assert_allclose(z, ref, atol=atol) # TODO(apaszke): Add support for f32 @parameterized.product( @@ -939,7 +940,6 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(inp, result) - def test_warp_tree_reduce(self): def kernel(ctx, out, *_): del ctx @@ -959,7 +959,6 @@ def kernel(ctx, out, *_): np.testing.assert_array_equal(result, x) - class ProfilerTest(TestCase): def test_measure(self): diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index ea285229f382..9e6f66b3a72d 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -28,18 +28,22 @@ else: from jax.experimental.mosaic.gpu.examples import matmul -config.update("jax_traceback_filtering", "off") + config.parse_flags_with_absl() os.environ["XLA_FLAGS"] = ( os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") +@jtu.with_config(jax_traceback_filtering="off") class MatmulTestCase(jtu.JaxTestCase): def setUp(self): super().setUp() if matmul is None: self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") @parameterized.product( m=(128, 256, 512, 2048), @@ -60,12 +64,6 @@ def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype): if n < tile_n: self.skipTest(f"No use in running a test with {n=} < {tile_n=}.") - # TODO(bchetioui): investigate why this test case fails with error - # Illegal barrier arrive operation - # under memcheck. - if tile_m == 64 and tile_n == 64 and stages == 2: - self.skipTest("Broken test case---skipping.") - try: matmul.verify( m, @@ -102,12 +100,6 @@ def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n, high_precision): if n < tile_n: self.skipTest(f"No use in running a test with {n=} < {tile_n=}.") - # TODO(bchetioui): investigate why this test case fails with error - # Illegal barrier arrive operation - # under memcheck. - if tile_m == 64 and tile_n == 64 and stages == 2: - self.skipTest("Broken test case---skipping.") - try: matmul.verify( m, diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 85386566885b..b2731f256566 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import os +import contextlib from unittest import SkipTest import tracemalloc as tm @@ -24,32 +23,17 @@ from jax import lax from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from jax._src import test_util as jtu -from jax._src import xla_bridge jax.config.parse_flags_with_absl() -prev_xla_flags = None - - # Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - -# Reset to previous configuration in case other test modules will be run. + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() class MultiDeviceTest(jtu.JaxTestCase): @@ -195,8 +179,10 @@ def f(): return lax.add(3., 4.) self.assertIsInstance(f(), jax.Array) self.assert_uncommitted_to_device(f(), devices[0]) self.assert_uncommitted_to_device(jax.jit(f)(), devices[0]) - self.assert_committed_to_device(jax.jit(f, device=devices[1])(), - devices[1]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + self.assert_committed_to_device(jax.jit(f, device=devices[1])(), + devices[1]) def test_reshape(self): devices = self.get_devices() diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index f498d788b08c..4f2e36c64f4b 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -34,6 +34,8 @@ class MultiBackendTest(jtu.JaxTestCase): """Tests jit targeting to different backends.""" @jtu.sample_product(backend=['cpu', 'gpu', 'tpu', None]) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testMultiBackend(self, backend): if backend not in ('cpu', jtu.device_under_test(), None): raise SkipTest("Backend is not CPU or the device under test") @@ -52,6 +54,8 @@ def fun(x, y): @jtu.sample_product( ordering=[('cpu', None), ('gpu', None), ('tpu', None), (None, None)] ) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testMultiBackendNestedJit(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), None): @@ -78,6 +82,8 @@ def infun(x, y): (None, 'cpu'), (None, 'gpu'), (None, 'tpu'), ], ) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testMultiBackendNestedJitConflict(self, ordering): outer, inner = ordering if outer not in ('cpu', jtu.device_under_test(), None): @@ -105,6 +111,8 @@ def infun(x, y): self.assertRaises(ValueError, lambda: fun(x, y)) @jtu.sample_product(backend=['cpu', 'gpu', 'tpu']) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testGpuMultiBackendOpByOpReturn(self, backend): if backend not in ('cpu', jtu.device_under_test()): raise SkipTest("Backend is not CPU or the device under test") @@ -119,6 +127,8 @@ def fun(x, y): self.assertEqual(list(w.devices())[0].platform, backend) @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def testJitCpu(self): @partial(jax.jit, backend='cpu') def get_arr(scale): @@ -135,6 +145,8 @@ def get_arr(scale): self.assertEqual(c.devices(), {jax.devices('cpu')[0]}) @jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_closed_over_values_device_placement(self): # see https://github.com/google/jax/issues/1431 def f(): return jnp.add(3., 4.) @@ -144,6 +156,8 @@ def f(): return jnp.add(3., 4.) {jax.devices('cpu')[0]}) @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_jit_on_nondefault_backend(self): cpus = jax.devices("cpu") self.assertNotEmpty(cpus) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index 69be62f52bab..760e340815af 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -28,7 +28,6 @@ import jax from jax._src import core from jax._src import distributed -from jax._src import maps from jax._src import test_util as jtu from jax._src import util from jax.experimental import pjit @@ -267,6 +266,7 @@ def test_gpu_mpi4py_distributed_initialize(self): os.environ.get("SLURM_JOB_NUM_NODES", None) != "2", "Slurm environment with at least two nodes needed!") @jtu.pytest_mark_if_available('SlurmMultiNodeGpuTest') +@jtu.with_config(experimental_xmap_spmd_lowering=True) class SlurmMultiNodeGpuTest(jtu.JaxTestCase): def sorted_devices(self): @@ -299,16 +299,6 @@ def create_2d_non_contiguous_mesh(self): ] == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15] return jax.sharding.Mesh(device_mesh, ("x", "y")) - def setUp(self): - super().setUp() - self.xmap_spmd_lowering_enabled = maps.SPMD_LOWERING.value - jax.config.update("experimental_xmap_spmd_lowering", True) - - def tearDown(self): - jax.config.update("experimental_xmap_spmd_lowering", - self.xmap_spmd_lowering_enabled) - super().tearDown() - def test_gpu_multi_node_initialize_and_psum(self): # Hookup the ENV vars expected to be set already in the SLURM environment diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index 5f6dc95b9a5b..993c729f01d8 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -28,7 +28,7 @@ def _get_hlo(f): def wrapped(*args, **kwargs): - c = jax.xla_computation(f)(*args, **kwargs) + c = jax.jit(f).lower(*args, **kwargs).compiler_ir('hlo') print_opts = xla_client._xla.HloPrintOptions.short_parsable() print_opts.print_metadata = True return c.as_hlo_module().to_string(print_opts) @@ -215,7 +215,7 @@ def f(x): self.assertIn('transpose(jvp(foo))/mul', hlo_text) def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self): - @jax.grad + @jax.value_and_grad @jax.named_scope('foo') @jax.jit def f(x): @@ -240,7 +240,7 @@ def f(x): def test_nested_jit_stack(self): - @jax.grad + @jax.value_and_grad @jax.jit def f(x): @jax.jit @@ -254,7 +254,7 @@ def g(y): self.assertIn('transpose(jvp(jit(f)))/jit(g)/mul', hlo_text) def test_nested_pjit_stack(self): - @jax.grad + @jax.value_and_grad @pjit def f(x): @pjit @@ -497,16 +497,19 @@ def false_fn(x): def test_grad_of_cond_transforms_name_stack(self): - @jax.grad + @jax.value_and_grad @jax.named_scope('foo') def f(x, y): @jax.named_scope('true') def true_fn(x): return x * x * 2. + @jax.named_scope('false') def false_fn(x): return x / jnp.square(x) + return lax.cond(y, true_fn, false_fn, x) + jaxpr = jax.make_jaxpr(f)(1., True) self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'jvp(foo)') self.assertEqual(str(jaxpr.eqns[2].source_info.name_stack), @@ -529,7 +532,7 @@ def false_fn(x): def test_vmap_of_grad_of_cond_transforms_name_stack(self): @functools.partial(jax.vmap, in_axes=(0, None)) - @jax.grad + @jax.value_and_grad @jax.named_scope('foo') def f(x, y): @jax.named_scope('true') diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 441caceccc62..c64548876e01 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -58,6 +58,7 @@ jax_test( deps = [ "//jax:pallas", "//jax:pallas_gpu", + "//jax:pallas_gpu_ops", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -209,11 +210,12 @@ jax_test( "gpu", ], main = "pallas_pipeline_tpu_test.py", + shard_count = 2, deps = [ "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ], + ] + py_deps("hypothesis"), ) jax_test( @@ -329,3 +331,32 @@ jax_test( "//jax:pallas_gpu", # build_cleaner: keep ], ) + +jax_test( + name = "export_pallas_test", + srcs = ["export_pallas_test.py"], + config_tags_overrides = { + "gpu_a100_x32": { + "ondemand": False, # Include in presubmit. + }, + }, + disable_configs = [ + "gpu", + "gpu_x32", + "gpu_a100", + "gpu_h100", + "gpu_p100", + "gpu_p100_x32", + "gpu_pjrt_c_api", + ], + enable_configs = [ + "gpu_a100_x32", + ], + tags = [], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_tpu", # build_cleaner: keep + "//jax/experimental/export", + ], +) diff --git a/tests/pallas/all_gather_test.py b/tests/pallas/all_gather_test.py index b151594d3e64..98b3e5b40135 100644 --- a/tests/pallas/all_gather_test.py +++ b/tests/pallas/all_gather_test.py @@ -84,13 +84,14 @@ def _array_dtypes(draw): class AllGatherTest(jtu.JaxTestCase): def setUp(self): - super().setUp() if not jtu.test_device_matches(["tpu"]): self.skipTest("Need TPU devices") if not jtu.is_device_tpu(version=5, variant="e"): # TODO(sharadmv,apaszke): expand support to more versions self.skipTest("Currently only supported on TPU v5e") + super().setUp() + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): if jax.device_count() < 2: diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py new file mode 100644 index 000000000000..3ea293f55c3d --- /dev/null +++ b/tests/pallas/export_pallas_test.py @@ -0,0 +1,64 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test exporting Pallas kernels.""" +import sys + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax import export +# Import mosaic for flag definitions +from jax.experimental import mosaic as _ # noqa: F401 +from jax.experimental import pallas as pl +import numpy as np + + +jax.config.parse_flags_with_absl() + + +class ExportTest(jtu.JaxTestCase): + + def setUp(self): + if sys.platform == "win32": + self.skipTest("Only works on non-Windows platforms") + + super().setUp() + + def test_cross_platform(self): + def add_vectors_kernel(x_ref, y_ref, o_ref): + x, y = x_ref[...], y_ref[...] + o_ref[...] = x + y + + @jax.jit + def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: + return pl.pallas_call(add_vectors_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) + + a = np.arange(8) + exp = export.export( + add_vectors, + lowering_platforms=["tpu", "cuda"], + )(a, a) + + if (jtu.device_under_test() == "tpu" or + (jtu.device_under_test() == "gpu" and + jtu.is_cuda_compute_capability_at_least("8.0"))): + res = exp.call(a, a) + self.assertAllClose(res, a + a) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/gmm_test.py b/tests/pallas/gmm_test.py index fe8802a92992..be830a6a4473 100644 --- a/tests/pallas/gmm_test.py +++ b/tests/pallas/gmm_test.py @@ -14,7 +14,7 @@ import functools import itertools -from typing import Any, Union +from typing import Any from absl.testing import absltest from absl.testing import parameterized @@ -114,7 +114,7 @@ def random_dense( shape: tuple[int, ...], key: jax.Array, dtype: jnp.dtype, - limit: Union[int, None] = None, + limit: int | None = None, ) -> jnp.ndarray: if limit is None: limit = 1 / np.prod(shape) @@ -191,12 +191,12 @@ def tolerances( class GroupedMatmulTest(jtu.JaxTestCase): def setUp(self): - super().setUp() - self.key = jax.random.PRNGKey(1234) - if not jtu.test_device_matches(["tpu"]): self.skipTest("Test requires TPU device.") + super().setUp() + self.key = jax.random.PRNGKey(1234) + def assert_allclose( self, out: jnp.ndarray, diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index ab88de607560..95688474a099 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -30,17 +30,18 @@ # pylint: disable=no-value-for-parameter -config.update("jax_traceback_filtering", "off") config.parse_flags_with_absl() +@jtu.with_config(jax_traceback_filtering="off") class DecodeAttentionTest(jtu.JaxTestCase): def setUp(self): - super().setUp() if not jtu.is_cuda_compute_capability_at_least("8.0"): self.skipTest("Fused attention only works on GPUs with capability >= sm80") + super().setUp() + @parameterized.named_parameters(*[ ( f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}", diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 799d052dcdb0..c11fca350d0d 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -19,6 +19,7 @@ import unittest from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu from jax._src import util @@ -99,25 +100,18 @@ def test_simple_ndindexer(self): def test_invalid_ndindexer(self): indices = (0, 0, 0) shape = (5, 5) - with self.assertRaises(ValueError): + with self.assertRaisesRegex( + ValueError, "`indices` must not be longer than `shape`" + ): _ = NDIndexer.from_indices_shape(indices, shape) - def test_invalid_ndindexer_oob_int(self): - indices = (4, 0) - shape = (3, 5) - with self.assertRaises(ValueError): - _ = NDIndexer.from_indices_shape(indices, shape) - - def test_invalid_ndindexer_oob_slice_start(self): - indices = (slice(3, 2), 0) - shape = (3, 5) - with self.assertRaises(ValueError): - _ = NDIndexer.from_indices_shape(indices, shape) - - def test_invalid_ndindexer_oob_slice_end(self): - indices = (Slice(2, 2), 0) - shape = (3, 5) - with self.assertRaises(ValueError): + @parameterized.parameters( + ((4, 0), (3, 5)), + ((slice(3, 2), 0), (3, 5)), + ((Slice(2, 2), 0), (3, 5)), + ) + def test_invalid_ndindexer_oob(self, indices, shape): + with self.assertRaisesRegex(ValueError, "Out of bound"): _ = NDIndexer.from_indices_shape(indices, shape) def test_ndindexer_with_padding(self): @@ -126,6 +120,12 @@ def test_ndindexer_with_padding(self): indexer = NDIndexer.from_indices_shape(indices, shape) self.assertTupleEqual(indexer.get_indexer_shape(), shape) + def test_ndindexer_with_ellipsis(self): + indices = (..., 4) + shape = (5, 5) + indexer = NDIndexer.from_indices_shape(indices, shape) + self.assertTupleEqual(indexer.get_indexer_shape(), (5,)) + def test_ndindexer_with_slices(self): indices = (slice(2, 3), slice(4, 7)) shape = (5, 6) @@ -154,6 +154,14 @@ def test_ndindexer_with_arrays_and_broadcasting(self): indexer = NDIndexer.from_indices_shape(indices, shape) self.assertTupleEqual(indexer.get_indexer_shape(), (10, 20)) + def test_ndindexer_with_arrays_and_invalid_broadcasting(self): + indices = (np.arange(10)[None], np.arange(20)[None, :]) + shape = (5, 5) + with self.assertRaisesRegex( + ValueError, "Cannot broadcast shapes for indexing" + ): + indexer = NDIndexer.from_indices_shape(indices, shape) + def test_indexer_with_all_types(self): indices = (0, slice(10), np.arange(5)) shape = (2, 3, 4) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index c6e19c39d975..57b38e8dc305 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -13,7 +13,6 @@ # limitations under the License. import functools - from absl.testing import absltest from absl.testing import parameterized import jax @@ -29,11 +28,11 @@ class PallasTest(jtu.JaxTestCase): def setUp(self): - super().setUp() - if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Only works on a GPU with capability >= sm90") + super().setUp() + class PallasCallTest(PallasTest): @@ -57,24 +56,30 @@ def test_layer_norm(self, input_factor): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), - compiler_params={"smem_scratch_bytes": 4 * 4} + compiler_params={"smem_scratch_bytes": 4 * 4}, ) def layer_norm(x_ref, o_ref): - o_ref[...] = (x_ref[...] - jnp.mean(x_ref[...], keepdims=True)) * jax.lax.rsqrt( - jnp.var(x_ref[...], keepdims=True) + eps - ) * gamma + beta + x_mean = jnp.mean(x_ref[...]) + x_centered = x_ref[...] - x_mean + o_ref[...] = ( + x_centered * jax.lax.rsqrt(jnp.mean(x_centered**2) + eps) * gamma + + beta + ) def layer_norm_np(x): - return (x - np.mean(x, keepdims=True)) / np.sqrt( - np.var(x, keepdims=True) + eps - ) * gamma + beta + x_mean = np.mean(x) + x_centered = x - x_mean + return (x_centered / np.sqrt(np.mean(x_centered**2) + eps) * gamma) + beta # Ones are always fully precise x = jnp.ones((256,)).astype(jnp.float32) * input_factor np.testing.assert_allclose(layer_norm(x), layer_norm_np(x)) # random (and anything else is not) - x = jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32) * input_factor + x = ( + jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32) + * input_factor + ) # TODO(cperivol): find out why in this particular case we have a small-ish error. rtol = 1e-07 if input_factor > 10 else 5e-5 np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol) @@ -85,10 +90,14 @@ def test_print(self): out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, o_ref): + del x_ref, o_ref pl.debug_print("It works!") x = jnp.arange(256).astype(jnp.float32) - kernel(x) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertEqual(output(), "It works!\n") def test_print_with_values(self): @functools.partial( @@ -96,6 +105,7 @@ def test_print_with_values(self): out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, o_ref): + del o_ref pl.debug_print("x[0] = {}", x_ref[0]) x = jnp.arange(256).astype(jnp.float32) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index db447fafb06a..78679a92b0fe 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -26,6 +26,9 @@ from jax._src import test_util as jtu from jax.experimental import pallas as pl +# Import mosaic for flag definitions +from jax.experimental import mosaic as _ # noqa: F401 + jax.config.parse_flags_with_absl() @@ -34,7 +37,6 @@ class OpsTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): - super().setUp() if jax.config.x64_enabled: self.skipTest("Only works in 32-bit") if not self.INTERPRET: @@ -44,6 +46,8 @@ def setUp(self): not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPUs with capability >= sm80") + super().setUp() + @classmethod def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) diff --git a/tests/pallas/paged_attention_kernel_test.py b/tests/pallas/paged_attention_kernel_test.py index c219c398b334..b1dcb2fab5f8 100644 --- a/tests/pallas/paged_attention_kernel_test.py +++ b/tests/pallas/paged_attention_kernel_test.py @@ -104,12 +104,8 @@ def _megacore_enabled(): ) +@jtu.with_config(jax_numpy_dtype_promotion="standard") class PagedAttentionKernelTest(jtu.JaxTestCase): - - def setUp(self): - super().setUp() - jax.config.update("jax_numpy_dtype_promotion", "standard") - @parameterized.product( dtype=(jnp.float32, jnp.bfloat16), page_size=(16, 32, 64), diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index 7c9e1baa1388..b09d4b3073c1 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -14,12 +14,16 @@ """Test TPU-specific extensions to pallas_call.""" +import contextlib import functools +import io import re +import sys from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax +from jax._src import checkify from jax._src import state from jax._src import test_util as jtu from jax._src.interpreters import partial_eval as pe @@ -41,15 +45,25 @@ partial = functools.partial +@contextlib.contextmanager +def string_stdout(): + """Redirects stdout to a string.""" + initial_stdout = sys.stdout + stringio = io.StringIO() + sys.stdout = stringio + yield stringio + sys.stdout = initial_stdout + class PallasTPUTest(jtu.JaxTestCase): interpret: bool = False def setUp(self): - super().setUp() if not self.interpret and jtu.device_under_test() != 'tpu': self.skipTest('Only interpret mode supported on non-TPU') + super().setUp() + def pallas_call(self, *args, **kwargs): return pl.pallas_call(*args, **kwargs, interpret=self.interpret) @@ -285,34 +299,40 @@ def body(_, x_ref, o_ref): o_ref[...] = x_ref[...] s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32) - x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) + x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): s = pl.load(s_ref, (i,)) return (s, 0) s = jnp.tile(s[None], [2, 1]) - x = jnp.tile(x[None], [2, 1, 1]) - - with self.assertRaises(NotImplementedError): - jax.vmap( - pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - in_specs=[ - pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])), - ], - out_specs=pl.BlockSpec( - lambda i, _: (i, 0), (x.shape[1] // 8, x.shape[2]) - ), - grid=8, + + @jax.jit + @jax.vmap + def kernel(s, x): + return pl.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])), + ], + out_specs=pl.BlockSpec( + lambda i, _: (i, 0), (x.shape[0] // 8, x.shape[1]) ), - interpret=self.interpret, - ) + grid=8, + ), + interpret=self.interpret, + compiler_params=dict(mosaic=dict(allow_input_fusion=[False, True])), )(s, x) + first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:]) + second = x[1, ...].reshape((1, 8, 8, -1))[:, s[1, ...]].reshape(x.shape[1:]) + + expected = jnp.stack([first, second]) + np.testing.assert_allclose(kernel(s, x), expected) + class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): interpret: bool = True @@ -320,6 +340,46 @@ class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): class PallasCallDynamicGridTest(PallasTPUTest): + def test_can_query_grid_statically_via_num_programs(self): + + def kernel(_): + num_programs = pl.num_programs(0) + self.assertIsInstance(num_programs, int) + self.assertEqual(num_programs, 2) + + pl.pallas_call(kernel, out_shape=None, grid=(2,))() + + def test_can_query_grid_statically_via_num_programs_in_block_spec(self): + + def kernel(*_): + pass + + def x_index_map(_): + num_programs = pl.num_programs(0) + self.assertIsInstance(num_programs, int) + self.assertEqual(num_programs, 2) + return 0 + pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(x_index_map, (8, 128))], + out_shape=None, + grid=(2,), + )(jnp.ones((8, 128))) + + def test_dynamic_grid_has_dynamic_size(self): + + def kernel(_): + num_programs = pl.num_programs(0) + self.assertIsInstance(num_programs, int, msg=type(num_programs)) + self.assertEqual(num_programs, 2) + num_programs = pl.num_programs(1) + self.assertIsInstance(num_programs, jax.Array) + + @jax.jit + def outer(x): + pl.pallas_call(kernel, out_shape=None, grid=(2, x))() + outer(2) + def test_dynamic_grid(self): shape = (8, 128) result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) @@ -404,6 +464,7 @@ def dynamic_kernel(steps, x): return self.pallas_call( kernel, grid=(steps * 2,), + in_specs=[pl.BlockSpec(lambda i: (0, 0), shape)], out_specs=pl.BlockSpec(lambda i: (0, 0), shape), out_shape=result_ty, )(x) @@ -433,8 +494,11 @@ def dynamic_kernel(steps): out_specs=pl.BlockSpec(lambda i: (0, 0), shape), out_shape=result_ty, )() - with self.assertRaises(NotImplementedError): - dynamic_kernel(jnp.array([4, 8], jnp.int32)) + out = dynamic_kernel(jnp.array([4, 8], jnp.int32)) + first = jnp.full(shape, fill_value=8.0, dtype=jnp.float32) + second = jnp.full(shape, fill_value=16.0, dtype=jnp.float32) + expected_out = jnp.stack([first, second], axis=0) + np.testing.assert_array_equal(out, expected_out) def test_vmap_dynamic_grid(self): shape = (8, 128) @@ -473,7 +537,7 @@ def dynamic_kernel(steps): out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32), )() - self.assertEqual(dynamic_kernel(4), 8) + self.assertEqual(dynamic_kernel(np.int32(4)), 8) @parameterized.parameters(range(1, 4)) def test_vmap_num_programs(self, num_vmaps): @@ -517,27 +581,28 @@ def dynamic_kernel(steps, x): )(x) x = np.arange(4 * 8 * 128., dtype=np.int32).reshape((4 * 8, 128)) - np.testing.assert_array_equal(dynamic_kernel(4, x), x[8:16]) + np.testing.assert_array_equal(dynamic_kernel(np.int32(4), x), x[8:16]) class PallasCallInterpretDynamicGridTest(PallasCallDynamicGridTest): interpret: bool = True -class PallasCallDMATest(parameterized.TestCase): +class PallasCallDMATest(PallasTPUTest): def setUp(self): - super().setUp() if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs not supported on TPU generations <= 3') + super().setUp() + def test_can_have_unspecified_memory_spaces(self): def kernel(x_ref, y_ref): # Just test whether things compile del x_ref, y_ref x = jnp.ones((8, 128), dtype=jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY)], out_specs=pl.BlockSpec(None, None, pltpu.TPUMemorySpace.ANY), @@ -572,7 +637,7 @@ def body(x_ref): pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32)) - o = pl.pallas_call( + o = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )() @@ -589,7 +654,7 @@ def inner_body(z_ref): y_ref[...] = 4 * x_ref[...] pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32)) - o = pl.pallas_call( + o = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )() @@ -601,7 +666,7 @@ def body(sem1): pass pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) - jax.block_until_ready(pl.pallas_call( + jax.block_until_ready(self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )()) @@ -613,7 +678,7 @@ def body(sem1, sem2): pltpu.run_scoped(body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR) - jax.block_until_ready(pl.pallas_call( + jax.block_until_ready(self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )()) @@ -628,7 +693,7 @@ def body(dma_sems, sems): pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((4,)), pltpu.SemaphoreType.REGULAR((3,))) - jax.block_until_ready(pl.pallas_call( + jax.block_until_ready(self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )()) @@ -640,6 +705,7 @@ def kernel(y_ref, dma_sems, sems): self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) + # TODO(b/345534352): Add interpret support for REGULAR semaphore. jax.block_until_ready( pl.pallas_call( kernel, @@ -674,6 +740,7 @@ def body3(sem): pltpu.semaphore_wait(sem) pltpu.run_scoped(body3, pltpu.SemaphoreType.REGULAR) + # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready(pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), @@ -697,6 +764,7 @@ def body(sems): pltpu.semaphore_wait(sems.at[2]) pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((3,))) + # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready(pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), @@ -721,6 +789,7 @@ def body(sems): pltpu.semaphore_wait(sems.at[i, 2]) pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((4, 3))) + # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready(pl.pallas_call( kernel, in_specs=[], @@ -744,6 +813,7 @@ def body(sems): pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n))) + # TODO(b/345534352): Add interpret support for semaphore signal/wait. y = jax.block_until_ready( pl.pallas_call( kernel, @@ -762,7 +832,7 @@ def body(sem): sem).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -778,6 +848,8 @@ def body(sem): pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) + + # TODO(b/345534352): Add interpret support for nonscalar semaphores. with self.assertRaisesRegex(ValueError, 'Cannot signal'): x = jnp.arange(8 * 128.).reshape((8, 128)) pl.pallas_call( @@ -796,6 +868,8 @@ def body(sem): sem.at[0]).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) x = jnp.arange(8 * 128.).reshape((8, 128)) + + # TODO(b/345534352): Add interpret support for nonscalar semaphores. y = pl.pallas_call( kernel, in_specs=[ @@ -817,7 +891,7 @@ def body(sem): ).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -837,7 +911,7 @@ def body(x_ref, sem): pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -854,7 +928,7 @@ def body(y_ref, sem): pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), @@ -872,7 +946,7 @@ def body(x_ref, y_ref, sem): pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -888,7 +962,7 @@ def body(x_ref, sem): pltpu.run_scoped(body, pltpu.SMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA) x = 4 * jnp.ones((8, 128), jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -906,7 +980,7 @@ def body(y_ref, sem): pltpu.run_scoped(body, pltpu.SMEM((1, 2), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), @@ -923,7 +997,7 @@ def body(sem): pltpu.async_copy(x_ref, y_ref, sem).wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), @@ -946,7 +1020,7 @@ def body(sem): dma2.wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((16, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -969,7 +1043,7 @@ def body(sem): dma2.wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -995,7 +1069,7 @@ def body(sem): dma2.wait() pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(3 * 2 * 8 * 128.).reshape((3, 2, 8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -1019,7 +1093,7 @@ def body(sem): pltpu.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(2 * 8 * 128.).reshape((2, 8, 128)) with self.assertRaises(Exception): - _ = pl.pallas_call( + _ = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -1091,6 +1165,7 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): pltpu.semaphore_wait(sem) pltpu.async_copy(x_bbm_ref, y_ref, dma_sem).wait() + # TODO(b/345534352): Add interpret support for semaphore signal/wait. x = jnp.arange(8 * 128.).reshape((8, 128)) y = pl.pallas_call( kernel, @@ -1115,7 +1190,7 @@ def test_large_array_indexing(self): def kernel(index, x, y, sem): pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait() - run = pl.pallas_call(kernel, + run = self.pallas_call(kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ @@ -1133,16 +1208,16 @@ def kernel(index, x, y, sem): np.testing.assert_array_equal(y, i) del y - class PallasCallRemoteDMATest(parameterized.TestCase): def setUp(self): - super().setUp() if jax.device_count() < 2: self.skipTest('Only >=2 devices are supported.') if not jtu.is_device_tpu_at_least(5): self.skipTest('Only works with TPU v5') + super().setUp() + @parameterized.named_parameters( ('vmem', pltpu.TPUMemorySpace.VMEM), ('hbm', pltpu.TPUMemorySpace.ANY), @@ -1342,10 +1417,11 @@ def body(x): class PallasCallTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_cost_analysis(self): def kernel(x, y): y[:] = x[:] @@ -1411,14 +1487,36 @@ def f(x, y): compiled = jax.jit(f).lower(x, y).compile().as_text() assert re.search(r'fusion.*kind=kCustom.*fused_computation', compiled) + def test_set_internal_scratch_size(self): + shape = (128, 128) + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + requested_bytes = 128 * 4 + with self.assertRaisesRegex( + Exception, + f'Requested internal scratch size {requested_bytes} needs to be at' + ' least', + ): + pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + compiler_params=dict( + mosaic=dict(internal_scratch_in_bytes=requested_bytes) + ), + )(x) + class PallasCallUnblockedIndexingTest(PallasTPUTest): def setUp(self): - super().setUp() if not self.interpret and jtu.device_under_test() != 'tpu': self.skipTest('Only interpret mode supported on non-TPU') + super().setUp() + def test_unblocked_indexing(self): shape = (16 * 8, 128) result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) @@ -1479,10 +1577,11 @@ class PallasCallInterpreterUnblockedIndexingTest( class PallasUXTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_mlir_location(self): # Make sure that MLIR locations are correctly propagated to primitives. args = (jax.ShapeDtypeStruct((8, 128), jnp.float32),) @@ -1502,10 +1601,11 @@ def capture_as_tpu_kernel(module, *args, **kwargs): class PallasCallInputOutputAliasingTest(PallasTPUTest): def setUp(self): - super().setUp() if not self.interpret and jtu.device_under_test() != 'tpu': self.skipTest('Only interpret mode supported on non-TPU') + super().setUp() + def test_basic_input_output_aliasing(self): # Input needs to be big so it doesn't fit in VMEM x = jnp.ones((32, 1024, 1024)) @@ -1568,10 +1668,11 @@ class PallasCallInterpreterInputOutputAliasingTest(PallasTPUTest): class PallasMegacoreTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_megacore_splitting(self): # We want to make sure a 3-sized dimension is split across megacore # correctly, and if we combine the (3, 3) dimensions together it is still @@ -1606,10 +1707,11 @@ def _(): class PallasCallVmapTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_scratch_input_vmap(self): """Test that vmapp-ing a kernel with scratch inputs works correctly.""" @@ -1648,10 +1750,11 @@ def add_one_with_scratch(x_ref, o_ref, scratch_ref): class PallasCallControlFlowTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_nested_conds(self): def kernel(y_ref): def select(pred, x, y, nesting=0): @@ -1683,10 +1786,11 @@ def _false(): class PallasCallWhileLoopTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_range_while_loop(self): """Tests lowering of a while_loop which can reduce to a fori_loop.""" @@ -1766,7 +1870,7 @@ def cond(state): def body(state): i, s = state - sl = sl = jax.lax.div(i, 128) + sl = jax.lax.div(i, 128) l = jax.lax.rem(i, 128) v = pl.load(x_ref, (0, sl, l)) return i + 1, s + v @@ -1877,7 +1981,7 @@ def kernel(in_key_ref, out_segment_count, out_size_ref, key_count): def inner_cond(carry): i, prev_key = carry - sl = sl = jax.lax.div(i, 128) + sl = jax.lax.div(i, 128) l = jax.lax.rem(i, 128) key = jax.lax.cond( i < key_count, lambda i: in_key_ref[sl, l], lambda i: -1, i @@ -1894,12 +1998,12 @@ def outer_cond(carry): def outer_body(carry): i, next_out_idx = carry - sl = sl = jax.lax.div(i, 128) + sl = jax.lax.div(i, 128) l = jax.lax.rem(i, 128) key = in_key_ref[sl, l] end, _ = jax.lax.while_loop(inner_cond, inner_body, (i + 1, key)) - sl = sl = jax.lax.div(next_out_idx, 128) + sl = jax.lax.div(next_out_idx, 128) l = jax.lax.rem(next_out_idx, 128) out_size_ref[sl, l] = end - i return end, next_out_idx + 1 @@ -1949,10 +2053,11 @@ def outer_body(carry): class PallasCallReductionTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + def test_integer_sum(self): def kernel(x_ref, o_ref): x = x_ref[:] @@ -2004,10 +2109,11 @@ def kernel(x_ref, o_ref): class PallasCallDynamicDMATest(PallasTPUTest): def setUp(self): - super().setUp() if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs not supported on TPU generations <= 3') + super().setUp() + def test_simple_tile_aligned_dynamic_size_dma(self): def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): @@ -2068,10 +2174,11 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): class PallasCallComparisonTest(PallasTPUTest): def setUp(self): - super().setUp() if jtu.device_under_test() != 'tpu': self.skipTest('Test only works on TPU') + super().setUp() + @parameterized.named_parameters( ('integer_1_1', (1, 1)), ('integer_1_16', (1, 16)), @@ -2208,7 +2315,12 @@ def kernel(x_ref, o_ref): pl.debug_print('It works!') x = jnp.array([4.2, 2.4]).astype(jnp.float32) - kernel(x) + compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) + compiled_kernel(x) def test_debug_print_with_values(self): @functools.partial( @@ -2228,5 +2340,271 @@ def kernel(x_ref, o_ref): compiled_kernel(x) +class PallasCallTPUInterpretTest(PallasTPUTest): + + def test_local_dma(self): + def test_kernel(x_ref, + o_ref, + copy_sem, + ): + o_ref[...] = jnp.zeros_like(o_ref[...]) + input_to_output_copy = pltpu.make_async_copy( + src_ref=x_ref.at[0:8], + dst_ref=o_ref.at[0:8], + sem=copy_sem, + ) + input_to_output_copy.start() + input_to_output_copy.wait() + + out_shape = (jax.ShapeDtypeStruct((9, 128), jnp.float32)) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + scratch_shapes=( + [pltpu.SemaphoreType.DMA] + ) + ) + + kernel = pl.pallas_call( + test_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=True + ) + x = jax.random.normal(jax.random.key(0), shape=(16, 128)) + result = kernel(x) + np.testing.assert_array_equal(result[0:8], x[0:8]) + np.testing.assert_array_equal(result[8:], jnp.zeros_like(result[8:])) + + @parameterized.parameters(('left',), ('right',)) + def test_remote_dma_ppermute(self, permutation): + if jax.device_count() <= 1: + self.skipTest('Test requires multiple devices.') + num_devices = jax.device_count() + if permutation == 'left': + permute_fn = lambda x: lax.rem(x + num_devices - 1, num_devices) + else: + permute_fn = lambda x: lax.rem(x + num_devices + 1, num_devices) + + # Construct a kernel which performs a ppermute based on permute_fn. + def test_kernel(x_ref, + o_ref, + copy_send_sem, + copy_recv_sem, + ): + o_ref[...] = jnp.zeros_like(o_ref[...]) + my_id = lax.axis_index('x') + dst_device = permute_fn(my_id) + input_to_output_copy = pltpu.make_async_remote_copy( + src_ref=x_ref, + dst_ref=o_ref, + send_sem=copy_send_sem, + recv_sem=copy_recv_sem, + device_id=dst_device, + device_id_type=pltpu.DeviceIdType.LOGICAL, + ) + input_to_output_copy.start() + input_to_output_copy.wait() + + out_shape = (jax.ShapeDtypeStruct((8, 128), jnp.float32)) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 2 + ) + ) + + devices = mesh_utils.create_device_mesh((1, num_devices)) + mesh = jax.sharding.Mesh(devices, P(None, 'x')) + sharding = jax.sharding.NamedSharding(mesh, P(None, 'x')) + unsharded_arr = jax.random.normal( + jax.random.key(0), shape=(8, 128 * num_devices)) + sharded_arr = jax.device_put(unsharded_arr, sharding) + + kernel = pl.pallas_call( + test_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=True + ) + compiled_func = jax.jit(shard_map.shard_map( + kernel, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P(None, 'x'), + check_rep=False)) + result = compiled_func(sharded_arr) + + perm = tuple((src, permute_fn(src)) for src in range(num_devices)) + perm = jax.tree_util.tree_map(int, perm) + def lax_permute(x): + return lax.ppermute(x, 'x', perm) + expected = jax.jit(shard_map.shard_map(lax_permute, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P(None, 'x')))(sharded_arr) + np.testing.assert_array_equal(result, expected) + + +class PallasCallTraceTest(PallasTPUTest): + interpret: bool = False + + def parse_debug_string(self, debug_string): + jaxpr, mlir = debug_string.split('module') + return {'jaxpr': jaxpr, 'mlir': mlir} + + def test_trace_start_stop_match(self): + def kernel(o_ref): + with jax.named_scope('scope1'): + o_ref[...] = jnp.zeros_like(o_ref[...]) + + with string_stdout() as msg: + _ = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + debug=True, + )() + # TODO(justinfu): Add an official lowering API to get the MLIR. + mlir = self.parse_debug_string(msg.getvalue())['mlir'] + + num_start = mlir.count('tpu.trace_start') + num_stop = mlir.count('tpu.trace_stop') + self.assertEqual(num_start, 1) + self.assertEqual(num_stop, 1) + + def test_run_scoped(self): + def kernel(o_ref): + def scope1(): + with jax.named_scope('scope1'): + o_ref[...] = jnp.zeros_like(o_ref[...]) + pltpu.run_scoped(scope1) + + def scope2(): + with jax.named_scope('scope2'): + o_ref[...] = o_ref[...] + 1 + pltpu.run_scoped(scope2) + + with string_stdout() as msg: + _ = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + debug=True, + )() + # TODO(justinfu): Add an official lowering API to get the MLIR. + mlir = self.parse_debug_string(msg.getvalue())['mlir'] + + num_start = mlir.count('tpu.trace_start') + num_stop = mlir.count('tpu.trace_stop') + self.assertEqual(num_start, 2) + self.assertEqual(num_stop, 2) + + +class PallasCallTPUCheckifyTest(PallasTPUTest): + interpret: bool = True + + @parameterized.parameters((2,), (5,), (6,), (7,)) + def test_checkify_with_scalar_prefetch(self, threshold): + def body(scalar_ref, x_ref, o_ref): + scalar = scalar_ref[pl.program_id(0)] + o_ref[...] = x_ref[...] + checkify.check(scalar < threshold, 'failed on value {x}', x=scalar) + + s = jnp.array([4, 3, 2, 6, 3, 5, 2, 7], jnp.int32) + x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) + + def _x_transform(i, s_ref): + s = pl.load(s_ref, (i,)) + return (s, 0) + + pallas_call = self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])), + ], + out_specs=pl.BlockSpec(lambda i, _: (i, 0), + (x.shape[0] // 8, x.shape[1])), + grid=8, + ), + ) + checked_call = checkify.checkify(pallas_call) + err, out = checked_call(s, x) + expected_error_value = s[jnp.argmax(s >= threshold)] + with self.assertRaisesRegex( + checkify.JaxRuntimeError, f'failed on value {expected_error_value}'): + err.throw() + np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape)) + + def test_checkify_with_scratch(self): + def body(x_ref, o_ref, scratch_ref): + scratch_ref[...] = x_ref[...] + o_ref[...] = scratch_ref[...] + all_nequal = ~jnp.all(o_ref[...] == x_ref[...]) + checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})', + x=pl.program_id(0), y=pl.program_id(1)) + + x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32) + pallas_call = self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(lambda i, j: (i, j), (32, 32)), + ], + out_specs=pl.BlockSpec(lambda i, j: (i, j), (32, 32)), + scratch_shapes=[pltpu.VMEM((32, 32), dtype=jnp.float32)], + grid=(4, 4), + ), + ) + checked_call = checkify.checkify(pallas_call) + err, out = checked_call(x) + with self.assertRaisesRegex( + checkify.JaxRuntimeError, r'x_ref equals o_ref id=\(0, 0\)'): + err.throw() + np.testing.assert_allclose(out, x) + + @parameterized.parameters((4,), (9,)) + def test_checkify_with_dynamic_grid(self, iteration): + grid_size = 4 + shape = (8, 128) + result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) + + def kernel(y_ref): + @pl.when(pl.program_id(0) == 0) + def _init(): + y_ref[...] = jnp.zeros_like(y_ref) + y_ref[...] += 1 + @pl.when(pl.program_id(0) == iteration) + def _(): + checkify.check(False, f"error on iteration {iteration}") + + @jax.jit + def dynamic_kernel(steps): + pallas_call = self.pallas_call( + kernel, + grid=(steps * 2,), + out_specs=pl.BlockSpec(lambda i: (0, 0), shape), + out_shape=result_ty, + ) + return checkify.checkify(pallas_call)() + + err, result = dynamic_kernel(jnp.int32(grid_size)) + if iteration < grid_size * 2: + with self.assertRaisesRegex( + checkify.JaxRuntimeError, f"error on iteration {iteration}"): + err.throw() + np.testing.assert_array_equal( + result, np.full(shape, grid_size * 2.0, np.float32) + ) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/pallas_pipeline_tpu_test.py b/tests/pallas/pallas_pipeline_tpu_test.py index ee73de6afaa7..f46c1742b63b 100644 --- a/tests/pallas/pallas_pipeline_tpu_test.py +++ b/tests/pallas/pallas_pipeline_tpu_test.py @@ -26,6 +26,24 @@ from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np +try: + import hypothesis as hp + import hypothesis.strategies as hps + CAN_USE_HYPOTHESIS = True +except (ModuleNotFoundError, ImportError): + CAN_USE_HYPOTHESIS = False + + +if CAN_USE_HYPOTHESIS: + hp.settings.register_profile( + 'deterministic', + database=None, + derandomize=True, + deadline=None, + max_examples=200, + print_blob=True, + ) + hp.settings.load_profile('deterministic') jax.config.parse_flags_with_absl() @@ -55,20 +73,52 @@ def basic_matmul_kernel( out_ref, acc_scratch_ref, *, - acc_steps: int, + k: int, ): - @pl.when(pl.program_id(2) == 0) + k_index = pl.program_id(2) + num_k = pl.num_programs(2) + bk = lhs_ref.shape[1] + @pl.when(k_index == 0) def _zero_acc(): acc_scratch_ref[...] = jnp.zeros( acc_scratch_ref.shape, acc_scratch_ref.dtype) - acc_scratch_ref[...] += jnp.dot( - lhs_ref[...], - rhs_ref[...], - preferred_element_type=acc_scratch_ref.dtype, - ) + divisible_k = k % bk == 0 + if divisible_k: + acc_scratch_ref[...] += jnp.dot( + lhs_ref[...], + rhs_ref[...], + preferred_element_type=acc_scratch_ref.dtype, + ) + else: + def _last_block(): + accum_dtype = acc_scratch_ref.dtype + lhs_mask = ( + k_index * bk + jax.lax.broadcasted_iota(jnp.int32, lhs_ref.shape, 1) + < k + ) + rhs_mask = ( + k_index * bk + jax.lax.broadcasted_iota(jnp.int32, rhs_ref.shape, 0) + < k + ) + dtype = lhs_ref.dtype + lhs = lhs_ref[...].astype(accum_dtype) + lhs = jnp.where(lhs_mask, lhs, 0).astype(dtype) + rhs = rhs_ref[...].astype(accum_dtype) + rhs = jnp.where(rhs_mask, rhs, 0).astype(dtype) + acc_scratch_ref[...] += jnp.dot( + lhs, rhs, preferred_element_type=acc_scratch_ref.dtype) + def _not_last_block(): + acc_scratch_ref[...] += jnp.dot( + lhs_ref[...], + rhs_ref[...], + preferred_element_type=acc_scratch_ref.dtype, + ) + jax.lax.cond( + k_index == num_k - 1, _last_block, _not_last_block + ) - @pl.when(pl.program_id(2) == acc_steps - 1) + @pl.when(k_index == num_k - 1) def _reduce_out(): out_ref[...] = acc_scratch_ref[...].astype(out_ref.dtype) @@ -76,12 +126,13 @@ def _reduce_out(): class PallasCallPipelineTest(parameterized.TestCase): def setUp(self): - super().setUp() if jax.device_count() < 2: self.skipTest('Only >=2 devices are supported.') if not jtu.is_device_tpu_at_least(5): self.skipTest('Only works with TPU v5') + super().setUp() + @parameterized.named_parameters( ('vmem', pltpu.TPUMemorySpace.VMEM), ('hbm', pltpu.TPUMemorySpace.ANY), @@ -173,15 +224,16 @@ def emit_pipeline(should_accumulate_out): np.testing.assert_allclose(z, jnp.dot(x, y) + jnp.dot(x, y)) -class PallasCallColectivePipelineTest(parameterized.TestCase): +class PallasCallCollectivePipelineTest(parameterized.TestCase): def setUp(self): - super().setUp() if jax.device_count() < 2: self.skipTest('Only >=2 devices are supported.') if not jtu.is_device_tpu_at_least(5): self.skipTest('Only works with TPU v5') + super().setUp() + @parameterized.named_parameters( ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), @@ -206,9 +258,8 @@ def test_pipeline_latency_optimized_allgather_matmul( sharded_k = k // num_devices inner_grid = (n // tn, m // tm, sharded_k // tk) - acc_steps = (sharded_k // tk) - inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps) + inner_kernel = partial(basic_matmul_kernel, k=sharded_k) inner_allocs = [ pltpu.BufferedRef.input( @@ -504,9 +555,8 @@ def test_pipeline_throughput_optimized_allgather_matmul( sharded_k = k // num_devices half_m = m // 2 inner_grid = (n // tn, half_m // tm, sharded_k // tk) - acc_steps = (sharded_k // tk) - inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps) + inner_kernel = partial(basic_matmul_kernel, k=sharded_k) inner_allocs = [ pltpu.BufferedRef.input( @@ -745,10 +795,9 @@ def test_pipeline_latency_optimized_matmul_reducescatter( sharded_k = k // num_devices inner_grid = (n // tn, sharded_m // tm, sharded_k // tk) outer_steps = num_devices // 2 - acc_steps = sharded_k // tk reduce_grid = (sharded_m // tm,) - inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps) + inner_kernel = partial(basic_matmul_kernel, k=sharded_k) def reduce_kernel( out_ref, # [tm, tn] @@ -1030,9 +1079,7 @@ def test_pipeline_throughput_optimized_matmul_reducescatter( sharded_k = k // num_devices inner_grid = (n // tn, half_m // tm, sharded_k // tk) outer_steps = num_devices - acc_steps = sharded_k // tk - - inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps) + inner_kernel = partial(basic_matmul_kernel, k=sharded_k) inner_allocs = [ pltpu.BufferedRef.input( @@ -1261,5 +1308,212 @@ def reference(x, y): ) +class PallasCallMegacoreTest(parameterized.TestCase): + + def setUp(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works with TPU v4') + + super().setUp() + + def test_can_partition_nondivisible_grid_with_dynamic_dimensions(self): + + def mul_pipeline(x_ref, y_ref): + y_ref[...] = x_ref[...] * 2 + + def mul_kernel(iters_ref, x_ref, y_ref): + pltpu.emit_pipeline( + mul_pipeline, + grid=(iters_ref[0], 5), + in_specs=[ + pl.BlockSpec(lambda i, j: (i, j), (128, 128)), + ], + out_specs=pl.BlockSpec(lambda i, j: (i, j), (128, 128)), + core_axis=0, + dimension_semantics=(pltpu.PARALLEL, pltpu.PARALLEL), + )(x_ref, y_ref) + + num_cores = jax.devices()[0].num_cores + func = pl.pallas_call( + mul_kernel, + out_shape=jax.ShapeDtypeStruct((640, 640), jnp.float32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + ), + compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + ) + x = jax.random.uniform(jax.random.key(0), (640, 640)) + np.testing.assert_allclose(func(jnp.array([5]), x), x * 2) + + def test_megacore_mul(self): + x = jax.random.uniform(jax.random.key(0), (512, 512)) + + def matmul_pipeline(x_ref, y_ref): + y_ref[...] = x_ref[...] * 2 + + def matmul_kernel(x_ref, y_ref): + pltpu.emit_pipeline( + matmul_pipeline, + grid=(4, 4), + in_specs=[ + pl.BlockSpec(lambda i, j: (i, j), (128, 128)), + ], + out_specs=pl.BlockSpec(lambda i, j: (i, j), (128, 128)), + core_axis=0, + dimension_semantics=(pltpu.ARBITRARY, pltpu.PARALLEL) + )(x_ref, y_ref) + + num_cores = jax.devices()[0].num_cores + func = pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + ) + np.testing.assert_allclose(func(x), x * 2) + + @parameterized.parameters( + (1024, 1024, 1024, 256, 512, 256), + (768, 1024, 1024, 256, 512, 256), + (1024, 1024, 768, 256, 512, 256), + (768, 1024, 768, 256, 512, 256), + ) + def test_megacore_matmul(self, m, k, n, bm, bk, bn): + k1, k2 = jax.random.split(jax.random.key(42)) + x = jax.random.uniform(k1, (m, k)) + y = jax.random.uniform(k2, (k, n)) + + def matmul_pipeline(x_ref, y_ref, z_ref): + @pl.when(pl.program_id(2) == 0) + def _(): + z_ref[...] = jnp.zeros_like(z_ref) + z_ref[...] += x_ref[...] @ y_ref[...] + + def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): + m, k = x_ref.shape + _, n = y_ref.shape + assert k % bk == 0 + pltpu.emit_pipeline( + matmul_pipeline, + grid=(pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)), + in_specs=[ + pl.BlockSpec(lambda i, j, k: (i, k), (bm, bk)), + pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn)), + ], + out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (bm, bn)), + core_axis=0, + dimension_semantics=(pltpu.PARALLEL, pltpu.PARALLEL, pltpu.ARBITRARY) + )(x_ref, y_ref, z_ref) + + num_cores = jax.devices()[0].num_cores + func = pl.pallas_call( + functools.partial(matmul_kernel, bm=bm, bk=bk, bn=bn), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + ) + np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5) + + +if CAN_USE_HYPOTHESIS: + + @partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) + def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): + + m, k = x.shape + _, n = y.shape + + def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): + + grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) + + def run(acc_scratch_ref): + pltpu.emit_pipeline( + partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), + in_specs=[ + pl.BlockSpec(lambda i, j, k: (i, k), (bm, bk)), + pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn)), + ], + out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (bm, bn)), + grid=grid, + core_axis=0, + dimension_semantics=( + pltpu.PARALLEL, + pltpu.PARALLEL, + pltpu.ARBITRARY, + ), + )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + + accum_dtype = ( + jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 + ) + pltpu.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + + num_cores = jax.devices()[0].num_cores + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + )(x, y) + + class PaddedPipelineEmitterTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only TPU v4+ allowed.') + + @hp.given( + hps.sampled_from(['float32', 'bfloat16', 'int8']), + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.sampled_from([8, 16, 32, 128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.integers(0, 4), + ) + def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): + hp.assume(bm <= m) + hp.assume(bn <= n) + hp.assume(bk <= k) + if dtype == 'bfloat16': + hp.assume(bm >= 16) + if dtype == 'int8': + hp.assume(bm >= 32) + hp.assume(jtu.is_device_tpu_at_least(5)) + k1, k2 = jax.random.split(jax.random.key(seed)) + x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) + y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) + + out = matmul(x, y, bm=bm, bk=bk, bn=bn) + expected = x @ y + atol = rtol = 1e-5 + if dtype == 'bfloat16': + out = out.astype('float32') + expected = expected.astype('float32') + atol = rtol = 1e-2 + np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index f2e80421a444..5c659ee08ee0 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -17,16 +17,15 @@ import itertools import os import sys -import unittest os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" from absl.testing import absltest from absl.testing import parameterized - import jax from jax import lax from jax import random +from jax._src import checkify from jax._src import config from jax._src import linear_util as lu from jax._src import state @@ -34,10 +33,10 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax._src.pallas.pallas_call import _trace_to_jaxpr from jax.experimental import pallas as pl -from jax.experimental.pallas.ops import attention -from jax.experimental.pallas.ops import layer_norm -from jax.experimental.pallas.ops import rms_norm -from jax.experimental.pallas.ops import softmax +from jax.experimental.pallas.ops.gpu import attention +from jax.experimental.pallas.ops.gpu import layer_norm +from jax.experimental.pallas.ops.gpu import rms_norm +from jax.experimental.pallas.ops.gpu import softmax from jax.interpreters import partial_eval as pe import jax.numpy as jnp import numpy as np @@ -51,8 +50,6 @@ # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs. # pylint: disable=no-value-for-parameter - -config.update("jax_traceback_filtering", "off") config.parse_flags_with_absl() @functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk", @@ -123,7 +120,8 @@ def body(i, acc_ref): return matmul_kernel(x, y) -class PallasTest(parameterized.TestCase): +@jtu.with_config(jax_traceback_filtering="off") +class PallasTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): @@ -209,6 +207,20 @@ def index(x_ref, i_ref, o_ref): for i in range(5): np.testing.assert_allclose(index(x, i), x[i]) + def test_hoisted_consts(self): + # See https://github.com/google/jax/issues/21557. + x = jnp.zeros(32) + indices = jnp.arange(4).reshape((2, 2)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + def kernel(src, dst): + dst[indices] = src[indices] + + jax.block_until_ready(kernel(x)) + def test_vector_slicing(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), @@ -222,100 +234,6 @@ def index(x_ref, idx_ref, o_ref): idx = jnp.arange(i, i + 2) np.testing.assert_allclose(index(x, idx), x[idx]) - def test_num_programs(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), - grid=4, - ) - def kernel(o_ref): - o_ref[pl.program_id(0)] = pl.num_programs(0) - - np.testing.assert_array_equal( - kernel(), np.asarray([4, 4, 4, 4], dtype=np.int32) - ) - - def test_where_broadcasting(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32), - grid=1) - def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): - mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])[:, None, None] - o_ref[...] = jnp.where(mask, x_ref[in_idx_ref[()]], 0) - - x = jnp.arange(7 * 2 * 2.).reshape(7, 2, 2) - for ii in range(7): - for oi in range(4): - out = copyitem(x, ii, oi) - self.assertEqual((4, 2, 2), out.shape) - np.testing.assert_allclose(out[:oi], jnp.zeros_like(out[:oi])) - np.testing.assert_allclose(out[oi], x[ii]) - np.testing.assert_allclose(out[oi + 1:], jnp.zeros_like(out[oi + 1:])) - - @parameterized.parameters(*[ - ((), (2,), ()), - ((1,), (2,), (0,)), - ((1, 1), (2, 2), (0, 1)), - ((), (2, 2), ()), - ]) - def test_broadcast_in_dim(self, in_shape, out_shape, dims): - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), - grid=1) - def f(x_ref, o_ref): - x = x_ref[...] - o_ref[...] = jax.lax.broadcast_in_dim(x, out_shape, dims) - - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) - expected = jax.lax.broadcast_in_dim(x, out_shape, dims) - np.testing.assert_allclose(f(x), expected) - - @parameterized.parameters(*[ - ((2, 4), (8,)), - ((2, 4), (8, 1)), - ((2, 4), (1, 8)), - ((64,), (32, 2)), - ]) - def test_reshape(self, in_shape, out_shape): - # TODO(sharadmv): re-enable when `reshape` works again - if not self.INTERPRET: - self.skipTest("Reshape not yet supported in Triton-MLIR") - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), - grid=1) - def f(x_ref, o_ref): - o_ref[...] = x_ref[...].reshape(out_shape) - - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) - expected = x.reshape(out_shape) - np.testing.assert_allclose(f(x), expected) - - @parameterized.parameters(*[ - ((), (1,)), - ((), (1, 1)), - ((2, 4), (2, 4)), - ((2, 4), (2, 4, 1)), - ((2, 4, 1), (2, 4)), - ((2, 4), (1, 2, 4)), - ((1, 2, 4), (2, 4)), - ((2, 4), (2, 1, 4)), - ((1, 2, 1, 4, 1), (2, 4)), - ((2, 4,), (1, 2, 1, 4)), - ((2, 4,), (1, 2, 4, 1)), - ((1, 2, 4, 1), (1, 2, 1, 4, 1)), - ]) - def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape): - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), - grid=1) - def f(x_ref, o_ref): - o_ref[...] = x_ref[...].reshape(out_shape) - - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) - expected = x.reshape(out_shape) - np.testing.assert_allclose(f(x), expected) - @parameterized.named_parameters(*[ (f"m_{m}_n_{n}_k_{k}_dtype_{dtype}_bm_{block_size_m}_" f"bn_{block_size_n}_bk_{block_size_k}_gm_{group_size_m}", m, n, k, dtype, @@ -325,7 +243,7 @@ def f(x_ref, o_ref): for n in [512, 1024] for dtype in ["float32", "float16"] for block_size_m in [64, 128] - for block_size_n in [128, 256] + for block_size_n in [64, 128] for block_size_k in [32] for group_size_m in [8] if block_size_m <= m and block_size_n <= n and block_size_k <= k @@ -347,7 +265,7 @@ def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm): for n in [512, 1024] for dtype in ["float32", "float16"] for block_size_m in [64, 128] - for block_size_n in [128, 256] + for block_size_n in [64, 128] for block_size_k in [32] if block_size_m <= m and block_size_n <= n and block_size_k <= k ]) @@ -359,33 +277,6 @@ def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk): interpret=self.INTERPRET), jnp.matmul(x, y) np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) - @parameterized.product( - size=[16, 32, 64], - dtype=["float32", "float16"], - trans_a=[False, True], - trans_b=[False, True], - ) - def test_dot(self, size, dtype, trans_a, trans_b): - if trans_a or trans_b: - # TODO(slebedev): Remove this once the problematic Triton pass is fixed. - raise unittest.SkipTest( - "Triton crashes if any of the operands are transposed") - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((size, size), dtype), - grid=1) - def dot(x_ref, y_ref, o_ref): - x = x_ref[:, :] - y = y_ref[:, :] - o_ref[:, :] = pl.dot(x, y, trans_a, trans_b).astype(o_ref.dtype) - - k1, k2 = random.split(random.key(0)) - x = random.normal(k1, (size, size), dtype=dtype) - y = random.normal(k2, (size, size), dtype=dtype) - out, expected = dot(x, y), jnp.dot(x, y) - np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) - @parameterized.named_parameters(*( dict(testcase_name=f"{batch_size}_{size}_{block_size}_{dtype}", batch_size=batch_size, size=size, block_size=block_size, dtype=dtype) @@ -416,83 +307,6 @@ def softmax(x_ref, o_ref): np.testing.assert_allclose(softmax(x), jax.nn.softmax(x, axis=-1), atol=1e-5, rtol=1e-5) - @parameterized.parameters(*( - (size, block_size) - for size in [1, 2, 64, 129, 1021] - for block_size in [1, 2, 32, 64, 128] - )) - def test_masked_load_store(self, size, block_size): - @functools.partial(self.pallas_call, - out_shape=( - jax.ShapeDtypeStruct((size,), jnp.float32) - ), - grid=pl.cdiv(size, block_size)) - def add_one(x_ref, o_ref): - idx = pl.program_id(0) * block_size + jnp.arange(block_size) - mask = idx < x_ref.shape[0] - x = pl.load(x_ref, (idx,), mask=mask) - pl.store(o_ref, (idx,), x + 1., mask=mask) - - key = random.key(0) - x = random.normal(key, (size,)) - np.testing.assert_allclose(add_one(x), x + 1., atol=1e-5, rtol=1e-5) - - def test_broadcasted_load_store(self): - m, n = 16, 32 - @functools.partial( - self.pallas_call, - out_shape=( - jax.ShapeDtypeStruct((m, n), jnp.float32) - ), grid=1) - def load(x_ref, o_ref): - x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :])) - pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), x + 1.) - - key = random.key(0) - x = random.normal(key, (m, n)) - np.testing.assert_allclose(load(x), x + 1., atol=1e-5, rtol=1e-5) - - def test_swap(self): - m, n = 16, 32 - - @functools.partial( - self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, - grid=1, - input_output_aliases={0: 0, 1: 1}, - ) - def swap(_, _2, x_ref, y_ref): - x = x_ref[:] - y = pl.swap(y_ref, (slice(None),), x) - x_ref[:] = y - - x = random.normal(random.key(0), (m, n)) - y = random.normal(random.key(1), (m, n)) - out = swap(x, y) - np.testing.assert_array_equal(out[0], y) - np.testing.assert_array_equal(out[1], x) - - def test_masked_swap(self): - m, n = 16, 32 - - @functools.partial( - self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, - grid=1, - input_output_aliases={0: 0, 1: 1}, - ) - def masked_swap(_, _2, mask_ref, x_ref, y_ref): - x = x_ref[:] - y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:]) - x_ref[:] = y - - x = random.normal(random.key(0), (m, n)) - y = random.normal(random.key(1), (m, n)) - mask = random.bernoulli(random.key(2), shape=(m, n)) - out = masked_swap(x, y, mask) - np.testing.assert_array_equal(out[0], jnp.where(mask, y, x)) - np.testing.assert_array_equal(out[1], jnp.where(mask, x, y)) - def test_unused_ref(self): m, n = 16, 32 @functools.partial( @@ -533,191 +347,30 @@ def add_inplace_kernel(_, o_ref, *, block_size): expected = x + 1 np.testing.assert_allclose(out, expected) - @parameterized.named_parameters(*[ - ("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum), - ("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max), - ("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min), - ("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum), - ("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum), - ("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max), - ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), - ]) - def test_scalar_atomic(self, op, value, numpy_op): - - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((), value.dtype), - grid=value.shape[0], - input_output_aliases={1: 0}) - def atomic_kernel(x_ref, _, o_ref): - pid = pl.program_id(axis=0) - op(o_ref, (), x_ref[pid]) - if op == pl.atomic_add: - neutral = np.array(0, dtype=value.dtype) - elif op == pl.atomic_max: - if np.issubdtype(value.dtype, np.integer): - neutral = np.array(np.iinfo(value.dtype).min, value.dtype) - else: - neutral = np.array(-float('inf'), value.dtype) - elif op == pl.atomic_min: - if np.issubdtype(value.dtype, np.integer): - neutral = np.array(np.iinfo(value.dtype).max, value.dtype) - else: - neutral = np.array(float('inf'), value.dtype) - elif op == pl.atomic_or: - neutral = np.array(False, value.dtype) - else: - raise NotImplementedError() - out = atomic_kernel(value, neutral) - np.testing.assert_allclose(out, numpy_op(value)) - - @parameterized.parameters(*[(0,), (1,)]) - def test_array_atomic_add(self, axis): - m, n = 32, 8 - if axis == 0: - grid = m - else: - grid = n - out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), jnp.float32) + def test_using_pallas_slice(self): + m, n = 32, 4 + out_shape = jax.ShapeDtypeStruct((4, n), jnp.float32) @functools.partial( self.pallas_call, out_shape=out_shape, - grid=grid, - input_output_aliases={1: 0}) - def reduce(x_ref, _, y_ref): - i = pl.program_id(axis=0) - if axis == 0: - idx = (i, jnp.arange(n)) - else: - idx = (jnp.arange(m), i) - x = pl.load(x_ref, idx) - pl.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x) + grid=1) + def slice_kernel(x_ref, y_ref): + x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) + pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) x = random.normal(random.key(0), (m, n)) - y = jnp.zeros(out_shape.shape, out_shape.dtype) - y = reduce(x, y) - y_ref = np.sum(x, axis=axis) + y = slice_kernel(x) + y_ref = x[:4] np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - @parameterized.parameters(False, True) - def test_reduce_only_dim(self, use_store): - m = 32 - x = random.normal(random.key(0), (m,), dtype=jnp.float32) - out_shape = jax.ShapeDtypeStruct((), x.dtype) + def test_pallas_trace_cache(self): + trace_count = 0 @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=1, debug=False) - def reduce(x_ref, y_ref): - x = pl.load(x_ref, (jnp.arange(m),)) - y = jnp.sum(x, axis=-1) - if use_store: - pl.store(y_ref, (), y) - else: - y_ref[...] = y - y = reduce(x) - y_ref = jnp.sum(x, axis=-1) - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - - @parameterized.named_parameters(*[ - (f"{op_name}_{dtype}_{axis}", op, dtype, axis) - for op_name, op in [ - ("add", jnp.sum), - ("max", jnp.max), - ("min", jnp.min), - ("argmax", jnp.argmax), - ("argmin", jnp.argmin), - ] - for axis in [0, 1, (1,), (0, 1)] - for dtype in ["float16", "float32", "int32", "uint32"] - if isinstance(axis, int) or "arg" not in op_name - ]) - def test_array_reduce(self, op, dtype, axis): - m, n = 32, 8 - out_dtype = dtype - if op in {jnp.argmin, jnp.argmax}: - out_dtype = jnp.int32 - def make_x(key): - if jnp.issubdtype(dtype, jnp.integer): - return random.permutation( - key, jnp.arange(m * n, dtype=dtype), independent=True - ).reshape(m, n) - else: - return random.normal(key, (m, n), dtype=dtype) - out_shape = jax.ShapeDtypeStruct( - op(make_x(random.key(0)), axis=axis).shape, out_dtype) - if isinstance(axis, int): - grid = tuple(a for i, a in enumerate((m, n)) if i != axis) - else: - grid = tuple(a for i, a in enumerate((m, n)) if i not in axis) - @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=grid) - def reduce(x_ref, y_ref): - x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None])) - y = op(x, axis=axis) - pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y) - for i, key in enumerate(random.split(random.key(0), 20)): - x = make_x(key) - y = reduce(x) - y_ref = op(x, axis=axis) - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - - @parameterized.named_parameters(*[ - (f"{dtype}_{axis}", dtype, axis) - for axis in [0, 1] - for dtype in ["float16", "float32", "int32", "uint32"] - if isinstance(axis, int) - ]) - def test_cumsum(self, dtype, axis): - m, n = 32, 8 - out_dtype = dtype - def make_x(key): - if jnp.issubdtype(dtype, jnp.integer): - return random.permutation( - key, jnp.arange(m * n, dtype=dtype), independent=True - ).reshape(m, n) - else: - return random.normal(key, (m, n), dtype=dtype) - out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - grid = () - @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=grid) - def reduce(x_ref, y_ref): - x = x_ref[...] - y_ref[...] = jnp.cumsum(x, axis=axis) - for i, key in enumerate(random.split(random.key(0), 20)): - x = make_x(key) - y = reduce(x) - y_ref = jnp.cumsum(x, axis=axis) - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - - def test_using_pallas_slice(self): - m, n = 32, 4 - out_shape = jax.ShapeDtypeStruct((4, n), jnp.float32) - @functools.partial( - self.pallas_call, - out_shape=out_shape, - grid=1) - def slice_kernel(x_ref, y_ref): - x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) - pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) - x = random.normal(random.key(0), (m, n)) - y = slice_kernel(x) - y_ref = x[:4] - np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - - def test_pallas_trace_cache(self): - trace_count = 0 - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), - grid=1) - def add_one(x_ref, o_ref): - nonlocal trace_count - o_ref[()] = x_ref[()] + 1. - trace_count += 1 + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), + grid=1) + def add_one(x_ref, o_ref): + nonlocal trace_count + o_ref[()] = x_ref[()] + 1. + trace_count += 1 @jax.jit def f(x): @@ -727,52 +380,6 @@ def f(x): self.assertEqual(f(x), 2.) self.assertEqual(trace_count, 1) - @parameterized.parameters(*[ - (0, 0, 1), - (0, 1, 1), - (1, 0, 1), - (1, 1, 1), - (2, 1, 1), - (2, 1, 1), - ]) - def test_atomic_cas(self, init_value, cmp, new_value): - @functools.partial( - self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), jnp.int32), - jax.ShapeDtypeStruct((), jnp.int32)), - input_output_aliases={0: 0}) - def swap(_, lock_ref, out_ref): - out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value) - - lock, out = swap(init_value) - np.testing.assert_allclose(lock, new_value if cmp == init_value else - init_value) - np.testing.assert_allclose(out, init_value) - - @parameterized.parameters(*[ - 1, 2, 3, 4, 8 - ]) - def test_atomic_counter(self, num_threads): - if self.INTERPRET: - self.skipTest("While loop not supported in interpreter mode.") - - @functools.partial( - self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), jnp.int32), - jax.ShapeDtypeStruct((), jnp.int32)), - input_output_aliases={0: 0, 1: 1}, - grid=(num_threads,)) - def increment(_, __, lock_ref, counter_ref): - def _cond(_): - return pl.atomic_cas(lock_ref, 0, 1) == 1 - lax.while_loop(_cond, lambda a: a, 0) - counter_ref[...] += 1 - pl.atomic_xchg(lock_ref, (), 0) - - lock, count = increment(0, 0) - np.testing.assert_allclose(lock, 0) - np.testing.assert_allclose(count, num_threads) - def test_custom_jvp_call(self): @functools.partial(jax.custom_jvp, nondiff_argnums=(1,)) def softmax(x, axis=-1): @@ -802,10 +409,11 @@ class PallasCallInterpreterTest(PallasCallTest): class PallasControlFlowTest(PallasTest): def setUp(self): - super().setUp() if self.INTERPRET: self.skipTest("Control flow not supported in interpreter mode yet.") + super().setUp() + def test_loop_with_float64_carry(self): # Test that the jnp.zeros(f64) loop init_val is actually f64, and that # fori_loop handles i64 index variables, i.e. error: 'scf.for' op along @@ -1374,6 +982,21 @@ def add(x_ref, _, o_ref): out_ref = jnp.arange(2, 10) np.testing.assert_allclose(out, out_ref) + def test_vmap_of_kernel_with_input_output_aliases_different_axes(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + debug=False, + input_output_aliases={0: 0}, + grid=(), + ) + def add(x_ref, o_ref): + o_ref[()] = x_ref[()] + 1 + + out = jax.vmap(add, in_axes=1)(jnp.arange(8).reshape((4, 2))) + out_ref = jnp.arange(1, 9).reshape((4, 2)).swapaxes(0, 1) + np.testing.assert_allclose(out, out_ref) + def test_vmap_of_slicing_kernel_different_axes(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), @@ -1457,7 +1080,7 @@ class PallasOpsTest(PallasTest): [jnp.abs, jnp.negative], ["int16", "int32", "int64", "float16", "float32", "float64"], ), - ([jnp.ceil, jnp.floor], ["float32", "float64"]), + ([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]), ( [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt], ["float16", "float32", "float64"], @@ -1511,6 +1134,17 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.array([1, 2, 3, 4]).astype(y_dtype) np.testing.assert_allclose(kernel(x, y), lax.pow(x, y)) + @parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3) + def test_integer_pow(self, y): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[:] = lax.integer_pow(x_ref[...], y) + + x = jnp.array([1, 2, 3, 4]).astype(jnp.float32) / 10 + np.testing.assert_allclose(kernel(x), lax.integer_pow(x, y)) + @parameterized.parameters("float32", "float64") def test_nextafter(self, dtype): if jtu.test_device_matches(["tpu"]) and dtype == "float64": @@ -1554,172 +1188,690 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) np.testing.assert_allclose(kernel(x, y), fn(x, y)) - def test_isnan(self): + def test_isnan(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + grid=1) + def isnan(x_ref, o_ref): + o_ref[:] = jnp.isnan(x_ref[...]) + + x = jnp.arange(8.) + x = x.at[3].set(jnp.nan) + np.testing.assert_allclose(isnan(x), jnp.isnan(x)) + + @parameterized.parameters( + ("int32", "float32"), + ("float32", "float32"), + ) + def test_true_divide(self, dtype, out_dtype): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8,), out_dtype), + grid=1, + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) + + x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) + y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) + np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y)) + + @parameterized.parameters("float16", "bfloat16") + def test_true_divide_unsupported(self, dtype): + if self.INTERPRET: + self.skipTest("No lowering in interpreter mode") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), dtype), + grid=1, + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) + + x = jnp.array([2.4, 4.2]).astype(dtype) + y = jnp.array([4.2, 2.4]).astype(dtype) + with self.assertRaises(Exception): + kernel(x, y) + + BINARY_OPS = [ + ([jnp.floor_divide], ["int32", "uint32"]), + ( + [jnp.add, jnp.subtract, jnp.multiply], + ["int16", "int32", "uint32", "float16", "float32"], + ), + ([jnp.remainder], ["int32", "uint32", "float32"]), + ( + # fmt: off + [jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor, + jnp.bitwise_left_shift, jnp.bitwise_right_shift], + # fmt: on + ["int32", "uint32"], + ), + ] + + @parameterized.named_parameters( + (f"{fn.__name__}_{dtype}", fn, dtype) + for args in BINARY_OPS + for fn, dtype in itertools.product(*args) + ) + def test_binary(self, f, dtype): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1 + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = f(x_ref[...], y_ref[...]) + + x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) + if (f == jnp.bitwise_left_shift): + y = jnp.array([3, 1, 4, 5, 2, 2, 2, 4]).astype(dtype) + else: + y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) + + np.testing.assert_allclose(f(x, y), kernel(x, y)) + + @parameterized.parameters( + ((8, 4), jnp.int32, 0), + ((8, 16), jnp.float32, 1), + ((8, 16, 2), jnp.int8, 1), + ) + def test_broadcasted_iota(self, shape, dtype, dimension): + f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension) + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, dtype), grid=1 + ) + def kernel(o_ref): + o_ref[...] = f() + + np.testing.assert_allclose(f(), kernel()) + + @parameterized.parameters("float16", "bfloat16", "float32") + def test_approx_tanh(self, dtype): + if self.INTERPRET: + self.skipTest("approx_tanh is not supported in interpreter mode") + if (dtype == "bfloat16" and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 + ) + def kernel(x_ref, o_ref): + o_ref[...] = plgpu.approx_tanh(x_ref[...]) + + x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/google/jax/issues/11014. + np.testing.assert_allclose( + kernel(x).astype(jnp.float32), + jnp.tanh(x).astype(jnp.float32), + atol=5e-3, + rtol=5e-3, + ) + + def test_elementwise_inline_asm(self): + if self.INTERPRET: + self.skipTest( + "elementwise_inline_asm is not supported in interpreter mode" + ) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((256,), jnp.float16), + grid=1, + ) + def kernel(x_ref, o_ref): + [o_ref[...]] = plgpu.elementwise_inline_asm( + "tanh.approx.f16x2 $0, $1;", + args=[x_ref[...]], + constraints="=r,r", + pack=2, + result_shape_dtypes=[jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype)], + ) + + x = jnp.arange(256).astype(jnp.float16) + np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) + + def test_debug_print(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + grid=1, + compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) + ) + def kernel(x_ref, o_ref): + pl.debug_print("It works!") + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + jax.effects_barrier() + + self.assertIn("It works!", output()) + + def test_debug_print_with_values(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + grid=1, + compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) + ) + def kernel(x_ref, o_ref): + pl.debug_print("x[0] =", x_ref[0]) + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + jax.effects_barrier() + + self.assertIn("x[0] = 4.2", output()) + + @parameterized.parameters( + ((2, 4), (8,)), + ((2, 4), (8, 1)), + ((2, 4), (1, 8)), + ((64,), (32, 2)), + ) + def test_reshape(self, in_shape, out_shape): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + grid=1, + ) + def f(x_ref, o_ref): + o_ref[...] = x_ref[...].reshape(out_shape) + + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + expected = x.reshape(out_shape) + np.testing.assert_allclose(f(x), expected) + + @parameterized.parameters( + # fmt: off + ((), (1,)), + ((), (1, 1)), + ((2, 4), (2, 4)), + ((2, 4), (2, 4, 1)), + ((2, 4, 1), (2, 4)), + ((2, 4), (1, 2, 4)), + ((1, 2, 4), (2, 4)), + ((2, 4), (2, 1, 4)), + ((1, 2, 1, 4, 1), (2, 4)), + ((2, 4,), (1, 2, 1, 4)), + ((2, 4,), (1, 2, 4, 1)), + ((1, 2, 4, 1), (1, 2, 1, 4, 1)), + # fmt: on + ) + def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + grid=1, + ) + def f(x_ref, o_ref): + o_ref[...] = x_ref[...].reshape(out_shape) + + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + expected = x.reshape(out_shape) + np.testing.assert_allclose(f(x), expected) + + def test_num_programs(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + grid=4, + ) + def kernel(o_ref): + o_ref[pl.program_id(0)] = pl.num_programs(0) + + np.testing.assert_array_equal( + kernel(), np.asarray([4, 4, 4, 4], dtype=np.int32) + ) + + def test_where_broadcasting(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32), + grid=1, + ) + def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): + mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])[:, None, None] + o_ref[...] = jnp.where(mask, x_ref[in_idx_ref[()]], 0) + + x = jnp.arange(7 * 2 * 2.0).reshape(7, 2, 2) + for ii in range(7): + for oi in range(4): + out = copyitem(x, ii, oi) + self.assertEqual((4, 2, 2), out.shape) + np.testing.assert_allclose(out[:oi], jnp.zeros_like(out[:oi])) + np.testing.assert_allclose(out[oi], x[ii]) + np.testing.assert_allclose(out[oi + 1 :], jnp.zeros_like(out[oi + 1 :])) + + @parameterized.parameters( + ((), (2,), ()), + ((1,), (2,), (0,)), + ((1, 1), (2, 2), (0, 1)), + ((), (2, 2), ()), + ) + def test_broadcast_in_dim(self, in_shape, out_shape, dims): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + grid=1, + ) + def f(x_ref, o_ref): + x = x_ref[...] + o_ref[...] = jax.lax.broadcast_in_dim(x, out_shape, dims) + + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + expected = jax.lax.broadcast_in_dim(x, out_shape, dims) + np.testing.assert_allclose(f(x), expected) + + @parameterized.product( + size=[16, 32, 64], + dtype=["float32", "float16"], + trans_x=[False, True], + trans_y=[False, True], + ) + def test_dot(self, size, dtype, trans_x, trans_y): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((size, size), dtype), + grid=1, + ) + def dot(x_ref, y_ref, o_ref): + x = x_ref[:, :] + y = y_ref[:, :] + o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype) + + k1, k2 = random.split(random.key(0)) + x = random.normal(k1, (size, size), dtype=dtype) + y = random.normal(k2, (size, size), dtype=dtype) + out = dot(x, y) + expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y) + np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) + + @parameterized.product( + size=[1, 2, 64, 129, 1021], + block_size=[1, 2, 32, 64, 128], + ) + def test_masked_load_store(self, size, block_size): + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((size,), jnp.float32)), + grid=pl.cdiv(size, block_size), + ) + def kernel(x_ref, o_ref): + idx = pl.program_id(0) * block_size + jnp.arange(block_size) + mask = idx < x_ref.shape[0] + x = pl.load(x_ref, (idx,), mask=mask) + pl.store(o_ref, (idx,), x + 1.0, mask=mask) + + key = random.key(0) + x = random.normal(key, (size,)) + np.testing.assert_allclose(kernel(x), x + 1.0, atol=1e-5, rtol=1e-5) + + def test_masked_oob_load_store_slice(self): + n = 16 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32)), + grid=1, + ) + def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): + x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)), + mask=mask_ref[:], other=-1.) + pl.store(o_ref, (pl.dslice(None),), x) + + x = random.normal(random.key(0), (n,)) + slice_start = random.randint(random.key(2), (), 1, n) + indices = jnp.arange(n) + slice_start + mask = indices < n + out = masked_oob_load_store_slice(x, mask, slice_start) + o_new = jnp.where(mask, x[indices], jnp.full_like(x, -1.)) + np.testing.assert_array_equal(out, o_new) + + def test_strided_load(self): + if self.INTERPRET: + # TODO(b/329733289): Remove this once the bug is fixed. + self.skipTest("Strided load not yet supported in interpreter mode") + + # Reproducer from https://github.com/google/jax/issues/20895. + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[::4] + + x = jnp.arange(16, dtype=jnp.float32) + np.testing.assert_array_equal(kernel(x), x[::4]) + + def test_broadcasted_load_store(self): + m, n = 16, 32 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32)), + grid=1, + ) + def load(x_ref, o_ref): + x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :])) + pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), x + 1.0) + + key = random.key(0) + x = random.normal(key, (m, n)) + np.testing.assert_allclose(load(x), x + 1.0, atol=1e-5, rtol=1e-5) + + @parameterized.parameters( + ((16, 32), (16,)), + ((16, 32), (32,)), + ((16, 32), (16, 31)), + ) + def test_invalid_broadcasted_load(self, x_shape, mask_shape): + if self.INTERPRET: + self.skipTest("No broadcasting checks in pl.load in interpreter mode") + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32) + ) + def kernel(x_ref, mask_ref, o_ref): + del o_ref # Unused. + pl.load(x_ref, slice(None), mask=mask_ref[:]) + + x = jnp.ones(x_shape, dtype=jnp.float32) + mask = jnp.ones(mask_shape, dtype=jnp.bool_) + # assertRaises* methods do not support inspecting the __cause__, so + # we have to check it manually. + try: + kernel(x, mask) + except Exception as e: + self.assertIn("Cannot broadcast", str(e.__cause__)) + else: + self.fail("Expected exception due to invalid broadcasting") + + def test_swap(self): + m, n = 16, 32 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, + grid=1, + input_output_aliases={0: 0, 1: 1}, + ) + def swap(_, _2, x_ref, y_ref): + x = x_ref[:] + y = pl.swap(y_ref, (slice(None),), x) + x_ref[:] = y + + x = random.normal(random.key(0), (m, n)) + y = random.normal(random.key(1), (m, n)) + out = swap(x, y) + np.testing.assert_array_equal(out[0], y) + np.testing.assert_array_equal(out[1], x) + + def test_masked_swap(self): + m, n = 16, 32 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, + grid=1, + input_output_aliases={0: 0, 1: 1}, + ) + def masked_swap(_, _2, mask_ref, x_ref, y_ref): + x = x_ref[:] + y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:]) + x_ref[:] = y + + x = random.normal(random.key(0), (m, n)) + y = random.normal(random.key(1), (m, n)) + mask = random.bernoulli(random.key(2), shape=(m, n)) + out = masked_swap(x, y, mask) + np.testing.assert_array_equal(out[0], jnp.where(mask, y, x)) + np.testing.assert_array_equal(out[1], jnp.where(mask, x, y)) + + def test_masked_oob_swap_slice(self): + m, n = 32, 16 + + @functools.partial( + self.pallas_call, + out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32), + jax.ShapeDtypeStruct((m,), jnp.float32)), + grid=1, + input_output_aliases={0: 0, 1: 1}, + ) + def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): + x, mask = x_ref[:], mask_ref[:] + y = pl.swap(y_ref, (pl.dslice(start_idx_ref[()], n)), x, mask=mask) + x_ref[:] = y + + x = random.normal(random.key(0), (n,)) + y = random.normal(random.key(1), (m,)) + slice_start = random.randint(random.key(2), (), m-n+1, m) + indices = jnp.arange(n) + slice_start + mask = indices < m + out = masked_oob_swap_slice(x, y, mask, slice_start) + + # the unjittable masked indexing equivalent + unmasked_idx = indices[mask] + x_new = x.at[mask].set(y[unmasked_idx]) + y_new = y.at[unmasked_idx].set(x[mask]) + np.testing.assert_array_equal(out[0], x_new) + np.testing.assert_array_equal(out[1], y_new) + + @parameterized.named_parameters( + ("add_i32", pl.atomic_add, np.array([1, 2, 3, 4], np.int32), np.sum), + ("max_i", pl.atomic_max, np.array([1, 2, 3, 4], np.int32), np.max), + ("min_i32", pl.atomic_min, np.array([1, 2, 3, 4], np.int32), np.min), + ("add_f16", pl.atomic_add, np.array([1, 2, 3, 4], np.float16), np.sum), + ("add_f32", pl.atomic_add, np.array([1, 2, 3, 4], np.float32), np.sum), + ("max_f32", pl.atomic_max, np.array([1, 2, 3, 4], np.float32), np.max), + ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min), + ) + def test_scalar_atomic(self, op, value, numpy_op): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((), value.dtype), + grid=value.shape[0], + input_output_aliases={1: 0}, + ) + def atomic_kernel(x_ref, _, o_ref): + pid = pl.program_id(axis=0) + op(o_ref, (), x_ref[pid]) + + if op == pl.atomic_add: + neutral = np.array(0, dtype=value.dtype) + elif op == pl.atomic_max: + if np.issubdtype(value.dtype, np.integer): + neutral = np.array(np.iinfo(value.dtype).min, value.dtype) + else: + neutral = np.array(-float("inf"), value.dtype) + elif op == pl.atomic_min: + if np.issubdtype(value.dtype, np.integer): + neutral = np.array(np.iinfo(value.dtype).max, value.dtype) + else: + neutral = np.array(float("inf"), value.dtype) + elif op == pl.atomic_or: + neutral = np.array(False, value.dtype) + else: + raise NotImplementedError() + out = atomic_kernel(value, neutral) + np.testing.assert_allclose(out, numpy_op(value)) + + @parameterized.parameters((0,), (1,)) + def test_array_atomic_add(self, axis): + m, n = 32, 8 + if axis == 0: + grid = m + else: + grid = n + out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), jnp.float32) + @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), - grid=1) - def isnan(x_ref, o_ref): - o_ref[:] = jnp.isnan(x_ref[...]) + self.pallas_call, + out_shape=out_shape, + grid=grid, + input_output_aliases={1: 0}, + ) + def reduce(x_ref, _, y_ref): + i = pl.program_id(axis=0) + if axis == 0: + idx = (i, jnp.arange(n)) + else: + idx = (jnp.arange(m), i) + x = pl.load(x_ref, idx) + pl.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x) - x = jnp.arange(8.) - x = x.at[3].set(jnp.nan) - np.testing.assert_allclose(isnan(x), jnp.isnan(x)) + x = random.normal(random.key(0), (m, n)) + y = jnp.zeros(out_shape.shape, out_shape.dtype) + y = reduce(x, y) + y_ref = np.sum(x, axis=axis) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) @parameterized.parameters( - ("int32", "float32"), - ("float32", "float32"), + (0, 0, 1), + (0, 1, 1), + (1, 0, 1), + (1, 1, 1), + (2, 1, 1), + (2, 1, 1), ) - def test_true_divide(self, dtype, out_dtype): + def test_atomic_cas(self, init_value, cmp, new_value): @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8,), out_dtype), - grid=1, - ) - def kernel(x_ref, y_ref, o_ref): - o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) + self.pallas_call, out_shape=( + jax.ShapeDtypeStruct((), jnp.int32), + jax.ShapeDtypeStruct((), jnp.int32)), + input_output_aliases={0: 0}) + def swap(_, lock_ref, out_ref): + out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value) - x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) - y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) - np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y)) + lock, out = swap(init_value) + np.testing.assert_allclose(lock, new_value if cmp == init_value else + init_value) + np.testing.assert_allclose(out, init_value) - @parameterized.parameters("float16", "bfloat16") - def test_true_divide_unsupported(self, dtype): + @parameterized.parameters(1, 2, 3, 4, 8) + def test_atomic_counter(self, num_threads): if self.INTERPRET: - self.skipTest("No lowering in interpreter mode") + self.skipTest("While loop not supported in interpreter mode.") @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((2,), dtype), - grid=1, - ) - def kernel(x_ref, y_ref, o_ref): - o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) + self.pallas_call, out_shape=( + jax.ShapeDtypeStruct((), jnp.int32), + jax.ShapeDtypeStruct((), jnp.int32)), + input_output_aliases={0: 0, 1: 1}, + grid=(num_threads,)) + def increment(_, __, lock_ref, counter_ref): + def _cond(_): + return pl.atomic_cas(lock_ref, 0, 1) == 1 + lax.while_loop(_cond, lambda a: a, 0) + counter_ref[...] += 1 + pl.atomic_xchg(lock_ref, (), 0) - x = jnp.array([2.4, 4.2]).astype(dtype) - y = jnp.array([4.2, 2.4]).astype(dtype) - with self.assertRaises(Exception): - kernel(x, y) + lock, count = increment(0, 0) + np.testing.assert_allclose(lock, 0) + np.testing.assert_allclose(count, num_threads) - BINARY_OPS = [ - ([jnp.floor_divide], ["int32", "uint32"]), - ( - [jnp.add, jnp.subtract, jnp.multiply], - ["int16", "int32", "uint32", "float16", "float32"], - ), - ([jnp.remainder], ["int32", "uint32", "float32"]), - ( - # fmt: off - [jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor, - jnp.bitwise_left_shift, jnp.bitwise_right_shift], - # fmt: on - ["int32", "uint32"], - ), - ] + @parameterized.parameters(False, True) + def test_reduce_only_dim(self, use_store): + m = 32 + x = random.normal(random.key(0), (m,), dtype=jnp.float32) + out_shape = jax.ShapeDtypeStruct((), x.dtype) - @parameterized.named_parameters( - (f"{fn.__name__}_{dtype}", fn, dtype) - for args in BINARY_OPS - for fn, dtype in itertools.product(*args) - ) - def test_binary(self, f, dtype): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1 + self.pallas_call, out_shape=out_shape, grid=1, debug=False ) - def kernel(x_ref, y_ref, o_ref): - o_ref[...] = f(x_ref[...], y_ref[...]) - - x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) - y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) - np.testing.assert_allclose(f(x, y), kernel(x, y)) - - @parameterized.parameters( - ((8, 4), jnp.int32, 0), - ((8, 16), jnp.float32, 1), - ((8, 16, 2), jnp.int8, 1), - ) - def test_broadcasted_iota(self, shape, dtype, dimension): - f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension) + def reduce(x_ref, y_ref): + x = pl.load(x_ref, (jnp.arange(m),)) + y = jnp.sum(x, axis=-1) + if use_store: + pl.store(y_ref, (), y) + else: + y_ref[...] = y - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, dtype), grid=1 - ) - def kernel(o_ref): - o_ref[...] = f() + y = reduce(x) + y_ref = jnp.sum(x, axis=-1) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2) - np.testing.assert_allclose(f(), kernel()) + @parameterized.named_parameters(*[ + (f"{op_name}_{dtype}_{axis}", op, dtype, axis) + for op_name, op in [ + ("add", jnp.sum), + ("max", jnp.max), + ("min", jnp.min), + ("argmax", jnp.argmax), + ("argmin", jnp.argmin), + ] + for axis in [0, 1, (1,), (0, 1)] + for dtype in ["float16", "float32", "int32", "uint32"] + if isinstance(axis, int) or "arg" not in op_name + ]) + def test_array_reduce(self, op, dtype, axis): + m, n = 32, 8 + out_dtype = dtype + if op in {jnp.argmin, jnp.argmax}: + out_dtype = jnp.int32 - @parameterized.parameters("float16", "bfloat16", "float32") - def test_approx_tanh(self, dtype): - if self.INTERPRET: - self.skipTest("approx_tanh is not supported in interpreter mode") - if (dtype == "bfloat16" and - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + def make_x(key): + if jnp.issubdtype(dtype, jnp.integer): + return random.permutation( + key, jnp.arange(m * n, dtype=dtype), independent=True + ).reshape(m, n) + else: + return random.normal(key, (m, n), dtype=dtype) - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 + out_shape = jax.ShapeDtypeStruct( + op(make_x(random.key(0)), axis=axis).shape, out_dtype ) - def kernel(x_ref, o_ref): - o_ref[...] = plgpu.approx_tanh(x_ref[...]) + if isinstance(axis, int): + grid = tuple(a for i, a in enumerate((m, n)) if i != axis) + else: + grid = tuple(a for i, a in enumerate((m, n)) if i not in axis) - x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) - # We upcast to float32 because NumPy <2.0 does not handle custom dtypes - # properly. See https://github.com/google/jax/issues/11014. - np.testing.assert_allclose( - kernel(x).astype(jnp.float32), - jnp.tanh(x).astype(jnp.float32), - atol=5e-3, - rtol=5e-3, - ) + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def reduce(x_ref, y_ref): + x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None])) + y = op(x, axis=axis) + pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y) - def test_elementwise_inline_asm(self): - if self.INTERPRET: - self.skipTest( - "elementwise_inline_asm is not supported in interpreter mode" - ) + for i, key in enumerate(random.split(random.key(0), 20)): + x = make_x(key) + y = reduce(x) + y_ref = op(x, axis=axis) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((256,), jnp.float16), - grid=1, - ) - def kernel(x_ref, o_ref): - [o_ref[...]] = plgpu.elementwise_inline_asm( - "tanh.approx.f16x2 $0, $1;", - args=[x_ref[...]], - constraints="=r,r", - pack=2, - result_shape_dtypes=[jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype)], - ) + @parameterized.product( + axis=[0, 1], + dtype=["float16", "float32", "int32", "uint32"], + ) + def test_cumsum(self, dtype, axis): + m, n = 32, 8 + out_dtype = dtype - x = jnp.arange(256).astype(jnp.float16) - np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) + def make_x(key): + if jnp.issubdtype(dtype, jnp.integer): + return random.permutation( + key, jnp.arange(m * n, dtype=dtype), independent=True + ).reshape(m, n) + else: + return random.normal(key, (m, n), dtype=dtype) - def test_debug_print(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - grid=1, - compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) - ) - def kernel(x_ref, o_ref): - pl.debug_print("It works!") + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) + grid = () - x = jnp.array([4.2, 2.4]).astype(jnp.float32) - kernel(x) + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def reduce(x_ref, y_ref): + x = x_ref[...] + y_ref[...] = jnp.cumsum(x, axis=axis) - def test_debug_print_with_values(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - grid=1, - compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) - ) - def kernel(x_ref, o_ref): - pl.debug_print("x[0] = ", x_ref[0]) + for i, key in enumerate(random.split(random.key(0), 20)): + x = make_x(key) + y = reduce(x) + y_ref = jnp.cumsum(x, axis=axis) + np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - x = jnp.array([4.2, 2.4]).astype(jnp.float32) - kernel(x) class PallasOpsInterpretTest(PallasOpsTest): INTERPRET = True @@ -2022,16 +2174,11 @@ class RmsNormInterpreterTest(PallasTest): class SoftmaxTest(PallasTest): - @parameterized.parameters( - (shape, dtype) - for shape in [(1024, 125), (4, 1024, 125)] - for dtype in (jnp.bfloat16, jnp.float16, jnp.float32) + @parameterized.product( + shape=[(1024, 125), (4, 1024, 125)], + dtype=[jnp.bfloat16, jnp.float16, jnp.float32] ) def test_softmax(self, shape, dtype): - # TODO(bchetioui): add Triton bug reference when filed - if dtype == jnp.bfloat16: - raise absltest.SkipTest("Disabled due to Triton lowering bug") - x = jax.random.normal(random.key(0), shape, dtype=dtype) atol, rtol = { @@ -2040,9 +2187,11 @@ def test_softmax(self, shape, dtype): jnp.float32: (1e-7, 1e-6), }[dtype] + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/google/jax/issues/11014. np.testing.assert_allclose( - softmax.softmax(x, axis=-1), - jax.nn.softmax(x, axis=-1), + softmax.softmax(x, axis=-1).astype(jnp.float32), + jax.nn.softmax(x, axis=-1).astype(jnp.float32), atol=atol, rtol=rtol, ) @@ -2135,5 +2284,114 @@ def _(): np.testing.assert_allclose(out, expected, atol=atol) +class PallasCheckifyTest(PallasTest): + # TODO(b/346651778): Support non-interpret mode checkify. + INTERPRET: bool = True + + def test_no_checkify(self,): + def kernel(y_ref): + y_ref[...] = jnp.zeros_like(y_ref[...]) + out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call) + err, result = checked_call() + err.throw() # Should not raise. + np.testing.assert_allclose(result, jnp.zeros_like(result)) + + def test_does_not_clobber_previous_error(self,): + def kernel(y_ref): + y_ref[...] = jnp.zeros_like(y_ref[...]) + checkify.check(False, "error in kernel") + out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + def error_before_call(): + checkify.check(False, "error before call") + return pallas_call() + checked_call = checkify.checkify(error_before_call) + err, result = checked_call() + with self.assertRaisesRegex( + checkify.JaxRuntimeError, "error before call"): + err.throw() + np.testing.assert_allclose(result, jnp.zeros_like(result)) + + @parameterized.parameters((False,), (True,)) + def test_trivial_check(self, assert_cond): + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + checkify.check(assert_cond, "pallas check failed") + input = jnp.arange(4, dtype=jnp.int32) + out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call) + err, result = checked_call(input) + if not assert_cond: + with self.assertRaisesRegex( + checkify.JaxRuntimeError, "pallas check failed"): + err.throw() + np.testing.assert_allclose(result, input) + + def test_nan_error(self): + def kernel(x_ref, y_ref): + y_ref[...] = jnp.log(x_ref[...]) + input = jnp.arange(4, dtype=jnp.float32) - 2 + out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call, + errors=checkify.all_checks) + err, result = checked_call(input) + with self.assertRaisesRegex( + checkify.JaxRuntimeError, "nan generated by primitive: log"): + err.throw() + is_nan = jnp.isnan(result) + np.testing.assert_allclose(is_nan, input < 0) + + def test_nan_error_with_assertion(self): + # TODO(b/346842088): Fix check asserts clobbering other errors. + self.skipTest('Known failure.') + # Test NaN error is not clobbered by an assertion failure + def kernel(x_ref, y_ref): + y_ref[...] = jnp.log(x_ref[...]) + checkify.check(False, "do not raise") + input = jnp.arange(4, dtype=jnp.float32) - 10 + out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype) + pallas_call = self.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call, + errors=checkify.all_checks) + err, _ = checked_call(input) + with self.assertRaisesRegex( + checkify.JaxRuntimeError, "nan generated by primitive: log"): + err.throw() + + @parameterized.parameters((5, 0), (8, 3), (4, 3)) + def test_checkify_returns_first_error_in_grid( + self, num_loops, fail_iteration): + # Check that checkify returns the first error that occurs + # TODO(justinfu): This test doesn't make sense on GPU, where threads run + # in parallel. Update checkify to return a grid of errors. + def kernel(x_ref, _): + value = jnp.squeeze(x_ref[...]) + checkify.check( + value < fail_iteration, "failed on loop {itr}", itr=value) + input_arr = jnp.arange(num_loops, dtype=jnp.float32) + in_specs = [pl.BlockSpec(lambda x : (x,), (1,))] + out_shape = jax.ShapeDtypeStruct((1,), dtype=jnp.float32) + pallas_call = self.pallas_call(kernel, + grid=(num_loops,), + in_specs=in_specs, + out_shape=out_shape) + + checked_call = checkify.checkify(pallas_call, + errors=checkify.all_checks) + err, _ = checked_call(input_arr) + with self.assertRaisesRegex( + checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"): + err.throw() + + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/splash_attention_kernel_test.py b/tests/pallas/splash_attention_kernel_test.py index 785c594238b8..e6132a1966a3 100644 --- a/tests/pallas/splash_attention_kernel_test.py +++ b/tests/pallas/splash_attention_kernel_test.py @@ -15,9 +15,10 @@ """Tests for splash_attention.""" from __future__ import annotations +from collections.abc import Callable import dataclasses import functools -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar import unittest from absl.testing import absltest @@ -307,13 +308,14 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: class AttentionTest(jtu.JaxTestCase): def setUp(self): - super().setUp() if not jtu.test_device_matches(["tpu"]): self.skipTest("Need TPU devices") # TODO(b/327487669): selectively re-enable tests that works on TPU v3. if not jtu.is_device_tpu_at_least(4): self.skipTest("Not supported on TPU generations <= 3") + super().setUp() + def _assert_allclose(self, x, y, **kwargs): if x.dtype == np.dtype(jnp.bfloat16): x = x.astype(np.float32) @@ -359,7 +361,7 @@ def test_splash_attention(self, is_mqa, is_segmented, data): attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) - mask = mask_lib.MultiHeadMask(tuple((m.get_mask() for m in masks))) + mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: @@ -420,7 +422,7 @@ def test_splash_attention_fwd( segment_ids = data.draw(segment_ids_strategy(q_seq_len)) attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) - mask = mask_lib.MultiHeadMask(tuple((m.get_mask() for m in masks))) + mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask) @@ -516,21 +518,18 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): atols["dk"] = 0.09 else: raise NotImplementedError - with self.subTest("dv"): - self._assert_allclose( - dv_vanilla, dv_ref, atol=atols_v["dv"], rtol=rtols_v["dv"] - ) - self._assert_allclose(dv, dv_ref, atol=atols["dv"], rtol=rtols["dv"]) - with self.subTest("dq"): - self._assert_allclose( - dq_vanilla, dq_ref, atol=atols_v["dq"], rtol=rtols_v["dq"] - ) - self._assert_allclose(dq, dq_ref, atol=atols["dq"], rtol=rtols["dq"]) - with self.subTest("dk"): - self._assert_allclose( - dk_vanilla, dk_ref, atol=atols_v["dk"], rtol=rtols_v["dk"] - ) - self._assert_allclose(dk, dk_ref, atol=atols["dk"], rtol=rtols["dk"]) + self._assert_allclose( + dv_vanilla, dv_ref, atol=atols_v["dv"], rtol=rtols_v["dv"] + ) + self._assert_allclose(dv, dv_ref, atol=atols["dv"], rtol=rtols["dv"]) + self._assert_allclose( + dq_vanilla, dq_ref, atol=atols_v["dq"], rtol=rtols_v["dq"] + ) + self._assert_allclose(dq, dq_ref, atol=atols["dq"], rtol=rtols["dq"]) + self._assert_allclose( + dk_vanilla, dk_ref, atol=atols_v["dk"], rtol=rtols_v["dk"] + ) + self._assert_allclose(dk, dk_ref, atol=atols["dk"], rtol=rtols["dk"]) @parameterized.product( is_mqa=(False, True), diff --git a/tests/pallas/splash_attention_mask_test.py b/tests/pallas/splash_attention_mask_test.py index a408872100c4..ce7d8fd09182 100644 --- a/tests/pallas/splash_attention_mask_test.py +++ b/tests/pallas/splash_attention_mask_test.py @@ -15,7 +15,6 @@ """Tests for splash_attention_masks.""" from __future__ import annotations -from typing import List from absl.testing import absltest from absl.testing import parameterized import jax @@ -733,7 +732,7 @@ def _expected_local_mask_next(self, mask_base_index: int): _expected_local_mask_next_dkv = _expected_local_mask_next - def _stack(self, arrays: List[np.ndarray]) -> np.ndarray: + def _stack(self, arrays: list[np.ndarray]) -> np.ndarray: return np.stack(arrays, axis=0) # For each test, check both the lazy and the dense versions of the mask. diff --git a/tests/pallas/tpu/BUILD b/tests/pallas/tpu/BUILD new file mode 100644 index 000000000000..4b0ffa941510 --- /dev/null +++ b/tests/pallas/tpu/BUILD @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "jax_generate_backend_suites", + "jax_test", + "py_deps", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//visibility:private"], +) + +jax_generate_backend_suites() + +jax_test( + name = "pallas_random_test", + srcs = [ + "pallas_random_test.py", + ], + disable_backends = [ + "cpu", + "gpu", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + "//jax/_src/pallas/mosaic:random", + "//third_party/py/absl/testing:absltest", + "//third_party/py/absl/testing:parameterized", + ] + py_deps("numpy"), +) diff --git a/tests/pallas/tpu/pallas_random_test.py b/tests/pallas/tpu/pallas_random_test.py new file mode 100644 index 000000000000..64892b707336 --- /dev/null +++ b/tests/pallas/tpu/pallas_random_test.py @@ -0,0 +1,203 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for random ops in Pallas + Mosaic.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import random as jax_random +from jax._src import test_util as jtu +from jax._src.pallas.mosaic import random as plrandom +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() + + +class PRNGTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + super().setUp() + + def test_pallas_key_raise_not_implemented_outside_of_kernel(self): + key = jax_random.key(0, impl="rbg") + pallas_key = plrandom.to_pallas_key(key) + # Using a pallas key outside of a kernel should raise an error when + # trying to lower TPU-specific ops to XLA. + # TODO(justinfu): Make this error more specific to pallas PRNG usage. + with self.assertRaisesRegex(NotImplementedError, + "MLIR translation rule .* not found"): + jax.random.uniform( + pallas_key, shape=(1,), minval=0.0, maxval=1.0) + + def test_seeded_reproducibility(self): + # Test whether generating random bits with the same seed + # produces the same result (and different seeds produce + # different results). + def seeded_body(seed: int): + def body(o_ref): + pltpu.prng_seed(seed) + o_ref[...] = pltpu.prng_random_bits(o_ref[...].shape) + return body + + out = jax.ShapeDtypeStruct((8, 128), jnp.int32) + result_1a = pl.pallas_call(seeded_body(0), out_shape=out)() + result_1b = pl.pallas_call(seeded_body(0), out_shape=out)() + result_2 = pl.pallas_call(seeded_body(1), out_shape=out)() + with self.subTest("same_seed_same_result"): + np.testing.assert_array_equal(result_1a, result_1b) + with self.subTest("diff_seed_diff_result"): + np.testing.assert_array_compare(np.not_equal, result_1a, result_2) + + @parameterized.parameters( + ((32, 256),), + ((8, 16),), + ) + def test_prng_non_vreg_shape_output(self, shape): + # Tests that RNG generation works with output shapes + # not equal to a native-sized VREG. + # This test makes sure that vector layout tiling + # is implemented correctly. + def body(o_ref): + pltpu.prng_seed(0) + samples = pltpu.prng_random_bits(o_ref[...].shape) + o_ref[...] = samples + + o_shape = jax.ShapeDtypeStruct(shape, jnp.int32) + result = pl.pallas_call(body, out_shape=o_shape)() + # Check that random_bits generates (mostly) unique values. + unique_frac = float(len(jnp.unique(result))) / np.prod(shape) + self.assertGreater(unique_frac, 0.99) + self.assertLessEqual(jnp.max(result), np.iinfo(jnp.int32).max) + self.assertGreaterEqual(jnp.min(result), np.iinfo(jnp.int32).min) + + def test_stateful_uniform_sample(self): + # Test stateful RNG using the jax.random API wrappers. + def body(key_ref, o_ref): + plrandom.set_seed(key_ref[...]) + o_ref[...] = plrandom.uniform( + shape=o_ref[...].shape, minval=0.0, maxval=1.0) + + rbg_key = jax_random.key(0, impl="rbg") + key = plrandom.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) + self.assertGreaterEqual(jnp.min(result), 0) + self.assertLessEqual(jnp.max(result), 1.0) + + def test_stateless_uniform_sample(self): + # Test keyed RNG using the jax.random API. + def body(key_ref, o_ref): + o_ref[...] = jax_random.uniform( + key_ref[...], shape=o_ref[...].shape, minval=0.0, maxval=1.0 + ) + + rbg_key = jax_random.key(0, impl="rbg") + key = plrandom.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) + self.assertGreaterEqual(jnp.min(result), 0) + self.assertLessEqual(jnp.max(result), 1.0) + + def test_fold_in(self): + # Test that folding in a value results in different random numbers. + def body(key_ref, o_ref): + key = key_ref[...] + o_ref[0, ...] = jax_random.uniform( + key, shape=o_ref[0, ...].shape, minval=0.0, maxval=1.0 + ) + + key = jax_random.fold_in(key, 2) + o_ref[1, ...] = jax_random.uniform( + key, shape=o_ref[1, ...].shape, minval=0.0, maxval=1.0 + ) + + rbg_key = jax_random.key(0, impl="rbg") + key = plrandom.to_pallas_key(rbg_key) + o_shape = jax.ShapeDtypeStruct((2, 8, 128), jnp.float32) + result = pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_shape=o_shape, + )(key) + result_a = result[0] + result_b = result[1] + np.testing.assert_array_compare(np.not_equal, result_a, result_b) + + +class BlockInvarianceTest(parameterized.TestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + super().setUp() + + def test_block_invariance(self): + + def make_kernel_body(index_map): + def body(key_ref, o_ref): + key = key_ref[0, 0] + samples = plrandom.sample_block( + jax.random.uniform, + key, + block_size=o_ref[...].shape, + tile_size=(16, 128), + total_size=(64, 512), + block_index=index_map(pl.program_id(0), pl.program_id(1)), + minval=0.0, + maxval=1.0) + o_ref[...] = samples + return body + + global_key = jax_random.key(0, impl="pallas_tpu") + o_shape = jnp.ones((64, 512), dtype=jnp.float32) + key_spec = pl.BlockSpec(lambda i, j: (0, 0), + block_shape=(1, 1), + memory_space=pltpu.TPUMemorySpace.SMEM) + out_spec = pl.BlockSpec(lambda i, j: (i, j), block_shape=(16, 128)) + result_16x128 = pl.pallas_call( + make_kernel_body(index_map=lambda i, j: (i, j)), + out_shape=o_shape, + in_specs=[key_spec], + out_specs=out_spec, + grid=(4, 4), + )(global_key) + + out_spec = pl.BlockSpec(lambda i, j: (j, i), block_shape=(32, 256)) + result_32x256 = pl.pallas_call( + make_kernel_body(index_map=lambda i, j: (j, i)), + in_specs=[key_spec], + out_shape=o_shape, + out_specs=out_spec, + grid=(2, 2), + )(global_key) + np.testing.assert_array_equal(result_16x128, result_32x256) + + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 3dbf0232fbcf..7dc015c90bca 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -18,38 +18,250 @@ import math import os import tempfile +import unittest from absl.testing import absltest import jax +from jax._src import config +from jax._src import profiler +from jax._src import pjit +from jax._src import monitoring from jax._src import test_util as jtu -from jax.sharding import NamedSharding +from jax._src import api from jax.experimental import profiler as exp_profiler import jax.numpy as jnp -from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding, PartitionSpec +from jax._src import compilation_cache as cc import numpy as np +from jax.experimental.serialize_executable import ( + deserialize_and_load, + serialize, +) + jax.config.parse_flags_with_absl() @jtu.pytest_mark_if_available('multiaccelerator') class PgleTest(jtu.JaxTestCase): + def setUp(self): + super().setUp() + cc.reset_cache() + + def tearDown(self): + cc.reset_cache() + super().tearDown() + + @unittest.skip("Test failing in CI") + def testPGLEProfilerGetFDOProfile(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + ) + def f(x, y): + return x @ y + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + y = x + 1 + + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x, y) + compiled = f_lowered.compile() + + pgle_profiler = profiler.PGLEProfiler(1, 90) + with config.enable_pgle(False): + with profiler.PGLEProfiler.trace(pgle_profiler): + compiled(x, y) + + fdo_profile = pgle_profiler.consume_fdo_profile() + self.assertIsNotNone(fdo_profile) + self.assertIn(b'custom', fdo_profile) + + @unittest.skip("Test failing in CI") + def testPGLEProfilerGetFDOProfileLarge(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + its = 500 + + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + ) + def f(x): + agg = x + for _ in range(its): + agg = agg @ x + return agg + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x) + f_compiled = f_lowered.compile() + + pgle_profiler = profiler.PGLEProfiler(1, 90) + with config.enable_pgle(False): + with profiler.PGLEProfiler.trace(pgle_profiler): + f_compiled(x) + fdo_profile = pgle_profiler.consume_fdo_profile() + self.assertEqual(fdo_profile.count(b'custom'), its) + + def testAutoPgle(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + ) + def f(x): + return x * 2 + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + expected = x * 2 + + with config.pgle_profiling_runs(2), config.enable_pgle(True): + # Run 1: Module should be compiled without FDO. Two modules are expected + # One is the funtion f, the other one is multi slice module + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + + # Run 2: Second PGLE run should not recompile the module + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + # Run 3: The module should be recompiled with FDO profiles + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + + # Run 4: Fast-path should be used after PGLE is done + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + def testAutoPgleWithAot(self): + @jax.jit + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + f_lowered = f.lower(x) + serialized, in_tree, out_tree = serialize(f_lowered.compile()) + compiled = deserialize_and_load(serialized, in_tree, out_tree) + + with config.pgle_profiling_runs(1), config.enable_pgle(True): + # Run 1 + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(compiled(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + # Run 2 + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(compiled(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + @unittest.skip("Test failing in CI") + def testAutoPgleWithPersistentCache(self): + + @jax.jit + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + profilers_dict = ( + pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict) + with (config.enable_compilation_cache(True), + config.enable_pgle(True), + config.raise_persistent_cache_errors(True), + config.raise_persistent_cache_errors(True), + config.persistent_cache_min_entry_size_bytes(0), + config.persistent_cache_min_compile_time_secs(0), + config.pgle_profiling_runs(2), + tempfile.TemporaryDirectory() as tmpdir): + cc.set_cache_dir(tmpdir) + # Run 1: Module should be compiled without FDO + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 1) + + # Non-pgle profiled version of module should be saved + non_pgle_profiled_files = os.listdir(tmpdir) + self.assertLen(non_pgle_profiled_files, 1) + + # Run 2: Compilation should not be called + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 0) + + # Run 3: Module should be compiled with FDO and stored to persistent cache + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 1) + + for pgle_profiler in profilers_dict.values(): + self.assertTrue(pgle_profiler.is_enabled()) + self.assertTrue(pgle_profiler.is_fdo_consumed()) + # One module is PGLEd version another one is not PGLEd + self.assertLen(os.listdir(tmpdir), 2) + + # Removing non-pgle profiled module from cache to check that later pgle + # profiled version will be used. + os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0])) + + api.clear_caches() + profilers_dict.clear() + + # Run 4: Persistent compilation cache should be hit PGLE profiler should + # be disabled + cache_hit = 0 + def check_if_cache_hit(event): + nonlocal cache_hit + if event == '/jax/compilation_cache/cache_hits': + cache_hit += 1 + + monitoring.register_event_listener(check_if_cache_hit) + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + monitoring._unregister_event_listener_by_callback(check_if_cache_hit) + + self.assertEqual(cache_miss_count[0], 1) + self.assertEqual(cache_hit, 1) + self.assertLen(profilers_dict, 1) + for pgle_profiler in profilers_dict.values(): + self.assertFalse(pgle_profiler.is_enabled()) + self.assertFalse(pgle_profiler.is_fdo_consumed()) def testPassingFDOProfile(self): mesh = jtu.create_global_mesh((2,), ('x',)) + @partial( jax.jit, - in_shardings=NamedSharding(mesh, P('x',)), - out_shardings=NamedSharding(mesh, P('x',)), + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), ) def f(x, y): - z = x @ y - return z @ y + return x @ y shape = (16, 16) x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) y = x + 1 - f_lowered = f.lower(x, y) - compiled = f_lowered.compile() + + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x, y) + compiled = f_lowered.compile() with tempfile.TemporaryDirectory() as tmpdir: jax.profiler.start_trace(tmpdir) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d59f1dcebd61..42b2297b556f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import OrderedDict, namedtuple -import os +import contextlib import re from functools import partial import logging @@ -32,7 +32,6 @@ import jax.numpy as jnp from jax._src import core from jax._src import config -from jax._src import maps from jax._src import test_util as jtu from jax import dtypes from jax import stages @@ -44,7 +43,7 @@ from jax.experimental import multihost_utils from jax.experimental.custom_partitioning import custom_partitioning from jax._src import array -from jax._src.sharding import Sharding +from jax._src.sharding import Sharding, common_devices_indices_map from jax._src import op_shardings from jax._src import sharding_impls from jax._src.sharding_impls import ( @@ -63,32 +62,15 @@ config.parse_flags_with_absl() -prev_xla_flags = None -prev_spmd_lowering_flag = None - +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - global prev_spmd_lowering_flag - prev_spmd_lowering_flag = maps.SPMD_LOWERING.value - config.update('experimental_xmap_spmd_lowering', True) + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + _exit_stack.enter_context(jtu.global_config_context(experimental_xmap_spmd_lowering=True)) def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() - config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag) - + _exit_stack.close() def create_array(global_shape, global_mesh, mesh_axes, global_data=None, dtype=np.float32): @@ -1303,6 +1285,16 @@ def f(x): """).strip(), ) + def test_with_sharding_constraint_vmap_spmd_axis_name_error(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + + def f(x): + return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x'))) + + xs = jnp.arange(4 * 16.).reshape(4, 16) + with self.assertRaisesRegex(ValueError, "spmd_axis_name"): + jax.vmap(f, spmd_axis_name='x')(xs) + @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): @@ -1572,7 +1564,7 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, def test_xla_arr_sharding_mismatch(self): global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - global_input_shape = (4, 2) + global_input_shape = (6, 2) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -2199,30 +2191,6 @@ def test_fast_path_array(self): self.assertTrue(out2.sharding.is_equivalent_to(out.sharding, out.ndim)) self.assertArraysEqual(out2, inp_data) - def test_not_xlacompatible_sharding_error(self): - shape = (8, 2) - inp_data = np.arange(math.prod(shape)).reshape(shape) - ts = TempSharding(jax.devices()) - arr = array.make_array_from_callback( - shape, ts, lambda idx: inp_data[idx]) - with self.assertRaisesRegex( - ValueError, - 'One of the argument to pjit got sharding.*which is not a subclass of ' - 'XLACompatibleSharding.'): - pjit(lambda x: x)(arr) - - with self.assertRaisesRegex( - ValueError, - 'One of in_shardings leaf specifications got sharding.*which is ' - 'not a subclass of XLACompatibleSharding.'): - pjit(lambda x: x, in_shardings=ts)(arr) - - with self.assertRaisesRegex( - ValueError, - 'One of out_shardings leaf specifications got sharding.*which is ' - 'not a subclass of XLACompatibleSharding.'): - pjit(lambda x: x, out_shardings=ts)(arr) - def test_array_enabled_non_empty_mesh_with_pspec(self): arr = jnp.array([1, 2, 3]) with self.assertRaisesRegex( @@ -2368,18 +2336,18 @@ def test_out_sharding_indices_id_cache_hit(self): out1 = f(arr) self.assertIsInstance(out1.sharding, NamedSharding) out1.sharding.devices_indices_map(shape) - cache_info1 = sharding_impls.common_devices_indices_map.cache_info() + cache_info1 = common_devices_indices_map.cache_info() out2 = f(out1) self.assertIsInstance(out2.sharding, NamedSharding) out2.sharding.devices_indices_map(shape) - cache_info2 = sharding_impls.common_devices_indices_map.cache_info() + cache_info2 = common_devices_indices_map.cache_info() self.assertEqual(cache_info2.hits, cache_info1.hits + 1) out3 = f(out2) self.assertIsInstance(out3.sharding, NamedSharding) out3.sharding.devices_indices_map(shape) - cache_info3 = sharding_impls.common_devices_indices_map.cache_info() + cache_info3 = common_devices_indices_map.cache_info() self.assertEqual(cache_info3.hits, cache_info2.hits + 1) def test_aot_compile_in_tree_mismatch(self): @@ -2496,6 +2464,8 @@ def f(x, y, z): r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) + @jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument") def test_jit_device_with_sharding_constraint_error(self): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) @@ -2921,7 +2891,9 @@ def _check(out, expected_device, expected_out): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - f = pjit(mul, device=jax.devices()[1]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + f = pjit(mul, device=jax.devices()[1]) x = jnp.arange(8).reshape(4, 2) f_out = f(x) f_out2 = f(f_out) @@ -2936,7 +2908,9 @@ def _check(out, expected_device, expected_out): self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - h = pjit(mul, device=jax.devices()[-1]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + h = pjit(mul, device=jax.devices()[-1]) h_out = h(y) cache_info3 = pjit_lib._pjit_lower_cached.cache_info() _check(h_out, jax.devices()[-1], y) @@ -2956,7 +2930,9 @@ def test_pjit_with_device_arg_input_from_another_pjit(self): out = pjit(lambda x: x * 2)(y) expected_device = jax.devices()[2] - final_out = pjit(lambda x: x * 3, device=expected_device)(out) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + final_out = pjit(lambda x: x * 3, device=expected_device)(out) self.assertEqual(final_out.devices(), {expected_device}) self.assertLen(final_out.sharding.device_set, 1) @@ -2970,7 +2946,9 @@ def _check(out, expected_device, expected_out): self.assertArraysEqual(out, expected_out) x = jnp.arange(8) - g = pjit(lambda x: x, backend='tpu') + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + g = pjit(lambda x: x, backend='tpu') g_out = g(x) _check(g_out, jax.devices()[0], x) @@ -2983,8 +2961,10 @@ def test_autodiff_with_device_arg(self): self.skipTest('Test requires more >1 device.') # Add a constant captured by the nested pjit to make things more complicated h = jnp.arange(4.) - f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1]) - g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1]) + g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1]) jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2) def test_pjit_device_backend_axis_resources_error(self): @@ -2993,13 +2973,17 @@ def test_pjit_device_backend_axis_resources_error(self): ValueError, 'If backend or device is specified on jit, then ' 'in_shardings should not be specified.'): - pjit(lambda x: x, in_shardings=s, backend='cpu') + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x, in_shardings=s, backend='cpu') with self.assertRaisesRegex( ValueError, 'If backend or device is specified on jit, then ' 'out_shardings should not be specified.'): - pjit(lambda x: x, out_shardings=s, device=jax.devices()[0]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x, out_shardings=s, device=jax.devices()[0]) def test_check_arg_error(self): sds = jax.ShapeDtypeStruct((4, 2), np.int32) @@ -3014,7 +2998,9 @@ def test_check_arg_error(self): def test_pjit_device_backend_both_error(self): with self.assertRaisesRegex( ValueError, "can't specify both a device and a backend for jit"): - pjit(lambda x: x, device=jax.devices()[0], backend='cpu') + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x, device=jax.devices()[0], backend='cpu') def test_pjit_mesh_with_device_or_backend_error(self): mesh = jtu.create_global_mesh((1,), ('x',)) @@ -3023,7 +3009,9 @@ def test_pjit_mesh_with_device_or_backend_error(self): ValueError, "Mesh context manager should not be used with jit when backend or " "device is also specified as an argument to jit."): - pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8)) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8)) def test_pjit_inline(self): @partial(pjit, inline=False) @@ -3171,7 +3159,7 @@ def test_jit_with_mesh_context_manager(self): mesh = jtu.create_global_mesh((1,), ('x',)) with self.assertRaisesRegex( RuntimeError, - "jax.jit only supports `XLACompatibleSharding`s being passed to " + "jax.jit only supports `Sharding`s being passed to " "in_shardings"): with mesh: jax.jit(lambda x: x, in_shardings=P('x'), @@ -3240,7 +3228,9 @@ def test_pjit_no_global_cache_hit_axis_resources(self): with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): - pjit(lambda x: x * 2, device=jax.devices()[0])(inp) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pjit(lambda x: x * 2, device=jax.devices()[0])(inp) self.assertEqual(count[0], 10) pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s) @@ -3249,7 +3239,9 @@ def test_pjit_no_global_cache_hit_axis_resources(self): pf(inp) self.assertEqual(count[0], 1) - pf1 = pjit(lambda x: x * 2, device=jax.devices()[0]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + pf1 = pjit(lambda x: x * 2, device=jax.devices()[0]) with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): pf1(inp) @@ -3664,7 +3656,9 @@ def test_different_named_sharding_object_replicated(self): self.assertNotEqual(x.sharding, y.sharding) def test_vmap_pjit_single_device(self): - jf = pjit(lambda x: x, device=jax.devices()[0]) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + jf = pjit(lambda x: x, device=jax.devices()[0]) out = jax.vmap(jf)(jnp.ones((3,))) # doesn't crash self.assertIsInstance(out.sharding, SingleDeviceSharding) @@ -3681,8 +3675,10 @@ def identity(x): self.assertEqual(out.devices(), {jax.devices()[0]}) self.assertArraysEqual(out, np_inp) - out2 = jax.jit(identity, device=jax.devices()[0])( - jax.device_put(np_inp, NamedSharding(mesh, P('x')))) + with jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument"): + out2 = jax.jit(identity, device=jax.devices()[0])( + jax.device_put(np_inp, NamedSharding(mesh, P('x')))) self.assertEqual(out2.devices(), {jax.devices()[0]}) self.assertArraysEqual(out2, np_inp) @@ -4137,31 +4133,92 @@ def g(x): self.assertNotEqual(f(1), g(1)) self.assertEqual(g(1), h(1)) + def test_wsc_vmap_unconstrained_spmd_axis_name(self): + def get_wsc_eqn_sharding(jaxpr): + for eqn in jaxpr.eqns: + if str(eqn.primitive) == 'sharding_constraint': + return eqn.params['sharding'], eqn.params['unconstrained_dims'] + for s in core.subjaxprs(jaxpr): + return get_wsc_eqn_sharding(s) -class TempSharding(Sharding): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + inp = jnp.ones((10, 10)) - def __init__(self, devices): - super().__init__() - self._devices = devices - self._internal_device_list = xc.DeviceList(tuple(self._devices)) + def a_function(x): + return with_sharding_constraint(x, NamedSharding(mesh, P(P.UNCONSTRAINED))) - @property - def device_set(self): - return set(self._devices) + def vmap_the_function_spmd(y): + return jax.vmap(a_function, spmd_axis_name='x')(y) - def devices_indices_map(self, global_shape): - return {d: (slice(None),) * len(global_shape) for d in self.device_set} + f1 = jax.jit(vmap_the_function_spmd) + f1(inp) # doesn't crash + jaxpr1 = jax.make_jaxpr(f1)(inp) + s1, u1 = get_wsc_eqn_sharding(jaxpr1) + self.assertEqual(s1.spec, P('x', P.UNCONSTRAINED)) + self.assertEqual(u1, {1}) - def shard_shape(self, global_shape): - return global_shape + def vmap_the_function_no_spmd(y): + return jax.vmap(a_function)(y) - @property - def memory_kind(self): - return None + f2 = jax.jit(vmap_the_function_no_spmd) + f2(inp) # doesn't crash + jaxpr2 = jax.make_jaxpr(f2)(inp) + s2, u2 = get_wsc_eqn_sharding(jaxpr2) + self.assertEqual(s2.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + self.assertEqual(u2, {0, 1}) - @property - def is_fully_replicated(self): - return True + def test_aot_sharding_dce(self): + inp = np.arange(8) + + @jax.jit + def f(x, y): + return x + + input_shardings, _ = f.lower(inp, inp).compile().input_shardings + self.assertLen(input_shardings, 2) + + def test_aot_out_info(self): + inp = np.arange(8, dtype=np.int32) + out_info = jax.jit(lambda x: x).lower((inp, inp)).out_info + self.assertEqual(out_info[0].shape, (8,)) + self.assertEqual(out_info[1].shape, (8,)) + self.assertEqual(out_info[0].dtype, np.int32) + self.assertEqual(out_info[1].dtype, np.int32) + self.assertEqual(out_info[0].sharding, None) + self.assertEqual(out_info[1].sharding, None) + + def test_jit_trace(self): + def f(x): + return x * 2 + + traced = jax.jit(f).trace(jnp.arange(8, dtype=jnp.int32)) + self.assertLen(traced.jaxpr.eqns, 1) + self.assertEqual(jax.tree.structure(traced.out_info).num_leaves, 1) + self.assertEqual(traced.out_info.shape, (8,)) + self.assertEqual(traced.out_info.dtype, jnp.int32) + # one for args, one for kwargs (though kwargs is empty) + self.assertLen(traced.in_avals, 2) + self.assertLen(traced.in_avals[0], 1) + self.assertLen(traced.in_avals[1], 0) # empty kwarg + + def test_jit_trace_lower_and_compile(self): + def f(x): + return x * 2 + + lowered = jax.jit(f).trace(jnp.arange(8)).lower() + self.assertEqual(lowered.args_info[0][0].shape, (8,)) + + compiled = lowered.compile() + out = compiled(jnp.arange(8)) + self.assertArraysEqual(out, np.arange(8) * 2) + + # fast-forward + lowered2 = jax.jit(f).lower(jnp.arange(8)) + self.assertEqual(lowered2.args_info[0][0].shape, (8,)) + + compiled2 = lowered2.compile() + out2 = compiled2(jnp.arange(8)) + self.assertArraysEqual(out2, np.arange(8) * 2) def spec_regex(s): @@ -4270,8 +4327,7 @@ def testRankTooLowConstraint(self): r".*rank at least 2, but was applied to a value of rank 1", re.M | re.S) with self.assertRaisesRegex(ValueError, error): pjit( - lambda x: with_sharding_constraint(x, spec), - in_shardings=None, + lambda x: with_sharding_constraint(x, spec), in_shardings=None, out_shardings=None, )(x) @@ -4706,19 +4762,19 @@ def test_device_indices_cache(self): ops = GSPMDSharding(devices, op1) ops.devices_indices_map(shape) - cache_info1 = sharding_impls.gspmd_sharding_devices_indices_map.cache_info() + cache_info1 = common_devices_indices_map.cache_info() ops.devices_indices_map(shape) - cache_info2 = sharding_impls.gspmd_sharding_devices_indices_map.cache_info() + cache_info2 = common_devices_indices_map.cache_info() self.assertEqual(cache_info2.hits, cache_info1.hits + 1) ops = GSPMDSharding(devices, op2) ops.devices_indices_map(shape) - cache_info3 = sharding_impls.gspmd_sharding_devices_indices_map.cache_info() + cache_info3 = common_devices_indices_map.cache_info() self.assertEqual(cache_info3.hits, cache_info2.hits + 1) ops.devices_indices_map(shape) - cache_info4 = sharding_impls.gspmd_sharding_devices_indices_map.cache_info() + cache_info4 = common_devices_indices_map.cache_info() self.assertEqual(cache_info4.hits, cache_info3.hits + 1) def test_op_sharding_semantically_replicated(self): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 6dc2e744bab6..5e279a5e6daa 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -15,11 +15,11 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor +import contextlib from functools import partial import itertools as it import gc import math -import os from random import shuffle import re from typing import Union, cast @@ -46,7 +46,6 @@ from jax._src import sharding_impls from jax._src import sharding_specs from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.internal_test_util import lax_test_util from jax._src.interpreters import mlir from jax._src.interpreters import pxla @@ -56,7 +55,14 @@ config.parse_flags_with_absl() -prev_xla_flags = None +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + +def setUpModule(): + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + +def tearDownModule(): + _exit_stack.close() compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]] @@ -85,26 +91,6 @@ def args_slicer(args, bdims): slicers = safe_map(slicer, args, bdims) return lambda i: [sl(i) for sl in slicers] -# Run all tests with 8 CPU devices. -def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - -# Reset to previous configuration in case other test modules will be run. -def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() - ignore_jit_of_pmap_warning = partial( jtu.ignore_warning, message=".*jit-of-pmap.*") @@ -1243,7 +1229,7 @@ def print_board(board): boards.append(''.join('*' if x else ' ' for x in board.ravel())) print_board(reshaped_board) - for _ in range(20): + for _ in range(9): reshaped_board = step(reshaped_board) print_board(reshaped_board) @@ -1259,17 +1245,6 @@ def print_board(board): ' ** **** ****** ', ' ** * *** * ', ' ** **** ** * *** ', - ' ** * * **** ** * ', - ' ** **** ** * * **** ', - ' ** * *** ** ** * * ', - ' ** **** ** *** *** ** *** ', - ' ** * * *** * *** * * ', - ' ** **** ** * * ***** ******* ', - ' ** * *** **** * *** * ', - ' ** **** ** *** ** ** * *** ', - ' ** * * *** * ** *** **** ** * ', - ' ** **** ** * ****** * * *** ****', - ' * * *** **** **** *** ** * ', )) print(ans) @@ -1817,8 +1792,8 @@ def matrix_vector(x, y, parallel=True): res = fv(x) return res - x = random.normal(random.PRNGKey(1), (80, 5)) - y = random.normal(random.PRNGKey(1), (10, 5)) + x = random.normal(random.PRNGKey(1), (40, 5)) + y = random.normal(random.PRNGKey(1), (5, 5)) result1 = vmap(lambda b: matrix_vector(x, b, True))(y) # vmap + pmap result2 = lax.map(lambda b: matrix_vector(x, b, False), y) # map + map @@ -2075,13 +2050,10 @@ def test_grad_of_pmap_compilation_caching(self, axis_size): def f(x): return jnp.sin(x) + # warm-up the cache x = jnp.ones(axis_size) - f(x) # warm-up any dispatching compilations - - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 - _, f_bwd = jax.vjp(f, x) - _ = f_bwd(x) - self.assertEqual(count[0], 2) # one for fwd, one for bwd + _, f_bwd = jax.vjp(f, x) + _ = f_bwd(x) with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 _, f_bwd2 = jax.vjp(f, x) @@ -2530,7 +2502,7 @@ def testOneDevice(self): f = lambda x: jnp.dot(x, x.T) f0 = pmap(f, devices=[d0]) f1 = pmap(f, devices=[d1]) - x = self.rng().rand(1, 1000, 1000) + x = self.rng().rand(1, 500, 500) r0 = f0(x) r1 = f1(x) expected = np.expand_dims(np.dot(x.squeeze(), x.squeeze().T), 0) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 5337c72e0cd0..8eccffb9b773 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -16,7 +16,6 @@ import contextlib import functools import logging -import textwrap import time import unittest @@ -41,22 +40,13 @@ config.parse_flags_with_absl() - -def _format_multiline(text): - return textwrap.dedent(text).lstrip() - -prev_xla_flags = None - +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - # This will control the CPU devices. On TPU we always have 2 devices - prev_xla_flags = jtu.set_host_platform_device_count(2) - + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - prev_xla_flags() + _exit_stack.close() map, unsafe_map = util.safe_map, map diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index 19dd5841463c..705c96f00f05 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -15,181 +15,155 @@ """Tests for the library of QDWH-based polar decomposition.""" import functools +from absl.testing import absltest import jax -import jax.numpy as jnp -import numpy as np -import scipy.linalg as osp_linalg from jax._src import config from jax._src import test_util as jtu from jax._src.lax import qdwh - -from absl.testing import absltest - +import jax.numpy as jnp +import numpy as np config.parse_flags_with_absl() -_JAX_ENABLE_X64_QDWH = config.enable_x64.value - -# Input matrix data type for QdwhTest. -_QDWH_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64_QDWH else np.float32 - -# Machine epsilon used by QdwhTest. -_QDWH_TEST_EPS = jnp.finfo(_QDWH_TEST_DTYPE).eps - -# Largest log10 value of condition numbers used by QdwhTest. -_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _QDWH_TEST_EPS)) +float_types = jtu.dtypes.floating +complex_types = jtu.dtypes.complex -def _check_symmetry(x: jax.Array) -> bool: - """Check if the array is symmetric.""" - m, n = x.shape - eps = jnp.finfo(x.dtype).eps - tol = 50.0 * eps - is_hermitian = False - if m == n: - if np.linalg.norm(x - x.T.conj()) / np.linalg.norm(x) < tol: - is_hermitian = True - return is_hermitian - -def _compute_relative_diff(actual, expected): +def _compute_relative_normwise_diff(actual, expected): """Computes relative difference between two matrices.""" return np.linalg.norm(actual - expected) / np.linalg.norm(expected) -_dot = functools.partial(jnp.dot, precision="highest") +_dot = functools.partial(jnp.dot, precision='highest') -class QdwhTest(jtu.JaxTestCase): - @jtu.sample_product( - [dict(m=m, n=n) for m, n in [(8, 6), (10, 10), (20, 18)]], - log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4), - ) - def testQdwhUnconvergedAfterMaxNumberIterations( - self, m, n, log_cond): - """Tests unconvergence after maximum number of iterations.""" - a = jnp.triu(jnp.ones((m, n))) - u, s, v = jnp.linalg.svd(a, full_matrices=False) - cond = 10**log_cond - s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) - with jax.numpy_dtype_promotion('standard'): - a = (u * s) @ v - is_hermitian = _check_symmetry(a) - max_iterations = 2 +class QdwhTest(jtu.JaxTestCase): - _, _, actual_num_iterations, is_converged = qdwh.qdwh( - a, is_hermitian=is_hermitian, max_iterations=max_iterations) + def _testReconstruction(self, a, u, h, tol): + """Tests that a = u*p.""" + with self.subTest('Test reconstruction'): + diff = _compute_relative_normwise_diff(_dot(u, h), a) + self.assertLessEqual(diff, tol) - with self.subTest('Number of iterations.'): - self.assertEqual(max_iterations, actual_num_iterations) + def _testUnitary(self, u, tol): + """Tests that u is unitary.""" + with self.subTest('Test unitary'): + m, n = u.shape + self.assertAllClose( + _dot(u.conj().T, u), np.eye(n, dtype=u.dtype), atol=tol, rtol=tol + ) - with self.subTest('Converged.'): - self.assertFalse(is_converged) + def _testHermitian(self, h, tol): + """Tests that h is Hermitian.""" + with self.subTest('Test hermitian'): + self.assertAllClose(h, h.conj().T, atol=tol, rtol=tol) + + def _testPolarDecomposition(self, a, u, h, tol): + """Tests that u*h is the polar decomposition of a""" + self._testReconstruction(a, u, h, tol) + self._testUnitary(u, tol) + self._testHermitian(h, tol) + + def _testQdwh(self, a, dynamic_shape=None): + """Computes the polar decomposition and tests its basic properties.""" + eps = jnp.finfo(a.dtype).eps + u, h, iters, conv = qdwh.qdwh(a, dynamic_shape=dynamic_shape) + tol = 13 * eps + if dynamic_shape is not None: + m, n = dynamic_shape + a = a[:m, :n] + u = u[:m, :n] + h = h[:n, :n] + self._testPolarDecomposition(a, u, h, tol=tol) @jtu.sample_product( - [dict(m=m, n=n) for m, n in [(8, 6), (10, 10), (20, 18)]], - log_cond=np.linspace(1, _MAX_LOG_CONDITION_NUM, 4), + shape=[(8, 6), (10, 10), (20, 18)], + dtype=float_types + complex_types, ) - def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond): + def testQdwhWithUpperTriangularInputAllOnes(self, shape, dtype): """Tests qdwh with upper triangular input of all ones.""" - a = jnp.triu(jnp.ones((m, n))).astype(_QDWH_TEST_DTYPE) - u, s, v = jnp.linalg.svd(a, full_matrices=False) - cond = 10**log_cond - s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) - a = (u * s) @ v - is_hermitian = _check_symmetry(a) - max_iterations = 10 + eps = jnp.finfo(dtype).eps + m, n = shape + a = jnp.triu(jnp.ones((m, n))).astype(dtype) + self._testQdwh(a) - actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=is_hermitian, - max_iterations=max_iterations) - expected_u, expected_h = osp_linalg.polar(a) - - # Sets the test tolerance. - rtol = 1E6 * _QDWH_TEST_EPS - - with self.subTest('Test u.'): - relative_diff_u = _compute_relative_diff(actual_u, expected_u) - np.testing.assert_almost_equal(relative_diff_u, 1E-6, decimal=5) - - with self.subTest('Test h.'): - relative_diff_h = _compute_relative_diff(actual_h, expected_h) - np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5) - - with self.subTest('Test u.dot(h).'): - a_round_trip = _dot(actual_u, actual_h) - relative_diff_a = _compute_relative_diff(a_round_trip, a) - np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5) - - with self.subTest('Test orthogonality.'): - actual_results = _dot(actual_u.T, actual_u) - expected_results = np.eye(n) - self.assertAllClose( - actual_results, expected_results, rtol=rtol, atol=1E-5) + @jtu.sample_product( + shape=[(2, 2), (5, 5), (8, 5), (10, 10)], + dtype=float_types + complex_types, + ) + def testQdwhWithDynamicShape(self, shape, dtype): + """Tests qdwh with dynamic shapes.""" + rng = jtu.rand_uniform(self.rng()) + a = rng((10, 10), dtype) + self._testQdwh(a, dynamic_shape=shape) @jtu.sample_product( - [dict(m=m, n=n) for m, n in [(6, 6), (8, 4)]], - padding=(None, (3, 2)), - log_cond=np.linspace(1, 4, 4), + shape=[(8, 6), (10, 10), (20, 18), (300, 300)], + log_cond=np.linspace(0, 1, 4), + dtype=float_types + complex_types, ) - def testQdwhWithRandomMatrix(self, m, n, log_cond, padding): - """Tests qdwh with random input.""" - rng = jtu.rand_uniform(self.rng(), low=0.3, high=0.9) - a = rng((m, n), _QDWH_TEST_DTYPE) - u, s, v = jnp.linalg.svd(a, full_matrices=False) + def testQdwhWithRandomMatrix(self, shape, log_cond, dtype): + """Tests qdwh with upper triangular input of all ones.""" + eps = jnp.finfo(dtype).eps + m, n = shape + max_cond = np.log10(1.0 / eps) + log_cond = log_cond * max_cond cond = 10**log_cond + + # Generates input matrix with prescribed condition number. + rng = jtu.rand_uniform(self.rng()) + a = rng((m, n), dtype) + u, _, v = jnp.linalg.svd(a, full_matrices=False) s = jnp.expand_dims(jnp.linspace(cond, 1, min(m, n)), range(u.ndim - 1)) - a = (u * s) @ v - is_hermitian = _check_symmetry(a) - max_iterations = 10 + a = (u * s.astype(u.dtype)) @ v + self._testQdwh(a) + @jtu.sample_product( + [dict(m=m, n=n) for m, n in [(6, 6), (8, 4)]], + padding=(None, (3, 2)), + dtype=float_types + complex_types, + ) + def testQdwhJitCompatibility(self, m, n, padding, dtype): + """Tests JIT compilation of QDWH with and without dynamic shape.""" + rng = jtu.rand_uniform(self.rng()) + a = rng((m, n), dtype) def lsp_linalg_fn(a): if padding is not None: pm, pn = padding a = jnp.pad(a, [(0, pm), (0, pn)], constant_values=jnp.nan) - u, h, _, _ = qdwh.qdwh( - a, is_hermitian=is_hermitian, max_iterations=max_iterations, - dynamic_shape=(m, n) if padding else None) + u, h, _, _ = qdwh.qdwh(a, dynamic_shape=(m, n) if padding else None) if padding is not None: u = u[:m, :n] h = h[:n, :n] return u, h args_maker = lambda: [a] - - # Sets the test tolerance. - rtol = 1E6 * _QDWH_TEST_EPS - with self.subTest('Test JIT compatibility'): self._CompileAndCheck(lsp_linalg_fn, args_maker) - with self.subTest('Test against numpy.'): - self._CheckAgainstNumpy(osp_linalg.polar, lsp_linalg_fn, args_maker, - rtol=rtol, atol=1E-3) - @jtu.sample_product( - [dict(m=m, n=n, r=r) for m, n, r in [(10, 10, 8), (8, 8, 7), (12, 8, 5)]], - log_cond=np.linspace(1, 4, 4), + [dict(m=m, n=n, r=r) for m, n, r in [(10, 10, 8), (8, 8, 7), (12, 8, 5)]], + log_cond=np.linspace(0, 1, 4), + dtype=float_types + complex_types, ) - def testQdwhOnRankDeficientInput(self, m, n, r, log_cond): + def testQdwhOnRankDeficientInput(self, m, n, r, log_cond, dtype): """Tests qdwh on rank-deficient input.""" - a = np.triu(np.ones((m, n))).astype(_QDWH_TEST_DTYPE) + eps = jnp.finfo(dtype).eps + a = np.triu(np.ones((m, n))).astype(dtype) - # Generates a rank-deficient input. + # Generates a rank-deficient input with prescribed condition number. + max_cond = np.log10(1.0 / eps) + log_cond = log_cond * max_cond u, _, vh = np.linalg.svd(a, full_matrices=False) s = 10**jnp.linspace(log_cond, 0, min(m, n)) + print(s) s = jnp.expand_dims(s.at[r:].set(0), range(u.ndim - 1)) - a = (u * s) @ vh + a = (u * s.astype(u.dtype)) @ vh - actual_u, actual_h, _, _ = qdwh.qdwh(a, is_hermitian=_check_symmetry(a)) - _, expected_h = osp_linalg.polar(a) + actual_u, actual_h, _, _ = qdwh.qdwh(a) - with self.subTest('Test h.'): - relative_diff_h = _compute_relative_diff(actual_h, expected_h) - np.testing.assert_almost_equal(relative_diff_h, 1E-6, decimal=5) - - with self.subTest('Test u.dot(h).'): - a_round_trip = _dot(actual_u, actual_h) - relative_diff_a = _compute_relative_diff(a_round_trip, a) - np.testing.assert_almost_equal(relative_diff_a, 1E-6, decimal=5) + self._testHermitian(actual_h, 10 * eps) + self._testReconstruction(a, actual_u, actual_h, 60 * eps) # QDWH gives U_p = U Σₖ V* for input A with SVD A = U Σ V*. For full rank # input, we expect convergence Σₖ → I, giving the correct polar factor @@ -202,34 +176,31 @@ def testQdwhOnRankDeficientInput(self, m, n, r, log_cond): vr = vh.conj().T[:, :r] uvr = _dot(actual_u, vr) actual_results = _dot(uvr.T.conj(), uvr) - expected_results = np.eye(r) + expected_results = np.eye(r, dtype=actual_u.dtype) self.assertAllClose( - actual_results, expected_results, rtol=_QDWH_TEST_EPS, atol=1e-6 + actual_results, expected_results, atol=25 * eps, rtol=25 * eps ) @jtu.sample_product( - [dict(m=m, n=n, r=r, c=c) for m, n, r, c in [(4, 3, 1, 1), (5, 2, 0, 0)]], - dtype=jtu.dtypes.floating, + [dict(m=m, n=n, r=r, c=c) for m, n, r, c in [(4, 3, 1, 1), (5, 2, 0, 0)]], + dtype=float_types + complex_types, ) def testQdwhWithTinyElement(self, m, n, r, c, dtype): """Tests qdwh on matrix with zeros and close-to-zero entries.""" a = jnp.zeros((m, n), dtype=dtype) - tiny_elem = jnp.finfo(a.dtype).tiny + one = dtype(1.0) + tiny_elem = dtype(jnp.finfo(a.dtype).tiny) a = a.at[r, c].set(tiny_elem) - is_hermitian = _check_symmetry(a) - max_iterations = 10 - @jax.jit def lsp_linalg_fn(a): - u, h, _, _ = qdwh.qdwh( - a, is_hermitian=is_hermitian, max_iterations=max_iterations) + u, h, _, _ = qdwh.qdwh(a) return u, h actual_u, actual_h = lsp_linalg_fn(a) expected_u = jnp.zeros((m, n), dtype=dtype) - expected_u = expected_u.at[r, c].set(1.0) + expected_u = expected_u.at[r, c].set(one) with self.subTest('Test u.'): np.testing.assert_array_equal(expected_u, actual_u) diff --git a/tests/random_test.py b/tests/random_test.py index 74659f2f0a9c..2c45d60cc64d 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -45,6 +45,9 @@ PRNG_IMPLS = list(prng_internal.prngs.items()) +# Remove Pallas keys from this test, which do not run in XLA. +PRNG_IMPLS = [ + (name, impl) for (name, impl) in PRNG_IMPLS if "pallas" not in name] class OnX64(enum.Enum): @@ -390,10 +393,16 @@ def test_threefry_gpu_kernel_lowering(self): f = lambda key: jax.random.uniform(key, (1,)) with jax._src.config.threefry_gpu_kernel_lowering(False): hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text() - self.assertNotIn("cu_threefry2x32", hlo_text) + if jtu.is_device_rocm(): + self.assertNotIn("hip_threefry2x32", hlo_text) + else: + self.assertNotIn("cu_threefry2x32", hlo_text) with jax._src.config.threefry_gpu_kernel_lowering(True): hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text() - self.assertIn("cu_threefry2x32", hlo_text) + if jtu.is_device_rocm(): + self.assertIn("hip_threefry2x32", hlo_text) + else: + self.assertIn("cu_threefry2x32", hlo_text) @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_random_seed_offset(self, make_key): diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index a0fe16d0849c..501f4cbe5e5f 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1633,21 +1633,25 @@ def testRankData(self, shape, dtype, axis, method): self._CompileAndCheck(lax_fun, args_maker, rtol=tol) @jtu.sample_product( - [dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy) + [dict(shape=shape, axis=axis, ddof=ddof, nan_policy=nan_policy, keepdims=keepdims) for shape in [(5,), (5, 6), (5, 6, 7)] for axis in [None, *range(len(shape))] for ddof in [0, 1, 2, 3] for nan_policy in ["propagate", "omit"] + for keepdims in [True, False] ], dtype=jtu.dtypes.integer + jtu.dtypes.floating, ) - def testSEM(self, shape, dtype, axis, ddof, nan_policy): + def testSEM(self, shape, dtype, axis, ddof, nan_policy, keepdims): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy) - lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy) + kwds = {} if scipy_version < (1, 11) else {'keepdims': keepdims} + scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, + **kwds) + lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, + **kwds) tol_spec = {np.float32: 2e-4, np.float64: 5e-6} tol = jtu.tolerance(dtype, tol_spec) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 700a25b5db04..f079d6753edd 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -16,13 +16,13 @@ from __future__ import annotations import enum -from collections.abc import Sequence +from collections.abc import Callable, Sequence import cProfile import itertools import math import os from pstats import Stats -from typing import Any, Callable +from typing import Any import unittest from absl import logging @@ -35,9 +35,7 @@ import re import jax -from jax.experimental import export -from jax.experimental.export import _shape_poly as shape_poly -from jax.experimental.export import _shape_poly_decision as shape_poly_decision +from jax import export from jax.experimental import pjit from jax import lax import jax.numpy as jnp @@ -46,6 +44,8 @@ from jax._src import config from jax._src import core from jax._src import test_util as jtu +from jax._src.export import shape_poly +from jax._src.export import shape_poly_decision from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow from jax._src.lib import xla_client @@ -1267,11 +1267,11 @@ def log_message(extra: str): tst.assertEqual(getattr(jax.config, fname), fvalue, ( f"Flag {fname} current value {getattr(jax.config, fname)} != {fvalue}")) - f_jax = self.dyn_fun + f_jax = jax.jit(self.dyn_fun) args = self.dyn_args_maker(tst.rng()) args = jax.tree.map(jnp.array, args) args_specs = export.symbolic_args_specs(args, self.polymorphic_shapes, - symbolic_constraints=self.symbolic_constraints) + constraints=self.symbolic_constraints) if self.expect_error is not None: with tst.assertRaisesRegex(self.expect_error[0], self.expect_error[1]): @@ -1283,7 +1283,7 @@ def log_message(extra: str): return None # Run the JAX natively and then the exported function and compare res_jax_native = f_jax(*args) - res_jax_exported = export.call_exported(exp)(*args) + res_jax_exported = exp.call(*args) custom_assert_lims = [ l for l in self.limitations if l.custom_assert is not None] assert len(custom_assert_lims) <= 1, custom_assert_lims @@ -1315,7 +1315,7 @@ def check_shape_poly(tst, f_jax: Callable, *, symbolic_constraints: Sequence[str] = (), expect_error=None) -> jax.Array | None: # Builds a PolyHarness and runs the test. See PolyHarness documentation. - h = PolyHarness("", "", f_jax, + h = PolyHarness("", "", jax.jit(f_jax), arg_descriptors=arg_descriptors, polymorphic_shapes=polymorphic_shapes, symbolic_constraints=symbolic_constraints, @@ -1408,11 +1408,10 @@ def test_kwargs(self): def f_jax(x, *, y): return x + jnp.sin(y) - f_exported = export.call_exported( - export.export(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), - x.dtype), - y=jax.ShapeDtypeStruct(y.shape, y.dtype))) - self.assertAllClose(f_jax(x, y=y), f_exported(x, y=y)) + exp = export.export(jax.jit(f_jax))( + jax.ShapeDtypeStruct(export.symbolic_shape("b"), x.dtype), + y=jax.ShapeDtypeStruct(y.shape, y.dtype)) + self.assertAllClose(f_jax(x, y=y), exp.call(x, y=y)) def test_arg_avals_errors(self): """Test error reporting for shape polymorphism.""" @@ -1617,8 +1616,8 @@ def f(x): # x: i32[a, b] acc += jnp.sum(slice, axis=0) return acc - _ = export.export(f)(jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), - np.int32)) + _ = export.export(jax.jit(f))( + jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32)) def test_constraints_compile_time_check(self): @@ -1630,29 +1629,30 @@ def f(x): # x: i32[a] x_spec = jax.ShapeDtypeStruct( export.symbolic_shape("a", constraints=["a >= 2", "a <= 4"]), np.int32) - exp = export.export(f)(x_spec) + exp = export.export(jax.jit(f))(x_spec) x_2 = np.arange(2, dtype=np.int32) - res_2 = export.call_exported(exp)(x_2) + res_2 = exp.call(x_2) self.assertAllClose(x_2[0:2], res_2) x_4 = np.arange(4, dtype=np.int32) - res_4 = export.call_exported(exp)(x_4) + res_4 = exp.call(x_4) self.assertAllClose(x_4[1:3], res_4) with self.assertRaisesRegex( ValueError, re.escape("Expected 'a - 2' to be greater or equal to 0, but found -1")): - export.call_exported(exp)(np.arange(1, dtype=np.int32)) + exp.call(np.arange(1, dtype=np.int32)) with self.assertRaisesRegex( ValueError, re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")): - export.call_exported(exp)(np.arange(5, dtype=np.int32)) + exp.call(np.arange(5, dtype=np.int32)) def test_caching_with_scopes(self): f_tracing_count = 0 expected_a_bounds = (1, np.inf) + @jax.jit def f(x): # x: i32[a] nonlocal f_tracing_count f_tracing_count += 1 @@ -1997,6 +1997,30 @@ def test_vmap_error(self): + jnp.sin(x))), arg_descriptors=[RandArg((3, 4), _f32)], polymorphic_shapes=["b, ..."]), + [ # approx_max_k + # x: f32[b, {n}, 32] with n being either 8 or the symbol "n" + # we reduce on dim=1, with size n + # k is either the constant 4 or the symbol "k" + PolyHarness("approx_max_k", f"n_{n}_k_{k}_agg={agg}", + lambda x, x_k, agg: lax.approx_max_k( + x, k=x_k.shape[0], reduction_dimension=1, + aggregate_to_topk=agg), + arg_descriptors=[RandArg((3, 8, 32), _f32), + RandArg((4,), _f32), + StaticArg(agg)], + polymorphic_shapes=[f"b, {n}, 32", f"{k},"], + # k must be at most the reduction dimension size + symbolic_constraints=[f"{k} <= {n}"], + expect_error=( + (NotImplementedError, "aggregate_to_topk=False") if ( + not agg and (isinstance(k, str) or + isinstance(n, str))) else + None + )) + for n in [8, "n"] + for k in [4, "k"] + for agg in [True, False] + ], [ # arange PolyHarness("arange", name, f_jax, @@ -3071,6 +3095,12 @@ def test_vmap_error(self): lambda x: jnp.tri(x.shape[0], M=x.shape[0] + 2) + x, arg_descriptors=[RandArg((3, 1), _f32)], polymorphic_shapes=["b, ..."]), + PolyHarness("tril", "", + lambda x: jnp.tril(jnp.ones((x.shape[0], x.shape[0] + x.shape[1]), + dtype=_f32), + k=x.shape[1]), + arg_descriptors=[RandArg((3, 4), _f32)], + polymorphic_shapes=["m, n"]), [ PolyHarness("triangular_solve", f"shape={jtu.format_shape_dtype_string(a_shape, dtype)}_{left_side=}_{a_poly=}_{b_poly=}", @@ -3259,7 +3289,7 @@ def test_harness(self, harness: PolyHarness): if harness.group_name == "vmap_eigh" and jtu.test_device_matches(["gpu"]): # For eigh on GPU with shape polymorphism under native serialization, - # we use a different lowering for small matrices. See README.md. + # we use a different lowering for small matrices. shape = harness.original_harness.params["shape"] if 0 < shape[-1] <= 32: harness.check_result = False @@ -3296,14 +3326,13 @@ def test_harness(self, harness: PolyHarness): if "random_gamma" in harness.group_name: config_flags = {**config_flags, "jax_debug_key_reuse": False} - prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags} - try: - for fname, fvalue in config_flags.items(): - jax.config.update(fname, fvalue) + # TPU precision is a little lower since we swap the order of matmul operands. + if "cholesky" in harness.group_name and jtu.test_device_matches(["tpu"]): + harness.tol = 5e-5 + + with jtu.global_config_context(**config_flags): harness.run_test(self) - finally: - for fname, _ in config_flags.items(): - jax.config.update(fname, prev_jax_config_flags[fname]) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 52e1c551d061..11305e937a08 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +import contextlib import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest -from jax._src import xla_bridge from jax._src import test_util as jtu from jax.sharding import NamedSharding, PartitionSpec as P from jax.experimental.shard_alike import shard_alike @@ -26,26 +25,14 @@ jax.config.parse_flags_with_absl() -prev_xla_flags = None - +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() class ShardAlikeDownstreamTest(jtu.JaxTestCase): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 9b3076434a56..ca9d813e2571 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -14,14 +14,14 @@ from __future__ import annotations -from collections.abc import Sequence, Iterable, Iterator, Generator +from collections.abc import Callable, Generator, Iterable, Iterator, Sequence +import contextlib from functools import partial import itertools as it import math import operator as op -import os from types import SimpleNamespace -from typing import Any, NamedTuple, Callable, TypeVar +from typing import Any, NamedTuple, TypeVar import unittest from absl.testing import absltest @@ -36,14 +36,12 @@ from jax._src import config from jax._src import core from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util -from jax._src.lib import xla_extension_version import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning @@ -68,31 +66,17 @@ def create_inputs(a_sharding, b_sharding): jax.sharding.NamedSharding(mesh, b_sharding)) return mesh, m1, m2 -# Run all tests with 8 CPU devices. -prev_xla_flags = None # Run all tests with 8 CPU devices. -def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() +_exit_stack = contextlib.ExitStack() +def setUpModule(): + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) if len(jax.devices()) < 8: raise unittest.SkipTest("tests require 8 devices") -# Reset to previous configuration in case other test modules will be run. def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() class ShardMapTest(jtu.JaxTestCase): @@ -1646,6 +1630,10 @@ def f(x): v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + self.assertIn( + 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}', + f.lower(v).as_text('hlo'), + ) self.assertAllClose(v*v, f(v), check_dtypes=False) def test_partial_auto_error_wsc_manual(self): @@ -1717,7 +1705,6 @@ def f(x): with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"): f(v) - @unittest.skipIf(xla_extension_version < 262, "Requires jaxlib 0.4.28") def test_nested_partial_auto(self): mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) @@ -1798,6 +1785,132 @@ def f(x): ir.as_text() ) + def test_vmap_spmd_axis_name_error(self): + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + + @partial( + shard_map, + mesh=mesh, + in_specs=P('i'), + out_specs=P('i'), + ) + def f(x): + return jnp.sin(x) + + xs = jnp.arange(4 * 16.).reshape(4, 16) + with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"): + jax.vmap(f, spmd_axis_name='i')(xs) + + @partial( + shard_map, + mesh=mesh, + in_specs=P('j'), + out_specs=P(('i', 'j')), + check_rep=False, + ) + def g(x): + return jnp.sin(x) + + xs = jnp.arange(4 * 16.).reshape(4, 16) + with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"): + jax.vmap(g, spmd_axis_name='i')(xs) + + def test_in_spec_none(self): + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + + x = jnp.arange(8).reshape(4, 2) + + def f(o, x): + self.assertIs(o, obj) + return jnp.sin(x) + + obj = object() + y = shard_map(f, mesh, (None, P('i')), P('i'))(obj, x) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + obj = None + y = shard_map(f, mesh, (None, P('i')), P('i'))(None, x) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + def f2(o, x): + self.assertIsInstance(o, dict) + self.assertIs(o['a'], obj['a']) + return jnp.sin(x) + + obj = {'a': object()} + y = shard_map(f2, mesh, ({'a': None}, P('i')), P('i'))(obj, x) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + def f3(x, o): + self.assertIs(o, obj) + return jnp.sin(x) + + obj = object() + y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + obj = None + y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + def f4(o1, o2, x, o3): + self.assertIs(o1, obj1) + self.assertIs(o2[0], obj2[0]) + self.assertIs(o2[1], obj2[1]) + self.assertIs(o3, obj3) + return jnp.sin(x) + + obj1 = object() + obj2 = (object(), object()) + obj3 = object() + y = shard_map(f4, mesh, (None, None, P('i'), None), P('i'))(obj1, obj2, x, obj3) + self.assertAllClose(y, jnp.sin(x), check_dtypes=False) + + def test_in_spec_none_divisibility_errors(self): + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + x = jnp.arange(4).reshape(2, 2) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, (None, P('i')), None)(object(), x) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, (P('i'), None), None)(x, object()) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, (P('i'), None), None + )(x, (object(), object())) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, (P('i'), (None, None)), None, + )(x, (object(), object())) + + with self.assertRaisesRegex(ValueError, 'divisible'): + shard_map(lambda *_: None, mesh, ((None, None), P('i')), None, + )((object(), object()), x) + + def test_in_spec_none_rank_errors(self): + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + x = jnp.arange(4) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, (None, P('i', 'j')), None)(object(), x) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None)(x, object()) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None + )(x, (object(), object())) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, (P('i', 'j'), (None, None)), None, + )(x, (object(), object())) + + with self.assertRaisesRegex(ValueError, 'rank'): + shard_map(lambda *_: None, mesh, ((None, None), P('i', 'j')), None, + )((object(), object()), x) + + class FunSpec(NamedTuple): name: str num_inputs: int diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index ba0ad5cb02c5..680bdda5675a 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -1939,6 +1939,18 @@ def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension): if jnp.issubdtype(dtype, jnp.floating): self._CheckGradsSparse(dense_func, sparse_func, args_maker) + def test_bcoo_spdot_abstract_eval_bug(self): + # Regression test for https://github.com/google/jax/issues/21921 + lhs = sparse.BCOO( + (jnp.float32([[1]]), lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0)), + shape=(10, 10)) + rhs = sparse.BCOO( + (jnp.float32([1]), jnp.int32([[3]])), + shape=(10,)) + args_maker = lambda: [lhs, rhs] + def func(lhs, rhs): + return (lhs @ rhs).todense() + self._CompileAndCheck(func, args_maker) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/state_test.py b/tests/state_test.py index 5573049e76f0..b6dbb490b794 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -14,10 +14,10 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial import itertools as it -from typing import Any, Callable, NamedTuple, Union +from typing import Any, NamedTuple, Union from absl.testing import absltest from absl.testing import parameterized diff --git a/tests/transfer_guard_test.py b/tests/transfer_guard_test.py index b6d9058db385..6a255b0a1b09 100644 --- a/tests/transfer_guard_test.py +++ b/tests/transfer_guard_test.py @@ -99,12 +99,16 @@ def _all_funcs(): ] +# TransferGuardTest disables `--jax_enable_checks` because it +# can prematurely fetch the value of device arrays and make +# device-to-host tests to incur no transfers unexpectedly. +@jtu.with_config(jax_enable_checks=False) class TransferGuardTest(jtu.JaxTestCase): - # `_default_config` is used by `jtu.JaxTestCase` to update the JAX config for - # every test case. TransferGuardTest disables `--jax_enable_checks` because it - # can prematurely fetch the value of device arrays and make device-to-host - # tests to incur no transfers unexpectedly. - _default_config = {"jax_enable_checks": False} + def setUp(self): + super().setUp() + # Nearly all test methods use the deprecated device argument to JIT. + self.enter_context(jtu.ignore_warning(category=DeprecationWarning, + message="backend and device argument")) @contextlib.contextmanager def assertAllows(self, func_name): diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 87121fc0eb11..23ddf73904b5 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -1144,20 +1144,117 @@ def test_different_structure_no_children(self): class TreeAliasTest(jtu.JaxTestCase): - @parameterized.parameters( - ('all', 'tree_all'), - ('flatten', 'tree_flatten'), - ('leaves', 'tree_leaves'), - ('map', 'tree_map'), - ('reduce', 'tree_reduce'), - ('structure', 'tree_structure'), - ('transpose', 'tree_transpose'), - ('unflatten', 'tree_unflatten'), - ) - def test_tree_aliases(self, tree_name, tree_util_name): - wrapper = getattr(jax.tree, tree_name) - original = getattr(jax.tree_util, tree_util_name) - self.assertIs(wrapper.__wrapped__, original) + """Simple smoke-tests for tree_util aliases under jax.tree""" + + def test_tree_all(self): + obj = [True, True, (True, False)] + self.assertEqual( + jax.tree.all(obj), + tree_util.tree_all(obj), + ) + + def test_tree_all_is_leaf(self): + obj = [True, True, (True, False)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.all(obj, is_leaf=is_leaf), + tree_util.tree_all(obj, is_leaf=is_leaf), + ) + + def test_tree_flatten(self): + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.flatten(obj), + tree_util.tree_flatten(obj), + ) + + def test_tree_flatten_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.flatten(obj, is_leaf=is_leaf), + tree_util.tree_flatten(obj, is_leaf=is_leaf), + ) + + def test_tree_leaves(self): + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.leaves(obj), + tree_util.tree_leaves(obj), + ) + + def test_tree_leaves_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.leaves(obj, is_leaf=is_leaf), + tree_util.tree_leaves(obj, is_leaf=is_leaf), + ) + + def test_tree_map(self): + func = lambda x: x * 2 + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.map(func, obj), + tree_util.tree_map(func, obj), + ) + + def test_tree_map_is_leaf(self): + func = lambda x: x * 2 + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.map(func, obj, is_leaf=is_leaf), + tree_util.tree_map(func, obj, is_leaf=is_leaf), + ) + + def test_tree_reduce(self): + func = lambda a, b: a + b + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.reduce(func, obj), + tree_util.tree_reduce(func, obj), + ) + + def test_tree_reduce_is_leaf(self): + func = lambda a, b: a + b + obj = [(1, 2), (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.reduce(func, obj, is_leaf=is_leaf), + tree_util.tree_reduce(func, obj, is_leaf=is_leaf), + ) + + def test_tree_structure(self): + obj = [1, 2, (3, 4)] + self.assertEqual( + jax.tree.structure(obj), + tree_util.tree_structure(obj), + ) + + def test_tree_structure_is_leaf(self): + obj = [1, 2, (3, 4)] + is_leaf = lambda x: isinstance(x, tuple) + self.assertEqual( + jax.tree.structure(obj, is_leaf=is_leaf), + tree_util.tree_structure(obj, is_leaf=is_leaf), + ) + + def test_tree_transpose(self): + obj = [(1, 2), (3, 4), (5, 6)] + outer_treedef = tree_util.tree_structure(['*', '*', '*']) + inner_treedef = tree_util.tree_structure(('*', '*')) + self.assertEqual( + jax.tree.transpose(outer_treedef, inner_treedef, obj), + tree_util.tree_transpose(outer_treedef, inner_treedef, obj) + ) + + def test_tree_unflatten(self): + leaves, treedef = jax.tree.flatten([1, 2, (3, 4)]) + self.assertEqual( + jax.tree.unflatten(treedef, leaves), + tree_util.tree_unflatten(treedef, leaves) + ) if __name__ == "__main__": diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index d118d0e6454b..7e778cc99d2c 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -20,13 +20,13 @@ from absl import logging from absl.testing import absltest +from jax import version from jax._src import compiler from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.interpreters import xla from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -161,14 +161,9 @@ def _mock_tpu_client_with_options(library_path=None, options=None): def _mock_tpu_client(library_path=None): _mock_tpu_client_with_options(library_path=library_path, options=None) - if xla_extension_version >= 267: - with mock.patch.object(xc, "make_tpu_client", - side_effect=_mock_tpu_client_with_options): - xb.tpu_client_timer_callback(0.01) - else: - with mock.patch.object(xc, "make_tpu_client", - side_effect=_mock_tpu_client): - xb.tpu_client_timer_callback(0.01) + with mock.patch.object(xc, "make_tpu_client", + side_effect=_mock_tpu_client_with_options): + xb.tpu_client_timer_callback(0.01) def test_register_plugin(self): with self.assertLogs(level="WARNING") as log_output: @@ -202,7 +197,12 @@ def test_register_plugin(self): self.assertIn("name2", xb._backend_factories) self.assertEqual(registration.priority, 400) self.assertTrue(registration.experimental) - mock_make.assert_called_once_with("name1", {}, None) + + options = {} + if xb.get_backend().platform == 'tpu': + options["ml_framework_name"] = "JAX" + options["ml_framework_version"] = version.__version__ + mock_make.assert_called_once_with("name1", options, None) def test_register_plugin_with_config(self): test_json_file_path = os.path.join( @@ -229,16 +229,19 @@ def test_register_plugin_with_config(self): self.assertIn("name1", xb._backend_factories) self.assertEqual(registration.priority, 400) self.assertTrue(registration.experimental) - mock_make.assert_called_once_with( - "name1", - { - "int_option": 64, - "int_list_option": [32, 64], - "string_option": "string", - "float_option": 1.0, - }, - None, - ) + + # The expectation is specified in example_pjrt_plugin_config.json. + options = { + "int_option": 64, + "int_list_option": [32, 64], + "string_option": "string", + "float_option": 1.0, + } + if xb.get_backend().platform == 'tpu': + options["ml_framework_name"] = "JAX" + options["ml_framework_version"] = version.__version__ + + mock_make.assert_called_once_with("name1", options, None) class GetBackendTest(jtu.JaxTestCase): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 0d11bb878d55..428c7fc66801 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -15,11 +15,11 @@ from __future__ import annotations from collections.abc import Generator, Iterator +import contextlib import functools import itertools as it from itertools import product, permutations import math -import os import re from unittest import SkipTest from typing import Union @@ -55,27 +55,14 @@ jax.config.parse_flags_with_absl() - -# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py # Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + def setUpModule(): - global prev_xla_flags - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - " --xla_force_host_platform_device_count=8") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - -# Reset to previous configuration in case other test modules will be run. + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + def tearDownModule(): - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags - xla_bridge.get_backend.cache_clear() + _exit_stack.close() def create_array(global_shape, global_mesh, mesh_axes, global_data=None): diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a1ff12724d2e..f7a921e00286 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "940e3a27542b7ce76666173e7b287aa2a9263916" -XLA_SHA256 = "bcdc778e5a456839869dea796117b723bdea488075bd9555fe118fd8d6fcf25e" +XLA_COMMIT = "c53e9f4e48f11d0da3469c85a0692db7ebd2a5d3" +XLA_SHA256 = "2bbed1d978ea5715676d7efbeffda10ac80817acb2cbf7a0df7cbceb4c75eab7" def repo(): tf_http_archive(