Skip to content

Commit

Permalink
Merge branch 'google:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
coreyjadams committed Jun 27, 2024
2 parents 5a91ac3 + 98b8754 commit 6701bd1
Show file tree
Hide file tree
Showing 503 changed files with 23,727 additions and 10,438 deletions.
11 changes: 3 additions & 8 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ build:win_clang --compiler=clang-cl
build:cuda_plugin --@xla//xla/python:enable_gpu=false
build:cuda_plugin --define=xla_python_enable_gpu=false

build:rocm_plugin --@xla//xla/python:enable_gpu=false
build:rocm_plugin --define=xla_python_enable_gpu=false

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
Expand All @@ -116,10 +119,6 @@ build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extention that rejects unknown arguments.
build:cuda_clang --copt=-Qunused-arguments

build:mosaic_gpu --@llvm-project//mlir:enable_cuda=true
build:mosaic_gpu --copt=-DLLVM_HAS_NVPTX_TARGET=1
build:mosaic_gpu --//jax:build_mosaic_gpu=true

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --@xla//xla/python:enable_gpu=true
Expand Down Expand Up @@ -224,8 +223,6 @@ build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylin
build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"

build:rbe_cpu_linux_py3.9 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9"
build:rbe_cpu_linux_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9"
build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
Expand Down Expand Up @@ -254,8 +251,6 @@ build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl"
# RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility
build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat
build:rbe_linux_cuda12.3_nvcc_py3.9 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9"
build:rbe_linux_cuda12.3_nvcc_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9"
build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
Expand Down
15 changes: 8 additions & 7 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
name: CI

# We test all supported Python versions as follows:
# - 3.9 : Documentation build
# - 3.9 : Part of Matrix with NumPy dispatch
# - 3.10 : Documentation build
# - 3.10 : Part of Matrix with NumPy dispatch
# - 3.10 : Part of Matrix
# - 3.11 : Part of Matrix

Expand Down Expand Up @@ -45,8 +45,8 @@ jobs:
matrix:
# Test the oldest and newest supported Python versions here.
include:
- name-prefix: "with 3.9"
python-version: "3.9"
- name-prefix: "with 3.10"
python-version: "3.10"
os: ubuntu-20.04-16core
enable-x64: 1
prng-upgrade: 1
Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
timeout-minutes: 10
strategy:
matrix:
python-version: [3.9]
python-version: ['3.10']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
Expand Down Expand Up @@ -136,10 +136,11 @@ jobs:
- name: Test documentation
env:
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
JAX_TRACEBACK_FILTERING: "off"
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short docs
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py
Expand All @@ -149,7 +150,7 @@ jobs:
timeout-minutes: 10
strategy:
matrix:
python-version: [3.9]
python-version: ['3.10']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/jax-array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: 'e38ce3466e596ed2b8fa4638f161f5563ded81a8' # Latest commit as of 2024-04-15
ref: '33f2d2ea2f3dd2b3ceeeb4519d55e08096184149' # Latest commit as of 2024-05-28
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
Expand Down
7 changes: 3 additions & 4 deletions .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
matrix:
os: [windows-2019-32core]
arch: [AMD64]
pyver: ['3.9', '3.10', '3.11', '3.12']
pyver: ['3.10', '3.11', '3.12']
name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build
runs-on: ${{ matrix.os }}

Expand All @@ -24,7 +24,7 @@ jobs:
access_token: ${{ github.token }}

- name: Install LLVM/Clang
run: choco install llvm --version=18.1.4 --yes
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade

- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4

Expand All @@ -39,8 +39,7 @@ jobs:
JAXLIB_RELEASE: true
run: |
python -m pip install -r build/test-requirements.txt
python -m pip uninstall -y matplotlib
python -m pip install --pre --upgrade numpy==2.0.0rc2 scipy==1.13.0
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py `
--bazel_options=--color=yes `
Expand Down
18 changes: 7 additions & 11 deletions .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@ on:
schedule:
- cron: "0 12 * * *" # Daily at 12:00 UTC
workflow_dispatch: # allows triggering the workflow run manually
pull_request:
types: [ labeled ] # allow force-windows-run label

env:
DISTUTILS_USE_SDK: 1
MSSdk: 1

jobs:
win-wheels:
if: ${{ (github.event.action != 'labeled') || (github.event.label.name == 'windows:force-run')}}
strategy:
fail-fast: true
matrix:
os: [windows-2019-32core]
arch: [AMD64]
pyver: ['3.9']
name: ${{ matrix.os }} CI build
pyver: ['3.10']
name: Windows CI build
runs-on: ${{ matrix.os }}

steps:
Expand All @@ -26,17 +29,12 @@ jobs:
access_token: ${{ github.token }}

- name: Install LLVM/Clang
run: choco install llvm --version=18.1.4 --yes
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade

- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
with:
path: jax

- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
with:
repository: openxla/xla
path: xla

- uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5
with:
python-version: ${{ matrix.pyver }}
Expand All @@ -49,11 +47,9 @@ jobs:
run: |
cd jax
python -m pip install -r build/test-requirements.txt
python -m pip uninstall -y matplotlib
python -m pip install --pre --upgrade numpy==2.0.0rc2 scipy==1.13.0
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py `
('--bazel_options=--override_repository=xla=${{ github.workspace }}\xla' -replace '\\','\\') `
--bazel_options=--color=yes `
--bazel_options=--config=win_clang
Expand Down
101 changes: 97 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,88 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
Remember to align the itemized text with the first line of an item within a list.
-->

## jax 0.4.29
## jax 0.4.31

* Breaking changes
* Changes
* The minimum Python version is now 3.10. 3.10 will remain the minimum
supported version until July 2025.
* The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
supported version until December 2024.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.

## jaxlib 0.4.31

* Bug fixes
* Fixed a bug that meant that negative static_argnums to a jit were mishandled
by the jit dispatch fast path.

## jax 0.4.30 (June 18, 2024)

* Changes
* JAX supports ml_dtypes >= 0.2. In 0.4.29 release, the ml_dtypes version was
bumped to 0.4.0 but this has been rolled back in this release to give users
of both TensorFlow and JAX more time to migrate to a newer TensorFlow
release.
* `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
* jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required.
* Added an API for exporting and serializing JAX functions. This used
to exist in `jax.experimental.export` (which is being deprecated),
and will now live in `jax.export`.
See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html).

* Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
in a future release.
* Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX
release. This previously was the case, but there was an inadvertent regression in
the last several JAX releases.
* `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead.
See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export).
* Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays
`x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`.
* `jax.xla_computation` is deprecated and will be removed in a future release.
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
* You can also use `.out_info` property of `jax.stages.Lowered` to get the
output information (like tree structure, shape and dtype).
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.


## jaxlib 0.4.30 (June 18, 2024)

* Support for monolithic CUDA jaxlibs has been dropped. You must use the
plugin-based installation (`pip install jax[cuda12]` or
`pip install jax[cuda12_local]`).

## jax 0.4.29 (June 10, 2024)

* Changes
* We anticipate that this will be the last release of JAX and jaxlib
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
plugin jaxlib (e.g. `pip install jax[cuda12]`).
* JAX now requires ml_dtypes version 0.4.0 or newer.
* Removed backwards-compatibility support for old usage of the
`jax.experimental.export` API. It is not possible anymore to use
`from jax.experimental.export import export`, and instead you should use
`from jax.experimental import export`.
The removed functionality has been deprecated since 0.4.24.
* Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`.

* Deprecations
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
`jax.sharding.Sharding`.
* `jax.experimental.Exported.in_shardings` has been renamed as
`jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`.
The old names will be removed after 3 months.
* Removed a number of previously-deprecated APIs:
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
* from {mod}`jax.lax`: `tie_in`
Expand All @@ -28,12 +104,29 @@ Remember to align the itemized text with the first line of an item within a list
* {mod}`jax.random` APIs no longer accept batched keys, where previously
some did unintentionally. Going forward, we recommend explicit use of
{func}`jax.vmap` in such cases.
* In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been
renamed to `a` and `b` for consistency with other `beta` APIs.

## jaxlib 0.4.29
* New Functionality
* Added {func}`jax.experimental.Exported.in_shardings_jax` to construct
shardings that can be used with the JAX APIs from the HloShardings
that are stored in the `Exported` objects.

## jaxlib 0.4.29 (June 10, 2024)

* Bug fixes
* Fixes a bug where XLA sharded some concatenation operations incorrectly,
* Fixed a bug where XLA sharded some concatenation operations incorrectly,
which manifested as an incorrect output for cumulative reductions (#21403).
* Fixed a bug where XLA:CPU miscompiled certain matmul fusions
(https://github.com/openxla/xla/pull/13301).
* Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396).

* Deprecations
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
raise an error in a future version of jax. `None` is only a tree-prefix of
itself. To preserve the current behavior, you can ask `jax.tree.map` to
treat `None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.

## jax 0.4.28 (May 9, 2024)

Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ Some standouts:

| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12]"` |
| CPU | `pip install -U jax` |
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
Expand All @@ -414,7 +414,8 @@ community-supported conda build, and answers to some frequently-asked questions.
Multiple Google research groups develop and share libraries for training neural
networks in JAX. If you want a fully featured library for neural network
training with examples and how-to guides, try
[Flax](https://github.com/google/flax).
[Flax](https://github.com/google/flax). Check out the new [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) API for a
simplified development experience.

Google X maintains the neural network library
[Equinox](https://github.com/patrick-kidger/equinox). This is used as the
Expand Down
4 changes: 3 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ python_init_rules()
load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
python_init_repositories(
requirements = {
"3.9": "//build:requirements_lock_3_9.txt",
"3.10": "//build:requirements_lock_3_10.txt",
"3.11": "//build:requirements_lock_3_11.txt",
"3.12": "//build:requirements_lock_3_12.txt",
"3.13": "//build:requirements_lock_3_13.txt",
},
local_wheel_workspaces = ["//jaxlib:jax.bzl"],
local_wheel_dist_folder = "../dist",
default_python_version = "system",
)

load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
Expand Down
Loading

0 comments on commit 6701bd1

Please sign in to comment.