From 5a732ad89f4e3a61b0b65374581392285a682045 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Fri, 1 Mar 2024 09:14:11 -0800 Subject: [PATCH 01/35] Adding script to convert NVIDIA nsys profiles to pbtxt --- jax/tools/pgo_nsys_converter.py | 45 +++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 jax/tools/pgo_nsys_converter.py diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py new file mode 100644 index 000000000000..95d860343991 --- /dev/null +++ b/jax/tools/pgo_nsys_converter.py @@ -0,0 +1,45 @@ +import csv +import re +import sys +import argparse +import psutil +import os +import shutil +import subprocess + +nsys_path = shutil.which("nsys") + +parser = argparse.ArgumentParser(description='Tool to sweep for optimal collective combiner threshold') +parser.add_argument("--profile_path", type=str, help="path to nsys profile") +parser.add_argument("--post_process", help="post process pbtxt to get minimum cost value for each instruction", action="store_true") +parser.add_argument("--pgle_output_path", type=str, help="output directory", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") + +args = parser.parse_args() + +pgle_filename = os.path.basename(args.pgle_output_path).partition('.')[0] +pgle_folder = os.path.join(os.path.split(args.pgle_output_path)[0], '') +profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') + +stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] + +print(f""" + ******Starting stats command****** + {stats_command}.""") + +proc = subprocess.Popen(stats_command, stdout=sys.stdout, stderr=sys.stderr) +proc.wait() + +thunk_re = re.compile("hlo_op=(.*)#") +cost_dictionary = dict() +with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: + with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + name = row['NVTX Range'] + time_ns = float(row['Avg (ns)']) + m = thunk_re.search(name) + if m is not None: + protofile.write(f'costs {{ name: "{m.group(1)}" cost_us: {time_ns / 1000.0} }}\n') + +clean_command = f"rm {profile_folder}/*.sqlite; rm {pgle_folder}/*.csv" +subprocess.call(clean_command, shell=True) From 2480ca383ee27bef2be064d6688f6fd6612db644 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Mon, 4 Mar 2024 11:42:01 -0800 Subject: [PATCH 02/35] respond to reviewer's comments --- jax/tools/pgo_nsys_converter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 95d860343991..67b2b5f62629 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -7,9 +7,11 @@ import shutil import subprocess +print("Script to convert NVIDIA Nsys Profiles to the .pbtxt format. This format is readable by XLA's Profile Guided Latency Estimator. 
Usage: pgo_nsys_converter.py --profile_path --pgle_output_path ") + nsys_path = shutil.which("nsys") -parser = argparse.ArgumentParser(description='Tool to sweep for optimal collective combiner threshold') +parser = argparse.ArgumentParser(description='Tool to convert NVIDIA Nsys Profiles to the .pbtxt format') parser.add_argument("--profile_path", type=str, help="path to nsys profile") parser.add_argument("--post_process", help="post process pbtxt to get minimum cost value for each instruction", action="store_true") parser.add_argument("--pgle_output_path", type=str, help="output directory", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") From a7a9f85535cb565ea43a8a371f8d360dee9f5e63 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Thu, 7 Mar 2024 11:55:55 -0800 Subject: [PATCH 03/35] Added license information --- jax/tools/pgo_nsys_converter.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 67b2b5f62629..4d3dfd3db23c 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -1,3 +1,17 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import csv import re import sys From 98f790f5d524372d45ca86bc8da07ca498d72b3b Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 14 Aug 2023 12:59:09 -0700 Subject: [PATCH 04/35] update package/API reference docs to new-style typed PRNG keys --- docs/jax.nn.initializers.rst | 4 ++-- jax/_src/api.py | 2 +- jax/_src/core.py | 4 ++-- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/nn/initializers.py | 26 ++++++++++++------------ jax/_src/random.py | 4 ++-- jax/random.py | 32 +++++++++++++++++------------- 7 files changed, 39 insertions(+), 35 deletions(-) diff --git a/docs/jax.nn.initializers.rst b/docs/jax.nn.initializers.rst index d96ba43f0d49..246e0cdbe9a1 100644 --- a/docs/jax.nn.initializers.rst +++ b/docs/jax.nn.initializers.rst @@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet. An initializer is a function that takes three arguments: ``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and -data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random -key used when generating random numbers to initialize the array. +data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from +:func:`jax.random.key`), used to generate random numbers to initialize the array. .. autosummary:: :toctree: _autosummary diff --git a/jax/_src/api.py b/jax/_src/api.py index 005e0ceca3a1..2fecc4fd78db 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -280,7 +280,7 @@ def jit( ... def selu(x, alpha=1.67, lmbda=1.05): ... 
return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha) >>> - >>> key = jax.random.PRNGKey(0) + >>> key = jax.random.key(0) >>> x = jax.random.normal(key, (10,)) >>> print(selu(x)) # doctest: +SKIP [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 diff --git a/jax/_src/core.py b/jax/_src/core.py index c351e6980bf7..b207b39e5376 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1277,7 +1277,7 @@ def f(x): @jax.jit def jax_fn(x): with jax.ensure_compile_time_eval(): - y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) + y = random.randint(random.key(0), (1000,1000), 0, 100) y2 = y @ y x2 = jnp.sum(y2) * x return x2 @@ -1285,7 +1285,7 @@ def jax_fn(x): A similar behavior can often be achieved simply by 'hoisting' the constant expression out of the corresponding staging API:: - y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) + y = random.randint(random.key(0), (1000,1000), 0, 100) @jax.jit def jax_fn(x): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index e939f2170f8a..0c6ae2fe022a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): Example 2: partial products of an array of matrices - >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2)) + >>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2)) >>> partial_prods = lax.associative_scan(jnp.matmul, mats) >>> partial_prods.shape (4, 2, 2) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 5468fd663181..d7353c396ae6 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -62,7 +62,7 @@ def zeros(key: KeyArray, The ``key`` argument is ignored. >>> import jax, jax.numpy as jnp - >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32) + >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) """ @@ -77,7 +77,7 @@ def ones(key: KeyArray, The ``key`` argument is ignored. 
>>> import jax, jax.numpy as jnp - >>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32) + >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32) @@ -96,7 +96,7 @@ def constant(value: ArrayLike, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.constant(-7) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32) """ @@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.uniform(10.0) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32) """ @@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.normal(5.0) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) """ @@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.50350785, 0.8088631 , 0.81566876], [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32) @@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.41770416, 0.75262755, 0.7619329 ], [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32) @@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.56293887, 0.90433645, 0.9119454 ], [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32) @@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.46700746, 0.8414632 , 0.8518669 ], [-0.61677957, -0.67402434, 0.09683388]], dtype=float32) @@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_uniform() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.79611576, 1.2789248 , 1.2896855 ], [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32) @@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2, >>> import jax, jax.numpy as jnp >>> initializer = 
jax.nn.initializers.he_normal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 0.6604483 , 1.1900088 , 1.2047218 ], [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32) @@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.orthogonal() - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ @@ -638,7 +638,7 @@ def delta_orthogonal( >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.delta_orthogonal() - >>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP Array([[[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]], diff --git a/jax/_src/random.py b/jax/_src/random.py index f9045ebf3135..2696b488ab40 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -244,7 +244,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: """Folds in data to a PRNG key to form a new PRNG key. Args: - key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``). + key: a PRNG key (from ``key``, ``split``, ``fold_in``). data: a 32bit integer representing data to be folded in to the key. Returns: @@ -274,7 +274,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: """Splits a PRNG key into `num` new keys by adding a leading axis. Args: - key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``). + key: a PRNG key (from ``key``, ``split``, ``fold_in``). num: optional, a positive integer (or tuple of integers) indicating the number (or shape) of keys to produce. Defaults to 2. diff --git a/jax/random.py b/jax/random.py index c06f48b3583b..f65ec58580a4 100644 --- a/jax/random.py +++ b/jax/random.py @@ -22,24 +22,25 @@ >>> seed = 1701 >>> num_steps = 100 ->>> key = jax.random.PRNGKey(seed) +>>> key = jax.random.key(seed) >>> for i in range(num_steps): ... key, subkey = jax.random.split(key) ... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP -PRNG Keys +PRNG keys --------- Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as a first argument. -The random state is described by two unsigned 32-bit integers that we call a **key**, -usually generated by the :py:func:`jax.random.PRNGKey` function:: +The random state is described by a special array element type that we call a **key**, +usually generated by the :py:func:`jax.random.key` function:: >>> from jax import random - >>> key = random.PRNGKey(0) + >>> key = random.key(0) >>> key - Array([0, 0], dtype=uint32) + Array((), dtype=key) overlaying: + [0 0] This key can then be used in any of JAX's random number generation routines:: @@ -60,8 +61,8 @@ Advanced -------- -Design and Context -================== +Design and background +===================== **TLDR**: JAX PRNG = `Threefry counter PRNG `_ + a functional array-oriented `splitting model `_ @@ -79,16 +80,19 @@ Advanced RNG configuration ========================== -JAX provides several PRNG implementations (controlled by the -`jax_default_prng_impl` flag). 
+JAX provides several PRNG implementations. A specific one can be +selected with the optional `impl` keyword argument to +`jax.random.key`. When no `impl` option is passed to the `key` +constructor, the implementation is determined by the global +`jax_default_prng_impl` configuration flag. -- **default** +- **default**, `"threefry2x32"`: `A counter-based PRNG built around the Threefry hash function `_. - *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See `TF doc `_. - - "rbg" uses ThreeFry for splitting, and XLA RBG for data generation. - - "unsafe_rbg" exists only for demonstration purposes, using RBG both for + - `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation. + - `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for splitting (using an untested made up algorithm) and generating. The random streams generated by these experimental implementations haven't @@ -126,7 +130,7 @@ less safe in the sense that the quality of random streams it generates from different keys is less well understood. -For more about jax_threefry_partitionable, see +For more about `jax_threefry_partitionable`, see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers """ From 0bdbe763aac157a1d36b69f2d7ad5c2aceeb51c3 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 14 Aug 2023 16:31:17 -0700 Subject: [PATCH 05/35] update documentation notes to new-style typed keys --- docs/Custom_Operation_for_GPUs.md | 4 ++-- docs/async_dispatch.rst | 2 +- docs/device_memory_profiling.md | 6 +++--- docs/profiling.md | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index 44cdaf1f15e8..38be1683449c 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -304,7 +304,7 @@ per_core_batch_size=4 seq_len=512 emb_dim=512 x = jax.random.normal( - jax.random.PRNGKey(0), + jax.random.key(0), shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim), dtype=jnp.bfloat16, ) @@ -1049,7 +1049,7 @@ per_core_batch_size=4 seq_len=512 emb_dim=512 x = jax.random.normal( - jax.random.PRNGKey(0), + jax.random.key(0), shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim), dtype=jnp.bfloat16, ) diff --git a/docs/async_dispatch.rst b/docs/async_dispatch.rst index 421a27b0b67b..00fb98a5185d 100644 --- a/docs/async_dispatch.rst +++ b/docs/async_dispatch.rst @@ -9,7 +9,7 @@ program: >>> import numpy as np >>> import jax.numpy as jnp >>> from jax import random ->>> x = random.uniform(random.PRNGKey(0), (1000, 1000)) +>>> x = random.uniform(random.key(0), (1000, 1000)) >>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`) >>> # will block until the value is ready. >>> jnp.dot(x, x) + 3. 
# doctest: +SKIP diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index 57b2894a53a0..a6f27e9e9710 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -59,7 +59,7 @@ def func2(x): y = func1(x) return y, jnp.tile(x, 10) + 1 -x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000)) +x = jax.random.normal(jax.random.key(42), (1000, 1000)) y, z = func2(x) z.block_until_ready() @@ -107,14 +107,14 @@ import jax.numpy as jnp import jax.profiler def afunction(): - return jax.random.normal(jax.random.PRNGKey(77), (1000000,)) + return jax.random.normal(jax.random.key(77), (1000000,)) z = afunction() def anotherfunc(): arrays = [] for i in range(1, 10): - x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000)) + x = jax.random.normal(jax.random.key(42), (i, 10000)) arrays.append(x) x.block_until_ready() jax.profiler.save_device_memory_profile(f"memory{i}.prof") diff --git a/docs/profiling.md b/docs/profiling.md index 86b539c0a7fb..24567850d22c 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -11,7 +11,7 @@ check out the Tensorboard profiler below. ```python with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): # Run the operations to be profiled - key = jax.random.PRNGKey(0) + key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() @@ -107,7 +107,7 @@ import jax jax.profiler.start_trace("/tmp/tensorboard") # Run the operations to be profiled -key = jax.random.PRNGKey(0) +key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() @@ -126,7 +126,7 @@ alternative to `start_trace` and `stop_trace`: import jax with jax.profiler.trace("/tmp/tensorboard"): - key = jax.random.PRNGKey(0) + key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() From 75a53f46e072d16be95f4a432ae7d8e9791e21a8 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 17 Aug 2023 17:33:42 -0700 Subject: [PATCH 06/35] update notebooks to new-style typed PRNG keys --- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 8 ++++---- docs/notebooks/Common_Gotchas_in_JAX.md | 8 ++++---- ...tributed_arrays_and_automatic_parallelization.ipynb | 10 +++++----- ...Distributed_arrays_and_automatic_parallelization.md | 10 +++++----- docs/notebooks/Neural_Network_and_Data_Loading.ipynb | 6 +++--- docs/notebooks/Neural_Network_and_Data_Loading.md | 6 +++--- .../notebooks/Writing_custom_interpreters_in_Jax.ipynb | 2 +- docs/notebooks/Writing_custom_interpreters_in_Jax.md | 2 +- docs/notebooks/autodiff_cookbook.ipynb | 8 ++++---- docs/notebooks/autodiff_cookbook.md | 8 ++++---- docs/notebooks/convolutions.ipynb | 4 ++-- docs/notebooks/convolutions.md | 4 ++-- docs/notebooks/neural_network_with_tfds_data.ipynb | 6 +++--- docs/notebooks/neural_network_with_tfds_data.md | 6 +++--- docs/notebooks/quickstart.ipynb | 2 +- docs/notebooks/quickstart.md | 2 +- docs/notebooks/vmapped_log_probs.ipynb | 2 +- docs/notebooks/vmapped_log_probs.md | 2 +- 18 files changed, 48 insertions(+), 48 deletions(-) diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index cd89e425956d..6aa5944d6cc2 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -1006,7 +1006,7 @@ "source": [ "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. 
JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", "\n", - "The random state is described by two unsigned-int32s that we call a __key__:" + "The random state is described by a special array element that we call a __key__:" ] }, { @@ -1030,7 +1030,7 @@ ], "source": [ "from jax import random\n", - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "key" ] }, @@ -2121,7 +2121,7 @@ } ], "source": [ - "x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n", + "x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n", "x.dtype" ] }, @@ -2188,7 +2188,7 @@ "source": [ "import jax.numpy as jnp\n", "from jax import random\n", - "x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n", + "x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n", "x.dtype # --> dtype('float64')" ] }, diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 93695cb65c75..4fba1c4448ec 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -463,14 +463,14 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. -The random state is described by two unsigned-int32s that we call a __key__: +The random state is described by a special array element that we call a __key__: ```{code-cell} ipython3 :id: yPHE7KTWgAWs :outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 from jax import random -key = random.PRNGKey(0) +key = random.key(0) key ``` @@ -1071,7 +1071,7 @@ At the moment, JAX by default enforces single-precision numbers to mitigate the :id: CNNGtzM3NDkO :outputId: b422bb23-a784-44dc-f8c9-57f3b6c861b8 -x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64) +x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) x.dtype ``` @@ -1117,7 +1117,7 @@ We can then confirm that `x64` mode is enabled: import jax.numpy as jnp from jax import random -x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64) +x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) x.dtype # --> dtype('float64') ``` diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index b6ea0ffe14e2..d54b0153508a 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -131,7 +131,7 @@ ], "source": [ "# Create an array of random values:\n", - "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n", + "x = jax.random.normal(jax.random.key(0), (8192, 8192))\n", "# and use jax.device_put to distribute it across devices:\n", "y = jax.device_put(x, sharding.reshape(4, 2))\n", "jax.debug.visualize_array_sharding(y)" @@ -272,7 +272,7 @@ "outputs": [], "source": [ "import jax\n", - "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))" + "x = jax.random.normal(jax.random.key(0), 
(8192, 8192))" ] }, { @@ -1513,7 +1513,7 @@ }, "outputs": [], "source": [ - "x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n", + "x = jax.random.normal(jax.random.key(0), (8192, 8192))\n", "x = jax.device_put(x, sharding.reshape(4, 2))" ] }, @@ -1738,7 +1738,7 @@ "layer_sizes = [784, 8192, 8192, 8192, 10]\n", "batch_size = 8192\n", "\n", - "params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)" + "params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)" ] }, { @@ -2184,7 +2184,7 @@ " numbers = jax.random.uniform(key, x.shape)\n", " return x + numbers\n", "\n", - "key = jax.random.PRNGKey(42)\n", + "key = jax.random.key(42)\n", "x_sharding = jax.sharding.PositionalSharding(jax.devices())\n", "x = jax.device_put(jnp.arange(24), x_sharding)" ] diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index 256e2622410f..ba7ec1a372af 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -81,7 +81,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) :outputId: 3b518df8-5c29-4848-acc3-e41df939f30b # Create an array of random values: -x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) +x = jax.random.normal(jax.random.key(0), (8192, 8192)) # and use jax.device_put to distribute it across devices: y = jax.device_put(x, sharding.reshape(4, 2)) jax.debug.visualize_array_sharding(y) @@ -144,7 +144,7 @@ For example, here's a value with a single-device `Sharding`: :id: VmoX4SUp3vGJ import jax -x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) +x = jax.random.normal(jax.random.key(0), (8192, 8192)) ``` ```{code-cell} @@ -609,7 +609,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) ```{code-cell} :id: Q1wuDp-L3vGT -x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)) +x = jax.random.normal(jax.random.key(0), (8192, 8192)) x = jax.device_put(x, sharding.reshape(4, 2)) ``` @@ -720,7 +720,7 @@ def init_model(key, layer_sizes, batch_size): layer_sizes = [784, 8192, 8192, 8192, 10] batch_size = 8192 -params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size) +params, batch = init_model(jax.random.key(0), layer_sizes, batch_size) ``` +++ {"id": "sJv_h0AS2drh"} @@ -902,7 +902,7 @@ def f(key, x): numbers = jax.random.uniform(key, x.shape) return x + numbers -key = jax.random.PRNGKey(42) +key = jax.random.key(42) x_sharding = jax.sharding.PositionalSharding(jax.devices()) x = jax.device_put(jnp.arange(24), x_sharding) ``` diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index fb0ac165be16..1a920c45c45f 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -84,7 +84,7 @@ "num_epochs = 8\n", "batch_size = 128\n", "n_targets = 10\n", - "params = init_network_params(layer_sizes, random.PRNGKey(0))" + "params = init_network_params(layer_sizes, random.key(0))" ] }, { @@ -150,7 +150,7 @@ ], "source": [ "# This works on single examples\n", - "random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n", + "random_flattened_image = random.normal(random.key(1), (28 * 28,))\n", "preds = predict(params, random_flattened_image)\n", "print(preds.shape)" ] @@ -173,7 +173,7 @@ ], "source": [ "# Doesn't work with a batch\n", - "random_flattened_images = 
random.normal(random.PRNGKey(1), (10, 28 * 28))\n", + "random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n", "try:\n", " preds = predict(params, random_flattened_images)\n", "except TypeError:\n", diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index ebbb6da3d107..aef4dd3982fb 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -71,7 +71,7 @@ step_size = 0.01 num_epochs = 8 batch_size = 128 n_targets = 10 -params = init_network_params(layer_sizes, random.PRNGKey(0)) +params = init_network_params(layer_sizes, random.key(0)) ``` +++ {"id": "BtoNk_yxWtIw"} @@ -109,7 +109,7 @@ Let's check that our prediction function only works on single images. :outputId: 9d3b29e8-fab3-4ecb-9f63-bc8c092f9006 # This works on single examples -random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,)) +random_flattened_image = random.normal(random.key(1), (28 * 28,)) preds = predict(params, random_flattened_image) print(preds.shape) ``` @@ -119,7 +119,7 @@ print(preds.shape) :outputId: d5d20211-b6da-44e9-f71e-946f2a9d0fc4 # Doesn't work with a batch -random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28)) +random_flattened_images = random.normal(random.key(1), (10, 28 * 28)) try: preds = predict(params, random_flattened_images) except TypeError: diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index aac86fd8b710..1a1a77eb9ee3 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -66,7 +66,7 @@ }, "outputs": [], "source": [ - "x = random.normal(random.PRNGKey(0), (5000, 5000))\n", + "x = random.normal(random.key(0), (5000, 5000))\n", "def f(w, b, x):\n", " return jnp.tanh(jnp.dot(x, w) + b)\n", "fast_f = jit(f)" diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index af4379b03802..49be088e2db2 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -48,7 +48,7 @@ JAX provides a NumPy-like API for numerical computing which can be used as is, b ```{code-cell} ipython3 :id: HmlMcICOcSXR -x = random.normal(random.PRNGKey(0), (5000, 5000)) +x = random.normal(random.key(0), (5000, 5000)) def f(w, b, x): return jnp.tanh(jnp.dot(x, w) + b) fast_f = jit(f) diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 39aad749de47..d0b8fe0c0c23 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -27,7 +27,7 @@ "from jax import grad, jit, vmap\n", "from jax import random\n", "\n", - "key = random.PRNGKey(0)" + "key = random.key(0)" ] }, { @@ -1055,7 +1055,7 @@ " outs, = vmap(vjp_fun)(M)\n", " return outs\n", "\n", - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "num_covecs = 128\n", "U = random.normal(key, (num_covecs,) + y.shape)\n", "\n", @@ -1306,7 +1306,7 @@ "outputs": [], "source": [ "def check(seed):\n", - " key = random.PRNGKey(seed)\n", + " key = random.key(seed)\n", "\n", " # random coeffs for u and v\n", " key, subkey = random.split(key)\n", @@ -1399,7 +1399,7 @@ "outputs": [], "source": [ "def check(seed):\n", - " key = random.PRNGKey(seed)\n", + " key = random.key(seed)\n", "\n", " # random coeffs for u and v\n", " key, subkey = 
random.split(key)\n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index fa50db4f2194..59b2a382914a 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -29,7 +29,7 @@ import jax.numpy as jnp from jax import grad, jit, vmap from jax import random -key = random.PRNGKey(0) +key = random.key(0) ``` +++ {"id": "YxnjtAGN6vu2"} @@ -614,7 +614,7 @@ def vmap_mjp(f, x, M): outs, = vmap(vjp_fun)(M) return outs -key = random.PRNGKey(0) +key = random.key(0) num_covecs = 128 U = random.normal(key, (num_covecs,) + y.shape) @@ -770,7 +770,7 @@ Here's a check: :id: BGZV__zupIMS def check(seed): - key = random.PRNGKey(seed) + key = random.key(seed) # random coeffs for u and v key, subkey = random.split(key) @@ -833,7 +833,7 @@ Here's a check of the VJP rules: :id: 4J7edvIBttcU def check(seed): - key = random.PRNGKey(seed) + key = random.key(seed) # random coeffs for u and v key, subkey = random.split(key) diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index f8dcaa3685ad..c4ef1961bbd2 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -60,7 +60,7 @@ "import jax.numpy as jnp\n", "import numpy as np\n", "\n", - "key = random.PRNGKey(1701)\n", + "key = random.key(1701)\n", "\n", "x = jnp.linspace(0, 10, 500)\n", "y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))\n", @@ -130,7 +130,7 @@ "ax[0].set_title('original')\n", "\n", "# Create a noisy version by adding random Gaussian noise\n", - "key = random.PRNGKey(1701)\n", + "key = random.key(1701)\n", "noisy_image = image + 50 * random.normal(key, image.shape)\n", "ax[1].imshow(noisy_image, cmap='binary_r')\n", "ax[1].set_title('noisy')\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 5d34ef950021..2dda3610eb14 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -43,7 +43,7 @@ from jax import random import jax.numpy as jnp import numpy as np -key = random.PRNGKey(1701) +key = random.key(1701) x = jnp.linspace(0, 10, 500) y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,)) @@ -84,7 +84,7 @@ ax[0].imshow(image, cmap='binary_r') ax[0].set_title('original') # Create a noisy version by adding random Gaussian noise -key = random.PRNGKey(1701) +key = random.key(1701) noisy_image = image + 50 * random.normal(key, image.shape) ax[1].imshow(noisy_image, cmap='binary_r') ax[1].set_title('noisy') diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 8a14bb7bbec9..ef06e26b3048 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -97,7 +97,7 @@ "num_epochs = 10\n", "batch_size = 128\n", "n_targets = 10\n", - "params = init_network_params(layer_sizes, random.PRNGKey(0))" + "params = init_network_params(layer_sizes, random.key(0))" ] }, { @@ -163,7 +163,7 @@ ], "source": [ "# This works on single examples\n", - "random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n", + "random_flattened_image = random.normal(random.key(1), (28 * 28,))\n", "preds = predict(params, random_flattened_image)\n", "print(preds.shape)" ] @@ -186,7 +186,7 @@ ], "source": [ "# Doesn't work with a batch\n", - "random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n", + "random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n", "try:\n", " preds = predict(params, 
random_flattened_images)\n", "except TypeError:\n", diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index b3c2be06bffb..197c02ad4beb 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -79,7 +79,7 @@ step_size = 0.01 num_epochs = 10 batch_size = 128 n_targets = 10 -params = init_network_params(layer_sizes, random.PRNGKey(0)) +params = init_network_params(layer_sizes, random.key(0)) ``` +++ {"id": "BtoNk_yxWtIw"} @@ -117,7 +117,7 @@ Let's check that our prediction function only works on single images. :outputId: ce9d86ed-a830-4832-e04d-10d1abb1fb8a # This works on single examples -random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,)) +random_flattened_image = random.normal(random.key(1), (28 * 28,)) preds = predict(params, random_flattened_image) print(preds.shape) ``` @@ -127,7 +127,7 @@ print(preds.shape) :outputId: f43bbc9d-bc8f-4168-ee7b-79ee9d33f245 # Doesn't work with a batch -random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28)) +random_flattened_images = random.normal(random.key(1), (10, 28 * 28)) try: preds = predict(params, random_flattened_images) except TypeError: diff --git a/docs/notebooks/quickstart.ipynb b/docs/notebooks/quickstart.ipynb index 722047adadbd..7987a8bcf5f4 100644 --- a/docs/notebooks/quickstart.ipynb +++ b/docs/notebooks/quickstart.ipynb @@ -81,7 +81,7 @@ } ], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, (10,))\n", "print(x)" ] diff --git a/docs/notebooks/quickstart.md b/docs/notebooks/quickstart.md index 46c37f208a8f..6f2ecf03102b 100644 --- a/docs/notebooks/quickstart.md +++ b/docs/notebooks/quickstart.md @@ -59,7 +59,7 @@ We'll be generating random data in the following examples. 
One big difference be :id: u0nseKZNqOoH :outputId: 03e20e21-376c-41bb-a6bb-57431823691b -key = random.PRNGKey(0) +key = random.key(0) x = random.normal(key, (10,)) print(x) ``` diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 4ee2e4924d53..833c3a40d145 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -483,7 +483,7 @@ "\n", "normal_sample = jax.jit(normal_sample, static_argnums=(1,))\n", "\n", - "key = random.PRNGKey(10003)\n", + "key = random.key(10003)\n", "\n", "beta_loc = jnp.zeros(num_features, jnp.float32)\n", "beta_log_scale = jnp.zeros(num_features, jnp.float32)\n", diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index 8b6d7ceeb61a..19e498db4362 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -210,7 +210,7 @@ def normal_sample(key, shape): normal_sample = jax.jit(normal_sample, static_argnums=(1,)) -key = random.PRNGKey(10003) +key = random.key(10003) beta_loc = jnp.zeros(num_features, jnp.float32) beta_log_scale = jnp.zeros(num_features, jnp.float32) From 7fe44c1228dfc704150492205bd06babf45708b8 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 17 Aug 2023 17:33:42 -0700 Subject: [PATCH 07/35] update stax to new-style typed keys --- jax/example_libraries/README.md | 2 +- jax/example_libraries/stax.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/example_libraries/README.md b/jax/example_libraries/README.md index edcfcb506f8e..349b7e12e56b 100644 --- a/jax/example_libraries/README.md +++ b/jax/example_libraries/README.md @@ -44,7 +44,7 @@ net_init, net_apply = stax.serial( ) # Initialize parameters, not committing to a batch shape -rng = random.PRNGKey(0) +rng = random.key(0) in_shape = (-1, 28, 28, 1) out_shape, net_params = net_init(rng, in_shape) diff --git a/jax/example_libraries/stax.py b/jax/example_libraries/stax.py index e5bf38c6a69f..476252d92d5d 100644 --- a/jax/example_libraries/stax.py +++ b/jax/example_libraries/stax.py @@ -268,7 +268,7 @@ def apply_fun(params, inputs, **kwargs): msg = ("Dropout layer requires apply_fun to be called with a PRNG key " "argument. That is, instead of `apply_fun(params, inputs)`, call " "it like `apply_fun(params, inputs, rng)` where `rng` is a " - "jax.random.PRNGKey value.") + "PRNG key (e.g. from `jax.random.key`).") raise ValueError(msg) if mode == 'train': keep = random.bernoulli(rng, rate, inputs.shape) From 753a6b2b0c9439e9e89c20669cdd16eaa0d9dadd Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 17 Aug 2023 17:33:43 -0700 Subject: [PATCH 08/35] update `experimental.sparse` to new-style typed keys --- jax/experimental/sparse/random.py | 2 +- jax/experimental/sparse/transform.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/experimental/sparse/random.py b/jax/experimental/sparse/random.py index b22d466dcf5a..f90c2572d282 100644 --- a/jax/experimental/sparse/random.py +++ b/jax/experimental/sparse/random.py @@ -29,7 +29,7 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=None, """Generate a random BCOO matrix. Args: - key : random.PRNGKey to be passed to ``generator`` function. + key : PRNG key to be passed to ``generator`` function. shape : tuple specifying the shape of the array to be generated. dtype : dtype of the array to be generated. indices_dtype: dtype of the BCOO indices. 
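The hunk above touches only the docstring of `random_bcoo`; for reference, a minimal usage sketch with a new-style typed key might look like the following. This is not part of the patch, and the `nse` argument (taken here as the fraction of specified elements) is assumed from the library's public signature rather than shown in the hunk.

```python
# Hedged sketch, not part of this patch: calling random_bcoo with a
# new-style typed PRNG key. `nse` is assumed to set the fraction of
# specified (stored) entries.
import jax
from jax.experimental.sparse import random as sparse_random

key = jax.random.key(1701)
mat = sparse_random.random_bcoo(key, (5, 5), nse=0.3)
print(mat.todense())
```
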
diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 55460118de81..19f4ca736ec8 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -27,9 +27,9 @@ >>> from jax import random >>> from jax.experimental.sparse import BCOO, sparsify ->>> mat = random.uniform(random.PRNGKey(1701), (5, 5)) +>>> mat = random.uniform(random.key(1701), (5, 5)) >>> mat = mat.at[mat < 0.5].set(0) ->>> vec = random.uniform(random.PRNGKey(42), (5,)) +>>> vec = random.uniform(random.key(42), (5,)) >>> def f(mat, vec): ... return -(jnp.sin(mat) @ vec) From 78fd4f1664fd11e04828409c5e7e864310ef8523 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 17 Aug 2023 17:33:43 -0700 Subject: [PATCH 09/35] update top-level examples to use new-style typed keys --- examples/advi.py | 4 ++-- examples/differentially_private_sgd.py | 2 +- examples/examples_test.py | 2 +- examples/gaussian_process_regression.py | 2 +- examples/mnist_classifier.py | 2 +- examples/mnist_vae.py | 6 +++--- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/advi.py b/examples/advi.py index 35ee94a58d2f..68092b2cf74a 100644 --- a/examples/advi.py +++ b/examples/advi.py @@ -78,7 +78,7 @@ def funnel_log_density(params): @jit def objective(params, t): - rng = random.PRNGKey(t) + rng = random.key(t) return -batch_elbo(funnel_log_density, rng, params, num_samples) # Set up figure. @@ -107,7 +107,7 @@ def callback(params, t): # Plot random samples from variational distribution. # Here we clone the rng used in computing the objective # so that we can show exactly the same samples. - rngs = random.split(random.PRNGKey(t), num_samples) + rngs = random.split(random.key(t), num_samples) samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params) ax.plot(samples[:, 0], samples[:, 1], 'b.') diff --git a/examples/differentially_private_sgd.py b/examples/differentially_private_sgd.py index 4777554b1127..ca368098d243 100644 --- a/examples/differentially_private_sgd.py +++ b/examples/differentially_private_sgd.py @@ -182,7 +182,7 @@ def main(_): num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value) num_batches = num_complete_batches + bool(leftover) - key = random.PRNGKey(_SEED.value) + key = random.key(_SEED.value) def data_stream(): rng = npr.RandomState(_SEED.value) diff --git a/examples/examples_test.py b/examples/examples_test.py index e2ca51d78155..b8b4d11e273d 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -35,7 +35,7 @@ def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape): - jax_rng = random.PRNGKey(0) + jax_rng = random.key(0) result_shape, params = init_fun(jax_rng, input_shape) result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32")) test_case.assertEqual(result.shape, result_shape) diff --git a/examples/gaussian_process_regression.py b/examples/gaussian_process_regression.py index 070943b72413..c42a024d42aa 100644 --- a/examples/gaussian_process_regression.py +++ b/examples/gaussian_process_regression.py @@ -30,7 +30,7 @@ def main(unused_argv): numpts = 7 - key = random.PRNGKey(0) + key = random.key(0) eye = jnp.eye(numpts) def cov_map(cov_func, xs, xs2=None): diff --git a/examples/mnist_classifier.py b/examples/mnist_classifier.py index a0fe8b996c98..a7730ab2b6aa 100644 --- a/examples/mnist_classifier.py +++ b/examples/mnist_classifier.py @@ -50,7 +50,7 @@ def accuracy(params, batch): Dense(10), LogSoftmax) if __name__ 
== "__main__": - rng = random.PRNGKey(0) + rng = random.key(0) step_size = 0.001 num_epochs = 10 diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py index 141be978f635..df207afd8749 100644 --- a/examples/mnist_vae.py +++ b/examples/mnist_vae.py @@ -87,14 +87,14 @@ def image_grid(nrow, ncol, imagevecs, imshape): batch_size = 32 nrow, ncol = 10, 10 # sampled image grid size - test_rng = random.PRNGKey(1) # fixed prng key for evaluation + test_rng = random.key(1) # fixed prng key for evaluation imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png") train_images, _, test_images, _ = datasets.mnist(permute_train=True) num_complete_batches, leftover = divmod(train_images.shape[0], batch_size) num_batches = num_complete_batches + bool(leftover) - enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2)) + enc_init_rng, dec_init_rng = random.split(random.key(2)) _, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28)) _, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10)) init_params = init_encoder_params, init_decoder_params @@ -131,7 +131,7 @@ def evaluate(opt_state, images): opt_state = opt_init(init_params) for epoch in range(num_epochs): tic = time.time() - opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images) + opt_state = run_epoch(random.key(epoch), opt_state, train_images) test_elbo, sampled_images = evaluate(opt_state, test_images) print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)") plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray) From 0fb408571cb716583e0cfed3a5ca11fb4410b275 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 17 Aug 2023 17:33:43 -0700 Subject: [PATCH 10/35] update jax-101 to new-style typed PRNG keys --- docs/jax-101/05-random-numbers.ipynb | 10 +++++----- docs/jax-101/05-random-numbers.md | 10 +++++----- docs/jax-101/06-parallelism.ipynb | 2 +- docs/jax-101/06-parallelism.md | 2 +- docs/jax-101/07-state.ipynb | 4 ++-- docs/jax-101/07-state.md | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/jax-101/05-random-numbers.ipynb b/docs/jax-101/05-random-numbers.ipynb index 65977674a062..ba246095aa7f 100644 --- a/docs/jax-101/05-random-numbers.ipynb +++ b/docs/jax-101/05-random-numbers.ipynb @@ -282,7 +282,7 @@ "source": [ "from jax import random\n", "\n", - "key = random.PRNGKey(42)\n", + "key = random.key(42)\n", "\n", "print(key)" ] @@ -293,7 +293,7 @@ "id": "XhFpKnW9F2nF" }, "source": [ - "A key is just an array of shape `(2,)`.\n", + "A single key is an array of scalar shape `()` and key element type.\n", "\n", "'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:" ] @@ -381,7 +381,7 @@ "source": [ "`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n", "\n", - "If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. 
Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n", + "If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n", "\n", "It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n", "\n", @@ -460,12 +460,12 @@ } ], "source": [ - "key = random.PRNGKey(42)\n", + "key = random.key(42)\n", "subkeys = random.split(key, 3)\n", "sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n", "print(\"individually:\", sequence)\n", "\n", - "key = random.PRNGKey(42)\n", + "key = random.key(42)\n", "print(\"all at once: \", random.normal(key, shape=(3,)))" ] }, diff --git a/docs/jax-101/05-random-numbers.md b/docs/jax-101/05-random-numbers.md index f9f3ae178efe..c8fc02a81db8 100644 --- a/docs/jax-101/05-random-numbers.md +++ b/docs/jax-101/05-random-numbers.md @@ -150,14 +150,14 @@ To avoid this issue, JAX does not use a global state. Instead, random functions from jax import random -key = random.PRNGKey(42) +key = random.key(42) print(key) ``` +++ {"id": "XhFpKnW9F2nF"} -A key is just an array of shape `(2,)`. +A single key is an array of scalar shape `()` and key element type. 'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated: @@ -201,7 +201,7 @@ key = new_key # If we wanted to do this again, we would use new_key as the key. `split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever. -If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it. +If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it. It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later. 
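To make the convention above concrete, a small sketch (not part of this patch) of the recommended split-and-consume pattern with new-style keys:

```python
# Minimal sketch of the key/subkey discipline described above.
import jax

key = jax.random.key(0)              # one root key per seed
key, subkey = jax.random.split(key)  # split before each random call
x = jax.random.normal(subkey)        # consume the subkey exactly once
key, subkey = jax.random.split(key)  # split again; never reuse old keys
y = jax.random.normal(subkey)
```
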
@@ -240,12 +240,12 @@ In the example below, sampling 3 values out of a normal distribution individuall :id: 4nB_TA54D-HT :outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56 -key = random.PRNGKey(42) +key = random.key(42) subkeys = random.split(key, 3) sequence = np.stack([random.normal(subkey) for subkey in subkeys]) print("individually:", sequence) -key = random.PRNGKey(42) +key = random.key(42) print("all at once: ", random.normal(key, shape=(3,))) ``` diff --git a/docs/jax-101/06-parallelism.ipynb b/docs/jax-101/06-parallelism.ipynb index 86aa7bd4260c..9f952ced7599 100644 --- a/docs/jax-101/06-parallelism.ipynb +++ b/docs/jax-101/06-parallelism.ipynb @@ -623,7 +623,7 @@ "ys = xs * true_w + true_b + noise\n", "\n", "# Initialise parameters and replicate across devices.\n", - "params = init(jax.random.PRNGKey(123))\n", + "params = init(jax.random.key(123))\n", "n_devices = jax.local_device_count()\n", "replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)" ] diff --git a/docs/jax-101/06-parallelism.md b/docs/jax-101/06-parallelism.md index f69301bcb74c..7c4b38b38a18 100644 --- a/docs/jax-101/06-parallelism.md +++ b/docs/jax-101/06-parallelism.md @@ -291,7 +291,7 @@ noise = 0.5 * np.random.normal(size=(128, 1)) ys = xs * true_w + true_b + noise # Initialise parameters and replicate across devices. -params = init(jax.random.PRNGKey(123)) +params = init(jax.random.key(123)) n_devices = jax.local_device_count() replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params) ``` diff --git a/docs/jax-101/07-state.ipynb b/docs/jax-101/07-state.ipynb index c7d75abf66b3..3393563136e9 100644 --- a/docs/jax-101/07-state.ipynb +++ b/docs/jax-101/07-state.ipynb @@ -249,7 +249,7 @@ "\n", "In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?\n", "\n", - "Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey." + "Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key." ] }, { @@ -351,7 +351,7 @@ "source": [ "import matplotlib.pyplot as plt\n", "\n", - "rng = jax.random.PRNGKey(42)\n", + "rng = jax.random.key(42)\n", "\n", "# Generate true data from y = w*x + b + noise\n", "true_w, true_b = 2, -1\n", diff --git a/docs/jax-101/07-state.md b/docs/jax-101/07-state.md index bd2d2aa390d9..527856e8e612 100644 --- a/docs/jax-101/07-state.md +++ b/docs/jax-101/07-state.md @@ -166,7 +166,7 @@ Notice that the need for a class becomes less clear once we have rewritten it th In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class? 
-Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey. +Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key. +++ {"id": "I2SqRx14_z98"} @@ -233,7 +233,7 @@ Notice that we manually pipe the params in and out of the update function. import matplotlib.pyplot as plt -rng = jax.random.PRNGKey(42) +rng = jax.random.key(42) # Generate true data from y = w*x + b + noise true_w, true_b = 2, -1 From fe3e798f82977759ff53fa32794f0ed7cf8787f3 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 17 Aug 2023 17:33:43 -0700 Subject: [PATCH 11/35] update pallas quickstart to new-style typed PRNG keys --- docs/pallas/quickstart.ipynb | 6 +++--- docs/pallas/quickstart.md | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index ee5fd44ed6f6..60fdf6314d1e 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -338,7 +338,7 @@ " lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n", " )\n", " )(x, y)\n", - "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n", + "k1, k2 = jax.random.split(jax.random.key(0))\n", "x = jax.random.normal(k1, (1024, 1024))\n", "y = jax.random.normal(k2, (1024, 1024))\n", "z = matmul(x, y)\n", @@ -376,7 +376,7 @@ " lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n", " ),\n", " )(x, y)\n", - "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n", + "k1, k2 = jax.random.split(jax.random.key(0))\n", "x = jax.random.normal(k1, (1024, 1024))\n", "y = jax.random.normal(k2, (1024, 1024))\n", "z = matmul(x, y, activation=jax.nn.relu)\n", @@ -397,7 +397,7 @@ "metadata": {}, "outputs": [], "source": [ - "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n", + "k1, k2 = jax.random.split(jax.random.key(0))\n", "x = jax.random.normal(k1, (4, 1024, 1024))\n", "y = jax.random.normal(k2, (4, 1024, 1024))\n", "z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)\n", diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 61e68ef1ea9a..e3d24b155acf 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -226,7 +226,7 @@ def matmul(x: jax.Array, y: jax.Array): lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2) ) )(x, y) -k1, k2 = jax.random.split(jax.random.PRNGKey(0)) +k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.normal(k1, (1024, 1024)) y = jax.random.normal(k2, (1024, 1024)) z = matmul(x, y) @@ -253,7 +253,7 @@ def matmul(x: jax.Array, y: jax.Array, *, activation): lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2) ), )(x, y) -k1, k2 = jax.random.split(jax.random.PRNGKey(0)) +k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.normal(k1, (1024, 1024)) y = jax.random.normal(k2, (1024, 1024)) z = matmul(x, y, activation=jax.nn.relu) @@ -263,7 +263,7 @@ np.testing.assert_allclose(z, jax.nn.relu(x @ y)) To conclude, 
let's highlight a cool feature of Pallas: it composes with `jax.vmap`! To turn this matrix multiplication into a batched version, we just need to `vmap` it. ```{code-cell} ipython3 -k1, k2 = jax.random.split(jax.random.PRNGKey(0)) +k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.normal(k1, (4, 1024, 1024)) y = jax.random.normal(k2, (4, 1024, 1024)) z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y) From 721ca3f7145d36fb14b5efec76dd22c7528d0ecf Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 22 Sep 2023 11:31:36 -0700 Subject: [PATCH 12/35] add key array upgrade note to `jax.random` module doc --- jax/random.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/jax/random.py b/jax/random.py index f65ec58580a4..91cdeb7ac656 100644 --- a/jax/random.py +++ b/jax/random.py @@ -58,6 +58,42 @@ >>> random.uniform(subkey) Array(0.10536897, dtype=float32) +.. note:: + + Typed key arrays, with element types such as ``key`` above, + were introduced in JAX v0.4.16. Before then, keys were + conventionally represented in ``uint32`` arrays, whose final + dimension represented the key's bit-level representation. + + Both forms of key array can still be created and used with the + :mod:`jax.random` module. New-style typed key arrays are made with + :py:func:`jax.random.key`. Legacy ``uint32`` key arrays are made + with :py:func:`jax.random.PRNGKey`. + + To convert between the two, use :py:func:`jax.random.key_data` and + :py:func:`jax.random.wrap_key_data`. The legacy key format may be + needed when interfacing with systems outside of JAX (e.g. exporting + arrays to a serializable format), or when passing keys to JAX-based + libraries that assume the legacy format. + + Otherwise, typed keys are recommended. Caveats of legacy keys + relative to typed ones include: + + * They have an extra trailing dimension. + + * They have a numeric dtype (``uint32``), allowing for operations + that are typically not meant to be carried out over keys, such as + integer arithmetic. + + * They do not carry information about the RNG implementation. When + legacy keys are passed to :mod:`jax.random` functions, a global + configuration setting determines the RNG implementation (see + "Advanced RNG configuration" below). + + To learn more about this upgrade, and the design of key types, see + `JEP 9263 + `_. 
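+  A short sketch of the conversions described in this note (the exact
+  shape of the key data depends on the configured RNG implementation):
+
+  >>> key = random.key(0)                   # new-style typed key
+  >>> raw = random.key_data(key)            # uint32 array carrying the key's bits
+  >>> legacy = random.PRNGKey(0)            # legacy uint32 key array
+  >>> restored = random.wrap_key_data(raw)  # back to a typed key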
+ Advanced -------- From 00c231cbb2e586375c4772a0ae9a87160288d924 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 7 Mar 2024 13:33:13 -0800 Subject: [PATCH 13/35] Reuse the utility `_gspmd_to_named_sharding_via_mesh` in other places PiperOrigin-RevId: 613686995 --- jax/_src/interpreters/pxla.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9fc30fb6530c..8458b9e35bf2 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2483,11 +2483,7 @@ def _gspmd_to_named_sharding_via_mesh( def _gspmd_to_named_sharding( out_s: sharding_impls.GSPMDSharding, orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding: - parsed_pspec = sharding_impls.parse_flatten_op_sharding( - out_s._hlo_sharding, orig_in_s.mesh)[0] - return create_mesh_pspec_sharding( - orig_in_s.mesh, parsed_pspec.get_partition_spec(), parsed_pspec, - out_s.memory_kind) + return _gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) _register_out_sharding_handler( sharding_impls.NamedSharding, _gspmd_to_named_sharding) From c986cbc80371c9c2b63e88e557680115dd608118 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 7 Mar 2024 13:58:34 -0800 Subject: [PATCH 14/35] Disable the input sharding propagation temporarily PiperOrigin-RevId: 613694719 --- jax/_src/interpreters/pxla.py | 4 ++-- tests/memories_test.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 8458b9e35bf2..b777994fab89 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2049,8 +2049,8 @@ def lower_sharding_computation( any(not is_unspecified(o) for o in out_shardings)) gs = GSPMDSharding.get_replicated(device_assignment) - if xla_extension_version < 241 or hasattr(backend, "compile_replicated"): - in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) + # if xla_extension_version < 241 or hasattr(backend, "compile_replicated"): + in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) da_object = _create_da_object(tuple(device_assignment)) diff --git a/tests/memories_test.py b/tests/memories_test.py index 6903cdbef789..9b3c8ac076d1 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -73,11 +73,9 @@ def setUp(self): super().setUp() self.orig_memories_flag = config.enable_memories.value jax.config.update('jax_enable_memories', True) - FLAGS.xla_tpu_enable_host_aware_passes = True def tearDown(self): jax.config.update('jax_enable_memories', self.orig_memories_flag) - FLAGS.xla_tpu_enable_host_aware_passes = False super().tearDown() @parameterized.named_parameters( @@ -1104,6 +1102,12 @@ def setUp(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Memories do not work on CPU and GPU backends yet.") super().setUp() + self.orig_memories_flag = config.enable_memories.value + jax.config.update('jax_enable_memories', True) + + def tearDown(self): + jax.config.update('jax_enable_memories', self.orig_memories_flag) + super().tearDown() def test_remat_jaxpr_offloadable(self): mesh = jtu.create_global_mesh((2,), ("x",)) From 8ac29132963a80421c5ea5ec9997e99d02a1fb0d Mon Sep 17 00:00:00 2001 From: Selam Waktola Date: Wed, 6 Mar 2024 15:43:42 -0800 Subject: [PATCH 15/35] minor modification for silu and swish func description Update 'aka' only inside functions.py modify SiLU (a.k.a. swish) activation function. to SiLU (aka swish) activation function. 
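For reference, `silu` computes the element-wise function `x * sigmoid(x)`;
a quick sketch of that identity:

```python
import jax.numpy as jnp
from jax import nn

x = jnp.linspace(-3.0, 3.0, 7)
# SiLU (aka swish) is x * sigmoid(x), applied element-wise.
assert jnp.allclose(nn.silu(x), x * nn.sigmoid(x))
```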
--- jax/_src/nn/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 8e232fcc6d8e..239d1b8d2d27 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -149,7 +149,7 @@ def sigmoid(x: ArrayLike) -> Array: @jax.jit def silu(x: ArrayLike) -> Array: - r"""SiLU (a.k.a. swish) activation function. + r"""SiLU (aka swish) activation function. Computes the element-wise function: From 0e79f95fdb240acfc96775ba461093559c7b146d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 7 Mar 2024 18:12:42 -0800 Subject: [PATCH 16/35] lint: fix unused import --- jax/tools/pgo_nsys_converter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 4d3dfd3db23c..7180f9ee8aa8 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -16,7 +16,6 @@ import re import sys import argparse -import psutil import os import shutil import subprocess From 6d67aa2242a09d88bc0e968de28f1fbf39aa6109 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 7 Mar 2024 18:18:04 -0800 Subject: [PATCH 17/35] Fix mypy errors --- jax/tools/pgo_nsys_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 7180f9ee8aa8..5623d81ed2bc 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -35,6 +35,7 @@ pgle_folder = os.path.join(os.path.split(args.pgle_output_path)[0], '') profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') +assert isinstance(nsys_path, str) stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] print(f""" @@ -45,7 +46,6 @@ proc.wait() thunk_re = re.compile("hlo_op=(.*)#") -cost_dictionary = dict() with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: reader = csv.DictReader(csvfile) From 2eff1f0f3f87b5494d248d1cf2e6b0ddf8767ba7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 7 Mar 2024 18:32:18 -0800 Subject: [PATCH 18/35] Guard script against execution on import --- jax/tools/pgo_nsys_converter.py | 60 +++++++++++++++++---------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 5623d81ed2bc..5e87220be606 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -20,41 +20,43 @@ import shutil import subprocess -print("Script to convert NVIDIA Nsys Profiles to the .pbtxt format. This format is readable by XLA's Profile Guided Latency Estimator. Usage: pgo_nsys_converter.py --profile_path --pgle_output_path ") +if __name__ == '__main__': -nsys_path = shutil.which("nsys") + print("Script to convert NVIDIA Nsys Profiles to the .pbtxt format. This format is readable by XLA's Profile Guided Latency Estimator. 
Usage: pgo_nsys_converter.py --profile_path --pgle_output_path ") -parser = argparse.ArgumentParser(description='Tool to convert NVIDIA Nsys Profiles to the .pbtxt format') -parser.add_argument("--profile_path", type=str, help="path to nsys profile") -parser.add_argument("--post_process", help="post process pbtxt to get minimum cost value for each instruction", action="store_true") -parser.add_argument("--pgle_output_path", type=str, help="output directory", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") + nsys_path = shutil.which("nsys") -args = parser.parse_args() + parser = argparse.ArgumentParser(description='Tool to convert NVIDIA Nsys Profiles to the .pbtxt format') + parser.add_argument("--profile_path", type=str, help="path to nsys profile") + parser.add_argument("--post_process", help="post process pbtxt to get minimum cost value for each instruction", action="store_true") + parser.add_argument("--pgle_output_path", type=str, help="output directory", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") -pgle_filename = os.path.basename(args.pgle_output_path).partition('.')[0] -pgle_folder = os.path.join(os.path.split(args.pgle_output_path)[0], '') -profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') + args = parser.parse_args() -assert isinstance(nsys_path, str) -stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] + pgle_filename = os.path.basename(args.pgle_output_path).partition('.')[0] + pgle_folder = os.path.join(os.path.split(args.pgle_output_path)[0], '') + profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') -print(f""" - ******Starting stats command****** - {stats_command}.""") + assert isinstance(nsys_path, str) + stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] -proc = subprocess.Popen(stats_command, stdout=sys.stdout, stderr=sys.stderr) -proc.wait() + print(f""" + ******Starting stats command****** + {stats_command}.""") -thunk_re = re.compile("hlo_op=(.*)#") -with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: - with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - name = row['NVTX Range'] - time_ns = float(row['Avg (ns)']) - m = thunk_re.search(name) - if m is not None: - protofile.write(f'costs {{ name: "{m.group(1)}" cost_us: {time_ns / 1000.0} }}\n') + proc = subprocess.Popen(stats_command, stdout=sys.stdout, stderr=sys.stderr) + proc.wait() -clean_command = f"rm {profile_folder}/*.sqlite; rm {pgle_folder}/*.csv" -subprocess.call(clean_command, shell=True) + thunk_re = re.compile("hlo_op=(.*)#") + with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: + with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + name = row['NVTX Range'] + time_ns = float(row['Avg (ns)']) + m = thunk_re.search(name) + if m is not None: + protofile.write(f'costs {{ name: "{m.group(1)}" cost_us: {time_ns / 1000.0} }}\n') + + clean_command = f"rm {profile_folder}/*.sqlite; rm {pgle_folder}/*.csv" + subprocess.call(clean_command, shell=True) From 4b0382ec7629cb87e9d67ffd1a2c4f970b32ff3e Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 7 Mar 2024 23:42:50 -0800 Subject: [PATCH 19/35] 
Update XLA dependency to use revision http://github.com/openxla/xla/commit/541962e88f52237bc6050e4c8d7270e7c7e12b4e. PiperOrigin-RevId: 613831465 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a6ac82fe100f..930afe0c3617 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -20,8 +20,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "421b738e400a15c053b02924712a0e915b73cf7b" -XLA_SHA256 = "86beb00e75e235a3c7c481840304a54bb8fac233b5e9f8cdcd2947a2b924cdc0" +XLA_COMMIT = "541962e88f52237bc6050e4c8d7270e7c7e12b4e" +XLA_SHA256 = "ca67f68edad0d898241b65cbd85869de2560e2fdba700d964c707b427158583d" def repo(): tf_http_archive( From 4244b218ca44dc7893eba3ee78396cd03ced4d37 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 8 Mar 2024 07:54:11 -0800 Subject: [PATCH 20/35] [XLA:Python] Port sharding and device lists to nanobind. PiperOrigin-RevId: 613933518 --- jax/_src/sharding_impls.py | 42 +++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index a3ff34e8dd01..8b62e19b4d52 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -284,24 +284,15 @@ def __init__( self.mesh = mesh self.spec = spec self._memory_kind = memory_kind - self._parsed_pspec = _parsed_pspec self._manual_axes = _manual_axes - self._preprocess() - - def _preprocess(self): - # This split exists because you can pass `_parsed_pspec` that has been - # modified from the original. For example: Adding extra dimension to - # axis_resources for vmap handlers. In such cases you need to preserve the - # `sync` attribute of parsed pspecs. - # PartitionSpec is inferred from the parsed pspec in this case. - # TODO(yaskatariya): Remove this and replace this with a normalized - # representation of Parsed Pspec - if self._parsed_pspec is None: - self._parsed_pspec, _, _ = prepare_axis_resources( - PartitionSpec() if self.spec is None else self.spec, - "NamedSharding spec", allow_unconstrained_dims=True) - - _check_mesh_resource_axis(self.mesh, self._parsed_pspec) + self._parsed_pspec = preprocess(self.mesh, self.spec, _parsed_pspec) + + # TODO(phawkins): remove this method when jaxlib 0.4.26 or newer is the + # minimum. This method is called by the C++ sharding implementation in earlier + # versions. + if xla_extension_version < 243: + def _preprocess(self): + self._parsed_pspec = preprocess(self.mesh, self.spec, self._parsed_pspec) def __repr__(self): mesh_repr = ", ".join(f"'{k}': {v}" for k, v in self.mesh.shape.items()) @@ -1115,6 +1106,23 @@ def __repr__(self): f"sync={self.sync})") +def preprocess(mesh, spec, parsed_pspec): + # This split exists because you can pass `_parsed_pspec` that has been + # modified from the original. For example: Adding extra dimension to + # axis_resources for vmap handlers. In such cases you need to preserve the + # `sync` attribute of parsed pspecs. + # PartitionSpec is inferred from the parsed pspec in this case. 
+ # TODO(yaskatariya): Remove this and replace this with a normalized + # representation of Parsed Pspec + if parsed_pspec is None: + parsed_pspec, _, _ = prepare_axis_resources( + PartitionSpec() if spec is None else spec, + "NamedSharding spec", allow_unconstrained_dims=True) + + _check_mesh_resource_axis(mesh, parsed_pspec) + return parsed_pspec + + def prepare_axis_resources(axis_resources, arg_name, allow_unconstrained_dims=False): From 6ada248b3c6c5f10a5b7751a19b95b58e4c177f5 Mon Sep 17 00:00:00 2001 From: yixiaoer Date: Sat, 9 Mar 2024 00:38:28 +0800 Subject: [PATCH 21/35] update --- jax/random.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/random.py b/jax/random.py index c06f48b3583b..d3649a6cdb8f 100644 --- a/jax/random.py +++ b/jax/random.py @@ -116,8 +116,9 @@ identical across JAX/XLA versions ✅ ✅ ================================= ======== ========= === ========== ===== ============ -(*): with jax_threefry_partitionable=1 set -(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set +(*): with ``jax_threefry_partitionable=1`` set + +(**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less robust/studied hash function for random value generation (but not for From 6771a59181b8b723464822cc3c07a9895ef18d69 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 8 Mar 2024 09:06:00 -0800 Subject: [PATCH 22/35] [key reuse] add jax.random.clone --- docs/jax.experimental.key_reuse.rst | 1 - docs/jax.random.rst | 1 + jax/_src/lax/control_flow/loops.py | 4 +--- jax/_src/prng.py | 23 ----------------------- jax/_src/random.py | 25 +++++++++++++++++++++++++ jax/experimental/jax2tf/jax2tf.py | 2 +- jax/experimental/key_reuse/__init__.py | 4 ---- jax/experimental/key_reuse/_core.py | 2 +- jax/random.py | 1 + tests/key_reuse_test.py | 7 ++++--- tests/random_test.py | 8 ++++++++ tests/state_test.py | 7 +++---- 12 files changed, 45 insertions(+), 40 deletions(-) diff --git a/docs/jax.experimental.key_reuse.rst b/docs/jax.experimental.key_reuse.rst index 5c7caf80f0ce..c78f23866d0d 100644 --- a/docs/jax.experimental.key_reuse.rst +++ b/docs/jax.experimental.key_reuse.rst @@ -9,5 +9,4 @@ API .. autosummary:: :toctree: _autosummary - reuse_key KeyReuseError diff --git a/docs/jax.random.rst b/docs/jax.random.rst index ea7845f8ce0e..9d6369d2d2b1 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -18,6 +18,7 @@ Key Creation & Manipulation wrap_key_data fold_in split + clone Random Samplers ~~~~~~~~~~~~~~~ diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index e939f2170f8a..3cbb73ed6982 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -394,9 +394,7 @@ def body_fun(vals): i_ = length - i - 1 if reverse else i # TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right, # because the scan body may consume any keys within it. 
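      # For context on the swap below: jax.random.clone is an identity on
      # key *values*; only the key-reuse checker treats its output as a
      # fresh, unconsumed key. A minimal sketch:
      #
      #   key = jax.random.key(0)
      #   x = jax.random.uniform(key)
      #   y = jax.random.uniform(jax.random.clone(key))  # x == y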
- # Import here to avoid circular imports - from jax.experimental import key_reuse - xs_unconsumed = _map(key_reuse.reuse_key, xs) + xs_unconsumed = _map(jax.random.clone, xs) x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed) out_flat = f_impl(*consts, *carry, *x) carry_out, y_updates = split_list(out_flat, [num_carry]) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7f125bd44741..db6d174687a3 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -1377,26 +1377,3 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: tag='urbg') register_prng(unsafe_rbg_prng_impl) - - -# Primitives related to key reuse -reuse_key_p = core.Primitive("reuse_key") -reuse_key_p.def_impl(lambda x: x) -reuse_key_p.def_abstract_eval(lambda x: x) -batching.defvectorized(reuse_key_p) -mlir.register_lowering(reuse_key_p, lambda _, k: [k]) - -def reuse_key(key): - """Explicitly mark a key as unconsumed. - - Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`) - this function operates as an identity. - - Example: - - >>> import jax - >>> key = jax.random.key(0) - >>> data = jax.random.uniform(key) - >>> same_data = jax.random.uniform(reuse_key(key)) - """ - return reuse_key_p.bind(key) diff --git a/jax/_src/random.py b/jax/_src/random.py index f9045ebf3135..85f516ca6868 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2611,3 +2611,28 @@ def binomial( if shape is not None: shape = core.canonicalize_shape(shape) return _binomial(key, n, p, shape, dtype) + + +# Functions related to key reuse checking +random_clone_p = core.Primitive("random_clone") +random_clone_p.def_impl(lambda x: x) +random_clone_p.def_abstract_eval(lambda x: x) +batching.defvectorized(random_clone_p) +mlir.register_lowering(random_clone_p, lambda _, k: [k]) + +def clone(key): + """Clone a key for reuse + + Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`) + this function operates as an identity. + + Example: + + >>> import jax + >>> key = jax.random.key(0) + >>> data = jax.random.uniform(key) + >>> cloned_key = jax.random.clone(key) + >>> same_data = jax.random.uniform(cloned_key) + >>> assert data == same_data + """ + return random_clone_p.bind(key) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 343a9ac9e323..5b0f9929565c 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1528,7 +1528,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "consume", ] -tf_impl[prng.reuse_key_p] = lambda x: x +tf_impl[random_internal.random_clone_p] = lambda x: x tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient diff --git a/jax/experimental/key_reuse/__init__.py b/jax/experimental/key_reuse/__init__.py index 72d9a861eacf..75b231ed23b9 100644 --- a/jax/experimental/key_reuse/__init__.py +++ b/jax/experimental/key_reuse/__init__.py @@ -39,10 +39,6 @@ ... 
print(jax.random.normal(key)) -0.20584226 """ -from jax._src.prng import ( - reuse_key as reuse_key, -) - from jax.experimental.key_reuse._core import ( KeyReuseError as KeyReuseError, ) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 489fcc14e8fa..0d832c3dd20f 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -149,7 +149,7 @@ def _check_consumed_value(eqn, consumed): key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)]) key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)]) +key_reuse_signatures[random.random_clone_p] = KeyReuseSignature([], [Source(0)]) key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], []) # TODO(jakevdp): should fold_in sink its input key? # key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)]) diff --git a/jax/random.py b/jax/random.py index c06f48b3583b..0262f06bb798 100644 --- a/jax/random.py +++ b/jax/random.py @@ -144,6 +144,7 @@ cauchy as cauchy, chisquare as chisquare, choice as choice, + clone as clone, dirichlet as dirichlet, double_sided_maxwell as double_sided_maxwell, exponential as exponential, diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index d164290ec48c..d1c390663f98 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -21,6 +21,7 @@ from jax import core import jax.numpy as jnp from jax._src import prng +from jax._src import random from jax._src import test_util as jtu from jax.experimental.key_reuse._core import ( assert_consumed, assert_unconsumed, consume, consume_p) @@ -36,7 +37,7 @@ primitives_with_static_signatures = { consume_p: (consume, key), - prng.reuse_key_p: (prng.reuse_key, key), + random.random_clone_p: (random.clone, key), prng.random_bits_p: (jax.random.bits, key), # prng.random_fold_in_p: (jax.random.fold_in, key, 2), prng.random_seed_p: (jax.random.key, 0), @@ -91,12 +92,12 @@ def f(key): assert_consumed(key2) self.check_key_reuse(f, jax.random.key(0)) - def test_reuse_key(self): + def test_random_clone(self): def f(key): assert_unconsumed(key) consume(key) assert_consumed(key) - key2 = prng.reuse_key(key) + key2 = jax.random.clone(key) assert_unconsumed(key2) self.check_key_reuse(f, jax.random.key(0)) diff --git a/tests/random_test.py b/tests/random_test.py index d8212106e211..170076989cde 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -588,6 +588,14 @@ def test_construction(self): key = random.key(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) + def test_random_clone(self): + # Here we test value semantics and compatibility with jit/vmap + # key reuse semantics are tested in key_reuse_test.py + keys = jax.random.split(jax.random.key(0), 5) + self.assertKeysEqual(keys, jax.random.clone(keys)) + self.assertKeysEqual(keys, jax.jit(jax.random.clone)(keys)) + self.assertKeysEqual(keys, jax.vmap(jax.random.clone)(keys)) + def test_issubdtype(self): key = random.key(42) diff --git a/tests/state_test.py b/tests/state_test.py index 1f109536fc16..31e97fd4bb5b 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -29,7 +29,6 @@ from jax._src import config from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax._src import prng from jax._src import test_util as jtu from jax._src.util import tuple_insert import jax.numpy as jnp @@ -1735,8 +1734,8 
@@ def ref(x): y, impl_vjp = jax.vjp(impl, x) y_ref, ref_vjp = jax.vjp(ref, x) self.assertAllClose(y, y_ref) - t = random.normal(prng.reuse_key(k2), x.shape) - y2 = random.normal(prng.reuse_key(k1), y.shape) + t = random.normal(jax.random.clone(k2), x.shape) + y2 = random.normal(jax.random.clone(k1), y.shape) self.assertAllClose(impl_vjp(t), ref_vjp(t)) # Second order @@ -1752,7 +1751,7 @@ def ref(x): (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) self.assertAllClose(x, x_ref) - y2 = random.normal(prng.reuse_key(k1), y.shape) + y2 = random.normal(jax.random.clone(k1), y.shape) self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) if __name__ == '__main__': From 895ad60d60cb77ea9c1c85c1f1bdae4e7e218013 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 8 Mar 2024 18:00:36 +0000 Subject: [PATCH 23/35] Removed dlpack extraction leading to forced legacy path --- jax/_src/numpy/lax_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3d621b6e3cb0..6fa887f1b376 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2438,7 +2438,7 @@ def fromiter(*args, **kwargs): """) def from_dlpack(x: Any) -> Array: from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top - return from_dlpack(x.__dlpack__()) + return from_dlpack(x) @util.implements(np.fromfunction) def fromfunction(function: Callable[..., Array], shape: Any, From 7634708743b6891d5ee9e4e5cec30078476e41a1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 8 Mar 2024 10:40:33 -0800 Subject: [PATCH 24/35] [key reuse] define KeyReuseError in jax.errors --- docs/errors.rst | 1 + docs/jax.experimental.key_reuse.rst | 8 -------- jax/_src/errors.py | 26 ++++++++++++++++++++++++++ jax/errors.py | 1 + jax/experimental/key_reuse/__init__.py | 3 --- jax/experimental/key_reuse/_core.py | 4 +--- tests/core_test.py | 6 +----- tests/key_reuse_test.py | 3 ++- 8 files changed, 32 insertions(+), 20 deletions(-) diff --git a/docs/errors.rst b/docs/errors.rst index 4c76f5dcf5a1..23dbaf29c46f 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -7,6 +7,7 @@ along with representative examples of how one might fix them. .. currentmodule:: jax.errors .. autoclass:: ConcretizationTypeError +.. autoclass:: KeyReuseError .. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerBoolConversionError diff --git a/docs/jax.experimental.key_reuse.rst b/docs/jax.experimental.key_reuse.rst index c78f23866d0d..7255afabfc10 100644 --- a/docs/jax.experimental.key_reuse.rst +++ b/docs/jax.experimental.key_reuse.rst @@ -2,11 +2,3 @@ ===================================== .. automodule:: jax.experimental.key_reuse - -API ---- - -.. autosummary:: - :toctree: _autosummary - - KeyReuseError diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 5594b261abcd..dd5d83d51ec1 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -655,3 +655,29 @@ class UnexpectedTracerError(JAXTypeError): def __init__(self, msg: str): super().__init__(msg) + + +@export +class KeyReuseError(JAXTypeError): + """ + This error occurs when a PRNG key is reused in an unsafe manner. + Key reuse is checked only when `jax_enable_key_reuse_checks` is + set to `True`. + + Here is a simple example of code that would lead to such an error:: + + >>> with jax.enable_key_reuse_checks(True): # doctest: +SKIP + ... key = jax.random.key(0) + ... value = jax.random.uniform(key) + ... 
new_value = jax.random.uniform(key) + ... + --------------------------------------------------------------------------- + KeyReuseError Traceback (most recent call last) + ... + KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 + + This sort of key reuse is problematic because the JAX PRNG is stateless, and keys + must be manually split; For more information on this see `Sharp Bits: Random Numbers + `_. + """ + pass diff --git a/jax/errors.py b/jax/errors.py index 4b8a0cf7547e..15a6654fa32d 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -24,5 +24,6 @@ TracerBoolConversionError as TracerBoolConversionError, TracerIntegerConversionError as TracerIntegerConversionError, UnexpectedTracerError as UnexpectedTracerError, + KeyReuseError as KeyReuseError, ) from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback diff --git a/jax/experimental/key_reuse/__init__.py b/jax/experimental/key_reuse/__init__.py index 75b231ed23b9..f33020009f9b 100644 --- a/jax/experimental/key_reuse/__init__.py +++ b/jax/experimental/key_reuse/__init__.py @@ -39,6 +39,3 @@ ... print(jax.random.normal(key)) -0.20584226 """ -from jax.experimental.key_reuse._core import ( - KeyReuseError as KeyReuseError, -) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 0d832c3dd20f..275c0e5f2adc 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -21,6 +21,7 @@ import jax from jax import lax from jax import tree_util +from jax.errors import KeyReuseError from jax.interpreters import batching, mlir from jax._src import api_util from jax._src import config @@ -99,9 +100,6 @@ def update_consumption(self, args_in, args_out): arg_out._consumed = arg_in._consumed -class KeyReuseError(RuntimeError): - pass - consume_p = core.Primitive("consume") consume_p.def_impl(lambda x: x) consume_p.def_abstract_eval(lambda x: x) diff --git a/tests/core_test.py b/tests/core_test.py index 788f61db943d..1831c57742ac 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -751,15 +751,11 @@ def g(x): return x def test_check_jaxpr_key_reuse(self): with config.enable_key_reuse_checks(True): - try: - from jax.experimental.key_reuse import KeyReuseError - except ImportError: - self.skipTest("Test requires jax.experimental.key_reuse") def f(seed): key = jax.random.key(seed) return jax.random.uniform(key) + jax.random.normal(key) with jax.enable_checks(True): - with self.assertRaises(KeyReuseError): + with self.assertRaises(jax.errors.KeyReuseError): jax.jit(f)(0) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index d1c390663f98..725c99958ac4 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -23,9 +23,10 @@ from jax._src import prng from jax._src import random from jax._src import test_util as jtu +from jax.errors import KeyReuseError from jax.experimental.key_reuse._core import ( assert_consumed, assert_unconsumed, consume, consume_p) -from jax.experimental.key_reuse import _core, KeyReuseError +from jax.experimental.key_reuse import _core from jax import config config.parse_flags_with_absl() From 0644f192f2efbe13d5bc098e66d6ad45bf417500 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 8 Mar 2024 12:28:00 -0800 Subject: [PATCH 25/35] [key reuse] improve KeyReuseSignature semantics --- jax/experimental/key_reuse/_core.py | 168 ++++++++++++++++++---------- tests/key_reuse_test.py | 74 +++++++++++- 2 files changed, 181 insertions(+), 61 deletions(-) diff --git 
a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 275c0e5f2adc..ab96f48f3fb0 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -16,7 +16,7 @@ from collections import defaultdict from functools import partial, reduce, wraps -from typing import Any, Callable, NamedTuple +from typing import Any, Callable, Iterator, NamedTuple import jax from jax import lax @@ -41,36 +41,91 @@ import numpy as np -class Sink(NamedTuple): +# Create Source() and Sink() objects which validate inputs, have +# correct equality semantics, and are hashable & immutable. +class _SourceSinkBase: idx: int - mask: bool | np.ndarray = True + mask: bool | np.ndarray + + def __init__(self, idx: int, mask: bool | np.bool_ | np.ndarray = True): + assert isinstance(idx, int) + if isinstance(mask, np.ndarray): + assert mask.dtype == np.dtype('bool') + if np.all(mask): + mask = True + elif not np.any(mask): + mask = False + elif mask.flags.writeable: + mask = np.array(mask, copy=True) + mask.flags.writeable = False + elif isinstance(mask, np.bool_): + mask = bool(mask) + else: + assert isinstance(mask, bool) + super().__setattr__("idx", idx) + super().__setattr__("mask", mask) + + def __setattr__(self, *args, **kwargs): + raise ValueError(f"{self.__class__.__name__} is immutable") + + def __eq__(self, other): + return (self.__class__ == other.__class__ + and self.idx == other.idx + and np.shape(self.mask) == np.shape(other.mask) + and np.all(self.mask == other.mask)) + + def __hash__(self): + if isinstance(self.mask, bool): + return hash((self.__class__, self.idx, self.mask)) + else: + mask = np.asarray(self.mask) + return hash((self.__class__, self.idx, mask.shape, + tuple(mask.flatten().tolist()))) def __repr__(self): - if isinstance(self.mask, bool) and self.mask: - return f"Sink({self.idx})" - else: - return f"Sink({self.idx}, mask={self.mask})" + if self.mask is True: + return f"{self.__class__.__name__}({self.idx})" + return f"{self.__class__.__name__}({self.idx}, {self.mask})" -class Source(NamedTuple): - idx: int - mask: bool | np.ndarray = True +class Sink(_SourceSinkBase): + pass + + +class Source(_SourceSinkBase): + pass - def __repr__(self): - if isinstance(self.mask, bool) and self.mask: - return f"Source({self.idx})" - else: - return f"Source({self.idx}, mask={self.mask})" class Forward(NamedTuple): in_idx: int out_idx: int -class KeyReuseSignature(NamedTuple): - sinks: list[Sink] - sources: list[Source] - forwards: list[Forward] = [] +# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward +# objects, with a few convenience methods related to key reuse checking. +class KeyReuseSignature: + _args: frozenset[Source | Sink | Forward] + + def __init__(self, *args): + self._args = frozenset(args) + + def __eq__(self, other): + return isinstance(other, KeyReuseSignature) and self._args == other._args + + def __hash__(self): + return hash(self._args) + + @property + def sinks(self) -> Iterator[Sink]: + yield from (s for s in self._args if isinstance(s, Sink)) + + @property + def sources(self) -> Iterator[Source]: + yield from (s for s in self._args if isinstance(s, Source)) + + @property + def forwards(self) -> Iterator[Forward]: + yield from (s for s in self._args if isinstance(s, Forward)) def check_signature(self, *args, funcname="function", context=None): for sink in self.sinks: @@ -145,34 +200,33 @@ def _check_consumed_value(eqn, consumed): # The behavior of most primitives can be described via simple signatures. 
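# A sketch of the semantics the definitions above provide: signatures
# compare as unordered sets, and all-true / all-false masks fold to plain
# booleans, e.g.
#
#   KeyReuseSignature(Sink(0), Source(1)) == KeyReuseSignature(Source(1), Sink(0))  # True
#   Sink(0, np.array([True, True])) == Sink(0)                                      # True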
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {} -key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)]) -key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[random.random_clone_p] = KeyReuseSignature([], [Source(0)]) -key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], []) +key_reuse_signatures[consume_p] = KeyReuseSignature(Sink(0), Forward(0, 0)) +key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[random.random_clone_p] = KeyReuseSignature(Source(0)) +key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature(Sink(0)) # TODO(jakevdp): should fold_in sink its input key? -# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)]) -key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)]) -key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)]) -key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)]) -key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], []) +key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature(Source(0)) +key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature(Source(0)) +key_reuse_signatures[prng.random_split_p] = KeyReuseSignature(Sink(0), Source(0)) +key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature(Sink(0)) # TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication -key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.copy_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.device_put_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.reshape_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature([], [Source(0)], []) +key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.copy_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.device_put_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.reshape_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature(Source(0)) # TODO(jakevdp): should unwrap sink its input key? 
-key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature([], [], []) -key_reuse_signatures[debug_callback_p] = KeyReuseSignature([], []) -key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([Sink(1)], [], [Forward(0, 0)]) -key_reuse_signatures[lax.gather_p] = KeyReuseSignature([], [], [Forward(0, 0)]) -key_reuse_signatures[lax.scatter_p] = KeyReuseSignature([Sink(2)], [], [Forward(0, 0)]) +key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature() +key_reuse_signatures[debug_callback_p] = KeyReuseSignature() +key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature(Sink(1), Forward(0, 0)) +key_reuse_signatures[lax.gather_p] = KeyReuseSignature(Forward(0, 0)) +key_reuse_signatures[lax.scatter_p] = KeyReuseSignature(Sink(2), Forward(0, 0)) # Equality checks don't consume -key_reuse_signatures[lax.eq_p] = KeyReuseSignature([], [], []) -key_reuse_signatures[lax.ne_p] = KeyReuseSignature([], [], []) +key_reuse_signatures[lax.eq_p] = KeyReuseSignature() +key_reuse_signatures[lax.ne_p] = KeyReuseSignature() # Rules which require more dynamic logic. key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {} @@ -182,8 +236,7 @@ def unknown_signature(eqn): def is_key(var: core.Atom): return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) return KeyReuseSignature( - sinks=[Sink(idx, True) for idx, var in enumerate(eqn.invars) if is_key(var)], - sources=[], + *(Sink(idx) for idx, var in enumerate(eqn.invars) if is_key(var)) ) @weakref_lru_cache @@ -216,7 +269,6 @@ def sink(var: core.Atom, mask=True): return True consumed[var] = np.logical_or(consumed.get(var, False), mask) - def source(var: core.Atom, mask=False): if not is_key(var): return @@ -262,13 +314,13 @@ def is_consumed(var: core.Atom): source(eqn.outvars[src.idx]) return KeyReuseSignature( - sinks=[Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars) - if is_key(v) and np.any(consumed.get(v, False))], - sources=[Source(i) for i, v in enumerate(jaxpr.outvars) - if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)], - forwards=[Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] - for idx_out, outvar in enumerate(jaxpr.outvars) - if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars] + *(Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars) + if is_key(v) and np.any(consumed.get(v, False))), + *(Source(i) for i, v in enumerate(jaxpr.outvars) + if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)), + *(Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] + for idx_out, outvar in enumerate(jaxpr.outvars) + if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars) ) @@ -292,16 +344,16 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None: def _slice_signature(eqn): in_aval = eqn.invars[0].aval if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key): - return KeyReuseSignature([], [], [Forward(0, 0)]) + return KeyReuseSignature(Forward(0, 0)) if any(core.is_symbolic_dim(s) for s in in_aval.shape): - return KeyReuseSignature([], [], [Forward(0, 0)]) + return KeyReuseSignature(Forward(0, 0)) start_indices = eqn.params['start_indices'] limit_indices = eqn.params['limit_indices'] strides = 
eqn.params['strides'] or (1,) * len(start_indices) idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides)) sink = np.zeros(in_aval.shape, dtype=bool) sink[idx] = True - return KeyReuseSignature([Sink(0, sink)], [Source(0)]) + return KeyReuseSignature(Sink(0, sink), Source(0)) key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature @@ -329,7 +381,7 @@ def _cond_key_type_signature(eqn): combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()] combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in set.intersection(*(set(sig.forwards) for sig in signatures))] - return KeyReuseSignature(combined_sinks, combined_sources, combined_forwards) + return KeyReuseSignature(*combined_sinks, *combined_sources, *combined_forwards) key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature @@ -410,7 +462,7 @@ def _remat_key_type_signature(eqn): # 2) will never create keys # Therefore, the differentiated pass is a no-op. if eqn.params['differentiated']: - return KeyReuseSignature([], []) + return KeyReuseSignature() return get_jaxpr_type_signature(eqn.params['jaxpr']) key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 725c99958ac4..697b1e318f98 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -25,7 +25,8 @@ from jax._src import test_util as jtu from jax.errors import KeyReuseError from jax.experimental.key_reuse._core import ( - assert_consumed, assert_unconsumed, consume, consume_p) + assert_consumed, assert_unconsumed, consume, consume_p, + Source, Sink, Forward, KeyReuseSignature) from jax.experimental.key_reuse import _core from jax import config @@ -589,7 +590,7 @@ def f_good(x, key): @jtu.with_config(jax_enable_key_reuse_checks=True) -class KeyReuseEager(jtu.JaxTestCase): +class KeyReuseEagerTest(jtu.JaxTestCase): jit_msg = "Previously-consumed key passed to jit-compiled function at index 0" eager_bits_msg = "Previously-consumed key passed to random_bits at index 0" traced_bits_msg = "In random_bits, argument 0 is already consumed." 
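  # A sketch of the failure mode these messages describe (with
  # jax_enable_key_reuse_checks=True, as configured on this test class):
  #
  #   key = jax.random.key(0)
  #   jax.random.bits(key)
  #   jax.random.bits(key)  # KeyReuseError: previously-consumed key ...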
@@ -616,9 +617,76 @@ def f(): f() +class KeyReuseImplementationTest(jtu.JaxTestCase): + + def assertEquivalent(self, a, b): + self.assertEqual(a, b) + self.assertEqual(hash(a), hash(b)) + + def assertNotEquivalent(self, a, b): + self.assertNotEqual(a, b) + self.assertNotEqual(hash(a), hash(b)) + + def test_source_sink_immutability(self): + mask = np.array([True, False]) + orig_mask_writeable = mask.flags.writeable + + sink = Sink(0, mask) + source = Source(0, mask) + + self.assertFalse(sink.mask.flags.writeable) + self.assertFalse(source.mask.flags.writeable) + self.assertEqual(mask.flags.writeable, orig_mask_writeable) + + with self.assertRaises(ValueError): + sink.idx = 1 + with self.assertRaises(ValueError): + sink.mask = True + with self.assertRaises(ValueError): + source.idx = 1 + with self.assertRaises(ValueError): + source.mask = True + + def test_source_sink_forward_equivalence_semantics(self): + + true_mask = np.array([True, True]) + false_mask = np.array([False, False]) + mixed_mask = np.array([True, False]) + + self.assertEquivalent(Source(0), Source(0, True)) + self.assertEquivalent(Source(0, True), Source(0, true_mask)) + self.assertEquivalent(Source(0, False), Source(0, false_mask)) + self.assertEquivalent(Source(0, mixed_mask), Source(0, mixed_mask)) + self.assertNotEquivalent(Source(0), Source(1)) + self.assertNotEquivalent(Source(0), Source(0, False)) + self.assertNotEquivalent(Source(0), Source(0, mixed_mask)) + + self.assertEquivalent(Sink(0), Sink(0, True)) + self.assertEquivalent(Sink(0, True), Sink(0, true_mask)) + self.assertEquivalent(Sink(0, False), Sink(0, false_mask)) + self.assertEquivalent(Sink(0, mixed_mask), Sink(0, mixed_mask)) + self.assertNotEquivalent(Sink(0), Sink(1)) + self.assertNotEquivalent(Sink(0), Sink(0, False)) + self.assertNotEquivalent(Sink(0), Sink(0, mixed_mask)) + + self.assertNotEquivalent(Source(0), Sink(0)) + + self.assertEquivalent(Forward(0, 1), Forward(0, 1)) + self.assertNotEquivalent(Forward(0, 1), Forward(1, 0)) + + def test_signature_equality_semantics(self): + self.assertEquivalent( + KeyReuseSignature(Sink(0), Source(1), Forward(1, 0)), + KeyReuseSignature(Forward(1, 0), Source(1), Sink(0))) + self.assertEquivalent( + KeyReuseSignature(), KeyReuseSignature()) + self.assertNotEquivalent( + KeyReuseSignature(Source(0)), KeyReuseSignature(Sink(0))) + + @jtu.with_config(jax_enable_checks=False) -class KeyReuseGlobalFlags(jtu.JaxTestCase): +class KeyReuseGlobalFlagsTest(jtu.JaxTestCase): def test_key_reuse_flag(self): @jax.jit From 29edfd89251c98a647ca6af33b813f50fb80c373 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 5 Mar 2024 20:09:14 -0800 Subject: [PATCH 26/35] define a loop-free untrue batching rule for `rng_bit_generator` --- jax/_src/lax/control_flow/loops.py | 16 ++++---- jax/_src/random.py | 2 +- tests/BUILD | 3 ++ tests/lax_test.py | 18 +++++++++ tests/random_lax_test.py | 59 ++++++++++++++++++++++++++---- 5 files changed, 82 insertions(+), 16 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index e939f2170f8a..99fbb010130a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2046,16 +2046,18 @@ def map(f, xs): return ys def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): - """Calls RBG in a loop and stacks the results.""" - key, = batched_args + keys, = batched_args bd, = batch_dims if bd is batching.not_mapped: - return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype, + return 
lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype, algorithm=algorithm), (None, None) - key = batching.moveaxis(key, bd, 0) - map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm) - stacked_keys, stacked_bits = map(map_body, key) - return (stacked_keys, stacked_bits), (0, 0) + keys = batching.moveaxis(keys, bd, 0) + batch_size = keys.shape[0] + key = keys[0] + new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), + dtype=dtype, algorithm=algorithm) + new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) + return (new_keys, bits), (0, 0) batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore diff --git a/jax/_src/random.py b/jax/_src/random.py index f9045ebf3135..31d10db01c05 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -1233,7 +1233,7 @@ def _gamma_impl(key, a, *, log_space, use_vmap=False): keys = keys.flatten() alphas = a.flatten() - if use_vmap: + if use_vmap and _key_impl(key) is prng.threefry_prng_impl: samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas) else: samples = lax.map( diff --git a/tests/BUILD b/tests/BUILD index 9c8ca93103b7..6c72e9da98f0 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -784,6 +784,9 @@ jax_test( "notsan", # Times out ], }, + backend_variant_args = { + "gpu": ["--jax_num_generated_cases=40"], + }, shard_count = { "cpu": 40, "gpu": 30, diff --git a/tests/lax_test.py b/tests/lax_test.py index aadac1d64566..613164650bd8 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2652,6 +2652,24 @@ def testRngBitGeneratorReturnedKey(self): new_key, _ = lax.rng_bit_generator(key, (0,)) self.assertAllClose(key, new_key) + def test_rng_bit_generator_vmap(self): + def f(key): + return lax.rng_bit_generator(key, shape=(5, 7)) + + keys = np.arange(3 * 4).reshape((3, 4)).astype(np.uint32) + out_keys, bits = jax.vmap(f)(keys) + self.assertEqual(out_keys.shape, (3, 4)) + self.assertEqual(bits.shape, (3, 5, 7)) + + def test_rng_bit_generator_vmap_vmap(self): + def f(key): + return lax.rng_bit_generator(key, shape=(5, 7)) + + keys = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.uint32) + out_keys, bits = jax.vmap(jax.vmap(f))(keys) + self.assertEqual(out_keys.shape, (2, 3, 4)) + self.assertEqual(bits.shape, (2, 3, 5, 7)) + @jtu.sample_product( dtype=lax_test_util.all_dtypes + lax_test_util.python_scalar_types, weak_type=[True, False], diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 52e8cbc7b262..f6280a5e02ec 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -1348,6 +1348,7 @@ def test_vmap_fold_in_shape(self): out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T) self.assertEqual(out.shape, (3, 2)) + @jax.enable_key_reuse_checks(False) def test_vmap_split_mapped_key(self): key = self.make_key(73) mapped_keys = random.split(key, num=3) @@ -1408,24 +1409,57 @@ def test_vmap_split_not_mapped_key(self): self.assertArraysEqual(random.key_data(vk), random.key_data(single_split_key)) - def test_vmap_split_mapped_key(self): + @jax.enable_key_reuse_checks(False) + def test_vmap_split_mapped_key_shape(self): key = self.make_key(73) mapped_keys = random.split(key, num=3) - forloop_keys = [random.split(k) for k in mapped_keys] vmapped_keys = vmap(random.split)(mapped_keys) self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape)) - for fk, vk in zip(forloop_keys, vmapped_keys): - self.assertArraysEqual(random.key_data(fk), + + 
@jax.enable_key_reuse_checks(False) + def test_vmap_split_mapped_key_values(self): + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + vmapped_keys = vmap(random.split)(mapped_keys) + ref_keys = [random.split(k) for k in mapped_keys] + for rk, vk in zip(ref_keys, vmapped_keys): + self.assertArraysEqual(random.key_data(rk), random.key_data(vk)) - def test_vmap_random_bits(self): - rand_fun = lambda key: random.randint(key, (), 0, 100) + @jax.enable_key_reuse_checks(False) + def test_vmap_random_bits_shape(self): + rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100) key = self.make_key(73) mapped_keys = random.split(key, num=3) - forloop_rand_nums = [rand_fun(k) for k in mapped_keys] rand_nums = vmap(rand_fun)(mapped_keys) self.assertEqual(rand_nums.shape, (3,)) - self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums)) + + @jtu.skip_on_devices("tpu") + @jax.enable_key_reuse_checks(False) + def test_vmap_random_bits_value(self): + rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100) + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + rand_nums = vmap(rand_fun)(mapped_keys) + ref_nums = rand_fun(mapped_keys[0], shape=(3,)) + self.assertArraysEqual(rand_nums, ref_nums) + + def test_vmap_random_bits_distribution(self): + dtype = jnp.float32 + keys = lambda: jax.random.split(self.make_key(0), 10) + + def rand(key): + nums = jax.vmap(lambda key: random.uniform(key, (1000,), dtype))(key) + return nums.flatten() + + crand = jax.jit(rand) + + uncompiled_samples = rand(keys()) + compiled_samples = crand(keys()) + + for samples in [uncompiled_samples, compiled_samples]: + self._CheckCollisions(samples, jnp.finfo(dtype).nmant) + self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf) def test_cannot_add(self): key = self.make_key(73) @@ -1455,6 +1489,15 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): def make_key(self, seed): return random.PRNGKey(seed, impl="unsafe_rbg") + @jtu.skip_on_devices("tpu") + @jax.enable_key_reuse_checks(False) + def test_vmap_split_mapped_key_values(self): + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + vmapped_keys = vmap(random.split)(mapped_keys) + ref_keys = random.split(mapped_keys[0], (3, 2)) + self.assertArraysEqual(random.key_data(vmapped_keys), + random.key_data(ref_keys)) def _sampler_unimplemented_with_custom_prng(*args, **kwargs): raise SkipTest('sampler only implemented for default RNG') From d1e49f9c89369f0a3190081a8834d3cceb21be5e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 8 Mar 2024 15:16:39 -0800 Subject: [PATCH 27/35] [key reuse] fix random_clone impl rule --- jax/_src/random.py | 3 ++- tests/key_reuse_test.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 69bd8d4b9a4b..cc1a4a38f944 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -30,6 +30,7 @@ from jax._src import config from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import prng from jax._src import xla_bridge @@ -2615,7 +2616,7 @@ def binomial( # Functions related to key reuse checking random_clone_p = core.Primitive("random_clone") -random_clone_p.def_impl(lambda x: x) +dispatch.simple_impl(random_clone_p) random_clone_p.def_abstract_eval(lambda x: x) batching.defvectorized(random_clone_p) mlir.register_lowering(random_clone_p, lambda _, k: [k]) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 
697b1e318f98..5d5d6e12e849 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -595,6 +595,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase): eager_bits_msg = "Previously-consumed key passed to random_bits at index 0" traced_bits_msg = "In random_bits, argument 0 is already consumed." + def test_clone_eager(self): + key = jax.random.key(0) + key2 = jax.random.clone(key) + self.assertIsNot(key, key2) + + _ = jax.random.uniform(key) + self.assertTrue(key._consumed) + self.assertFalse(key2._consumed) + def test_simple_reuse_nojit(self): key = jax.random.key(0) _ = jax.random.bits(key) From 930aaa5e47d05af8f30d12365689c7aaac29bb7f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 8 Mar 2024 16:50:02 -0800 Subject: [PATCH 28/35] Deprecated the jax.experimental.maps submodule PiperOrigin-RevId: 614082251 --- CHANGELOG.md | 3 +++ jax/experimental/maps.py | 26 ++++++++++++++++++++++++-- pyproject.toml | 1 + 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a4fe3af345c8..e2426b107a44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. + * The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are + deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the + `spmd_axis_name` argument for expressing SPMD device-parallel computations. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` that cannot be converted to a JAX array now results in an exception. diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index afdd11dbb3f7..6e398df0d5a5 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -12,18 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + +from jax._src import deprecations from jax._src.maps import ( AxisName as AxisName, ResourceSet as ResourceSet, SerialLoop as SerialLoop, + _prepare_axes as _prepare_axes, make_xmap_callable as make_xmap_callable, serial_loop as serial_loop, - xmap as xmap, xmap_p as xmap_p, - _prepare_axes as _prepare_axes, + xmap as xmap, ) from jax._src.mesh import ( EMPTY_ENV as EMPTY_ENV, ResourceEnv as ResourceEnv, thread_resources as thread_resources, ) + +# Added March 7, 2024. +_msg = ( + "jax.experimental.maps and jax.experimental.maps.xmap are deprecated and" + " will be removed in a future release. Use jax.experimental.shard_map or" + " jax.vmap with the spmd_axis_name argument for expressing SPMD" + " device-parallel computations. Please file an issue on" + " https://github.com/google/jax/issues if neither" + " jax.experimental.shard_map nor jax.vmap are suitable for your use case." 
+)
+
+deprecations.register("jax.experimental.maps", "maps-module")
+
+if deprecations.is_accelerated("jax.experimental.maps", "maps-module"):
+  raise ImportError(_msg)
+else:
+  warnings.warn(_msg, DeprecationWarning, stacklevel=2)
+
+del deprecations, warnings, _msg
diff --git a/pyproject.toml b/pyproject.toml
index 0a5873d89e16..fdbcf155556f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,6 +84,7 @@ filterwarnings = [
     "ignore:Special cases found for .* but none were parsed.*:UserWarning",
     # end array_api_tests-related warnings
     "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
+    "ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning",
 ]
 doctest_optionflags = [
   "NUMBER",

From 632d095690732302f2f1f5fa1b0c6295e6ce7590 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Fri, 8 Mar 2024 23:11:22 -0800
Subject: [PATCH 29/35] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/a4204a1e81f50e375778397bcdde3515b2674f43.

PiperOrigin-RevId: 614150920
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 930afe0c3617..32614261e062 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -20,8 +20,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.

-XLA_COMMIT = "541962e88f52237bc6050e4c8d7270e7c7e12b4e"
-XLA_SHA256 = "ca67f68edad0d898241b65cbd85869de2560e2fdba700d964c707b427158583d"
+XLA_COMMIT = "a4204a1e81f50e375778397bcdde3515b2674f43"
+XLA_SHA256 = "6de2c29774063dd43c557985144e537da799f74b0014beed2a3545900e7bc0b0"

 def repo():
     tf_http_archive(

From 0f1cb74ad53f34ae66fec96744e1923325818a9a Mon Sep 17 00:00:00 2001
From: Naums Mogers
Date: Sat, 9 Mar 2024 21:16:07 -0800
Subject: [PATCH 30/35] Prevent the XLA compiler from sharding the custom call
 in favour of Mosaic sharding based on user annotations.

PiperOrigin-RevId: 614336455
---
 jax/_src/tpu_custom_call.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index bc5a8c0f2c10..9846a5e27969 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -141,6 +141,9 @@ def to_json(self) -> bytes:
       if i + 1 != len(self.flags):
         config.write(b",")
       config.write(b"]")
+    # Prevent the compiler from sharding the custom call beyond what Mosaic does
+    # based on user annotations
+    config.write(b', "implicit_sharding": {"type": "MANUAL"}')
     config.write(b"}")
     return config.getvalue()

From 63ceb5f539c45fe00766634ce7b01ea5176e0cc4 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Sat, 9 Mar 2024 22:19:18 -0800
Subject: [PATCH 31/35] Update XLA dependency to use revision
 http://github.com/openxla/xla/commit/3684be8fa3a5e0356fcd9197ec33c26dee8029ff.

PiperOrigin-RevId: 614344476
---
 third_party/xla/workspace.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 32614261e062..8fdac4f9bc24 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -20,8 +20,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
-XLA_COMMIT = "a4204a1e81f50e375778397bcdde3515b2674f43" -XLA_SHA256 = "6de2c29774063dd43c557985144e537da799f74b0014beed2a3545900e7bc0b0" +XLA_COMMIT = "3684be8fa3a5e0356fcd9197ec33c26dee8029ff" +XLA_SHA256 = "0599cce382ccada773dc8721891a3069581df0329b4dcdd25edbeb30e149a822" def repo(): tf_http_archive( From 477a5aa148367dd12a78ce2a5cf2ad8f278504a9 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sun, 10 Mar 2024 10:14:21 -0700 Subject: [PATCH 32/35] set a lower pval for vmap-of-rbg based uniform statistical test PiperOrigin-RevId: 614434781 --- tests/random_lax_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index f6280a5e02ec..4dc57e475a07 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -1459,7 +1459,8 @@ def rand(key): for samples in [uncompiled_samples, compiled_samples]: self._CheckCollisions(samples, jnp.finfo(dtype).nmant) - self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf) + self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf, + pval=0.005) def test_cannot_add(self): key = self.make_key(73) From 7863508184691d5177398d6334d8f9685dc466d2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 11 Mar 2024 06:34:24 -0700 Subject: [PATCH 33/35] Include source info as ir.Locations when lowering Pallas kernels on GPU I decided to leave out the name stacks for now for simplicity, but we might want to add them in the future. PiperOrigin-RevId: 614644216 --- jax/_src/pallas/triton/lowering.py | 17 +++++++++++++---- .../pallas/triton/pallas_call_registration.py | 4 +++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 976e51a71e82..d3721f83b06b 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -32,6 +32,7 @@ from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src import pjit +from jax._src import source_info_util from jax._src import state from jax._src import util from jax._src.interpreters import mlir @@ -44,7 +45,6 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils -from jax._src.state import AbstractRef from jax._src.state import discharge from jax._src.state import indexing from jax._src.state import primitives as sp @@ -73,6 +73,7 @@ class ModuleContext: name: str grid_mapping: GridMapping program_ids: Sequence[ir.Value] + traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False) @dataclasses.dataclass @@ -269,7 +270,9 @@ def lower_jaxpr_to_triton_module( for i, pid in enumerate(program_ids) if i not in grid_mapping.mapped_dims ] - ctx = ModuleContext(name, grid_mapping, local_program_ids) + ctx = ModuleContext( + name, grid_mapping, local_program_ids, mlir.TracebackCaches() + ) if grid_mapping.num_index_operands: raise NotImplementedError( "Scalar prefetch not supported in Triton lowering." 
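
For context, a minimal sketch of the mechanism this patch builds on (using the
MLIR Python bindings shipped with jaxlib; the file name and line number below
are illustrative only, not part of the patch): an ir.Location can be pushed as
the implicit location for all operations constructed inside a `with` block, so
the emitted Triton IR carries pointers back to the Python source.

    from jaxlib.mlir import ir

    with ir.Context():
      # Build a file/line/column location and make it the implicit location
      # for any operations constructed inside the inner block.
      loc = ir.Location.file("my_kernel.py", 42, 0)
      with loc:
        pass  # ops created here inherit `loc` in their debug info
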
@@ -336,9 +339,13 @@ def write_env(var: jax_core.Var, val): avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] eqn_block_infos = map(read_block_info_env, eqn.invars) + loc = mlir._source_info_to_location( + ctx, eqn.primitive, eqn.params, eqn.source_info + ) rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) try: - outvals = rule(rule_ctx, *invals, **eqn.params) + with source_info_util.user_context(eqn.source_info.traceback), loc: + outvals = rule(rule_ctx, *invals, **eqn.params) except LoweringError: raise # We only add the extra info to the innermost exception. except Exception as e: @@ -2039,7 +2046,9 @@ def _for_lowering_rule( step = _i32_constant(1) init_args = map(_ensure_ir_value, args, ctx.avals_in) # Partially discharge state from jaxpr for non-pointers - should_discharge = [not isinstance(a, AbstractRef) for a in ctx.avals_in] + should_discharge = [ + not isinstance(a, state.AbstractRef) for a in ctx.avals_in + ] discharged_jaxpr, () = discharge.discharge_state( jaxpr, (), should_discharge=[True, *should_discharge] ) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index aa7b90b1bc15..cb8250d95693 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -205,7 +205,9 @@ def _pallas_call_ttir_lowering( lowering_result = lowering.lower_jaxpr_to_triton_module( jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, cuda_options ) + module_op = lowering_result.module.operation if debug: + print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True)) lowering_result.module.dump() grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid) @@ -214,7 +216,7 @@ def _pallas_call_ttir_lowering( for shape in out_shapes ] buf = io.BytesIO() - lowering_result.module.operation.write_bytecode(buf) + module_op.write_bytecode(buf) backend_config = dict( name=ir.StringAttr.get(name), ir=ir.StringAttr.get(buf.getvalue()), From de455e70030bd9452deabcc57a83c2594a85c03c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 11 Mar 2024 07:04:14 -0700 Subject: [PATCH 34/35] Fix small bug in random_test. unsafe_buffer_pointer() and on_device_size_in_bytes() are methods, not properties, so presumably the test intended to call them rather than test equality of the bound methods. 
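
To make the bug concrete, a minimal sketch of the pitfall (the class below is
illustrative, not a JAX API): equality of bound methods compares __func__ and
__self__, never the return values, so the old assertions could pass without
executing the methods at all.

    class Buffer:
      def on_device_size_in_bytes(self):
        raise RuntimeError("never reached by the equality check")

    buf = Buffer()
    # Passes: both sides are the same function bound to the same instance;
    # the method body is never executed.
    assert buf.on_device_size_in_bytes == buf.on_device_size_in_bytes
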
PiperOrigin-RevId: 614651090 --- tests/random_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/random_test.py b/tests/random_test.py index 170076989cde..257acee544ac 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1023,8 +1023,10 @@ def test_array_impl_attributes(self): with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"): self.assertEqual(key.device(), key._base_array.device()) self.assertEqual(key.devices(), key._base_array.devices()) - self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes) - self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer) + self.assertEqual(key.on_device_size_in_bytes(), + key._base_array.on_device_size_in_bytes()) + self.assertEqual(key.unsafe_buffer_pointer(), + key._base_array.unsafe_buffer_pointer()) self.assertArraysEqual(key.addressable_data(0)._base_array, key._base_array.addressable_data(0)) self.assertLen(key.addressable_shards, len(key._base_array.addressable_shards)) From 71ec6e33ca21535c0ebad37682d54dbac224e86d Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 11 Mar 2024 08:40:34 -0700 Subject: [PATCH 35/35] Make pl.num_programs lowering take the vmapped axes into account Otherwise the size of the wrong axis is returned. PiperOrigin-RevId: 614677218 --- jax/_src/pallas/mosaic/lowering.py | 28 ++++++++++++++++++++++------ tests/pallas/pallas_call_tpu_test.py | 21 +++++++++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 27ab8082a8fb..2a0caaa95253 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -86,7 +86,9 @@ class MeshContext: @dataclasses.dataclass class LoweringContext: ir_context: ir.Context - grid_indices: Sequence[ir.Value] | None + grid_rank: int # Includes both user and vmap axes. + mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions. + user_grid_indices: Sequence[ir.Value] | None block_shapes: list[tuple[int | pl_core.Mapped, ...]] name_stack: source_info_util.NameStack mesh_context: MeshContext | None @@ -475,6 +477,8 @@ def body_func(*args): mesh_context = None lowering_context = LoweringContext( ctx, + len(mosaic_grid_mapping.grid), + mosaic_grid_mapping.mapped_dims, None, arg_block_shapes, source_info_util.NameStack(), @@ -531,6 +535,8 @@ def body_func(*args): mesh_context = None lowering_context = LoweringContext( ctx, + len(mosaic_grid_mapping.grid), + mosaic_grid_mapping.mapped_dims, jaxpr_indices, arg_block_shapes, source_info_util.NameStack(), @@ -1846,22 +1852,32 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - if ctx.lowering_context.grid_indices is None: + if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." 
     )
-  length = len(ctx.lowering_context.grid_indices)
+  length = len(ctx.lowering_context.user_grid_indices)
   if not (0 <= axis < length):
     raise ValueError(
         f"user passed in program id with axis: {axis}, but grid only has"
         f" length: {length}"
     )
-  return ctx.lowering_context.grid_indices[axis]
+  return ctx.lowering_context.user_grid_indices[axis]
 lowering_rules[primitives.program_id_p] = _program_id_lowering_rule

 def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
-  del ctx
-  return tpu.iteration_bound(axis)
+  mapped_axes = set(ctx.lowering_context.mapped_dims)
+  seen_user_axes = 0
+  for i in range(ctx.lowering_context.grid_rank):
+    seen_user_axes += int(i not in mapped_axes)
+    if seen_user_axes == axis + 1:
+      break
+  else:
+    raise ValueError(
+        f"user passed in num_programs with axis: {axis}, but grid only has"
+        f" length: {ctx.lowering_context.grid_rank}"
+    )
+  return tpu.iteration_bound(i)
 lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule
diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py
index c27e6468bb4b..48aa84b3056d 100644
--- a/tests/pallas/pallas_call_tpu_test.py
+++ b/tests/pallas/pallas_call_tpu_test.py
@@ -450,6 +450,27 @@ def dynamic_kernel(steps):

     self.assertEqual(dynamic_kernel(4), 8)

+  @parameterized.parameters(range(1, 4))
+  def test_vmap_num_programs(self, num_vmaps):
+    result_ty = jax.ShapeDtypeStruct((8, 128), jnp.int32)
+
+    def kernel(y_ref):
+      y_ref[...] = jnp.full_like(y_ref, pl.num_programs(0))
+
+    kernel_call = self.pallas_call(
+        kernel,
+        grid=(8,),
+        out_specs=pl.BlockSpec(lambda i: (0, 0), result_ty.shape),
+        out_shape=result_ty,
+    )
+
+    out_shape = (*(2 for _ in range(num_vmaps)), *result_ty.shape)
+    f = kernel_call
+    for _ in range(num_vmaps):
+      f = lambda impl=f: jax.vmap(impl, axis_size=2)()
+    out = jax.jit(f)()
+    np.testing.assert_array_equal(out, np.full(out_shape, 8.0))
+
   def test_num_programs_block_spec(self):
     def kernel(x_ref, y_ref):
       y_ref[...] = x_ref[...]
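
To make the new axis translation concrete, a self-contained sketch of the
lookup it performs (an illustrative helper, not a JAX API): a user-visible
grid axis maps to the physical grid dimension reached after skipping the
vmapped dimensions.

    def user_axis_to_grid_dim(axis, grid_rank, mapped_dims):
      # Count only non-vmapped dimensions until `axis + 1` of them are seen.
      seen_user_axes = 0
      for i in range(grid_rank):
        seen_user_axes += int(i not in mapped_dims)
        if seen_user_axes == axis + 1:
          return i
      raise ValueError(f"axis {axis} is out of range for the user grid")

    # With one vmapped dimension prepended, user axis 0 is physical dim 1.
    assert user_axis_to_grid_dim(0, grid_rank=2, mapped_dims=(0,)) == 1
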