Skip to content

Commit

Permalink
Merge branch 'main' into add-device-kwarg-to-jnp-ops
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jun 26, 2024
2 parents cb26958 + 66287cd commit 5b64bc5
Show file tree
Hide file tree
Showing 71 changed files with 1,673 additions and 1,172 deletions.
7 changes: 3 additions & 4 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ build:win_clang --compiler=clang-cl
build:cuda_plugin --@xla//xla/python:enable_gpu=false
build:cuda_plugin --define=xla_python_enable_gpu=false

build:rocm_plugin --@xla//xla/python:enable_gpu=false
build:rocm_plugin --define=xla_python_enable_gpu=false

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
Expand Down Expand Up @@ -220,8 +223,6 @@ build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylin
build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"

build:rbe_cpu_linux_py3.9 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9"
build:rbe_cpu_linux_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9"
build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
Expand Down Expand Up @@ -250,8 +251,6 @@ build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl"
# RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility
build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat
build:rbe_linux_cuda12.3_nvcc_py3.9 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9"
build:rbe_linux_cuda12.3_nvcc_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9"
build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
name: CI

# We test all supported Python versions as follows:
# - 3.9 : Documentation build
# - 3.9 : Part of Matrix with NumPy dispatch
# - 3.10 : Documentation build
# - 3.10 : Part of Matrix with NumPy dispatch
# - 3.10 : Part of Matrix
# - 3.11 : Part of Matrix

Expand Down Expand Up @@ -45,8 +45,8 @@ jobs:
matrix:
# Test the oldest and newest supported Python versions here.
include:
- name-prefix: "with 3.9"
python-version: "3.9"
- name-prefix: "with 3.10"
python-version: "3.10"
os: ubuntu-20.04-16core
enable-x64: 1
prng-upgrade: 1
Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
timeout-minutes: 10
strategy:
matrix:
python-version: [3.9]
python-version: ['3.10']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
Expand Down Expand Up @@ -150,7 +150,7 @@ jobs:
timeout-minutes: 10
strategy:
matrix:
python-version: [3.9]
python-version: ['3.10']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
matrix:
os: [windows-2019-32core]
arch: [AMD64]
pyver: ['3.9', '3.10', '3.11', '3.12']
pyver: ['3.10', '3.11', '3.12']
name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build
runs-on: ${{ matrix.os }}

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
matrix:
os: [windows-2019-32core]
arch: [AMD64]
pyver: ['3.9']
pyver: ['3.10']
name: Windows CI build
runs-on: ${{ matrix.os }}

Expand Down
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.31

* Changes
* The minimum NumPy version is now 1.24.
* The minimum Python version is now 3.10. 3.10 will remain the minimum
supported version until July 2025.
* The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
supported version until December 2024.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.

## jaxlib 0.4.31

Expand Down
1 change: 0 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ python_init_rules()
load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
python_init_repositories(
requirements = {
"3.9": "//build:requirements_lock_3_9.txt",
"3.10": "//build:requirements_lock_3_10.txt",
"3.11": "//build:requirements_lock_3_11.txt",
"3.12": "//build:requirements_lock_3_12.txt",
Expand Down
67 changes: 45 additions & 22 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def get_python_version(python_bin_path):
return major, minor

def check_python_version(python_version):
if python_version < (3, 9):
print("ERROR: JAX requires Python 3.9 or newer, found ", python_version)
if python_version < (3, 10):
print("ERROR: JAX requires Python 3.10 or newer, found ", python_version)
sys.exit(-1)


Expand Down Expand Up @@ -318,7 +318,10 @@ def write_bazelrc(*, remote_build,
if not enable_nccl:
f.write("build --config=nonccl\n")
if build_gpu_plugin:
f.write("build --config=cuda_plugin\n")
if enable_cuda:
f.write("build --config=cuda_plugin\n")
elif enable_rocm:
f.write("build --config=rocm_plugin\n")
if python_version:
f.write(
"build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format(
Expand Down Expand Up @@ -431,21 +434,21 @@ def main():
"plugin is still experimental and is not ready for use yet."
),
)
add_boolean_argument(
parser,
"build_cuda_kernel_plugin",
default=False,
help_str=(
"Are we building the cuda kernel plugin? jaxlib will not be built "
"when this flag is True."
parser.add_argument(
"--build_gpu_kernel_plugin",
choices=["cuda", "rocm"],
default="",
help=(
"Specify 'cuda' or 'rocm' to build the respective kernel plugin."
" When this flag is set, jaxlib will not be built."
),
)
add_boolean_argument(
parser,
"build_cuda_pjrt_plugin",
"build_gpu_pjrt_plugin",
default=False,
help_str=(
"Are we building the cuda pjrt plugin? jaxlib will not be built "
"Are we building the cuda/rocm pjrt plugin? jaxlib will not be built "
"when this flag is True."
),
)
Expand All @@ -454,6 +457,11 @@ def main():
choices=["11", "12"],
default="12",
help="Which CUDA major version the gpu plugin is for.")
parser.add_argument(
"--gpu_plugin_rocm_version",
choices=["60"],
default="60",
help="Which ROCM major version the gpu plugin is for.")
add_boolean_argument(
parser,
"enable_rocm",
Expand Down Expand Up @@ -676,7 +684,7 @@ def main():
*args.bazel_options,
)

if not args.build_cuda_kernel_plugin and not args.build_cuda_pjrt_plugin:
if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin:
build_cpu_wheel_command = [
*command_base,
"//jaxlib/tools:build_wheel", "--",
Expand All @@ -691,29 +699,44 @@ def main():
print(" ".join(build_cpu_wheel_command))
shell(build_cpu_wheel_command)

if args.build_gpu_plugin or args.build_cuda_kernel_plugin:
build_cuda_kernels_command = [
if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \
(args.build_gpu_kernel_plugin == "rocm"):
build_gpu_kernels_command = [
*command_base,
"//jaxlib/tools:build_cuda_kernels_wheel", "--",
"//jaxlib/tools:build_gpu_kernels_wheel", "--",
f"--output_path={output_path}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"
]
if args.enable_cuda:
build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}")
build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}")
elif args.enable_rocm:
build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}")
build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}")
else:
raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.")
if args.editable:
build_cuda_kernels_command.append("--editable")
print(" ".join(build_cuda_kernels_command))
shell(build_cuda_kernels_command)
build_gpu_kernels_command.append("--editable")
print(" ".join(build_gpu_kernels_command))
shell(build_gpu_kernels_command)

if args.build_gpu_plugin or args.build_cuda_pjrt_plugin:
if args.build_gpu_plugin or args.build_gpu_pjrt_plugin:
build_pjrt_plugin_command = [
*command_base,
"//jaxlib/tools:build_gpu_plugin_wheel", "--",
f"--output_path={output_path}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"
]
if args.enable_cuda:
build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}")
build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_cuda_version}")
elif args.enable_rocm:
build_pjrt_plugin_command.append(f"--enable-rocm={args.enable_rocm}")
build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_rocm_version}")
else:
raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.")
if args.editable:
build_pjrt_plugin_command.append("--editable")
print(" ".join(build_pjrt_plugin_command))
Expand Down
1 change: 0 additions & 1 deletion build/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ numpy~=2.0.0
#
scipy~=1.13.1

importlib_metadata; python_version<"3.10"
ml_dtypes>=0.4.0
opt_einsum
zstandard
Expand Down
Loading

0 comments on commit 5b64bc5

Please sign in to comment.