Skip to content

Commit

Permalink
Merge branch 'google:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
coreyjadams committed May 28, 2024
2 parents e4fd97e + ff3db9b commit 3a96e73
Show file tree
Hide file tree
Showing 258 changed files with 11,550 additions and 3,558 deletions.
76 changes: 32 additions & 44 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -217,21 +217,21 @@ build:rbe_linux --host_linkopt=-lm
build:rbe_cpu_linux_base --config=rbe_linux
build:rbe_cpu_linux_base --config=cuda_clang
build:rbe_cpu_linux_base --action_env=TF_NVCC_CLANG="1"
build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform"
build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform"
build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform"

build:rbe_cpu_linux_py3.9 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.9"
build:rbe_cpu_linux_py3.9 --python_path="/usr/local/bin/python3.9"
build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.10"
build:rbe_cpu_linux_py3.10 --python_path="/usr/local/bin/python3.10"
build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.11"
build:rbe_cpu_linux_py3.11 --python_path="/usr/local/bin/python3.11"
build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.12"
build:rbe_cpu_linux_py3.12 --python_path="/usr/local/bin/python3.12"
build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"

build:rbe_cpu_linux_py3.9 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9"
build:rbe_cpu_linux_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9"
build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12"
build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"

build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --config=cuda
Expand All @@ -241,27 +241,27 @@ build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1"
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDA_VERSION=12
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=9
build:rbe_linux_cuda12.3_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12"
build:rbe_linux_cuda12.3_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_cuda"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_nccl"
build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl"
# RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility
build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat
build:rbe_linux_cuda12.3_nvcc_py3.9 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.9"
build:rbe_linux_cuda12.3_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.10"
build:rbe_linux_cuda12.3_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.11"
build:rbe_linux_cuda12.3_nvcc_py3.11 --python_path="/usr/local/bin/python3.11"
build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9_config_python3.12"
build:rbe_linux_cuda12.3_nvcc_py3.12 --python_path="/usr/local/bin/python3.12"
build:rbe_linux_cuda12.3_nvcc_py3.9 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9"
build:rbe_linux_cuda12.3_nvcc_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9"
build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12"
build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"

# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
Expand All @@ -288,18 +288,6 @@ build:cross_compile_linux_arm64 --cpu=aarch64
build:cross_compile_linux_arm64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite

build:rbe_cross_compile_base --config=rbe
# JAX depends on some local Python headers that are configured as Genrule. They
# are present on the local host machine but not on the remote execution machine,
# leading to build failures. To resolve the issue, the following line is added
# to make sure all Genrule targets are excuted locally.
build:rbe_cross_compile_base --strategy=Genrule=standalone
# Due to the above strategy, all Genrule commands are executed locally, but the
# following actions invoke tools (E.g `flatc`, `llvm-tblgen`, etc.) that are
# only executabe on the RBE (x86) machine, so the strategy_regexp options are
# added to override and run the actions using remote strategy.
build:rbe_cross_compile_base --strategy_regexp='Generating code from table.*=remote'
build:rbe_cross_compile_base --strategy_regexp='Generating flatbuffer files.*=remote'
build:rbe_cross_compile_base --strategy_regexp='Executing genrule @llvm-project.*=remote'

# RBE cross-compile configs for Linux Aarch64
build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
PY_COLORS: 1
run: |
pytest -n auto --tb=short docs
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py
documentation_render:
Expand Down
14 changes: 9 additions & 5 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ jobs:
fail-fast: false # don't cancel all jobs on failure
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu-type: ["v3-8", "v4-8", "v5e-4"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
tpu: [
{type: "v3-8", cores: "4"},
{type: "v4-8", cores: "4"},
{type: "v5e-8", cores: "8"}
]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20240228
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
timeout-minutes: 120
defaults:
run:
Expand Down Expand Up @@ -84,7 +88,7 @@ jobs:
PY_COLORS: 1
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=4 --tb=short \
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
Expand All @@ -95,5 +99,5 @@ jobs:
curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
--header 'Content-Type: application/json' \
--data-raw "{
'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu-type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
}"
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ repos:
files: \.py$

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.5
rev: v0.4.4
hooks:
- id: ruff

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.9.0'
rev: 'v1.10.0'
hooks:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.23, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.27, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
args: [--config=pyproject.toml]

- repo: https://github.com/mwouts/jupytext
Expand Down
37 changes: 28 additions & 9 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,33 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.29

* Changes
* Computations on the CPU backend may now be dispatched asynchronously in
more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Breaking changes
* JAX now requires ml_dtypes version 0.4.0 or newer.

* Deprecations
* Removed a number of previously-deprecated APIs:
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
* from {mod}`jax.lax`: `tie_in`
* from {mod}`jax.nn`: `normalize`
* from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
`translations`, `register_translation`, `xla_destructure`,
`TranslationRule`, `TranslationContext`, `XlaOp`.
* The ``tol`` argument of {func}`jax.numpy.linalg.matrix_rank` is being
deprecated and will soon be removed. Use `rtol` instead.
* The ``rcond`` argument of {func}`jax.numpy.linalg.pinv` is being
deprecated and will soon be removed. Use `rtol` instead.
* The deprecated `jax.config` submodule has been removed. To configure JAX
use `import jax` and then reference the config object via `jax.config`.
* {mod}`jax.random` APIs no longer accept batched keys, where previously
some did unintentionally. Going forward, we recommend explicit use of
{func}`jax.vmap` in such cases.

## jaxlib 0.4.29

* Bug fixes
* Fixes a bug where XLA sharded some concatenation operations incorrectly,
which manifested as an incorrect output for cumulative reductions (#21403).

## jax 0.4.28 (May 9, 2024)

* Bug fixes
Expand Down Expand Up @@ -114,7 +133,7 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
Previously, no copy would be made when the output array would have the same
dtype as the input array. This may result in some increased memory usage.
The default value is set to `copy=False` to preserve backwards compatability.
The default value is set to `copy=False` to preserve backwards compatibility.

## jaxlib 0.4.27 (May 7, 2024)

Expand Down Expand Up @@ -219,7 +238,7 @@ Remember to align the itemized text with the first line of an item within a list
* Changes

* JAX lowering to StableHLO does not depend on physical devices anymore.
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
If your primitive wraps custom_partitioning or JAX callbacks in the lowering
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
primitive to `jax._src.dispatch.prim_requires_devices_during_lowering` set.
This is needed because custom_partitioning and JAX callbacks need physical
Expand Down Expand Up @@ -1276,7 +1295,7 @@ Changes:
* Added {func}`jax.random.ball`.
* Added {func}`jax.default_device`.
* Added a `python -m jax.collect_profile` script to manually capture program
traces as an alternative to the Tensorboard UI.
traces as an alternative to the TensorBoard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`).
* In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit
Expand Down Expand Up @@ -2442,7 +2461,7 @@ Changes:
* Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
* Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
* Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitray order {jax-issue}`#2597`.
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`.
* Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
* Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
* Add `callback_transform` {jax-issue}`#2665`.
Expand Down
33 changes: 33 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,39 @@
load("//third_party/xla:workspace.bzl", jax_xla_workspace = "repo")
jax_xla_workspace()

# Initialize hermetic Python
load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
python_init_rules()

load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
python_init_repositories(
requirements = {
"3.9": "//build:requirements_lock_3_9.txt",
"3.10": "//build:requirements_lock_3_10.txt",
"3.11": "//build:requirements_lock_3_11.txt",
"3.12": "//build:requirements_lock_3_12.txt",
"3.13": "//build:requirements_lock_3_13.txt",
},
)

load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
python_init_toolchains()

load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip")
python_init_pip()

load("@pypi//:requirements.bzl", "install_deps")
install_deps()

# Optional, to facilitate testing against newest versions of Python
load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter")
custom_python_interpreter(
name = "python_dev",
urls = ["https://www.python.org/ftp/python/3.13.0/Python-{version}.tgz"],
strip_prefix = "Python-{version}",
version = "3.13.0a6",
)

load("@xla//:workspace4.bzl", "xla_workspace4")
xla_workspace4()

Expand Down
Loading

0 comments on commit 3a96e73

Please sign in to comment.