Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compilation cache doc #21819

Merged
merged 6 commits into from
Jul 4, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions docs/persistent_compilation_cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The compilation cache is enabled when the
is set. This should be done prior to the first compilation. Set the location as
follows:

```
```python
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should mention that there are 3 ways to set the cache dirctory:

(1) Using environment variable

In shell, before running the script:

export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"

Or on the top of the Python script:

import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"

(2) Using jax.config.update()

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

(3) Using set_cache_dir()

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")

The original doc mentions the existance of the function set_cache_dir(), but does not give an example of it. This is bad.

import jax

# Make sure this is called before jax runs any operations!
Expand All @@ -27,7 +27,7 @@ is an alternate way of setting `cache-location`.

`cache-location` can be a directory on the local filesystem. For example:

```
```python
import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"the cache does not have an eviction mechanism implemented" is no longer true. There is indeed an LRU cache eviction mechanism for local filesystems, after #21394

Expand Down Expand Up @@ -67,8 +67,74 @@ Cloud Storage (GCS) bucket. We recommend the following configuration:
Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
follows:

```
```python
import jax

jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should mention that gs:// does not work automatically. It requires etils[epath] to be installed.

Should mention that etils[epath] supports many cloud providers, while GCS is one common choice.

```

## How it works

The JAX compilation cache works by hashing a number of parameters to create a signature for a compiled function these are:
keshavb96 marked this conversation as resolved.
Show resolved Hide resolved

* The computation performed by the function captured by the non-optimized HLO of the JAX function being hashed
* The Jaxlib version
keshavb96 marked this conversation as resolved.
Show resolved Hide resolved
* Relevant XLA compilation flags
* Device configuration captured in general, by the number of devices and the topology of the devices. Currently for GPUs, the topology only contains a string representation of the GPU name
* Compression algorithm used to compress the compiled executable
* Any custom hooks

When the signature for a function created using the parameters above matches that of a compiled function in the persistent cache, the function will not be compiled but just read and deserialized from the persistent cache.
keshavb96 marked this conversation as resolved.
Show resolved Hide resolved

### Compile time dependent caching
keshavb96 marked this conversation as resolved.
Show resolved Hide resolved

JAX only caches executables that take a certain amount of time to compile. This threshold is controlled by the `jax_persistent_cache_min_compile_time_secs` configuration option. To cache every executable that is compiled regardless of compile time, set this value to zero:

```python
import jax

jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it useful to tell that the compilation time isn't constant.
So to have a full cache, if the value isn't set to 0, you may need to run multiple time?

```

## Different Runtimes

Below we outline some observed behavior of the persistent compilation cache in a variety of different runtimes as it relates to which processes write to the cache.

### Single node with single process and single device

This is the simplest runtime and the only process that compiles and writes to the compilation cache is the singular process. The number of devices does not matter to the cache key in this setup, only the type of device does.

### Single node with single process and multiple devices

Once again the only process that compiles and writes to the compilation cache is the singular proess. The difference between this setup and the previous is that now the number of devices matters in addition to the type of device.

### Multiple process and multiple devices (on either single or multiple nodes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a jax flag jax_share_binary_between_hosts and an issue #18819 which might be related to this topic.


In this runtime the first time a program is run (the persistent cache is cold / empty) all processes will compile, but only the process with rank 0 in the global communication group will write to the persistent cache. In subsequent runs, all processes will attempt to read from the persistent cache, so it is important for the persistent cache to be in a shared file system (eg: NFS) or remote storage (eg: GFS). If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile in subsequent runs as a result of a compilation cache miss.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also "Multi-Host with separate disks". For example, TPU v4-32 has 4 nodes, and there is no shared storage by default. You can choose to store the compilation cache on all of the 4 nodes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we include 2 sections then? One for TPUs and one for GPUs? Can't hurt to mention both options?


## Logging cache activity

It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging. While there is no singular canonical way of debugging and examining what's happening in the compilation cache, here are a few suggestions on how to begin.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should mention that users can enable the logging of related source files by placing

import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"

on the top of the script.

### Examining cache misses

To merely examine and understand why there are cache misses JAX includes a configuration flag that enables the logging of all cache misses (including persistent compilation cache misses) with their explanations. This can be enabled by setting the following configuration.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This flag is for all cache misses in jax, but only tracing cache miss is implemented right now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it quite helpful in debugging an issue I had, maybe it's useful to mention exactly what you mentioned, i.e. that it is indeed a flag for all cache misses wherever JAX is using a cache and that if the issue they're trying to debug is a tracing cache issue that this flag is still useful?


```python
import jax

jax.config.update("jax_explain_cache_misses", True)
```

## Pitfalls

There are a couple of pitfalls that have currently been discovered:

* Currently the persistent cache doesn't work with function that have host callbacks. In this situation, caching in completely avoided.
- This is because the HLO contains a pointer to the callback and changes from run to run even if the computation and compute infrastructure is exactly the same.

* Currently the persistent cache doesn't work with a function that uses primitives that implement their own custom_partitioning.
- The HLO of the function contains a pointer to the custom_partitioning callback, and leads to different cache keys for the same computation across runs.
- In this situation, caching still proceeds, but a different key is produced every time, making the cache ineffective.


keshavb96 marked this conversation as resolved.
Show resolved Hide resolved