Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compilation cache doc #21819

Merged
merged 6 commits into from
Jul 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions docs/persistent_compilation_cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,23 @@ The JAX compilation cache works by hashing a number of parameters to create a si
* Relevant XLA compilation flags
* Device configuration captured in general, by the number of devices and the topology of the devices. Currently for GPUs, the topology only contains a string representation of the GPU name
* Compression algorithm used to compress the compiled executable
* Any custom hooks
* Any custom hooks

When the signature for a function created using the parameters above matches that of a compiled function in the persistent cache, the function will not be compiled but just read and deserialized from the persistent cache. Below we outline some observed behavior of the persistent compilation cache in a variety of different runtimes as it relates to which processes write to the cache.
When the signature for a function created using the parameters above matches that of a compiled function in the persistent cache, the function will not be compiled but just read and deserialized from the persistent cache.
keshavb96 marked this conversation as resolved.
Show resolved Hide resolved

### Compile time dependent caching
keshavb96 marked this conversation as resolved.
Show resolved Hide resolved

JAX only caches executables that take a certain amount of time to compile. This threshold is controlled by the `jax_persistent_cache_min_compile_time_secs` configuration option. To cache every executable that is compiled regardless of compile time, set this value to zero:

```python
import jax

jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it useful to tell that the compilation time isn't constant.
So to have a full cache, if the value isn't set to 0, you may need to run multiple time?

```

## Different Runtimes

Below we outline some observed behavior of the persistent compilation cache in a variety of different runtimes as it relates to which processes write to the cache.

### Single node with single process and single device

Expand Down