diff --git a/.bazelrc b/.bazelrc index fcb6be479a96..a08df1017f91 100644 --- a/.bazelrc +++ b/.bazelrc @@ -90,6 +90,9 @@ build:win_clang --compiler=clang-cl build:cuda_plugin --@xla//xla/python:enable_gpu=false build:cuda_plugin --define=xla_python_enable_gpu=false +build:rocm_plugin --@xla//xla/python:enable_gpu=false +build:rocm_plugin --define=xla_python_enable_gpu=false + # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to # point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA @@ -220,8 +223,6 @@ build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylin build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_cpu_linux_py3.9 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9" -build:rbe_cpu_linux_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9" build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" @@ -250,8 +251,6 @@ build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04- build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl" # RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat -build:rbe_linux_cuda12.3_nvcc_py3.9 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9" -build:rbe_linux_cuda12.3_nvcc_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9" build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 3efdc0c2bf12..9c5dee48fae9 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,8 +1,8 @@ name: CI # We test all supported Python versions as follows: -# - 3.9 : Documentation build -# - 3.9 : Part of Matrix with NumPy dispatch +# - 3.10 : Documentation build +# - 3.10 : Part of Matrix with NumPy dispatch # - 3.10 : Part of Matrix # - 3.11 : Part of Matrix @@ -45,8 +45,8 @@ jobs: matrix: # Test the oldest and newest supported Python versions here. include: - - name-prefix: "with 3.9" - python-version: "3.9" + - name-prefix: "with 3.10" + python-version: "3.10" os: ubuntu-20.04-16core enable-x64: 1 prng-upgrade: 1 @@ -108,7 +108,7 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: [3.9] + python-version: ['3.10'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 @@ -150,7 +150,7 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: [3.9] + python-version: ['3.10'] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 393be393f0be..6433fb66039e 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -13,7 +13,7 @@ jobs: matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.9', '3.10', '3.11', '3.12'] + pyver: ['3.10', '3.11', '3.12'] name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build runs-on: ${{ matrix.os }} diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index d0dac6519203..92f9355ae200 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -18,7 +18,7 @@ jobs: matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.9'] + pyver: ['3.10'] name: Windows CI build runs-on: ${{ matrix.os }} diff --git a/CHANGELOG.md b/CHANGELOG.md index aa15035fe5e8..85ec11685874 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,12 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.31 * Changes - * The minimum NumPy version is now 1.24. + * The minimum Python version is now 3.10. 3.10 will remain the minimum + supported version until July 2025. + * The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum + supported version until December 2024. + * {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output + of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point. ## jaxlib 0.4.31 diff --git a/WORKSPACE b/WORKSPACE index 4b99b07ef814..e574bd9f9611 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -9,7 +9,6 @@ python_init_rules() load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") python_init_repositories( requirements = { - "3.9": "//build:requirements_lock_3_9.txt", "3.10": "//build:requirements_lock_3_10.txt", "3.11": "//build:requirements_lock_3_11.txt", "3.12": "//build:requirements_lock_3_12.txt", diff --git a/build/build.py b/build/build.py index d525c36fbcc7..2f68222816ac 100755 --- a/build/build.py +++ b/build/build.py @@ -70,8 +70,8 @@ def get_python_version(python_bin_path): return major, minor def check_python_version(python_version): - if python_version < (3, 9): - print("ERROR: JAX requires Python 3.9 or newer, found ", python_version) + if python_version < (3, 10): + print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) sys.exit(-1) @@ -318,7 +318,10 @@ def write_bazelrc(*, remote_build, if not enable_nccl: f.write("build --config=nonccl\n") if build_gpu_plugin: - f.write("build --config=cuda_plugin\n") + if enable_cuda: + f.write("build --config=cuda_plugin\n") + elif enable_rocm: + f.write("build --config=rocm_plugin\n") if python_version: f.write( "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( @@ -431,21 +434,21 @@ def main(): "plugin is still experimental and is not ready for use yet." ), ) - add_boolean_argument( - parser, - "build_cuda_kernel_plugin", - default=False, - help_str=( - "Are we building the cuda kernel plugin? jaxlib will not be built " - "when this flag is True." + parser.add_argument( + "--build_gpu_kernel_plugin", + choices=["cuda", "rocm"], + default="", + help=( + "Specify 'cuda' or 'rocm' to build the respective kernel plugin." + " When this flag is set, jaxlib will not be built." ), ) add_boolean_argument( parser, - "build_cuda_pjrt_plugin", + "build_gpu_pjrt_plugin", default=False, help_str=( - "Are we building the cuda pjrt plugin? jaxlib will not be built " + "Are we building the cuda/rocm pjrt plugin? jaxlib will not be built " "when this flag is True." ), ) @@ -454,6 +457,11 @@ def main(): choices=["11", "12"], default="12", help="Which CUDA major version the gpu plugin is for.") + parser.add_argument( + "--gpu_plugin_rocm_version", + choices=["60"], + default="60", + help="Which ROCM major version the gpu plugin is for.") add_boolean_argument( parser, "enable_rocm", @@ -676,7 +684,7 @@ def main(): *args.bazel_options, ) - if not args.build_cuda_kernel_plugin and not args.build_cuda_pjrt_plugin: + if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: build_cpu_wheel_command = [ *command_base, "//jaxlib/tools:build_wheel", "--", @@ -691,29 +699,44 @@ def main(): print(" ".join(build_cpu_wheel_command)) shell(build_cpu_wheel_command) - if args.build_gpu_plugin or args.build_cuda_kernel_plugin: - build_cuda_kernels_command = [ + if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \ + (args.build_gpu_kernel_plugin == "rocm"): + build_gpu_kernels_command = [ *command_base, - "//jaxlib/tools:build_cuda_kernels_wheel", "--", + "//jaxlib/tools:build_gpu_kernels_wheel", "--", f"--output_path={output_path}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", - f"--cuda_version={args.gpu_plugin_cuda_version}" ] + if args.enable_cuda: + build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}") + build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") + elif args.enable_rocm: + build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}") + build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + else: + raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") if args.editable: - build_cuda_kernels_command.append("--editable") - print(" ".join(build_cuda_kernels_command)) - shell(build_cuda_kernels_command) + build_gpu_kernels_command.append("--editable") + print(" ".join(build_gpu_kernels_command)) + shell(build_gpu_kernels_command) - if args.build_gpu_plugin or args.build_cuda_pjrt_plugin: + if args.build_gpu_plugin or args.build_gpu_pjrt_plugin: build_pjrt_plugin_command = [ *command_base, "//jaxlib/tools:build_gpu_plugin_wheel", "--", f"--output_path={output_path}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", - f"--cuda_version={args.gpu_plugin_cuda_version}" ] + if args.enable_cuda: + build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}") + build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") + elif args.enable_rocm: + build_pjrt_plugin_command.append(f"--enable-rocm={args.enable_rocm}") + build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + else: + raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") if args.editable: build_pjrt_plugin_command.append("--editable") print(" ".join(build_pjrt_plugin_command)) diff --git a/build/requirements.in b/build/requirements.in index 5e2d6a86b146..add6b8577350 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -18,7 +18,6 @@ numpy~=2.0.0 # scipy~=1.13.1 -importlib_metadata; python_version<"3.10" ml_dtypes>=0.4.0 opt_einsum zstandard diff --git a/build/requirements_lock_3_9.txt b/build/requirements_lock_3_9.txt deleted file mode 100644 index 3216e833e535..000000000000 --- a/build/requirements_lock_3_9.txt +++ /dev/null @@ -1,635 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# bazel run //build:requirements.update -# -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff - # via -r build/test-requirements.txt -attrs==23.2.0 \ - --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ - --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 - # via hypothesis -build==1.2.1 \ - --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ - --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt -cloudpickle==3.0.0 \ - --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ - --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 - # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt -contourpy==1.2.1 \ - --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ - --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ - --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ - --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ - --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ - --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ - --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ - --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ - --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ - --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ - --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ - --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ - --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ - --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ - --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ - --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ - --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ - --hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ - --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ - --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ - --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ - --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ - --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ - --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ - --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ - --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ - --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ - --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ - --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ - --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ - --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ - --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ - --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ - --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ - --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ - --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ - --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ - --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ - --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ - --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ - --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ - --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ - --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ - --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 - # via matplotlib -cycler==0.12.1 \ - --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ - --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c - # via matplotlib -etils[epath,epy]==1.5.2 \ - --hash=sha256:6dc882d355e1e98a5d1a148d6323679dc47c9a5792939b9de72615aa4737eb0b \ - --hash=sha256:ba6a3e1aff95c769130776aa176c11540637f5dd881f3b79172a5149b6b1c446 - # via -r build/requirements.in -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via - # hypothesis - # pytest -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 - # via pytest-xdist -filelock==3.14.0 \ - --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ - --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a - # via -r build/test-requirements.txt -flatbuffers==24.3.25 \ - --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ - --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 - # via -r build/test-requirements.txt -fonttools==4.51.0 \ - --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ - --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ - --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ - --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ - --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ - --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ - --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ - --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ - --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ - --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ - --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ - --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ - --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ - --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ - --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ - --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ - --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ - --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ - --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ - --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ - --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ - --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ - --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ - --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ - --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ - --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ - --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ - --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ - --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ - --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ - --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ - --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ - --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ - --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ - --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ - --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ - --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ - --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ - --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ - --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ - --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ - --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 - # via matplotlib -fsspec==2024.5.0 \ - --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ - --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c - # via etils -hypothesis==6.102.4 \ - --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ - --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 - # via -r build/test-requirements.txt -importlib-metadata==7.1.0 ; python_version < "3.10" \ - --hash=sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570 \ - --hash=sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2 - # via - # -r build/requirements.in - # build -importlib-resources==6.4.0 \ - --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ - --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 - # via - # etils - # matplotlib -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 - # via pytest -kiwisolver==1.4.5 \ - --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ - --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ - --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ - --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ - --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ - --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ - --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ - --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ - --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ - --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ - --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ - --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ - --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ - --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ - --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ - --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ - --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ - --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ - --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ - --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ - --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ - --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ - --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ - --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ - --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ - --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ - --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ - --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ - --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ - --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ - --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ - --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ - --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ - --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ - --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ - --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ - --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ - --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ - --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ - --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ - --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ - --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ - --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ - --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ - --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ - --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ - --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ - --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ - --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ - --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ - --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ - --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ - --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ - --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ - --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ - --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ - --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ - --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ - --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ - --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ - --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ - --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ - --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ - --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ - --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ - --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ - --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ - --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ - --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ - --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ - --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ - --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ - --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ - --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ - --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ - --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ - --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ - --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ - --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ - --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ - --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ - --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ - --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ - --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ - --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ - --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ - --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ - --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ - --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ - --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ - --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ - --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ - --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ - --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ - --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ - --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ - --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ - --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ - --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ - --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ - --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ - --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ - --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ - --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f - # via matplotlib -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb - # via rich -matplotlib==3.8.4 ; python_version <= "3.10" \ - --hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \ - --hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \ - --hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \ - --hash=sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb \ - --hash=sha256:606e3b90897554c989b1e38a258c626d46c873523de432b1462f295db13de6f9 \ - --hash=sha256:6209e5c9aaccc056e63b547a8152661324404dd92340a6e479b3a7f24b42a5d0 \ - --hash=sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616 \ - --hash=sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa \ - --hash=sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661 \ - --hash=sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a \ - --hash=sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae \ - --hash=sha256:843cbde2f0946dadd8c5c11c6d91847abd18ec76859dc319362a0964493f0ba6 \ - --hash=sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea \ - --hash=sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106 \ - --hash=sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef \ - --hash=sha256:9bb0189011785ea794ee827b68777db3ca3f93f3e339ea4d920315a0e5a78d54 \ - --hash=sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f \ - --hash=sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014 \ - --hash=sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338 \ - --hash=sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25 \ - --hash=sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b \ - --hash=sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35 \ - --hash=sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732 \ - --hash=sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71 \ - --hash=sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10 \ - --hash=sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0 \ - --hash=sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30 \ - --hash=sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc - # via -r build/requirements.in -mdurl==0.1.2 \ - --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ - --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba - # via markdown-it-py -ml-dtypes==0.4.0 \ - --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ - --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ - --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ - --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ - --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ - --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ - --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ - --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ - --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ - --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ - --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ - --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ - --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ - --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ - --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ - --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ - --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 - # via -r build/requirements.in -mpmath==1.4.0a1 \ - --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ - --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 - # via -r build/test-requirements.txt -numpy==2.0.0 \ - --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ - --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ - --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ - --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ - --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ - --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ - --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ - --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ - --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ - --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ - --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ - --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ - --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ - --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ - --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ - --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ - --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ - --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ - --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ - --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ - --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ - --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ - --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ - --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ - --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ - --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ - --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ - --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ - --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ - --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ - --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ - --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ - --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ - --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ - --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ - --hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ - --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ - --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ - --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ - --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ - --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ - --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ - --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ - --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ - --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 - # via - # -r build/requirements.in - # -r build/test-requirements.txt - # contourpy - # matplotlib - # ml-dtypes - # opt-einsum - # scipy -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in -packaging==24.0 \ - --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ - --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 - # via - # build - # matplotlib - # pytest -pillow==10.3.0 \ - --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ - --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ - --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ - --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ - --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ - --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ - --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ - --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ - --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ - --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ - --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ - --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ - --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ - --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ - --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ - --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ - --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ - --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ - --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ - --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ - --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ - --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ - --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ - --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ - --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ - --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ - --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ - --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ - --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ - --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ - --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ - --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ - --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ - --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ - --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ - --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ - --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ - --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ - --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ - --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ - --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ - --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ - --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ - --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ - --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ - --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ - --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ - --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ - --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ - --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ - --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ - --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ - --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ - --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ - --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ - --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ - --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ - --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ - --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ - --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ - --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ - --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ - --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ - --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ - --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ - --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ - --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ - --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ - --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a - # via - # -r build/test-requirements.txt - # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 - # via pytest -portpicker==1.6.0 \ - --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ - --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt -psutil==5.9.8 \ - --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ - --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ - --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ - --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ - --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ - --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ - --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ - --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ - --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ - --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ - --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ - --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ - --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ - --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ - --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ - --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 - # via portpicker -pygments==2.18.0 \ - --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ - --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via rich -pyparsing==3.1.2 \ - --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ - --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 - # via matplotlib -pyproject-hooks==1.1.0 \ - --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ - --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 - # via build -pytest==8.2.0 \ - --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ - --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f - # via pytest-xdist -pytest-xdist==3.6.1 \ - --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d - # via -r build/test-requirements.txt -python-dateutil==2.9.0.post0 \ - --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ - --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 - # via matplotlib -rich==13.7.1 \ - --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ - --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 - # via -r build/test-requirements.txt -scipy==1.13.1 \ - --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ - --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ - --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ - --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ - --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ - --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ - --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ - --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ - --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ - --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ - --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ - --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ - --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ - --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ - --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ - --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ - --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ - --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ - --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ - --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ - --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ - --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ - --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ - --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ - --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil -sortedcontainers==2.4.0 \ - --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ - --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 - # via hypothesis -tomli==2.0.1 \ - --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ - --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f - # via - # build - # pytest -typing-extensions==4.12.0rc1 \ - --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ - --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe - # via etils -wheel==0.43.0 \ - --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ - --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt -zipp==3.18.2 \ - --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ - --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e - # via - # etils - # importlib-metadata - # importlib-resources -zstandard==0.22.0 \ - --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ - --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ - --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ - --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ - --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ - --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ - --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ - --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ - --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ - --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ - --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ - --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ - --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ - --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ - --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ - --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ - --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ - --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ - --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ - --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ - --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ - --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ - --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ - --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ - --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ - --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ - --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ - --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ - --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ - --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ - --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ - --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ - --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ - --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ - --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ - --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ - --hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ - --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ - --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ - --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ - --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ - --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ - --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ - --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ - --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ - --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in - -# The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r build/test-requirements.txt diff --git a/build/rocm/build_rocm.sh b/build/rocm/build_rocm.sh index 64e166239fc8..6374a2a18929 100755 --- a/build/rocm/build_rocm.sh +++ b/build/rocm/build_rocm.sh @@ -57,7 +57,7 @@ rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1) export JAX_ROCM_VERSION=${rocm_version//./} #Build and install wheel -python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} +python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} JAX_RELEASE=1 python -m build pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA) diff --git a/build/rocm/run_single_gpu.py b/build/rocm/run_single_gpu.py index bf37a49ee61e..13dac5ebbb48 100755 --- a/build/rocm/run_single_gpu.py +++ b/build/rocm/run_single_gpu.py @@ -22,7 +22,7 @@ GPU_LOCK = threading.Lock() LAST_CODE = 0 -base_dir="./logs" +base_dir = "./logs" def extract_filename(path): base_name = os.path.basename(path) diff --git a/docs/contributing.md b/docs/contributing.md index 2d1331bf233a..cad7cfc1ea64 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -36,7 +36,7 @@ Follow these steps to contribute code: [repository page](http://www.github.com/google/jax). This creates a copy of the JAX repository in your own account. -3. Install Python >= 3.9 locally in order to run tests. +3. Install Python >= 3.10 locally in order to run tests. 4. `pip` installing your fork from source. This allows you to modify the code and immediately test it out: diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index 40a5d2f0d9f5..1f5cc0727605 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -66,10 +66,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta ### Communication flags -* **--xla_gpu_enable_async_collectives** This flag enables the collective ops - such as `AllReduce`, `AllGather`, `ReduceScatter` and `CollectivePermute` to - be asynchronous. Asynchronous communication can overlap cross-core - communication with computation. The default value is False. * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False. diff --git a/jax/BUILD b/jax/BUILD index b3d85106b031..2f7480e31b1a 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -963,7 +963,7 @@ pytype_strict_library( ":traceback_util", ":util", "//jax/_src/lib", - ] + py_deps("importlib_metadata"), + ], ) # Public JAX libraries below this point. diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 6f37c16e6715..fbdd4894843e 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import Any, Callable, Optional, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any, Union import numpy as np from jax._src.sharding import Sharding @@ -113,24 +114,24 @@ class Array(abc.ABC): def __release_buffer__(self, view: memoryview) -> None: ... # np.ndarray methods: - def all(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, - keepdims=None, *, where: Optional[ArrayLike] = ...) -> Array: ... - def any(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, - keepdims=None, *, where: Optional[ArrayLike] = ...) -> Array: ... - def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ... - def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Array: ... + def all(self, axis: int | Sequence[int] | None = None, out=None, + keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... + def any(self, axis: int | Sequence[int] | None = None, out=None, + keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... + def argmax(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... + def argmin(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ... - def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ... + def argsort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... def astype(self, dtype) -> Array: ... def choose(self, choices, out=None, mode='raise') -> Array: ... def clip(self, min=None, max=None, out=None) -> Array: ... - def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ... + def compress(self, condition, axis: int | None = None, out=None) -> Array: ... def conj(self) -> Array: ... def conjugate(self) -> Array: ... def copy(self) -> Array: ... - def cumprod(self, axis: Optional[Union[int, Sequence[int]]] = None, + def cumprod(self, axis: int | Sequence[int] | None = None, dtype=None, out=None) -> Array: ... - def cumsum(self, axis: Optional[Union[int, Sequence[int]]] = None, + def cumsum(self, axis: int | Sequence[int] | None = None, dtype=None, out=None) -> Array: ... def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ... def dot(self, b, *, precision=None) -> Array: ... @@ -138,35 +139,35 @@ class Array(abc.ABC): @property def imag(self) -> Array: ... def item(self, *args) -> Any: ... - def max(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def max(self, axis: int | Sequence[int] | None = None, out=None, keepdims=None, initial=None, where=None) -> Array: ... - def mean(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def mean(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=False, *, where=None,) -> Array: ... - def min(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def min(self, axis: int | Sequence[int] | None = None, out=None, keepdims=None, initial=None, where=None) -> Array: ... @property def nbytes(self) -> int: ... def nonzero(self, *, size=None, fill_value=None) -> Array: ... - def prod(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def prod(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=None, initial=None, where=None) -> Array: ... - def ptp(self, axis: Optional[Union[int, Sequence[int]]] = None, out=None, + def ptp(self, axis: int | Sequence[int] | None = None, out=None, keepdims=False,) -> Array: ... def ravel(self, order='C') -> Array: ... @property def real(self) -> Array: ... - def repeat(self, repeats, axis: Optional[int] = None, *, + def repeat(self, repeats, axis: int | None = None, *, total_repeat_length=None) -> Array: ... def reshape(self, *args, order='C') -> Array: ... def round(self, decimals=0, out=None) -> Array: ... def searchsorted(self, v, side='left', sorter=None) -> Array: ... - def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ... - def squeeze(self, axis: Optional[Union[int, Sequence[int]]] = None) -> Array: ... - def std(self, axis: Optional[Union[int, Sequence[int]]] = None, + def sort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... + def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ... + def std(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... - def sum(self, axis: Optional[Union[int, Sequence[int]]] = None, dtype=None, + def sum(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, keepdims=None, initial=None, where=None) -> Array: ... def swapaxes(self, axis1: int, axis2: int) -> Array: ... - def take(self, indices, axis: Optional[int] = None, out=None, + def take(self, indices, axis: int | None = None, out=None, mode=None) -> Array: ... def tobytes(self, order='C') -> bytes: ... def tolist(self) -> list[Any]: ... @@ -177,15 +178,15 @@ class Array(abc.ABC): def T(self) -> Array: ... @property def mT(self) -> Array: ... - def var(self, axis: Optional[Union[int, Sequence[int]]] = None, + def var(self, axis: int | Sequence[int] | None = None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... def view(self, dtype=None, type=None) -> Array: ... # Even though we don't always support the NumPy array protocol, e.g., for # tracer types, for type checking purposes we must declare support so we # implement the NumPy ArrayLike protocol. - def __array__(self, dtype: Optional[np.dtype] = ..., - copy: Optional[bool] = ...) -> np.ndarray: ... + def __array__(self, dtype: np.dtype | None = ..., + copy: bool | None = ...) -> np.ndarray: ... def __dlpack__(self) -> Any: ... # JAX extensions @@ -237,23 +238,23 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[str] = None, fill_value: Optional[StaticScalar] = None) -> Array: ... + mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: Optional[str] = None, fill_value: Optional[StaticScalar] = None) -> Array: ... + mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... def add(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def mul(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def multiply(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def divide(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def power(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def min(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def max(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... def apply(self, func: Callable[[ArrayLike], ArrayLike], indices_are_sorted: bool = False, - unique_indices: bool = False, mode: Optional[str] = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None) -> Array: ... diff --git a/jax/_src/config.py b/jax/_src/config.py index 8eadb26095fa..a0d575b37d6d 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -276,6 +276,8 @@ def __bool__(self) -> NoReturn: type(self).__name__)) def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) self._value = value if self._update_global_hook: self._update_global_hook(value) @@ -1414,7 +1416,7 @@ def _update_disable_jit_thread_local(val): default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', - enum_values=['bfloat16', 'tensorfloat32', 'float32'], + enum_values=['default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32'], default=None, help=('Control the default matmul and conv precision for 32bit inputs.\n\n' diff --git a/jax/_src/core.py b/jax/_src/core.py index 812da70fde17..b0f4c97e6c3d 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -686,7 +686,7 @@ def __init__(self, trace: Trace): def _error_repr(self): if self.aval is None: return f"traced array with aval {self.aval}" - return f"traced array with shape {raise_to_shaped(self.aval).str_short()}." + return f"traced array with shape {raise_to_shaped(self.aval).str_short()}" def __array__(self, *args, **kw): raise TracerArrayConversionError(self) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 29ace0925e39..d4b51731dc8d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1857,10 +1857,11 @@ def find_progenitors(self, tracer): produced = set(eqn.outvars) & active_vars if produced: active_vars.difference_update(produced) - active_vars.update(eqn.invars) + active_vars.update({v for v in eqn.invars if type(v) is Var}) invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars] + const_eqns = [eqn for eqn in self.eqns + if {v for v in eqn.invars if type(v) is Var} & constvars] return invar_positions, const_eqns def _const_folding_and_forwarding( diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index bf2d2c6b3da4..3150f972ce0d 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -23,6 +23,7 @@ import operator from typing import Any, Callable, TypeVar +import jax from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import config @@ -247,6 +248,7 @@ def cond(pred, true_fun, false_fun, *operands): if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals): raise ValueError("Cannot pass `Ref`s into `cond`.") true_jaxpr, false_jaxpr = jaxprs + out_tree, false_out_tree = out_trees if any(isinstance(out_aval, AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals): @@ -255,6 +257,14 @@ def cond(pred, true_fun, false_fun, *operands): _check_tree_and_avals("true_fun and false_fun output", out_tree, true_jaxpr.out_avals, false_out_tree, false_jaxpr.out_avals) + # prune passhtrough outputs + true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) + false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr) + in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)] + keep = [f is None for f in in_fwd] + true_jaxpr = pe.prune_closed_jaxpr_outputs(true_jaxpr, keep) + false_jaxpr = pe.prune_closed_jaxpr_outputs(false_jaxpr, keep) + joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: @@ -273,6 +283,18 @@ def cond(pred, true_fun, false_fun, *operands): out = cond_p.bind( index, *consts, *ops, branches=(false_jaxpr, true_jaxpr), linear=tuple(linear)) + num_consts = len(consts) + out_ = iter(out) + + def _cast_to_array(x): + _copy = isinstance(x, np.bool_) + return jax.numpy.asarray(x, copy=_copy) + + out = [ + next(out_) if fwd is None else _cast_to_array(ops[fwd - num_consts]) + for fwd in in_fwd + ] + assert next(out_, None) is None return tree_unflatten(out_tree, out) @api_boundary diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 2fb00a99ec8e..11af8366b44c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2959,16 +2959,6 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): core.ShapedArray(rhs_aval.shape, aval_out.dtype)) lhs_dtype = rhs_dtype = aval_out.dtype - # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul - if platform == "cpu": - if lhs_dtype == np.float16: - lhs = mlir.convert_hlo(ctx, lhs, lhs_aval, - core.ShapedArray(lhs_aval.shape, np.float32)) - - if rhs_dtype == np.float16: - rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, - core.ShapedArray(rhs_aval.shape, np.float32)) - dot_dnums = hlo.DotDimensionNumbers.get( lhs_batching_dimensions=list(lhs_batch), diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ffcfc617c1f0..692067bb6072 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -376,6 +376,8 @@ def result_type(*args: Any) -> DType: @jit def trunc(x: ArrayLike) -> Array: util.check_arraylike('trunc', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax_internal.asarray(x) return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) @@ -1024,9 +1026,26 @@ def iscomplex(x: ArrayLike) -> Array: i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) -@util.implements(np.isreal) @jit def isreal(x: ArrayLike) -> Array: + """Return boolean array showing where the input is real. + + JAX implementation of :func:`numpy.isreal`. + + Args: + x: input array to check. + + Returns: + A new array containing boolean values indicating real elements. + + See Also: + - :func:`jax.numpy.iscomplex` + - :func:`jax.numpy.isrealobj` + + Examples: + >>> jnp.isreal(jnp.array([False, 0j, 1, 2.1, 1+2j])) + Array([ True, True, True, True, False], dtype=bool) + """ i = ufuncs.imag(x) return lax.eq(i, _lax_const(i, 0)) @@ -1203,8 +1222,37 @@ def gradient_along_axis(a, h, axis): return a_grad[0] if len(axis_tuple) == 1 else a_grad -@util.implements(np.isrealobj) def isrealobj(x: Any) -> bool: + """Check if the input is not a complex number or an array containing complex elements. + + JAX implementation of :func:`numpy.isrealobj`. + + The function evaluates based on input type rather than value. + Inputs with zero imaginary parts are still considered complex. + + Args: + x: input object to check. + + Returns: + False if ``x`` is a complex number or an array containing at least one complex element, + True otherwise. + + See Also: + - :func:`jax.numpy.iscomplexobj` + - :func:`jax.numpy.isreal` + + Examples: + >>> jnp.isrealobj(0) + True + >>> jnp.isrealobj(1.2) + True + >>> jnp.isrealobj(jnp.array([1, 2])) + True + >>> jnp.isrealobj(1+2j) + False + >>> jnp.isrealobj(jnp.array([0, 1+2j])) + False + """ return not iscomplexobj(x) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 29c36ab3b59c..45595c4387a2 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -99,6 +99,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: """ check_arraylike("roots", p) p_arr = atleast_1d(promote_dtypes_inexact(p)[0]) + del p if p_arr.ndim != 1: raise ValueError("Input must be a rank-1 array.") if p_arr.size < 2: @@ -116,8 +117,8 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) -def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, - full: bool = False, w: Array | None = None, cov: bool = False +def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, + full: bool = False, w: ArrayLike | None = None, cov: bool = False ) -> Array | tuple[Array, ...]: r"""Least squares polynomial fit to data. @@ -217,42 +218,47 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, >>> p.shape, C.shape ((3, 3), (3, 3, 1)) """ - check_arraylike("polyfit", x, y) + if w is None: + check_arraylike("polyfit", x, y) + else: + check_arraylike("polyfit", x, y, w) deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 # check arguments + x_arr, y_arr = asarray(x), asarray(y) + del x, y if deg < 0: raise ValueError("expected deg >= 0") - if x.ndim != 1: + if x_arr.ndim != 1: raise TypeError("expected 1D vector for x") - if x.size == 0: + if x_arr.size == 0: raise TypeError("expected non-empty vector for x") - if y.ndim < 1 or y.ndim > 2: + if y_arr.ndim < 1 or y_arr.ndim > 2: raise TypeError("expected 1D or 2D array for y") - if x.shape[0] != y.shape[0]: + if x_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected x and y to have same length") # set rcond if rcond is None: - rcond = len(x) * finfo(x.dtype).eps + rcond = len(x_arr) * finfo(x_arr.dtype).eps rcond = core.concrete_or_error(float, rcond, "rcond must be float") # set up least squares equation for powers of x - lhs = vander(x, order) - rhs = y + lhs = vander(x_arr, order) + rhs = y_arr # apply weighting if w is not None: - check_arraylike("polyfit", w) w, = promote_dtypes_inexact(w) - if w.ndim != 1: + w_arr = asarray(w) + if w_arr.ndim != 1: raise TypeError("expected a 1-d array for weights") - if w.shape[0] != y.shape[0]: + if w_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected w and y to have the same length") - lhs *= w[:, np.newaxis] + lhs *= w_arr[:, np.newaxis] if rhs.ndim == 2: - rhs *= w[:, np.newaxis] + rhs *= w_arr[:, np.newaxis] else: - rhs *= w + rhs *= w_arr # scale lhs to improve condition number and solve scale = sqrt((lhs*lhs).sum(axis=0)) @@ -268,12 +274,12 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, if cov == "unscaled": fac = 1 else: - if len(x) <= order: + if len(x_arr) <= order: raise ValueError("the number of data points must exceed order " "to scale the covariance matrix") - fac = resids / (len(x) - order) + fac = resids / (len(x_arr) - order) fac = fac[0] #making np.array() of shape (1,) to int - if y.ndim == 1: + if y_arr.ndim == 1: return c, Vbase * fac else: return c, Vbase[:, :, np.newaxis] * fac @@ -282,7 +288,7 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, @jit -def poly(seq_of_zeros: Array) -> Array: +def poly(seq_of_zeros: ArrayLike) -> Array: r"""Returns the coefficients of a polynomial for the given sequence of roots. JAX implementation of :func:`numpy.poly`. @@ -340,30 +346,31 @@ def poly(seq_of_zeros: Array) -> Array: """ check_arraylike('poly', seq_of_zeros) seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros) - seq_of_zeros = atleast_1d(seq_of_zeros) + seq_of_zeros_arr = atleast_1d(seq_of_zeros) + del seq_of_zeros - sh = seq_of_zeros.shape + sh = seq_of_zeros_arr.shape if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0: # import at runtime to avoid circular import from jax._src.numpy import linalg - seq_of_zeros = linalg.eigvals(seq_of_zeros) + seq_of_zeros_arr = linalg.eigvals(seq_of_zeros_arr) - if seq_of_zeros.ndim != 1: + if seq_of_zeros_arr.ndim != 1: raise ValueError("input must be 1d or non-empty square 2d array.") - dt = seq_of_zeros.dtype - if len(seq_of_zeros) == 0: + dt = seq_of_zeros_arr.dtype + if len(seq_of_zeros_arr) == 0: return ones((), dtype=dt) a = ones((1,), dtype=dt) - for k in range(len(seq_of_zeros)): - a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full') + for k in range(len(seq_of_zeros_arr)): + a = convolve(a, array([1, -seq_of_zeros_arr[k]], dtype=dt), mode='full') return a @partial(jit, static_argnames=['unroll']) -def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: +def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: r"""Evaluates the polynomial at specific values. JAX implementations of :func:`numpy.polyval`. @@ -417,25 +424,27 @@ def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: [ 8., 34., 76.]], dtype=float32) """ check_arraylike("polyval", p, x) - p, x = promote_dtypes_inexact(p, x) - shape = lax.broadcast_shapes(p.shape[1:], x.shape) - y = lax.full_like(x, 0, shape=shape, dtype=x.dtype) - y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll) + p_arr, x_arr = promote_dtypes_inexact(p, x) + del p, x + shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape) + y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype) + y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) return y @implements(np.polyadd) @jit -def polyadd(a1: Array, a2: Array) -> Array: +def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: check_arraylike("polyadd", a1, a2) - a1, a2 = promote_dtypes(a1, a2) - if a2.shape[0] <= a1.shape[0]: - return a1.at[-a2.shape[0]:].add(a2) + a1_arr, a2_arr = promote_dtypes(a1, a2) + del a1, a2 + if a2_arr.shape[0] <= a1_arr.shape[0]: + return a1_arr.at[-a2_arr.shape[0]:].add(a2_arr) else: - return a2.at[-a1.shape[0]:].add(a1) + return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr) @partial(jit, static_argnames=('m',)) -def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: +def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: r"""Returns the coefficients of the integration of specified order of a polynomial. JAX implementation of :func:`numpy.polyint`. @@ -484,7 +493,8 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") k = 0 if k is None else k check_arraylike("polyint", p, k) - p, k_arr = promote_dtypes_inexact(p, k) + p_arr, k_arr = promote_dtypes_inexact(p, k) + del p, k if m < 0: raise ValueError("Order of integral must be positive (see polyder)") k_arr = atleast_1d(k_arr) @@ -493,16 +503,16 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: if k_arr.shape != (m,): raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.") if m == 0: - return p + return p_arr else: - grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis] - - arange(m, dtype=p.dtype)[:, np.newaxis]) + grid = (arange(len(p_arr) + m, dtype=p_arr.dtype)[np.newaxis] + - arange(m, dtype=p_arr.dtype)[:, np.newaxis]) coeff = maximum(1, grid).prod(0)[::-1] - return true_divide(concatenate((p, k_arr)), coeff) + return true_divide(concatenate((p_arr, k_arr)), coeff) @partial(jit, static_argnames=('m',)) -def polyder(p: Array, m: int = 1) -> Array: +def polyder(p: ArrayLike, m: int = 1) -> Array: r"""Returns the coefficients of the derivative of specified order of a polynomial. JAX implementation of :func:`numpy.polyder`. @@ -540,14 +550,15 @@ def polyder(p: Array, m: int = 1) -> Array: """ check_arraylike("polyder", p) m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder") - p, = promote_dtypes_inexact(p) + p_arr, = promote_dtypes_inexact(p) + del p if m < 0: raise ValueError("Order of derivative must be positive") if m == 0: - return p - coeff = (arange(m, len(p), dtype=p.dtype)[np.newaxis] - - arange(m, dtype=p.dtype)[:, np.newaxis]).prod(0) - return p[:-m] * coeff[::-1] + return p_arr + coeff = (arange(m, len(p_arr), dtype=p_arr.dtype)[np.newaxis] + - arange(m, dtype=p_arr.dtype)[:, np.newaxis]).prod(0) + return p_arr[:-m] * coeff[::-1] _LEADING_ZEROS_DOC = """\ @@ -562,6 +573,7 @@ def polyder(p: Array, m: int = 1) -> Array: def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: check_arraylike("polymul", a1, a2) a1_arr, a2_arr = promote_dtypes_inexact(a1, a2) + del a1, a2 if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1): a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f') if len(a1_arr) == 0: @@ -574,6 +586,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: check_arraylike("polydiv", u, v) u_arr, v_arr = promote_dtypes_inexact(u, v) + del u, v m = len(u_arr) - 1 n = len(v_arr) - 1 scale = 1. / v_arr[0] @@ -589,7 +602,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> @implements(np.polysub) @jit -def polysub(a1: Array, a2: Array) -> Array: +def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: check_arraylike("polysub", a1, a2) a1, a2 = promote_dtypes(a1, a2) return polyadd(a1, -a2) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 23432b45b102..7fa667836e31 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -204,15 +204,6 @@ def force(x): force, x, "The axis argument must be known statically.") -# TODO(jakevdp) change promote_integers default to False -_PROMOTE_INTEGERS_DOC = """ -promote_integers : bool, default=True - If True, then integer inputs will be promoted to the widest available integer - dtype, following numpy's behavior. If False, the result will have the same dtype - as the input. ``promote_integers`` is ignored if ``dtype`` is specified. -""" - - @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, @@ -224,10 +215,76 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, initial=initial, where_=where, parallel_reduce=lax.psum, promote_integers=promote_integers) -@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) + def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: + r"""Sum of the elements of the array over a given axis. + + JAX implementation of :func:`numpy.sum`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the sum to be computed. + If None, the sum is computed along all the axes. + dtype: The type of the output array. Default=None. + out: Unused by JAX + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the sum. + where: int or array, default=None. The elements to be used in the sum. Array + should be broadcast compatible to the input. + promote_integers : bool, default=True. If True, then integer inputs will be + promoted to the widest available integer dtype, following numpy's behavior. + If False, the result will have the same dtype as the input. + ``promote_integers`` is ignored if ``dtype`` is specified. + + Returns: + An array of the sum along the given axis. + + See also: + - :func:`jax.numpy.prod`: Compute the product of array elements over a given + axis. + - :func:`jax.numpy.max`: Compute the maximum of array elements over given axis. + - :func:`jax.numpy.min`: Compute the minimum of array elements over given axis. + + Examples: + + By default, the sum is computed along all the axes. + + >>> x = jnp.array([[1, 3, 4, 2], + ... [5, 2, 6, 3], + ... [8, 1, 3, 9]]) + >>> jnp.sum(x) + Array(47, dtype=int32) + + If ``axis=1``, the sum is computed along axis 1. + + >>> jnp.sum(x, axis=1) + Array([10, 16, 21], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output is equal to that of the input. + + >>> jnp.sum(x, axis=1, keepdims=True) + Array([[10], + [16], + [21]], dtype=int32) + + To include only specific elements in the sum, you can use a``where``. + + >>> where=jnp.array([[0, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.sum(x, axis=1, keepdims=True, where=where) + Array([[ 4], + [ 9], + [12]], dtype=int32) + >>> where=jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.sum(x, axis=0, keepdims=True, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @@ -243,11 +300,77 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) -@implements(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) + def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: + r"""Return product of the array elements over a given axis. + + JAX implementation of :func:`numpy.prod`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the product to be computed. + If None, the product is computed along all the axes. + dtype: The type of the output array. Default=None. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the product. + where: int or array, default=None. The elements to be used in the product. + Array should be broadcast compatible to the input. + promote_integers : bool, default=True. If True, then integer inputs will be + promoted to the widest available integer dtype, following numpy's behavior. + If False, the result will have the same dtype as the input. + ``promote_integers`` is ignored if ``dtype`` is specified. + out: Unused by JAX. + + Returns: + An array of the product along the given axis. + + See also: + - :func:`jax.numpy.sum`: Compute the sum of array elements over a given axis. + - :func:`jax.numpy.max`: Compute the maximum of array elements over given axis. + - :func:`jax.numpy.min`: Compute the minimum of array elements over given axis. + + Examples: + By default, ``jnp.prod`` computes along all the axes. + + >>> x = jnp.array([[1, 3, 4, 2], + ... [5, 2, 1, 3], + ... [2, 1, 3, 1]]) + >>> jnp.prod(x) + Array(4320, dtype=int32) + + If ``axis=1``, product is computed along axis 1. + + >>> jnp.prod(x, axis=1) + Array([24, 30, 6], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output is equal to that of the input. + + >>> jnp.prod(x, axis=1, keepdims=True) + Array([[24], + [30], + [ 6]], dtype=int32) + + To include only specific elements in the sum, you can use a``where``. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.prod(x, axis=1, keepdims=True, where=where) + Array([[4], + [3], + [6]], dtype=int32) + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.prod(x, axis=1, keepdims=True, where=where) + Array([[1], + [1], + [1]], dtype=int32) + """ return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @@ -261,10 +384,77 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) -@implements(np.max, skip_params=['out']) + def max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + r"""Return the maximum of the array elements along a given axis. + + JAX implementation of :func:`numpy.max`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the maximum to be computed. + If None, the maximum is computed along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, default=None. Initial value for the maximum. + where: int or array of boolean dtype, default=None. The elements to be used + in the maximum. Array should be broadcast compatible to the input. + ``initial`` must be specified when ``where`` is used. + out: Unused by JAX. + + Returns: + An array of maximum values along the given axis. + + See also: + - :func:`jax.numpy.min`: Compute the minimum of array elements along a given + axis. + - :func:`jax.numpy.sum`: Compute the sum of array elements along a given axis. + - :func:`jax.numpy.prod`: Compute the product of array elements along a given + axis. + + Examples: + + By default, ``jnp.max`` computes the maximum of elements along all the axes. + + >>> x = jnp.array([[9, 3, 4, 5], + ... [5, 2, 7, 4], + ... [8, 1, 3, 6]]) + >>> jnp.max(x) + Array(9, dtype=int32) + + If ``axis=1``, the maximum will be computed along axis 1. + + >>> jnp.max(x, axis=1) + Array([9, 7, 8], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. + + >>> jnp.max(x, axis=1, keepdims=True) + Array([[9], + [7], + [8]], dtype=int32) + + To include only specific elements in computing the maximum, you can use + ``where``. It can either have same dimension as input + + >>> where=jnp.array([[0, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.max(x, axis=1, keepdims=True, initial=0, where=where) + Array([[4], + [7], + [8]], dtype=int32) + + or must be broadcast compatible with input. + + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @@ -276,10 +466,76 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) -@implements(np.min, skip_params=['out']) + def min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + r"""Return the minimum of array elements along a given axis. + + JAX implementation of :func:`numpy.min`. + + Args: + a: Input array. + axis: int or array, default=None. Axis along which the minimum to be computed. + If None, the minimum is computed along all the axes. + keepdims: bool, default=False. If true, reduced axes are left in the result + with size 1. + initial: int or array, Default=None. Initial value for the minimum. + where: int or array, default=None. The elements to be used in the minimum. + Array should be broadcast compatible to the input. ``initial`` must be + specified when ``where`` is used. + out: Unused by JAX. + + Returns: + An array of minimum values along the given axis. + + See also: + - :func:`jax.numpy.max`: Compute the maximum of array elements along a given + axis. + - :func:`jax.numpy.sum`: Compute the sum of array elements along a given axis. + - :func:`jax.numpy.prod`: Compute the product of array elements along a given + axis. + + Examples: + By default, the minimum is computed along all the axes. + + >>> x = jnp.array([[2, 5, 1, 6], + ... [3, -7, -2, 4], + ... [8, -4, 1, -3]]) + >>> jnp.min(x) + Array(-7, dtype=int32) + + If ``axis=1``, the minimum is computed along axis 1. + + >>> jnp.min(x, axis=1) + Array([ 1, -7, -4], dtype=int32) + + If ``keepdims=True``, ``ndim`` of the output will be same of that of the input. + + >>> jnp.min(x, axis=1, keepdims=True) + Array([[ 1], + [-7], + [-4]], dtype=int32) + + To include only specific elements in computing the minimum, you can use + ``where``. ``where`` can either have same dimension as input. + + >>> where=jnp.array([[1, 0, 1, 0], + ... [0, 0, 1, 1], + ... [1, 1, 1, 0]], dtype=bool) + >>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where) + Array([[ 0], + [-2], + [-4]], dtype=int32) + + or must be broadcast compatible with input. + + >>> where = jnp.array([[False], + ... [False], + ... [False]]) + >>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where) + Array([[0, 0, 0, 0]], dtype=int32) + """ return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @@ -307,8 +563,19 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) -amin = min -amax = max +def amin(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Alias of :func:`jax.numpy.min`.""" + return min(a, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) + +def amax(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Alias of :func:`jax.numpy.max`.""" + return max(a, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) def _axis_size(a: ArrayLike, axis: int | Sequence[int]): if not isinstance(axis, (tuple, list)): diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index ccc448e56060..673ff2c4d11d 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -92,11 +92,17 @@ def sign(x: ArrayLike, /) -> Array: @implements(np.floor, module='numpy') @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: + check_arraylike('floor', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax.asarray(x) return lax.floor(*promote_args_inexact('floor', x)) @implements(np.ceil, module='numpy') @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: + check_arraylike('ceil', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax.asarray(x) return lax.ceil(*promote_args_inexact('ceil', x)) @implements(np.exp, module='numpy') @@ -579,9 +585,43 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: divide = true_divide -@implements(np.floor_divide, module='numpy') @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculates the floor division of x1 by x2 element-wise + + LAX-backend implementation of :func:`numpy.floor_divide`. + + Args: + x1: Input array, the dividend + x2: Input array, the divisor + + Returns: + An array-like object containing each of the quotients rounded down + to the nearest integer towards negative infinity. This is equivalent + to ``x1 // x2`` in Python. + + Examples: + >>> x1 = jnp.array([10, 20, 30]) + >>> x2 = jnp.array([3, 4, 7]) + >>> jnp.floor_divide(x1, x2) + Array([3, 5, 4], dtype=int32) + + >>> x1 = jnp.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]) + >>> x2 = 3 + >>> jnp.floor_divide(x1, x2) + Array([-2, -2, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=int32) + + >>> x1 = jnp.array([6, 6, 6], dtype=jnp.int32) + >>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32) + >>> jnp.floor_divide(x1, x2) + Array([3., 2., 2.], dtype=float32) + + Note: + ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` and ``x2`` + + See Also: + :func:`jnp.divide` and :func:`jnp.true_divide` for floating point division + """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) dtype = dtypes.dtype(x1) if dtypes.issubdtype(dtype, np.unsignedinteger): diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 667c2707d68e..b8efde364fc8 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -17,7 +17,6 @@ from __future__ import annotations from collections.abc import Sequence -import sys from typing import Callable, Union import warnings @@ -36,11 +35,8 @@ from jax._src.typing import Array, ArrayLike -if sys.version_info >= (3, 10): - from types import EllipsisType - SingleIndex = int | slice | Sequence[int] | Array | EllipsisType | None -else: - SingleIndex = Union[int, slice, Sequence[int], Array, None] +from types import EllipsisType +SingleIndex = int | slice | Sequence[int] | Array | EllipsisType | None Index = Union[SingleIndex, tuple[SingleIndex, ...]] Scalar = Union[complex, float, int, np.number] diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 29770cfd4417..aca5ce6e67db 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -415,7 +415,7 @@ def supported_dtypes(): return types def is_device_rocm(): - return xla_bridge.get_backend().platform_version.startswith('rocm') + return 'rocm' in xla_bridge.get_backend().platform_version def is_device_cuda(): return 'cuda' in xla_bridge.get_backend().platform_version diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index f2600db6187e..87925c142e50 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -30,7 +30,6 @@ import os import pkgutil import platform as py_platform -import sys import threading import traceback from typing import Any, Callable, Union @@ -570,12 +569,7 @@ def discover_pjrt_plugins() -> None: logger.debug("No jax_plugins namespace packages available") # Augment with advertised entrypoints. - if sys.version_info < (3, 10): - # Use the backport library because it provides a forward-compatible - # implementation. - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points + from importlib.metadata import entry_points for entry_point in entry_points(group="jax_plugins"): logger.debug("Discovered entry-point based JAX plugin: %s", diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 24d18027b115..e0d8c4ee67f5 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -67,6 +67,7 @@ broadcast_arrays as broadcast_arrays, broadcast_to as broadcast_to, can_cast as can_cast, + ceil as ceil, complex128 as complex128, complex64 as complex64, concat as concat, @@ -87,6 +88,7 @@ flip as flip, float32 as float32, float64 as float64, + floor as floor, floor_divide as floor_divide, from_dlpack as from_dlpack, full as full, @@ -162,6 +164,7 @@ tile as tile, tril as tril, triu as triu, + trunc as trunc, uint16 as uint16, uint32 as uint32, uint64 as uint64, @@ -192,11 +195,8 @@ ) from jax.experimental.array_api._elementwise_functions import ( - ceil as ceil, clip as clip, - floor as floor, hypot as hypot, - trunc as trunc, ) from jax.experimental.array_api._statistical_functions import ( diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index 5587b9a60f1a..103f8ab7d1ef 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -18,15 +18,6 @@ from jax._src.numpy.util import promote_args -# TODO(micky774): Update jnp.ceil to preserve integral dtype -def ceil(x, /): - """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i.""" - x, = promote_args("ceil", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.ceil(x) - - # TODO(micky774): Remove when jnp.clip deprecation is completed # (began 2024-4-2) and default behavior is Array API 2023 compliant def clip(x, /, min=None, max=None): @@ -43,15 +34,6 @@ def clip(x, /, min=None, max=None): return jax.numpy.clip(x, min=min, max=max) -# TODO(micky774): Update jnp.floor to preserve integral dtype -def floor(x, /): - """Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i.""" - x, = promote_args("floor", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.floor(x) - - # TODO(micky774): Remove when jnp.hypot deprecation is completed # (began 2024-4-14) and default behavior is Array API 2023 compliant def hypot(x1, x2, /): @@ -64,12 +46,3 @@ def hypot(x1, x2, /): "values first, such as by using jnp.real or jnp.imag to take the real " "or imaginary components respectively.") return jax.numpy.hypot(x1, x2) - - -# TODO(micky774): Update jnp.trunc to preserve integral dtype -def trunc(x, /): - """Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i.""" - x, = promote_args("trunc", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.trunc(x) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index c942c4867f40..4cbe52383751 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1224,15 +1224,24 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic out_n_batch = lhs.n_batch + rhs.n_batch - len(lhs_batch) out_nse = min(out_nse, math.prod(out_aval.shape[out_n_batch:])) + lhs_batch_shape = np.broadcast_shapes( + tuple(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), + tuple(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), + ) + rhs_batch_shape = np.broadcast_shapes( + tuple(rhs_data.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + tuple(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + ) + data_shape = ( *(lhs_shape[dim] for dim in lhs_batch), - *(lhs_data.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), - *(rhs_data.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + *lhs_batch_shape, + *rhs_batch_shape, out_nse) indices_shape = ( *(lhs_shape[dim] for dim in lhs_batch), - *(lhs_indices.shape[dim] for dim in range(lhs.n_batch) if dim not in lhs_batch), - *(rhs_indices.shape[dim] for dim in range(rhs.n_batch) if dim not in rhs_batch), + *lhs_batch_shape, + *rhs_batch_shape, out_nse, lhs.n_sparse + rhs.n_sparse - 2 * len(lhs_contracting)) data_aval = core.ShapedArray(data_shape, out_aval.dtype) diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py index 344e8c3bae0e..6c827325befc 100644 --- a/jax/experimental/sparse/nm.py +++ b/jax/experimental/sparse/nm.py @@ -182,7 +182,7 @@ def _nm_spmm_abstract_eval( mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda") if gpu_sparse.rocm_is_supported: - mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="ROCM") + mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="rocm") # -------------------------------------------------------------------- # nm_pack diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 6d7f48408b57..089f69996d54 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -2,7 +2,8 @@ from __future__ import annotations import builtins -from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, TypeVar, Union, overload +from collections.abc import Callable, Sequence +from typing import Any, Literal, NamedTuple, TypeVar, Union, overload from jax._src import core as _core from jax._src import dtypes as _dtypes @@ -33,21 +34,21 @@ def acos(x: ArrayLike, /) -> Array: ... def acosh(x: ArrayLike, /) -> Array: ... def add(x: ArrayLike, y: ArrayLike, /) -> Array: ... def amax(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def amin(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def all(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., *, where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... alltrue = all def angle(z: ArrayLike, deg: builtins.bool = ...) -> Array: ... def any(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., *, where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def append( - arr: ArrayLike, values: ArrayLike, axis: Optional[int] = ... + arr: ArrayLike, values: ArrayLike, axis: int | None = ... ) -> Array: ... def apply_along_axis(func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs) -> Array: ... @@ -56,9 +57,9 @@ def apply_over_axes( ) -> Array: ... def arange( start: DimSize, - stop: Optional[DimSize] = ..., - step: Optional[DimSize] = ..., - dtype: Optional[DTypeLike] = ..., *, + stop: DimSize | None = ..., + step: DimSize | None = ..., + dtype: DTypeLike | None = ..., device: _Device | _Sharding | None = ..., ) -> Array: ... def arccos(x: ArrayLike, /) -> Array: ... @@ -70,20 +71,20 @@ def arctan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def arctanh(x: ArrayLike, /) -> Array: ... def argmax( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def argmin( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ... def argsort( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., *, stable: builtins.bool = ..., descending: builtins.bool = ..., @@ -93,8 +94,8 @@ def argsort( def argwhere( a: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... around = round def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True, @@ -106,17 +107,17 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: ... array_repr = _np.array_repr def array_split( ary: ArrayLike, - indices_or_sections: Union[int, Sequence[int], ArrayLike], + indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = ..., ) -> list[Array]: ... array_str = _np.array_str def asarray( - a: Any, dtype: Optional[DTypeLike] = ..., order: Optional[str] = ..., - *, copy: Optional[builtins.bool] = ... + a: Any, dtype: DTypeLike | None = ..., order: str | None = ..., + *, copy: builtins.bool | None = ... ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... -def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: _Device | _Sharding | None = ...) -> Array: ... +def astype(a: ArrayLike, dtype: DTypeLike | None, /, *, copy: builtins.bool = ..., device: _Device | _Sharding | None = ...) -> Array: ... def atan(x: ArrayLike, /) -> Array: ... def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def atanh(x: ArrayLike, /) -> Array: ... @@ -142,19 +143,19 @@ def atleast_3d(x: ArrayLike, /) -> Array: ... def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., returned: Literal[False] = False, keepdims: builtins.bool = False) -> Array: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., *, +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., *, returned: Literal[True], keepdims: builtins.bool = False) -> tuple[Array, Array]: ... @overload -def average(a: ArrayLike, axis: _Axis = ..., weights: Optional[ArrayLike] = ..., - returned: builtins.bool = False, keepdims: builtins.bool = False) -> Union[Array, tuple[Array, Array]]: ... +def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., + returned: builtins.bool = False, keepdims: builtins.bool = False) -> Array | tuple[Array, Array]: ... def bartlett(M: int) -> Array: ... bfloat16: Any -def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ..., - minlength: int = ..., *, length: Optional[int] = ...) -> Array: ... +def bincount(x: ArrayLike, weights: ArrayLike | None = ..., + minlength: int = ..., *, length: int | None = ...) -> Array: ... def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_invert(x: ArrayLike, /) -> Array: ... @@ -164,7 +165,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... def blackman(M: int) -> Array: ... -def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ... +def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... bool: Any bool_: Any def broadcast_arrays(*args: ArrayLike) -> list[Array]: ... @@ -173,8 +174,8 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: ... def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... @overload -def broadcast_shapes(*shapes: Sequence[Union[int, _core.Tracer]] - ) -> tuple[Union[int, _core.Tracer], ...]: ... +def broadcast_shapes(*shapes: Sequence[int | _core.Tracer] + ) -> tuple[int | _core.Tracer, ...]: ... def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: ... c_: _CClass @@ -188,53 +189,53 @@ def choose(a: ArrayLike, choices: Sequence[ArrayLike], def clip( x: ArrayLike | None = ..., /, - min: Optional[ArrayLike] = ..., - max: Optional[ArrayLike] = ..., + min: ArrayLike | None = ..., + max: ArrayLike | None = ..., a: ArrayLike | DeprecatedArg | None = ..., a_min: ArrayLike | DeprecatedArg | None = ..., a_max: ArrayLike | DeprecatedArg | None = ... ) -> Array: ... def column_stack( - tup: Union[_np.ndarray, Array, Sequence[ArrayLike]] + tup: _np.ndarray | Array | Sequence[ArrayLike] ) -> Array: ... complex128: Any complex64: Any complex_: Any complexfloating = _np.complexfloating -def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = ..., +def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., out: None = ...) -> Array: ... def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ... def concatenate( - arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]], - axis: Optional[int] = ..., - dtype: Optional[DTypeLike] = ..., + arrays: _np.ndarray | Array | Sequence[ArrayLike], + axis: int | None = ..., + dtype: DTypeLike | None = ..., ) -> Array: ... def conjugate(x: ArrayLike, /) -> Array: ... conj = conjugate def convolve(a: ArrayLike, v: ArrayLike, mode: str = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def copy(a: ArrayLike, order: Optional[str] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def copy(a: ArrayLike, order: str | None = ...) -> Array: ... def copysign(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = ..., rowvar: builtins.bool = ...) -> Array: ... +def corrcoef(x: ArrayLike, y: ArrayLike | None = ..., rowvar: builtins.bool = ...) -> Array: ... def correlate(a: ArrayLike, v: ArrayLike, mode: str = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def cos(x: ArrayLike, /) -> Array: ... def cosh(x: ArrayLike, /) -> Array: ... def count_nonzero(a: ArrayLike, axis: _Axis = ..., keepdims: builtins.bool = ...) -> Array: ... -def cov(m: ArrayLike, y: Optional[ArrayLike] = ..., rowvar: builtins.bool = ..., - bias: builtins.bool = ..., ddof: Optional[int] = ..., - fweights: Optional[ArrayLike] = ..., - aweights: Optional[ArrayLike] = ...) -> Array: ... +def cov(m: ArrayLike, y: ArrayLike | None = ..., rowvar: builtins.bool = ..., + bias: builtins.bool = ..., ddof: int | None = ..., + fweights: ArrayLike | None = ..., + aweights: ArrayLike | None = ...) -> Array: ... def cross( a: ArrayLike, b: ArrayLike, axisa: int = -1, axisb: int = -1, axisc: int = -1, - axis: Optional[int] = ..., + axis: int | None = ..., ) -> Array: ... csingle: Any def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., @@ -250,8 +251,8 @@ def deg2rad(x: ArrayLike, /) -> Array: ... degrees = rad2deg def delete( arr: ArrayLike, - obj: Union[ArrayLike, slice], - axis: Optional[int] = ..., + obj: ArrayLike | slice, + axis: int | None = ..., *, assume_unique_indices: builtins.bool = ..., ) -> Array: ... @@ -263,33 +264,33 @@ def diagonal( a: ArrayLike, offset: ArrayLike = ..., axis1: int = ..., axis2: int = ... ): ... def diff(a: ArrayLike, n: int = ..., axis: int = ..., - prepend: Optional[ArrayLike] = ..., - append: Optional[ArrayLike] = ...) -> Array: ... + prepend: ArrayLike | None = ..., + append: ArrayLike | None = ...) -> Array: ... def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ... divide = true_divide def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... double: Any def dsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def dstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def dstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... dtype = _np.dtype e: float -def ediff1d(ary: ArrayLike, to_end: Optional[ArrayLike] = ..., - to_begin: Optional[ArrayLike] = ...) -> Array: ... +def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = ..., + to_begin: ArrayLike | None = ...) -> Array: ... @overload def einsum( subscript: str, /, *operands: ArrayLike, out: None = ..., - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., ) -> Array: ... @@ -298,11 +299,11 @@ def einsum( def einsum( arr: ArrayLike, axes: Sequence[Any], /, - *operands: Union[ArrayLike, Sequence[Any]], + *operands: ArrayLike | Sequence[Any], out: None = ..., - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., ) -> Array: ... @@ -311,9 +312,9 @@ def einsum( subscripts, /, *operands, out: None = ..., - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ..., + preferred_element_type: DTypeLike | None = ..., _use_xeinsum: builtins.bool = ..., _dot_general: Callable[..., Array] = ..., ) -> Array: ... @@ -322,38 +323,38 @@ def einsum( def einsum_path( subscripts: str, /, *operands: ArrayLike, - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... @overload def einsum_path( arr: ArrayLike, axes: Sequence[Any], /, - *operands: Union[ArrayLike, Sequence[Any]], - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + *operands: ArrayLike | Sequence[Any], + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... @overload def einsum_path( subscripts, /, *operands: ArrayLike, - optimize: Union[str, builtins.bool, list[tuple[int, ...]]] = ..., + optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... -def empty(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def empty_like(prototype: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +def empty(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def empty_like(prototype: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... euler_gamma: float def exp(x: ArrayLike, /) -> Array: ... def exp2(x: ArrayLike, /) -> Array: ... -def expand_dims(a: ArrayLike, axis: Union[int, Sequence[int]]) -> Array: ... +def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: ... def expm1(x: ArrayLike, /) -> Array: ... def extract(condition: ArrayLike, arr: ArrayLike, *, size: int | None = None, fill_value: ArrayLike = 0) -> Array: ... -def eye(N: DimSize, M: Optional[DimSize] = ..., k: int | ArrayLike = ..., - dtype: Optional[DTypeLike] = ..., *, +def eye(N: DimSize, M: DimSize | None = ..., k: int | ArrayLike = ..., + dtype: DTypeLike | None = ..., *, device: _Device | _Sharding | None = ...) -> Array: ... def fabs(x: ArrayLike, /) -> Array: ... finfo = _dtypes.finfo @@ -361,12 +362,12 @@ def fix(x: ArrayLike, out: None = ...) -> Array: ... def flatnonzero( a: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike]] = ..., + size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike] = ..., ) -> Array: ... flexible = _np.flexible def flip( - m: ArrayLike, axis: Optional[Union[int, Sequence[int]]] = ... + m: ArrayLike, axis: int | Sequence[int] | None = ... ) -> Array: ... def fliplr(m: ArrayLike) -> Array: ... @@ -390,7 +391,7 @@ def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ... def from_dlpack(x: Any, /, *, device: _Device | _Sharding | None = None, copy: builtins.bool | None = None) -> Array: ... -def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ..., +def frombuffer(buffer: bytes | Any, dtype: DTypeLike = ..., count: int = ..., offset: int = ...) -> Array: ... def fromfile(*args, **kwargs): ... def fromfunction(function: Callable[..., Array], shape: Any, @@ -400,12 +401,12 @@ def fromstring( string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str ) -> Array: ... def full(shape: Any, fill_value: ArrayLike, - dtype: Optional[DTypeLike] = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def full_like(a: Union[ArrayLike, DuckTypedArray], - fill_value: ArrayLike, dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., *, + device: _Device | _Sharding | None = ...) -> Array: ... +def full_like(a: ArrayLike | DuckTypedArray, + fill_value: ArrayLike, dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: ... generic = _np.generic def geomspace( @@ -413,48 +414,48 @@ def geomspace( stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = ..., ) -> Array: ... get_printoptions = _np.get_printoptions def gradient(f: ArrayLike, *varargs: ArrayLike, - axis: Optional[Union[int, Sequence[int]]] = ..., - edge_order: Optional[int] = ...) -> Union[Array, list[Array]]: ... + axis: int | Sequence[int] | None = ..., + edge_order: int | None = ...) -> Array | list[Array]: ... def greater(x: ArrayLike, y: ArrayLike, /) -> Array: ... def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... def hamming(M: int) -> Array: ... def hanning(M: int) -> Array: ... def heaviside(x: ArrayLike, y: ArrayLike, /) -> Array: ... def histogram(a: ArrayLike, bins: ArrayLike = ..., - range: Optional[Sequence[ArrayLike]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ...) -> tuple[Array, Array]: ... + range: Sequence[ArrayLike] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ...) -> tuple[Array, Array]: ... def histogram2d( x: ArrayLike, y: ArrayLike, - bins: Union[ArrayLike, Sequence[ArrayLike]] = ..., - range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ..., + bins: ArrayLike | Sequence[ArrayLike] = ..., + range: Sequence[None | Array | Sequence[ArrayLike]] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ..., ) -> tuple[Array, Array, Array]: ... def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = ..., - range: Union[None, Array, Sequence[ArrayLike]] = ..., - weights: Optional[ArrayLike] = ...) -> Array: ... + range: None | Array | Sequence[ArrayLike] = ..., + weights: ArrayLike | None = ...) -> Array: ... def histogramdd( sample: ArrayLike, - bins: Union[ArrayLike, Sequence[ArrayLike]] = ..., - range: Optional[Sequence[Union[None, Array, Sequence[ArrayLike]]]] = ..., - weights: Optional[ArrayLike] = ..., - density: Optional[builtins.bool] = ..., + bins: ArrayLike | Sequence[ArrayLike] = ..., + range: Sequence[None | Array | Sequence[ArrayLike]] | None = ..., + weights: ArrayLike | None = ..., + density: builtins.bool | None = ..., ) -> tuple[Array, list[Array]]: ... def hsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def hstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def hstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... def hypot(x: ArrayLike, y: ArrayLike, /) -> Array: ... def i0(x: ArrayLike) -> Array: ... -def identity(n: DimSize, dtype: Optional[DTypeLike] = ...) -> Array: ... +def identity(n: DimSize, dtype: DTypeLike | None = ...) -> Array: ... iinfo = _dtypes.iinfo def imag(x: ArrayLike, /) -> Array: ... index_exp = _np.index_exp @@ -467,15 +468,15 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, *, sparse: Literal[True]) -> tuple[Array, ...]: ... @overload def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, - sparse: builtins.bool = False) -> Union[Array, tuple[Array, ...]]: ... + sparse: builtins.bool = False) -> Array | tuple[Array, ...]: ... inexact = _np.inexact inf: float def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def insert(arr: ArrayLike, obj: Union[ArrayLike, slice], values: ArrayLike, - axis: Optional[int] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, + axis: int | None = ...) -> Array: ... int16: Any int32: Any int4: Any @@ -484,17 +485,17 @@ int8: Any int_: Any integer = _np.integer def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, - left: Union[ArrayLike, str, None] = ..., - right: Union[ArrayLike, str, None] = ..., - period: Optional[ArrayLike] = ...) -> Array: ... + left: ArrayLike | str | None = ..., + right: ArrayLike | str | None = ..., + period: ArrayLike | None = ...) -> Array: ... def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., - return_indices: builtins.bool = ...) -> Union[Array, tuple[Array, Array, Array]]: ... + return_indices: builtins.bool = ...) -> Array | tuple[Array, Array, Array]: ... def invert(x: ArrayLike, /) -> Array: ... def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... def iscomplex(m: ArrayLike) -> Array: ... def iscomplexobj(x: Any) -> builtins.bool: ... -def isdtype(dtype: DTypeLike, kind: Union[DType, str, tuple[Union[DType, str], ...]]) -> builtins.bool: ... +def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ... def isfinite(x: ArrayLike, /) -> Array: ... def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: builtins.bool = ..., invert: builtins.bool = ...) -> Array: ... @@ -520,23 +521,23 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, retstep: Literal[False] = False, - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> Array: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int, endpoint: builtins.bool, retstep: Literal[True], - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, *, retstep: Literal[True], - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., axis: int = 0) -> tuple[Array, Array]: ... @overload def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: builtins.bool = True, retstep: builtins.bool = False, - dtype: Optional[DTypeLike] = ..., - axis: int = 0) -> Union[Array, tuple[Array, Array]]: ... + dtype: DTypeLike | None = ..., + axis: int = 0) -> Array | tuple[Array, Array]: ... def load(*args: Any, **kwargs: Any) -> Array: ... def log(x: ArrayLike, /) -> Array: ... @@ -551,20 +552,20 @@ def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., - dtype: Optional[DTypeLike] = ..., axis: int = ...) -> Array: ... + dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... def mask_indices( n: int, mask_func: Callable, k: int = ... ) -> tuple[Array, ...]: ... def matmul( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def matrix_transpose(x: ArrayLike, /) -> Array: ... max = amax def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ... def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... -def median(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + where: ArrayLike | None = ...) -> Array: ... +def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., keepdims: builtins.bool = ...) -> Array: ... def meshgrid(*xi: ArrayLike, copy: builtins.bool = ..., sparse: builtins.bool = ..., @@ -574,121 +575,121 @@ min = amin def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: ... def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... -def moveaxis(a: ArrayLike, source: Union[int, Sequence[int]], - destination: Union[int, Sequence[int]]) -> Array: ... +def moveaxis(a: ArrayLike, source: int | Sequence[int], + destination: int | Sequence[int]) -> Array: ... def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ... nan: float def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ..., - posinf: Optional[ArrayLike] = ..., - neginf: Optional[ArrayLike] = ...) -> Array: ... + posinf: ArrayLike | None = ..., + neginf: ArrayLike | None = ...) -> Array: ... def nanargmax( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def nanargmin( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - keepdims: Optional[builtins.bool] = ..., + keepdims: builtins.bool | None = ..., ) -> Array: ... def nancumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def nancumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - where: Optional[ArrayLike] = ...) -> Array: ... -def nanmedian(a: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + where: ArrayLike | None = ...) -> Array: ... +def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., keepdims: builtins.bool = ...) -> Array: ... def nanmin(a: ArrayLike, axis: _Axis = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanpercentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = ..., + axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., - keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... -def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., + keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... +def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + where: ArrayLike | None = ...) -> Array: ... def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ...) -> Array: ... + initial: ArrayLike | None = ..., + where: ArrayLike | None = ...) -> Array: ... def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = 0, keepdims: builtins.bool = False, - where: Optional[ArrayLike] = ...) -> Array: ... + where: ArrayLike | None = ...) -> Array: ... ndarray = Array ndim = _np.ndim def negative(x: ArrayLike, /) -> Array: ... newaxis = ... def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def nonzero(a: ArrayLike, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... +def nonzero(a: ArrayLike, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> tuple[Array, ...]: ... def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... number = _np.number object_ = _np.object_ ogrid: _Ogrid -def ones(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def ones_like(a: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +def ones(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def ones_like(a: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def outer(a: ArrayLike, b: Array, out: None = ...) -> Array: ... def packbits( - a: ArrayLike, axis: Optional[int] = ..., bitorder: str = ... + a: ArrayLike, axis: int | None = ..., bitorder: str = ... ) -> Array: ... PadValueLike = Union[_T, Sequence[_T], Sequence[Sequence[_T]]] def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | _np.ndarray], - mode: Union[str, Callable[..., Any]] = ..., **kwargs) -> Array: ... + mode: str | Callable[..., Any] = ..., **kwargs) -> Array: ... def partition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ... def percentile(a: ArrayLike, q: ArrayLike, - axis: Optional[Union[int, tuple[int, ...]]] = ..., + axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array: ... pi: float -def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]], - funclist: Sequence[Union[ArrayLike, Callable[..., Array]]], +def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], + funclist: Sequence[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: ... def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: builtins.bool = ...) -> Array: ... -def poly(seq_of_zeros: Array) -> Array: ... -def polyadd(a1: Array, a2: Array) -> Array: ... -def polyder(p: Array, m: int = ...) -> Array: ... +def poly(seq_of_zeros: ArrayLike) -> Array: ... +def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: ... +def polyder(p: ArrayLike, m: int = ...) -> Array: ... def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> tuple[Array, Array]: ... -def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = ..., - full: builtins.bool = ..., w: Optional[Array] = ..., cov: builtins.bool = ... - ) -> Union[Array, tuple[Array, ...]]: ... -def polyint(p: Array, m: int = ..., k: Optional[int] = ...) -> Array: ... +def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = ..., + full: builtins.bool = ..., w: ArrayLike | None = ..., cov: builtins.bool = ... + ) -> Array | tuple[Array, ...]: ... +def polyint(p: ArrayLike, m: int = ..., k: int | ArrayLike | None = ...) -> Array: ... def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> Array: ... -def polysub(a1: Array, a2: Array) -> Array: ... -def polyval(p: Array, x: Array, *, unroll: int = ...) -> Array: ... +def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: ... +def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = ...) -> Array: ... def positive(x: ArrayLike, /) -> Array: ... def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ... printoptions = _np.printoptions def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., where: Optional[ArrayLike] = ..., + initial: ArrayLike | None = ..., where: ArrayLike | None = ..., promote_integers: builtins.bool = ...) -> Array: ... product = prod promote_types = _np.promote_types @@ -696,7 +697,7 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ...) -> Array: ... def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ... -def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ..., +def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... r_: _RClass @@ -709,19 +710,19 @@ def real(x: ArrayLike, /) -> Array: ... def reciprocal(x: ArrayLike, /) -> Array: ... register_jax_array_methods: Any def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = ..., *, - total_repeat_length: Optional[int] = ...) -> Array: ... +def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, + total_repeat_length: int | None = ...) -> Array: ... def reshape( - a: ArrayLike, shape: Union[DimSize, Shape] = ..., - newshape: Union[DimSize, Shape] | None = ..., order: str = ... + a: ArrayLike, shape: DimSize | Shape = ..., + newshape: DimSize | Shape | None = ..., order: str = ... ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... def result_type(*args: Any) -> DType: ... def right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def rint(x: ArrayLike, /) -> Array: ... -def roll(a: ArrayLike, shift: Union[ArrayLike, Sequence[int]], - axis: Optional[Union[int, Sequence[int]]] = ...) -> Array: ... +def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], + axis: int | Sequence[int] | None = ...) -> Array: ... def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: ... def roots(p: ArrayLike, *, strip_zeros: builtins.bool = ...) -> Array: ... def rot90(m: ArrayLike, k: int = ..., axes: tuple[int, int] = ...) -> Array: ... @@ -743,8 +744,8 @@ def setdiff1d( ar2: ArrayLike, assume_unique: builtins.bool = ..., *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... shape = _np.shape @@ -759,7 +760,7 @@ size = _np.size sometrue = any def sort( a: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., *, stable: builtins.bool = ..., descending: builtins.bool = ..., @@ -769,24 +770,24 @@ def sort( def sort_complex(a: ArrayLike) -> Array: ... def split( ary: ArrayLike, - indices_or_sections: Union[int, Sequence[int], ArrayLike], + indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = ..., ) -> list[Array]: ... def sqrt(x: ArrayLike, /) -> Array: ... def square(x: ArrayLike, /) -> Array: ... def squeeze( - a: ArrayLike, axis: Optional[Union[int, Sequence[int]]] = ... + a: ArrayLike, axis: int | Sequence[int] | None = ... ) -> Array: ... def stack( - arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]], + arrays: _np.ndarray | Array | Sequence[ArrayLike], axis: int = ..., out: None = ..., - dtype: Optional[DTypeLike] = ..., + dtype: DTypeLike | None = ..., ) -> Array: ... def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ... + where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ... def sum( a: ArrayLike, @@ -794,53 +795,53 @@ def sum( dtype: DTypeLike = ..., out: None = ..., keepdims: builtins.bool = ..., - initial: Optional[ArrayLike] = ..., - where: Optional[ArrayLike] = ..., + initial: ArrayLike | None = ..., + where: ArrayLike | None = ..., promote_integers: builtins.bool = ..., ) -> Array: ... def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: ... def take( a: ArrayLike, indices: ArrayLike, - axis: Optional[int] = ..., + axis: int | None = ..., out: None = ..., - mode: Optional[str] = ..., + mode: str | None = ..., unique_indices: builtins.bool = ..., indices_are_sorted: builtins.bool = ..., - fill_value: Optional[StaticScalar] = ..., + fill_value: StaticScalar | None = ..., ) -> Array: ... def take_along_axis( arr: ArrayLike, indices: ArrayLike, - axis: Optional[int], - mode: Optional[Union[str, GatherScatterMode]] = ..., - fill_value: Optional[StaticScalar] = None, + axis: int | None, + mode: str | GatherScatterMode | None = ..., + fill_value: StaticScalar | None = None, ) -> Array: ... def tan(x: ArrayLike, /) -> Array: ... def tanh(x: ArrayLike, /) -> Array: ... def tensordot(a: ArrayLike, b: ArrayLike, - axes: Union[int, Sequence[int], Sequence[Sequence[int]]] = ..., + axes: int | Sequence[int] | Sequence[Sequence[int]] = ..., *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... -def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... +def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: ... def trace(a: ArrayLike, offset: int | ArrayLike = ..., axis1: int = ..., axis2: int = ..., - dtype: Optional[DTypeLike] = ..., out: None = ...) -> Array: ... -def transpose(a: ArrayLike, axes: Optional[Sequence[int]] = ...) -> Array: ... + dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... +def transpose(a: ArrayLike, axes: Sequence[int] | None = ...) -> Array: ... def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ..., axis: int = ...) -> Array: ... def tri( - N: int, M: Optional[int] = ..., k: int = ..., dtype: DTypeLike = ... + N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike = ... ) -> Array: ... def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( - n: int, k: int = ..., m: Optional[int] = ... + n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) -> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( - n: int, k: int = ..., m: Optional[int] = ... + n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... def triu_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -855,8 +856,8 @@ def union1d( ar1: ArrayLike, ar2: ArrayLike, *, - size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ..., + size: int | None = ..., + fill_value: ArrayLike | None = ..., ) -> Array: ... class _UniqueAllResult(NamedTuple): values: Array @@ -870,71 +871,71 @@ class _UniqueInverseResult(NamedTuple): values: Array inverse_indices: Array def unique(ar: ArrayLike, return_index: builtins.bool = ..., return_inverse: builtins.bool = ..., - return_counts: builtins.bool = ..., axis: Optional[int] = ..., - *, equal_nan: builtins.bool = ..., size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ... + return_counts: builtins.bool = ..., axis: int | None = ..., + *, equal_nan: builtins.bool = ..., size: int | None = ..., + fill_value: ArrayLike | None = ... ): ... -def unique_all(x: ArrayLike, /, *, size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ...) -> _UniqueAllResult: ... -def unique_counts(x: ArrayLike, /, *, size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ...) -> _UniqueCountsResult: ... -def unique_inverse(x: ArrayLike, /, *, size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ...) -> _UniqueInverseResult: ... -def unique_values(x: ArrayLike, /, *, size: Optional[int] = ..., - fill_value: Optional[ArrayLike] = ...) -> Array: ... +def unique_all(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueAllResult: ... +def unique_counts(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueCountsResult: ... +def unique_inverse(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> _UniqueInverseResult: ... +def unique_values(x: ArrayLike, /, *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> Array: ... def unpackbits( a: ArrayLike, - axis: Optional[int] = ..., - count: Optional[ArrayLike] = ..., + axis: int | None = ..., + count: ArrayLike | None = ..., bitorder: str = ..., ) -> Array: ... def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: ... unsignedinteger = _np.unsignedinteger def unstack(x: ArrayLike , /, *, axis: int = ...) -> tuple[Array, ...]: ... -def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = ..., +def unwrap(p: ArrayLike, discont: ArrayLike | None = ..., axis: int = ..., period: ArrayLike = ...) -> Array: ... def vander( - x: ArrayLike, N: Optional[int] = ..., increasing: builtins.bool = ... + x: ArrayLike, N: int | None = ..., increasing: builtins.bool = ... ) -> Array: ... def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ... + where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def vdot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = ..., precision: PrecisionLike = ..., - preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ...) -> Array: ... def vsplit( - ary: ArrayLike, indices_or_sections: Union[int, ArrayLike] + ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... -def vstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]], - dtype: Optional[DTypeLike] = ...) -> Array: ... +def vstack(tup: _np.ndarray | Array | Sequence[ArrayLike], + dtype: DTypeLike | None = ...) -> Array: ... @overload def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ..., - /, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... + /, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> tuple[Array, ...]: ... @overload def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *, - size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... + size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> Array: ... @overload -def where(condition: ArrayLike, x: Optional[ArrayLike] = ..., - y: Optional[ArrayLike] = ..., /, *, size: Optional[int] = ..., - fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... - ) -> Union[Array, tuple[Array, ...]]: ... - -def zeros(shape: Any, dtype: Optional[DTypeLike] = ..., - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... -def zeros_like(a: Union[ArrayLike, DuckTypedArray], - dtype: Optional[DTypeLike] = ..., +def where(condition: ArrayLike, x: ArrayLike | None = ..., + y: ArrayLike | None = ..., /, *, size: int | None = ..., + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... + ) -> Array | tuple[Array, ...]: ... + +def zeros(shape: Any, dtype: DTypeLike | None = ..., + device: _Device | _Sharding | None = ...) -> Array: ... +def zeros_like(a: ArrayLike | DuckTypedArray, + dtype: DTypeLike | None = ..., shape: Any = ..., *, - device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... + device: _Device | _Sharding | None = ...) -> Array: ... def vectorize(pyfunc, *, excluded = ..., signature = ...) -> Callable: ... diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index d3a84c384932..84cc697d1894 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -107,3 +107,13 @@ def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str): ) with open(src_file, "w") as f: f.write(content) + +def update_setup_with_rocm_version(file_dir: pathlib.Path, rocm_version: str): + src_file = file_dir / "setup.py" + with open(src_file) as f: + content = f.read() + content = content.replace( + "rocm_version = 0 # placeholder", f"rocm_version = {rocm_version}" + ) + with open(src_file, "w") as f: + f.write(content) diff --git a/jax_plugins/BUILD.bazel b/jax_plugins/BUILD.bazel index 2102c6404c5a..6e2cf6aadbaf 100644 --- a/jax_plugins/BUILD.bazel +++ b/jax_plugins/BUILD.bazel @@ -17,6 +17,7 @@ licenses(["notice"]) load( "//jaxlib:jax.bzl", "if_cuda_is_configured", + "if_rocm_is_configured", "py_library_providing_imports_info", ) @@ -30,5 +31,7 @@ py_library( ":jax_plugins", ] + if_cuda_is_configured([ "//jax_plugins/cuda:cuda_plugin", + ]) + if_rocm_is_configured([ + "//jax_plugins/rocm:rocm_plugin", ]), -) \ No newline at end of file +) diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index 5215f88832b9..cd26731aa629 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -49,7 +49,7 @@ def has_ext_modules(self): author="JAX team", author_email="jax-dev@google.com", packages=[package_name], - python_requires=">=3.9", + python_requires=">=3.10", install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], extras_require={ 'with_cuda': [ @@ -76,7 +76,6 @@ def has_ext_modules(self): license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel new file mode 100644 index 000000000000..08a61c786262 --- /dev/null +++ b/jax_plugins/rocm/BUILD.bazel @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +load("//jaxlib:symlink_files.bzl", "symlink_files") +load( + "//jaxlib:jax.bzl", + "if_windows", + "py_library_providing_imports_info", + "pytype_library", +) + +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) + +exports_files([ + "__init__.py", + "plugin_pyproject.toml", + "plugin_setup.py", + "pyproject.toml", + "setup.py", +]) + +symlink_files( + name = "pjrt_c_api_gpu_plugin", + srcs = if_windows( + ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], + ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], + ), + dst = ".", + flatten = True, +) + +py_library_providing_imports_info( + name = "rocm_plugin", + srcs = [ + "__init__.py", + ], + data = [":pjrt_c_api_gpu_plugin"], + lib_rule = pytype_library, +) diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py new file mode 100644 index 000000000000..4535f1b3bbc8 --- /dev/null +++ b/jax_plugins/rocm/__init__.py @@ -0,0 +1,91 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import importlib +import logging +import pathlib +import platform + +from jax._src.lib import xla_client +import jax._src.xla_bridge as xb + +# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without +# preinstalled jax rocm plugin packages. +for pkg_name in ['jax_rocm60_plugin', 'jaxlib']: + try: + rocm_plugin_extension = importlib.import_module( + f'{pkg_name}.rocm_plugin_extension' + ) + except ImportError: + rocm_plugin_extension = None + else: + break + +logger = logging.getLogger(__name__) + + +def _get_library_path(): + base_path = pathlib.Path(__file__).resolve().parent + installed_path = ( + base_path / 'xla_rocm_plugin.so' + ) + if installed_path.exists(): + return installed_path + + local_path = ( + base_path / 'pjrt_c_api_gpu_plugin.so' + ) + if local_path.exists(): + logger.debug( + 'Native library %s does not exist. This most likely indicates an issue' + ' with how %s was built or installed. Fallback to local test' + ' library %s', + installed_path, + __package__, + local_path, + ) + return local_path + + logger.debug( + 'WARNING: Native library %s and local test library path %s do not' + ' exist. This most likely indicates an issue with how %s was built or' + ' installed or missing src files.', + installed_path, + local_path, + __package__, + ) + return None + + +def initialize(): + path = _get_library_path() + if path is None: + return + options = xla_client.generate_pjrt_gpu_plugin_options() + options["platform_name"] = "ROCM" + c_api = xb.register_plugin( + 'rocm', priority=500, library_path=str(path), options=options + ) + if rocm_plugin_extension: + xla_client.register_custom_call_handler( + "ROCM", + functools.partial( + rocm_plugin_extension.register_custom_call_target, c_api + ), + ) + for _name, _value in rocm_plugin_extension.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") + else: + logger.warning('rocm_plugin_extension is not found.') diff --git a/jax_plugins/rocm/plugin_pyproject.toml b/jax_plugins/rocm/plugin_pyproject.toml new file mode 100644 index 000000000000..8fe2f47af9a1 --- /dev/null +++ b/jax_plugins/rocm/plugin_pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py new file mode 100644 index 000000000000..9ccf3bf44339 --- /dev/null +++ b/jax_plugins/rocm/plugin_setup.py @@ -0,0 +1,70 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from setuptools import setup +from setuptools.dist import Distribution + +__version__ = None +rocm_version = 0 # placeholder +project_name = f"jax-rocm{rocm_version}-plugin" +package_name = f"jax_rocm{rocm_version}_plugin" + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(package_name) +__version__ = _version_module._get_version_for_build() +_cmdclass = _version_module._get_cmdclass(package_name) + +class BinaryDistribution(Distribution): + """This class makes 'bdist_wheel' include an ABI tag on the wheel.""" + + def has_ext_modules(self): + return True + +setup( + name=project_name, + version=__version__, + cmdclass=_cmdclass, + description="JAX Plugin for AMD GPUs", + long_description="", + long_description_content_type="text/markdown", + author="Ruturaj4", + author_email="Ruturaj.Vaidya@amd.com", + packages=[package_name], + python_requires=">=3.9", + install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], + url="https://github.com/google/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + package_data={ + package_name: [ + "*", + ], + }, + zip_safe=False, + distclass=BinaryDistribution, +) diff --git a/jax_plugins/rocm/pyproject.toml b/jax_plugins/rocm/pyproject.toml new file mode 100644 index 000000000000..8fe2f47af9a1 --- /dev/null +++ b/jax_plugins/rocm/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/jax_plugins/rocm/setup.py b/jax_plugins/rocm/setup.py new file mode 100644 index 000000000000..8782676ce9a2 --- /dev/null +++ b/jax_plugins/rocm/setup.py @@ -0,0 +1,66 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from setuptools import setup, find_namespace_packages + +__version__ = None +rocm_version = 0 # placeholder +project_name = f"jax-rocm{rocm_version}-pjrt" +package_name = f"jax_plugins.xla_rocm{rocm_version}" + +def load_version_module(pkg_path): + spec = importlib.util.spec_from_file_location( + 'version', os.path.join(pkg_path, 'version.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +_version_module = load_version_module(f"jax_plugins/xla_rocm{rocm_version}") +__version__ = _version_module._get_version_for_build() + +packages = find_namespace_packages( + include=[ + package_name, + f"{package_name}.*", + ] +) + +setup( + name=project_name, + version=__version__, + description="JAX XLA PJRT Plugin for AMD GPUs", + long_description="", + long_description_content_type="text/markdown", + author="Ruturaj4", + author_email="Ruturaj.Vaidya@amd.com", + packages=packages, + install_requires=[], + url="https://github.com/google/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + ], + package_data={ + package_name: ["xla_rocm_plugin.so"], + }, + zip_safe=False, + entry_points={ + "jax_plugins": [ + f"xla_rocm{rocm_version} = {package_name}", + ], + }, +) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 606f9b0735c7..dc8b5148ca93 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -215,6 +215,28 @@ pybind_extension( ], ) +pybind_extension( + name = "rocm_plugin_extension", + srcs = ["rocm_plugin_extension.cc"], + module_name = "rocm_plugin_extension", + deps = [ + "//jaxlib:kernel_nanobind_helpers", + "@xla//third_party/python_runtime:headers", + "@xla//xla:status", + "@xla//xla:util", + "@xla//xla/ffi/api:c_api", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/python/lib/core:numpy", + "@com_google_absl//absl/status", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], +) + # CPU kernels # TODO(phawkins): Remove this forwarding target. diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 87fd6954d12e..a1ce5fa4d2f7 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -42,14 +42,21 @@ _name, _value, platform="CUDA", api_version=api_version ) -try: - from .rocm import _linalg as _hip_linalg # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_linalg = importlib.import_module( + f"{rocm_module_name}._linalg", package="jaxlib" + ) + except ImportError: + _hip_linalg = None + else: + break + +if _hip_linalg: for _name, _value in _hip_linalg.registrations().items(): xla_client.register_custom_call_target( _name, _value, platform="ROCM", api_version=1 ) -except ImportError: - _hip_linalg = None _prod = lambda xs: functools.reduce(operator.mul, xs, 1) diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index a03dc8a8a4ca..12dcacebfa46 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -41,12 +41,19 @@ for _name, _value in _cuda_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -try: - from .rocm import _prng as _hip_prng # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_prng = importlib.import_module( + f"{rocm_module_name}._prng", package="jaxlib" + ) + except ImportError: + _hip_prng = None + else: + break + +if _hip_prng: for _name, _value in _hip_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") -except ImportError: - _hip_prng = None _prod = lambda xs: functools.reduce(operator.mul, xs, 1) diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index e804068e5e6d..f9a704a79f3d 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -59,21 +59,34 @@ for _name, _value in _cusolver.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") - try: from .rocm import _blas as _hipblas # pytype: disable=import-error - for _name, _value in _hipblas.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: - _hipblas = None + for rocm_module_name in ["jax_rocm60_plugin"]: + try: + _hipblas = importlib.import_module(f"{rocm_module_name}._blas") + except: + _hipblas = None + else: + break -try: - from .rocm import _solver as _hipsolver # pytype: disable=import-error - for _name, _value in _hipsolver.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") -except ImportError: - _hipsolver = None +if _hipblas: + for _name, _value in _hipblas.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") + +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hipsolver = importlib.import_module( + f"{rocm_module_name}._solver", package="jaxlib" + ) + except ImportError: + _hipsolver = None + else: + break +if _hipsolver: + for _name, _value in _hipsolver.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM") def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index 7b840e9595ef..84192d4d0286 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -41,11 +41,17 @@ for _name, _value in _cusparse.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") -try: - from .rocm import _sparse as _hipsparse # pytype: disable=import-error -except ImportError: - _hipsparse = None -else: +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hipsparse = importlib.import_module( + f"{rocm_module_name}._sparse", package="jaxlib" + ) + except ImportError: + _hipsparse = None + else: + break + +if _hipsparse: for _name, _value in _hipsparse.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index e38a7feddefc..f2d37bfec03d 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -38,8 +38,17 @@ get_custom_call = _cuda_triton.get_custom_call get_serialized_metadata = _cuda_triton.get_serialized_metadata -try: - from .rocm import _triton as _hip_triton # pytype: disable=import-error +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_triton = importlib.import_module( + f"{rocm_module_name}._triton", package="jaxlib" + ) + except ImportError: + _hip_triton = None + else: + break + +if _hip_triton: xla_client.register_custom_call_target( "triton_kernel_call", _hip_triton.get_custom_call(), platform='ROCM') @@ -51,5 +60,3 @@ get_compute_capability = _hip_triton.get_compute_capability get_custom_call = _hip_triton.get_custom_call get_serialized_metadata = _hip_triton.get_serialized_metadata -except ImportError: - _hip_triton = None diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 55d58ba96092..a4da463a3447 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -48,11 +48,6 @@ jax_internal_test_harnesses_visibility = [] jax_test_util_visibility = [] loops_visibility = [] -def get_importlib_metadata(): - if HERMETIC_PYTHON_VERSION == "3.9": - return ["@pypi_importlib_metadata//:pkg"] - return [] - # TODO(vam): remove this once zstandard builds against Python 3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION == "3.13": @@ -69,7 +64,6 @@ _py_deps = { "filelock": ["@pypi_filelock//:pkg"], "flatbuffers": ["@pypi_flatbuffers//:pkg"], "hypothesis": ["@pypi_hypothesis//:pkg"], - "importlib_metadata": get_importlib_metadata(), "matplotlib": ["@pypi_matplotlib//:pkg"], "opt_einsum": ["@pypi_opt_einsum//:pkg"], "pil": ["@pypi_pillow//:pkg"], diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 6e952b546866..74d5ef30bf09 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -398,3 +398,12 @@ py_library( ":_triton", ], ) + +py_library( + name = "gpu_only_test_deps", + # `if_rocm_is_configured` will default to `[]`. + deps = if_rocm_is_configured([ + ":rocm_gpu_support", + "//jaxlib:rocm_plugin_extension", + ]), +) diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc new file mode 100644 index 000000000000..9b0743d27cd9 --- /dev/null +++ b/jaxlib/rocm_plugin_extension.cc @@ -0,0 +1,152 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include + +#include "nanobind/nanobind.h" +#include "absl/status/status.h" +#include "rocm/include/hip/hip_runtime.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/c_api.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/py_client_gpu.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" + +namespace nb = nanobind; + +namespace xla { +namespace { +Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, + nb::capsule fn, int api_version, + XLA_FFI_Handler_Traits traits) { + if (c_api->extension_start == nullptr) { + return Unimplemented("The plugin does not have extension."); + } + const PJRT_Extension_Base* next = + reinterpret_cast(c_api->extension_start); + while (next != nullptr && + next->type != + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { + next = next->next; + } + if (next == nullptr) { + return Unimplemented("The plugin does not have a custom call extension."); + } + + if (traits != 0) { + return Unimplemented("The plugin does not support custom call traits."); + } + + PJRT_Gpu_Register_Custom_Call_Args args; + args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; + args.function_name = fn_name.c_str(); + args.function_name_size = nb::len(fn_name); +#if PJRT_API_GPU_EXTENSION_VERSION >= 1 + args.api_version = api_version; +#endif + args.custom_call_function = static_cast(fn.data()); + RETURN_STATUS_IF_PJRT_ERROR( + reinterpret_cast(next)->custom_call(&args), + c_api); + return OkStatus(); +} + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(xla::XlaPythonGpuCallback); + return dict; +} + +std::string ToString(hipError_t result) { +#define OSTREAM_ROCM_ERROR(__name) \ + case hipError##__name: \ + return "HIP_ERROR_" #__name; + + switch (result) { + OSTREAM_ROCM_ERROR(InvalidValue) + OSTREAM_ROCM_ERROR(OutOfMemory) + OSTREAM_ROCM_ERROR(NotInitialized) + OSTREAM_ROCM_ERROR(Deinitialized) + OSTREAM_ROCM_ERROR(NoDevice) + OSTREAM_ROCM_ERROR(InvalidDevice) + OSTREAM_ROCM_ERROR(InvalidImage) + OSTREAM_ROCM_ERROR(InvalidContext) + OSTREAM_ROCM_ERROR(InvalidHandle) + OSTREAM_ROCM_ERROR(NotFound) + OSTREAM_ROCM_ERROR(NotReady) + OSTREAM_ROCM_ERROR(NoBinaryForGpu) + + // Encountered an uncorrectable ECC error during execution. + OSTREAM_ROCM_ERROR(ECCNotCorrectable) + + // Load/store on an invalid address. Must reboot all context. + case 700: + return "ROCM_ERROR_ILLEGAL_ADDRESS"; + // Passed too many / wrong arguments, too many threads for register count. + case 701: + return "ROCM_ERROR_LAUNCH_OUT_OF_RESOURCES"; + + OSTREAM_ROCM_ERROR(ContextAlreadyInUse) + OSTREAM_ROCM_ERROR(PeerAccessUnsupported) + OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM. + default: + return absl::StrCat("hipError_t(", static_cast(result), ")"); + } +} +} // namespace + +NB_MODULE(rocm_plugin_extension, m) { + tsl::ImportNumpy(); + m.def( + "register_custom_call_target", + [](nb::capsule c_api, nb::str fn_name, nb::capsule fn, + nb::str xla_platform_name, int api_version, + XLA_FFI_Handler_Traits traits) { + xla::ThrowIfError(RegisterCustomCallTarget( + static_cast(c_api.data()), fn_name, std::move(fn), + api_version, traits)); + }, + nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), + nb::arg("xla_platform_name"), nb::arg("api_version") = 0, + nb::arg("traits") = 0); + m.def("registrations", &Registrations); + m.def( + "get_device_ordinal", + [](std::intptr_t data_value) { + if (data_value == 0) { + return 0; + } + int device_ordinal; + void* data_ptr = reinterpret_cast(data_value); + hipError_t result = + hipPointerGetAttribute(static_cast(&device_ordinal), + HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(data_ptr)); + if (result != hipSuccess) { + LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << data_ptr + << ". Error: " << ToString(result); + } + return device_ordinal; + }, + nb::arg("data_value")); +} +} // namespace xla diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 0e4f422be67a..adc3ba452111 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -59,7 +59,7 @@ def has_ext_modules(self): author='JAX team', author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension'], - python_requires='>=3.9', + python_requires='>=3.10', install_requires=[ 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", @@ -69,7 +69,6 @@ def has_ext_modules(self): url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 4297e1b39394..089cba21dc7b 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -93,6 +93,12 @@ py_binary( "//jax_plugins/cuda:setup.py", "//jax_plugins/cuda:__init__.py", "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jaxlib:version", + "//jaxlib/rocm:rocm_gpu_support", + "//jax_plugins/rocm:pyproject.toml", + "//jax_plugins/rocm:setup.py", + "//jax_plugins/rocm:__init__.py", ]), deps = [ "//jax/tools:build_utils", @@ -104,8 +110,8 @@ py_binary( ) py_binary( - name = "build_cuda_kernels_wheel", - srcs = ["build_cuda_kernels_wheel.py"], + name = "build_gpu_kernels_wheel", + srcs = ["build_gpu_kernels_wheel.py"], data = [ "LICENSE.txt", ] + if_cuda([ @@ -116,6 +122,12 @@ py_binary( "//jax_plugins/cuda:plugin_pyproject.toml", "//jax_plugins/cuda:plugin_setup.py", "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jaxlib:rocm_plugin_extension", + "//jaxlib:version", + "//jaxlib/rocm:rocm_gpu_support", + "//jax_plugins/rocm:plugin_pyproject.toml", + "//jax_plugins/rocm:plugin_setup.py", ]), deps = [ "//jax/tools:build_utils", diff --git a/jaxlib/tools/build_cuda_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py similarity index 63% rename from jaxlib/tools/build_cuda_kernels_wheel.py rename to jaxlib/tools/build_gpu_kernels_wheel.py index b95169eb8589..f0da3d2530b6 100644 --- a/jaxlib/tools/build_cuda_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -43,16 +43,24 @@ "--cpu", default=None, required=True, help="Target CPU architecture. Required." ) parser.add_argument( - "--cuda_version", + "--platform_version", default=None, required=True, - help="Target CUDA version. Required.", + help="Target CUDA/ROCM version. Required.", ) parser.add_argument( "--editable", action="store_true", - help="Create an 'editable' jax cuda plugin build instead of a wheel.", + help="Create an 'editable' jax cuda/rocm plugin build instead of a wheel.", ) +parser.add_argument( + "--enable-cuda", + default=False, + help="Should we build with CUDA enabled? Requires CUDA and CuDNN.") +parser.add_argument( + "--enable-rocm", + default=False, + help="Should we build with ROCM enabled?") args = parser.parse_args() r = runfiles.Create() @@ -70,7 +78,7 @@ def write_setup_cfg(sources_path, cpu): """) -def prepare_wheel( +def prepare_wheel_cuda( sources_path: pathlib.Path, *, cpu, cuda_version ): """Assembles a source tree for the cuda kernel wheel in `sources_path`.""" @@ -112,15 +120,58 @@ def prepare_wheel( ], ) +def prepare_wheel_rocm( + sources_path: pathlib.Path, *, cpu, rocm_version +): + """Assembles a source tree for the rocm kernel wheel in `sources_path`.""" + copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + + copy_runfiles( + "__main__/jax_plugins/rocm/plugin_pyproject.toml", + dst_dir=sources_path, + dst_filename="pyproject.toml", + ) + copy_runfiles( + "__main__/jax_plugins/rocm/plugin_setup.py", + dst_dir=sources_path, + dst_filename="setup.py", + ) + build_utils.update_setup_with_rocm_version(sources_path, rocm_version) + write_setup_cfg(sources_path, cpu) + + plugin_dir = sources_path / f"jax_rocm{rocm_version}_plugin" + copy_runfiles( + dst_dir=plugin_dir, + src_files=[ + f"__main__/jaxlib/rocm/_solver.{pyext}", + f"__main__/jaxlib/rocm/_blas.{pyext}", + f"__main__/jaxlib/rocm/_linalg.{pyext}", + f"__main__/jaxlib/rocm/_prng.{pyext}", + f"__main__/jaxlib/rocm/_sparse.{pyext}", + f"__main__/jaxlib/rocm/_triton.{pyext}", + f"__main__/jaxlib/rocm_plugin_extension.{pyext}", + "__main__/jaxlib/version.py", + ], + ) + # Build wheel for cuda kernels -tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin") +if args.enable_rocm: + tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") +else: + tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin") sources_path = tmpdir.name try: os.makedirs(args.output_path, exist_ok=True) - prepare_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version - ) - package_name = f"jax cuda{args.cuda_version} plugin" + if args.enable_cuda: + prepare_wheel_cuda( + pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + ) + package_name = f"jax cuda{args.platform_version} plugin" + elif args.enable_rocm: + prepare_wheel_rocm( + pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + ) + package_name = f"jax rocm{args.platform_version} plugin" if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index aebf199b71e4..73cb8a9e020d 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Script that builds a jax cuda plugin wheel, intended to be run via bazel run -# as part of the jax cuda plugin build process. +# Script that builds a jax cuda/rocm plugin wheel, intended to be run via bazel run +# as part of the jax cuda/rocm plugin build process. # Most users should not run this script directly; use build.py instead. @@ -49,16 +49,24 @@ "--cpu", default=None, required=True, help="Target CPU architecture. Required." ) parser.add_argument( - "--cuda_version", + "--platform_version", default=None, required=True, - help="Target CUDA version. Required.", + help="Target CUDA/ROCM version. Required.", ) parser.add_argument( "--editable", action="store_true", - help="Create an 'editable' jax cuda plugin build instead of a wheel.", + help="Create an 'editable' jax cuda/rocm plugin build instead of a wheel.", ) +parser.add_argument( + "--enable-cuda", + default=False, + help="Should we build with CUDA enabled? Requires CUDA and CuDNN.") +parser.add_argument( + "--enable-rocm", + default=False, + help="Should we build with ROCM enabled?") args = parser.parse_args() r = runfiles.Create() @@ -106,18 +114,56 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): ) +def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): + """Assembles a source tree for the ROCm wheel in `sources_path`.""" + copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + + plugin_dir = sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" + copy_runfiles( + dst_dir=sources_path, + src_files=[ + "__main__/jax_plugins/rocm/pyproject.toml", + "__main__/jax_plugins/rocm/setup.py", + ], + ) + build_utils.update_setup_with_rocm_version(sources_path, rocm_version) + write_setup_cfg(sources_path, cpu) + copy_runfiles( + dst_dir=plugin_dir, + src_files=[ + "__main__/jax_plugins/rocm/__init__.py", + "__main__/jaxlib/version.py", + ], + ) + copy_runfiles( + "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + dst_dir=plugin_dir, + dst_filename="xla_rocm_plugin.so", + ) + + tmpdir = None sources_path = args.sources_path if sources_path is None: - tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudapjrt") + tmpdir = tempfile.TemporaryDirectory(prefix="jaxgpupjrt") sources_path = tmpdir.name try: os.makedirs(args.output_path, exist_ok=True) - prepare_cuda_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version - ) - package_name = "jax cuda plugin" + + if args.enable_cuda: + prepare_cuda_plugin_wheel( + pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + ) + package_name = "jax cuda plugin" + elif args.enable_rocm: + prepare_rocm_plugin_wheel( + pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + ) + package_name = "jax rocm plugin" + else: + raise ValueError("Unsupported backend. Choose either 'cuda' or 'rocm'.") + if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 9a45c047783e..68a0e0aba380 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -238,7 +238,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/cuda/_versions.{pyext}", ], ) - if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"): + if exists(f"__main__/jaxlib/rocm/_solver.{pyext}") and not skip_gpu_kernels: copy_runfiles( dst_dir=jaxlib_dir / "rocm", src_files=[ diff --git a/pyproject.toml b/pyproject.toml index be6c89d3cce8..74110060a669 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ module = [ "absl.*", "colorama.*", "filelock.*", - "importlib_metadata.*", "IPython.*", "numpy.*", "opt_einsum.*", diff --git a/setup.py b/setup.py index 82b163ee5d4f..cc2c75ab7ff4 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ def load_version_module(pkg_path): author_email='jax-dev@google.com', packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]), package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, - python_requires='>=3.9', + python_requires='>=3.10', install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', 'ml_dtypes>=0.2.0', @@ -60,10 +60,6 @@ def load_version_module(pkg_path): 'opt_einsum', 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", - # Required by xla_bridge.discover_pjrt_plugins for forwards compat with - # Python versions < 3.10. Can be dropped when 3.10 is the minimum - # required Python version. - 'importlib_metadata>=4.6;python_version<"3.10"', ], extras_require={ # Minimum jaxlib version; used in testing. @@ -111,7 +107,6 @@ def load_version_module(pkg_path): url='https://github.com/google/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/tests/BUILD b/tests/BUILD index 4a3c05f1cac9..036946529de1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -75,6 +75,11 @@ jax_test( }, ) +jax_test( + name = "config_test", + srcs = ["config_test.py"], +) + jax_test( name = "core_test", srcs = ["core_test.py"], diff --git a/tests/api_test.py b/tests/api_test.py index b7eb015d7b87..fb1d6f4cd0c8 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4765,6 +4765,27 @@ def f(x): f.trace(jnp.arange(8)).lower(lowering_platforms=('tpu',)) # doesn't crash + def test_no_double_dots_in_error_message(self): + @jax.jit + def f(x): + return 1 if x > 0 else 0 + + with self.assertRaisesRegex(TracerBoolConversionError, r"with shape bool\[\]\.[^\.]"): + f(0) + + def test_inlined_literals_with_error(self): + @jax.jit + def f(): + @partial(jax.jit, inline=True) + def g(): + return jnp.sin(1.) + if g() > 0: + return 1. + return 0. + + with self.assertRaisesRegex(TracerBoolConversionError, "Attempted boolean"): + f() + class RematTest(jtu.JaxTestCase): diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 000000000000..0f49d988a46c --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,73 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +from jax._src import test_util as jtu +from jax._src import config + +jax.config.parse_flags_with_absl() + + +jax_test_enum_config = config.enum_state( + name='jax_test_enum_config', + enum_values=['default', 'xxx', 'yyy'], + default='default', + help='Configuration only used for tests.') + + +class ConfigTest(jtu.JaxTestCase): + def test_config_setting_via_update(self): + self.assertEqual(jax_test_enum_config.value, 'default') + + jax.config.update('jax_test_enum_config', 'xxx') + self.assertEqual(jax_test_enum_config.value, 'xxx') + + jax.config.update('jax_test_enum_config', 'yyy') + self.assertEqual(jax_test_enum_config.value, 'yyy') + + jax.config.update('jax_test_enum_config', 'default') + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_setting_via_context(self): + self.assertEqual(jax_test_enum_config.value, 'default') + + with jax_test_enum_config('xxx'): + self.assertEqual(jax_test_enum_config.value, 'xxx') + + with jax_test_enum_config('yyy'): + self.assertEqual(jax_test_enum_config.value, 'yyy') + + self.assertEqual(jax_test_enum_config.value, 'xxx') + + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_update_validation(self): + self.assertEqual(jax_test_enum_config.value, 'default') + with self.assertRaisesRegex(ValueError, 'new enum value must be in.*'): + jax.config.update('jax_test_enum_config', 'invalid') + # Error should raise before changing the value + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_context_validation(self): + self.assertEqual(jax_test_enum_config.value, 'default') + with self.assertRaisesRegex(ValueError, 'new enum value must be in.*'): + with jax_test_enum_config('invalid'): + pass + self.assertEqual(jax_test_enum_config.value, 'default') + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/export_test.py b/tests/export_test.py index d93e07f1f906..7875f82b099b 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -335,7 +335,8 @@ def f(a, *, b): # a: f32[4] and b: f32[4] def test_default_export_platform(self): test_platform = jtu.device_under_test() - if test_platform == "gpu": test_platform = "cuda" + if test_platform == "gpu": + test_platform = "rocm" if jtu.is_device_rocm() else "cuda" self.assertEqual(export.default_export_platform(), test_platform) exp = export.export(jnp.sin)(1.) self.assertEqual(exp.platforms, (export.default_export_platform(),)) @@ -1402,8 +1403,11 @@ def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2), "cpu") - times_3 = core.Primitive("__testing_times_3") # x3 for cuda + times_3 = core.Primitive("__testing_times_3") # x3 for cuda and rocm times_3.def_abstract_eval(lambda x: x) + + mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3), + "rocm") mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3), "cuda") @@ -1412,22 +1416,27 @@ def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, mlir.register_lowering(times_4, functools.partial(times_n_lowering, 4), "tpu") - times_2_or_3 = core.Primitive("__testing_times_2_or_3") # x2 for cpu, x3 for cuda + times_2_or_3 = core.Primitive("__testing_times_2_or_3") # x2 for cpu, x3 for cuda and rocm times_2_or_3.def_abstract_eval(lambda x: x) mlir.register_lowering(times_2_or_3, mlir.lower_fun(times_2.bind, multiple_results=False), "cpu") + + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_3.bind, + multiple_results=False), "rocm") mlir.register_lowering(times_2_or_3, mlir.lower_fun(times_3.bind, multiple_results=False), "cuda") - times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4") # x2 for cpu, x3 for cuda, x4 for tpu + times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4") # x2 for cpu, x3 for cuda and rocm, x4 for tpu times_2_or_3_or_4.def_abstract_eval(lambda x: x) - times_2_or_3_or_4_lowering_cpu_cuda = mlir.lower_fun(times_2_or_3.bind, + times_2_or_3_or_4_lowering_cpu_gpu = mlir.lower_fun(times_2_or_3.bind, multiple_results=False) - for platform in ["cpu", "cuda"]: + + for platform in ["cpu", "cuda", "rocm"]: mlir.register_lowering(times_2_or_3_or_4, - times_2_or_3_or_4_lowering_cpu_cuda, + times_2_or_3_or_4_lowering_cpu_gpu, platform) mlir.register_lowering(times_2_or_3_or_4, mlir.lower_fun(times_4.bind, multiple_results=False), @@ -1437,7 +1446,7 @@ def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, def f(x): return times_2_or_3_or_4.bind(x) x = np.float32(42.) - exp = export.export(f, lowering_platforms=["cpu", "cuda", "tpu"])(x) + exp = export.export(f, lowering_platforms=["cpu", "cuda", "rocm", "tpu"])(x) expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()]) self.assertAllClose(exp.call(x), expected) diff --git a/tests/jet_test.py b/tests/jet_test.py index 79132174e13d..b1e2ef3f8380 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -243,6 +243,8 @@ def test_floor(self): self.unary_check(jnp.floor) @jtu.skip_on_devices("tpu") def test_ceil(self): self.unary_check(jnp.ceil) @jtu.skip_on_devices("tpu") + def test_trunc(self): self.unary_check(jnp.trunc) + @jtu.skip_on_devices("tpu") def test_round(self): self.unary_check(lax.round) @jtu.skip_on_devices("tpu") def test_sign(self): self.unary_check(lax.sign) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 4d03e46c1b32..040603555ff5 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2976,6 +2976,31 @@ def body(carry, x): hlo_text = fn.lower(init).as_text('hlo') self.assertNotIn('4,1,2,2', hlo_text) + def test_cond_vmap_forwarding_doesnt_promote(self): + def f(x, y): + x, y = jax.lax.cond( + x < 3, + lambda x, y: (x * 2, y), + lambda x, y: (x * 3, y), + x, y + ) + return x, y + + x = jnp.arange(3) + y = jnp.array(3.) + + x2, y2 = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(x, y) # don't crash + + assert x is not x2 + assert y is y2 + + def test_cond_casting(self): + x = 1.0 + identity = lambda x: x + + y = lax.cond(True, identity, identity, x) + self.assertEqual(y, x) + self.assertIsInstance(y, jax.Array) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f9493e5199b1..45cc177fbfd1 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6115,7 +6115,8 @@ def test_lax_numpy_docstrings(self): # Test that docstring wrapping & transformation didn't fail. unimplemented = ['fromfile', 'fromiter'] - aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2'] + aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', + 'amax', 'amin'] for name in dir(jnp): if name.startswith('_') or name in unimplemented: diff --git a/tests/logging_test.py b/tests/logging_test.py index 29eb5ce3559f..5a495d47d31b 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -124,7 +124,7 @@ def test_debug_logging(self): self.assertIn("Compiling ", log_output.getvalue()) # Turn off all debug logging. - with jax_debug_log_modules(None): + with jax_debug_log_modules(""): with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) @@ -137,7 +137,7 @@ def test_debug_logging(self): self.assertNotIn("Compiling ", log_output.getvalue()) # Turn everything off again. - with jax_debug_log_modules(None): + with jax_debug_log_modules(""): with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) diff --git a/tests/mosaic/flash_attention_test.py b/tests/mosaic/flash_attention_test.py index c93997920f27..1d15159ca44e 100644 --- a/tests/mosaic/flash_attention_test.py +++ b/tests/mosaic/flash_attention_test.py @@ -15,7 +15,6 @@ """Test different parameterizations of FlashAttention.""" import os -import sys from absl.testing import absltest, parameterized from jax._src import config @@ -23,8 +22,6 @@ # pylint: disable=g-import-not-at-top try: - if sys.version_info < (3, 10): - raise ImportError("Mosaic GPU requires Python 3.10+") # We only import this to see if Mosaic is available. import jax.experimental.mosaic.gpu # noqa: F401 except ImportError: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 91bb7b5f5195..6835c644bed7 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -16,7 +16,6 @@ from functools import partial import operator -import sys from typing import Optional from absl.testing import absltest, parameterized @@ -31,8 +30,6 @@ import jax.numpy as jnp import numpy as np try: - if sys.version_info < (3, 10): - raise ImportError("Mosaic requires Python 3.10") import jax._src.lib.mosaic_gpu # noqa: F401 HAS_MOSAIC_GPU = True except ImportError: diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index e53ef5bcc8e8..9e6f66b3a72d 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -15,15 +15,12 @@ """Test different parameterizations of a matmul.""" import os -import sys from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp try: - if sys.version_info < (3, 10): - raise ImportError("Mosaic GPU requires Python 3.10+") # We only import this to see if Mosaic is available. import jax.experimental.mosaic.gpu # noqa: F401 except ImportError: diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 39d638caf071..5c659ee08ee0 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1080,7 +1080,7 @@ class PallasOpsTest(PallasTest): [jnp.abs, jnp.negative], ["int16", "int32", "int64", "float16", "float32", "float64"], ), - ([jnp.ceil, jnp.floor], ["float32", "float64"]), + ([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]), ( [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt], ["float16", "float32", "float64"], diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index ba0ad5cb02c5..680bdda5675a 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -1939,6 +1939,18 @@ def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension): if jnp.issubdtype(dtype, jnp.floating): self._CheckGradsSparse(dense_func, sparse_func, args_maker) + def test_bcoo_spdot_abstract_eval_bug(self): + # Regression test for https://github.com/google/jax/issues/21921 + lhs = sparse.BCOO( + (jnp.float32([[1]]), lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0)), + shape=(10, 10)) + rhs = sparse.BCOO( + (jnp.float32([1]), jnp.int32([[3]])), + shape=(10,)) + args_maker = lambda: [lhs, rhs] + def func(lhs, rhs): + return (lhs @ rhs).todense() + self._CompileAndCheck(func, args_maker) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 277897f6366a..f7a921e00286 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "41b06ce46a795dddc5176b8fbd28cba449077349" -XLA_SHA256 = "0d4e82ddf354994405cde8fec8ee752e795bd4545bc63005dad93d89a53cb74c" +XLA_COMMIT = "c53e9f4e48f11d0da3469c85a0692db7ebd2a5d3" +XLA_SHA256 = "2bbed1d978ea5715676d7efbeffda10ac80817acb2cbf7a0df7cbceb4c75eab7" def repo(): tf_http_archive(