# Compilation cache doc #21819
The compilation cache is enabled when `cache-location` is set. This should be
done prior to the first compilation. Set the location as follows:

```python
import jax

# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
```
Setting the environment variable `JAX_COMPILATION_CACHE_DIR` is an alternate
way of setting `cache-location`.
`cache-location` can be a directory on the local filesystem. For example:

```python
import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
```

Note that the local filesystem cache has an LRU eviction mechanism (added in #21394).
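Besides `jax.config.update()`, two other ways to set the cache directory are the `JAX_COMPILATION_CACHE_DIR` environment variable and the `set_cache_dir()` helper. The sketch below assumes the import path `jax.experimental.compilation_cache.compilation_cache`, which may move between JAX versions:

```python
import os

# Option 1: environment variable; set this before jax runs any operations,
# e.g. at the top of the script or in the shell before launching it.
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax-cache"

# Option 2: the set_cache_dir() helper (import path assumed; it may
# differ across JAX versions).
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax-cache")
```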
Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
follows:

```python
import jax

jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
```
||
## How it works | ||
|
||
The JAX compilation cache works by hashing a number of parameters to create a signature for a compiled function. These are: | ||
|
||
* The computation performed by the function captured by the non-optimized HLO of the JAX function being hashed | ||
|
||
* The jaxlib version | ||
|
||
* Relevant XLA compilation flags | ||
|
||
* Device configuration captured in general, by the number of devices and the topology of the devices. | ||
Currently for GPUs, the topology only contains a string representation of the GPU name | ||
|
||
* Compression algorithm used to compress the compiled executable | ||
|
||
* Any custom hooks | ||
|
||
When the signature for a function created using the parameters above matches | ||
that of a compiled function in the persistent cache the function will not be compiled, | ||
but will just be read and deserialized from the persistent cache. | ||
Comment on lines
+93
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This pararaph is redundant, because this is just how a cache works. Can just mention that the cache key consists of these items. |
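As an illustrative sketch only (not JAX's actual implementation, and with all names hypothetical), combining these parameters into a single signature could look like this:

```python
import hashlib

def toy_cache_key(hlo_text, jaxlib_version, xla_flags, device_config,
                  compression, custom_hooks=()):
    # Hash each component into one stable digest; identical inputs
    # always produce an identical key.
    h = hashlib.sha256()
    for part in (hlo_text, jaxlib_version, *sorted(xla_flags),
                 device_config, compression, *custom_hooks):
        h.update(repr(part).encode("utf-8"))
    return h.hexdigest()

k1 = toy_cache_key("hlo-module", "0.4.30", ["--some_xla_flag=1"], "8xA100", "zstd")
k2 = toy_cache_key("hlo-module", "0.4.30", ["--some_xla_flag=1"], "8xA100", "zstd")
assert k1 == k2  # same parameters -> same key -> cache hit
# Any change to a component (here, the HLO) yields a different key:
assert k1 != toy_cache_key("other-hlo", "0.4.30", ["--some_xla_flag=1"], "8xA100", "zstd")
```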
### Caching thresholds

Two thresholds control whether JAX caches an executable:
`jax_persistent_cache_min_entry_size_bytes` and `jax_persistent_cache_min_compile_time_secs`.
Only executables that are at least `jax_persistent_cache_min_entry_size_bytes` in size and take
at least `jax_persistent_cache_min_compile_time_secs` to compile will be cached.
To cache every executable that is compiled, set the former to -1 and the latter to 0 as follows:

```python
import jax

jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
```
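The effect of the two thresholds can be sketched as follows (a hypothetical helper, not JAX's actual code; JAX's real default values may differ):

```python
def should_cache(entry_size_bytes, compile_time_secs,
                 min_entry_size_bytes, min_compile_time_secs):
    # An executable is cached only if it clears both thresholds;
    # -1 disables the size threshold entirely.
    size_ok = (min_entry_size_bytes == -1
               or entry_size_bytes >= min_entry_size_bytes)
    time_ok = compile_time_secs >= min_compile_time_secs
    return size_ok and time_ok

# With the "cache everything" settings from above:
assert should_cache(1, 0.0, min_entry_size_bytes=-1, min_compile_time_secs=0)
# A tiny, fast compilation is skipped under stricter thresholds:
assert not should_cache(1, 0.1, min_entry_size_bytes=4096, min_compile_time_secs=1.0)
```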
## Different Runtimes

Below we outline the behavior of the persistent compilation cache in a variety
of runtimes, as it relates to which processes write to the cache.

### Single node with single process and single device

This is the simplest runtime: the only process that compiles and writes to the compilation cache is the single process.
The number of devices does not matter to the cache key in this setup, only the type of device does.

### Single node with single process and multiple devices

Once again, the only process that compiles and writes to the compilation cache is the single process.
The difference between this setup and the previous one is that now the number of devices matters in addition to the type of device.
### Multiple processes and multiple devices (on either single or multiple nodes)

In this runtime, the first time a program is run (while the persistent cache is cold/empty) all processes will compile,
but only the process with rank 0 in the global communication group will write to the persistent cache.
In subsequent runs, all processes will attempt to read from the persistent cache,
so it is important for the persistent cache to be on a shared file system (e.g. NFS) or in remote storage (e.g. GFS).
If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
in subsequent runs as a result of a compilation cache miss.
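The write policy above can be sketched with a hypothetical in-memory cache (all names here are illustrative, not JAX's API):

```python
def maybe_write(process_rank, key, serialized_executable, shared_cache):
    # All processes compile on a cold cache, but only rank 0 writes.
    if process_rank == 0:
        shared_cache[key] = serialized_executable

shared_cache = {}
maybe_write(3, "sig", b"exe-bytes", shared_cache)  # non-zero rank: no write
maybe_write(0, "sig", b"exe-bytes", shared_cache)  # rank 0 writes
assert shared_cache == {"sig": b"exe-bytes"}
```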
## Logging cache activity

It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
Here are a few suggestions on how to begin.
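One option is to enable debug logging for the modules that implement the cache by setting `JAX_DEBUG_LOG_MODULES` at the very top of the script, before importing jax (the module list shown is an example; `jax._src.lru_cache` only exists in recent JAX versions):

```python
import os

# Must be set before `import jax` so the loggers pick it up.
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"
```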
### Examining cache misses

To examine and understand why there are cache misses, JAX includes a configuration flag that
enables the logging of all cache misses (including persistent compilation cache misses) with their explanations.
Although this is currently only implemented for tracing cache misses, the eventual goal is to
explain all cache misses. This can be enabled by setting the following configuration:

```python
import jax

jax.config.update("jax_explain_cache_misses", True)
```
## Pitfalls

There are a couple of pitfalls that have been discovered so far:

* The persistent cache doesn't work with functions that have host callbacks. In this situation, caching is completely avoided.
  - This is because the HLO contains a pointer to the callback that changes from run to run, even if the computation and compute infrastructure are exactly the same.

* The persistent cache doesn't work with functions that use primitives implementing their own `custom_partitioning`.
  - The HLO of the function contains a pointer to the `custom_partitioning` callback, which leads to different cache keys for the same computation across runs.
  - In this situation, caching still proceeds, but a different key is produced every time, making the cache ineffective.
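The `custom_partitioning` pitfall can be illustrated with a toy example: if a pointer-like value (here Python's `id()`) leaks into the hashed representation, two identical computations get different keys:

```python
import hashlib

class FakeCallback:
    """Stand-in for a custom_partitioning callback (illustrative only)."""

def toy_key(callback):
    # Hashing a memory address makes the key unstable across runs,
    # even though the computation itself is unchanged.
    payload = f"same-computation,callback@{id(callback)}"
    return hashlib.sha256(payload.encode()).hexdigest()

cb_run1, cb_run2 = FakeCallback(), FakeCallback()
assert toy_key(cb_run1) != toy_key(cb_run2)  # identical computation, different keys
```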
### Working around `custom_partitioning`

As mentioned, the compilation cache doesn't work with a function that is composed of primitives that implement `custom_partitioning`. However, it is possible to use `shard_map` to circumvent `custom_partitioning` for those primitives that do implement it and make the compilation cache work as expected.

Suppose we have a function `F` that implements a layernorm followed by a matrix multiplication, using a primitive `LayerNorm` that implements `custom_partitioning`:

```python
import jax

def F(x1, x2, gamma, beta):
    ln_out = LayerNorm(x1, gamma, beta)
    return ln_out @ x2
```
If we were to merely compile this function without `shard_map`, the cache key for `layernorm_matmul_without_shard_map` would be different every time we ran the same code:

```python
layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_shardings=(...))(x1, x2, gamma, beta)
```

However, if we wrap the `LayerNorm` primitive in `shard_map` and define a function `G` that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same every time, despite `LayerNorm` implementing `custom_partitioning`:
```python
import jax
from jax.experimental.shard_map import shard_map

def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
    ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, gamma, beta)
    return ln_out @ x2

ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)
```

Note that the primitive that implements `custom_partitioning` must be wrapped in `shard_map` for this workaround; it is insufficient to wrap the outer function `F` in `shard_map`.