Skip to content

Commit

Permalink
Merge branch 'google:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
coreyjadams committed Mar 11, 2024
2 parents 900a037 + 71ec6e3 commit 4f22d86
Show file tree
Hide file tree
Showing 78 changed files with 715 additions and 311 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array now results in an exception.

Expand Down
4 changes: 2 additions & 2 deletions docs/Custom_Operation_for_GPUs.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
jax.random.PRNGKey(0),
jax.random.key(0),
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
dtype=jnp.bfloat16,
)
Expand Down Expand Up @@ -1049,7 +1049,7 @@ per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
jax.random.PRNGKey(0),
jax.random.key(0),
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
dtype=jnp.bfloat16,
)
Expand Down
2 changes: 1 addition & 1 deletion docs/async_dispatch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ program:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3. # doctest: +SKIP
Expand Down
6 changes: 3 additions & 3 deletions docs/device_memory_profiling.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()
Expand Down Expand Up @@ -107,14 +107,14 @@ import jax.numpy as jnp
import jax.profiler

def afunction():
return jax.random.normal(jax.random.PRNGKey(77), (1000000,))
return jax.random.normal(jax.random.key(77), (1000000,))

z = afunction()

def anotherfunc():
arrays = []
for i in range(1, 10):
x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000))
x = jax.random.normal(jax.random.key(42), (i, 10000))
arrays.append(x)
x.block_until_ready()
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
Expand Down
1 change: 1 addition & 0 deletions docs/errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ along with representative examples of how one might fix them.

.. currentmodule:: jax.errors
.. autoclass:: ConcretizationTypeError
.. autoclass:: KeyReuseError
.. autoclass:: NonConcreteBooleanIndexError
.. autoclass:: TracerArrayConversionError
.. autoclass:: TracerBoolConversionError
Expand Down
10 changes: 5 additions & 5 deletions docs/jax-101/05-random-numbers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@
"source": [
"from jax import random\n",
"\n",
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"\n",
"print(key)"
]
Expand All @@ -293,7 +293,7 @@
"id": "XhFpKnW9F2nF"
},
"source": [
"A key is just an array of shape `(2,)`.\n",
"A single key is an array of scalar shape `()` and key element type.\n",
"\n",
"'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:"
]
Expand Down Expand Up @@ -381,7 +381,7 @@
"source": [
"`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n",
"\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"\n",
"It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n",
"\n",
Expand Down Expand Up @@ -460,12 +460,12 @@
}
],
"source": [
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"subkeys = random.split(key, 3)\n",
"sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n",
"print(\"individually:\", sequence)\n",
"\n",
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"print(\"all at once: \", random.normal(key, shape=(3,)))"
]
},
Expand Down
10 changes: 5 additions & 5 deletions docs/jax-101/05-random-numbers.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ To avoid this issue, JAX does not use a global state. Instead, random functions
from jax import random
key = random.PRNGKey(42)
key = random.key(42)
print(key)
```

+++ {"id": "XhFpKnW9F2nF"}

A key is just an array of shape `(2,)`.
A single key is an array of scalar shape `()` and key element type.

'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:

Expand Down Expand Up @@ -201,7 +201,7 @@ key = new_key # If we wanted to do this again, we would use new_key as the key.

`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.

If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.

It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.

Expand Down Expand Up @@ -240,12 +240,12 @@ In the example below, sampling 3 values out of a normal distribution individuall
:id: 4nB_TA54D-HT
:outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56
key = random.PRNGKey(42)
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.PRNGKey(42)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
```

Expand Down
2 changes: 1 addition & 1 deletion docs/jax-101/06-parallelism.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@
"ys = xs * true_w + true_b + noise\n",
"\n",
"# Initialise parameters and replicate across devices.\n",
"params = init(jax.random.PRNGKey(123))\n",
"params = init(jax.random.key(123))\n",
"n_devices = jax.local_device_count()\n",
"replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)"
]
Expand Down
2 changes: 1 addition & 1 deletion docs/jax-101/06-parallelism.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise
# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
params = init(jax.random.key(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)
```
Expand Down
4 changes: 2 additions & 2 deletions docs/jax-101/07-state.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
"\n",
"In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?\n",
"\n",
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey."
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key."
]
},
{
Expand Down Expand Up @@ -351,7 +351,7 @@
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"rng = jax.random.PRNGKey(42)\n",
"rng = jax.random.key(42)\n",
"\n",
"# Generate true data from y = w*x + b + noise\n",
"true_w, true_b = 2, -1\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/jax-101/07-state.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ Notice that the need for a class becomes less clear once we have rewritten it th

In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?

Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey.
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key.

+++ {"id": "I2SqRx14_z98"}

Expand Down Expand Up @@ -233,7 +233,7 @@ Notice that we manually pipe the params in and out of the update function.
import matplotlib.pyplot as plt
rng = jax.random.PRNGKey(42)
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
Expand Down
9 changes: 0 additions & 9 deletions docs/jax.experimental.key_reuse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,3 @@
=====================================

.. automodule:: jax.experimental.key_reuse

API
---

.. autosummary::
:toctree: _autosummary

reuse_key
KeyReuseError
4 changes: 2 additions & 2 deletions docs/jax.nn.initializers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.

An initializer is a function that takes three arguments:
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
key used when generating random numbers to initialize the array.
data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
:func:`jax.random.key`), used to generate random numbers to initialize the array.

.. autosummary::
:toctree: _autosummary
Expand Down
1 change: 1 addition & 0 deletions docs/jax.random.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Key Creation & Manipulation
wrap_key_data
fold_in
split
clone

Random Samplers
~~~~~~~~~~~~~~~
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/Common_Gotchas_in_JAX.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@
"source": [
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by two unsigned-int32s that we call a __key__:"
"The random state is described by a special array element that we call a __key__:"
]
},
{
Expand All @@ -1030,7 +1030,7 @@
],
"source": [
"from jax import random\n",
"key = random.PRNGKey(0)\n",
"key = random.key(0)\n",
"key"
]
},
Expand Down Expand Up @@ -2121,7 +2121,7 @@
}
],
"source": [
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
"x.dtype"
]
},
Expand Down Expand Up @@ -2188,7 +2188,7 @@
"source": [
"import jax.numpy as jnp\n",
"from jax import random\n",
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
"x.dtype # --> dtype('float64')"
]
},
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/Common_Gotchas_in_JAX.md
Original file line number Diff line number Diff line change
Expand Up @@ -463,14 +463,14 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha

JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.

The random state is described by two unsigned-int32s that we call a __key__:
The random state is described by a special array element that we call a __key__:

```{code-cell} ipython3
:id: yPHE7KTWgAWs
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
from jax import random
key = random.PRNGKey(0)
key = random.key(0)
key
```

Expand Down Expand Up @@ -1071,7 +1071,7 @@ At the moment, JAX by default enforces single-precision numbers to mitigate the
:id: CNNGtzM3NDkO
:outputId: b422bb23-a784-44dc-f8c9-57f3b6c861b8
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
```

Expand Down Expand Up @@ -1117,7 +1117,7 @@ We can then confirm that `x64` mode is enabled:
import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
],
"source": [
"# Create an array of random values:\n",
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
"# and use jax.device_put to distribute it across devices:\n",
"y = jax.device_put(x, sharding.reshape(4, 2))\n",
"jax.debug.visualize_array_sharding(y)"
Expand Down Expand Up @@ -272,7 +272,7 @@
"outputs": [],
"source": [
"import jax\n",
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))"
"x = jax.random.normal(jax.random.key(0), (8192, 8192))"
]
},
{
Expand Down Expand Up @@ -1513,7 +1513,7 @@
},
"outputs": [],
"source": [
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
"x = jax.device_put(x, sharding.reshape(4, 2))"
]
},
Expand Down Expand Up @@ -1738,7 +1738,7 @@
"layer_sizes = [784, 8192, 8192, 8192, 10]\n",
"batch_size = 8192\n",
"\n",
"params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)"
"params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)"
]
},
{
Expand Down Expand Up @@ -2184,7 +2184,7 @@
" numbers = jax.random.uniform(key, x.shape)\n",
" return x + numbers\n",
"\n",
"key = jax.random.PRNGKey(42)\n",
"key = jax.random.key(42)\n",
"x_sharding = jax.sharding.PositionalSharding(jax.devices())\n",
"x = jax.device_put(jnp.arange(24), x_sharding)"
]
Expand Down
Loading

0 comments on commit 4f22d86

Please sign in to comment.