Compilation cache doc #21819

Merged
merged 6 commits into from
Jul 4, 2024

Conversation

keshavb96
Contributor

@keshavb96 keshavb96 commented Jun 12, 2024

More details on persistent compilation caching. Still a WIP.

@ayaka14732 ayaka14732 self-assigned this Jun 12, 2024
@ayaka14732 ayaka14732 self-requested a review June 12, 2024 13:37
@ayaka14732
Member

I will review this PR because I have recently been working on this topic.

```python
import jax

jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
```
Collaborator

Would it be useful to mention that the compilation time isn't constant?
So to get a fully populated cache when the value isn't set to 0, you may need to run the program multiple times.
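
To illustrate the point about non-constant compile times, here is a toy simulation (not JAX code). The threshold, function names, and timings are all made up, and the `>=` comparison is an assumption about how the threshold is applied. It only shows why a nonzero `jax_persistent_cache_min_compile_time_secs` may need several runs before the cache stops growing.

```python
# Toy model: an entry is written only if its compile time meets the threshold.
MIN_COMPILE_TIME_SECS = 1.0  # assumed value for illustration

# Hypothetical compile times (seconds) of the same three functions in three runs.
RUNS = [
    {"f": 1.20, "g": 0.90, "h": 0.40},
    {"f": 1.10, "g": 1.05, "h": 0.50},
    {"f": 1.30, "g": 0.95, "h": 0.45},
]


def simulate(runs, min_secs):
    """Return the sorted list of cached function names after each run."""
    cache = set()
    history = []
    for times in runs:
        for name, secs in times.items():
            if name in cache:
                continue  # cache hit: nothing is compiled or written
            if secs >= min_secs:
                cache.add(name)  # slow enough to be persisted
        history.append(sorted(cache))
    return history


print(simulate(RUNS, MIN_COMPILE_TIME_SECS))
print(simulate(RUNS, 0))
```

With these assumed numbers, `g` is only cached on the second run and `h` never is, whereas a threshold of 0 caches everything on the first run.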

@keshavb96 keshavb96 marked this pull request as ready for review June 12, 2024 19:06
@ayaka14732
Member

Besides, I have recently been working on adding LRU cache eviction support for the JAX persistent compilation cache, and here is the design doc: https://docs.google.com/document/d/111YibwGXOFb_hMm-lua1u63QooAzIBEH-xfRPGmibis/edit?usp=sharing. I am linking it here because I think it might be useful.

@ayaka14732
Member

@keshavb96 Can you confirm that this is ready to merge? I will review it if so.

@keshavb96
Contributor Author

@nouiz @jaro-sevcik It looks ready to me, what do you think?


### Multiple processes and multiple devices (on either single or multiple nodes)

In this runtime, the first time a program is run (when the persistent cache is cold, i.e. empty), all processes will compile, but only the process with rank 0 in the global communication group will write to the persistent cache. In subsequent runs, all processes will attempt to read from the persistent cache, so it is important for the persistent cache to be on a shared file system (e.g. NFS) or remote storage (e.g. GFS). If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile in subsequent runs as a result of a compilation cache miss.
Member

There is also "Multi-Host with separate disks". For example, TPU v4-32 has 4 nodes, and there is no shared storage by default. You can choose to store the compilation cache on all of the 4 nodes.

Contributor Author

Should we include 2 sections then? One for TPUs and one for GPUs? Can't hurt to mention both options?
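
A minimal sketch of the two layouts discussed in this thread, assuming a hypothetical shared mount `/mnt/shared/jax_cache` and a hypothetical per-host path `/tmp/jax_cache`. `jax.process_index()` and `jax.process_count()` are existing JAX calls; the `USE_SHARED_STORAGE` switch is made up for illustration. In a real multi-process job, `jax.distributed.initialize()` would normally be called before querying the process index.

```python
import jax

# Hypothetical switch and paths, for illustration only.
USE_SHARED_STORAGE = True
SHARED_CACHE_DIR = "/mnt/shared/jax_cache"  # e.g. an NFS mount seen by all hosts
LOCAL_CACHE_DIR = "/tmp/jax_cache"          # a separate disk on each host

cache_dir = SHARED_CACHE_DIR if USE_SHARED_STORAGE else LOCAL_CACHE_DIR
# The cache location must be set before the first compilation.
jax.config.update("jax_compilation_cache_dir", cache_dir)

# Report which process this is; per the quoted text, only rank 0 writes entries.
print(f"process {jax.process_index()} of {jax.process_count()} -> cache at {cache_dir}")
```

With separate disks, each host would read from its own directory, so a host whose directory has no entry would compile again.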


### Examining cache misses

To examine and understand why cache misses occur, JAX includes a configuration flag that enables logging of all cache misses (including persistent compilation cache misses) with their explanations. This can be enabled by setting the following configuration.
Member

This flag is for all cache misses in JAX, but only tracing cache misses are implemented right now.

Contributor Author

I found it quite helpful when debugging an issue I had. Maybe it's worth stating exactly what you mentioned: that it is a flag for all cache misses wherever JAX uses a cache, and that it is still useful if the issue being debugged is a tracing cache issue?
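
For reference, a minimal sketch of using the flag, assuming it is `jax_explain_cache_misses` (the flag name is not spelled out in this thread). The function `double` and the input shapes are made up. The second call uses a new shape, which should cause a tracing cache miss that is logged with an explanation.

```python
import jax
import jax.numpy as jnp

# Log tracing cache misses together with an explanation of their cause.
jax.config.update("jax_explain_cache_misses", True)


@jax.jit
def double(x):
    return x * 2


double(jnp.ones(3))           # first call: traced and compiled
result = double(jnp.ones(4))  # new input shape: retraced, and the miss is explained in the log
```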

@ayaka14732
Member

You should split long lines into multiple lines, as in the other docs.

@ayaka14732
Member

ping @keshavb96

@keshavb96
Contributor Author

@ayaka14732 sorry! I got busy working on something else and some medical emergencies, I will push the changes by tomorrow!

@keshavb96
Contributor Author

ping @ayaka14732

@@ -11,7 +11,7 @@ The compilation cache is enabled when the
is set. This should be done prior to the first compilation. Set the location as
follows:

````diff
-```
+```python
````
Member

Should mention that there are 3 ways to set the cache directory:

(1) Using environment variable

In shell, before running the script:

export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"

Or on the top of the Python script:

import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"

(2) Using jax.config.update()

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

(3) Using set_cache_dir()

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")

The original doc mentions the existence of the function set_cache_dir(), but does not give an example of it. This is bad.

@@ -67,8 +67,135 @@ Cloud Storage (GCS) bucket. We recommend the following configuration:
Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
follows:

````diff
-```
+```python
 import jax
 
 jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
````
Member

Should mention that gs:// does not work automatically. It requires etils[epath] to be installed.

Should mention that etils[epath] supports many cloud providers, while GCS is one common choice.
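
A small sketch of guarding against the missing dependency before pointing the cache at a bucket. `resolve_cache_dir` and the fallback path are hypothetical and not part of JAX, and checking for the top-level `etils` package is only an approximation of checking that `etils[epath]` is fully installed.

```python
import importlib.util


def resolve_cache_dir(preferred: str, fallback: str = "/tmp/jax_cache") -> str:
    """Return `preferred`, unless it is a gs:// path and etils is not importable."""
    needs_epath = preferred.startswith("gs://")
    has_etils = importlib.util.find_spec("etils") is not None
    if needs_epath and not has_etils:
        return fallback  # fall back to a local directory instead of failing later
    return preferred


cache_dir = resolve_cache_dir("gs://jax-cache")
print(cache_dir)
# The result could then be passed to:
# jax.config.update("jax_compilation_cache_dir", cache_dir)
```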

@@ -27,7 +27,7 @@ is an alternate way of setting `cache-location`.

`cache-location` can be a directory on the local filesystem. For example:

````diff
-```
+```python
 import jax
 
 jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
````
Member

"the cache does not have an eviction mechanism implemented" is no longer true. There is indeed an LRU cache eviction mechanism for local filesystems, after #21394

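To make the idea of LRU eviction concrete, here is a standalone sketch that trims a directory to a size budget by deleting the least recently accessed files first. This is not JAX's implementation from #21394; `evict_lru`, the size budget, and the reliance on file access times are assumptions made for illustration.

```python
import os
import tempfile


def evict_lru(cache_dir, max_bytes):
    """Delete least recently accessed files until the directory fits in max_bytes."""
    entries = []
    for name in os.listdir(cache_dir):
        path = os.path.join(cache_dir, name)
        if os.path.isfile(path):
            st = os.stat(path)
            entries.append((st.st_atime, st.st_size, path))
    total = sum(size for _, size, _ in entries)
    removed = []
    for _, size, path in sorted(entries):  # oldest access time first
        if total <= max_bytes:
            break
        os.remove(path)
        total -= size
        removed.append(os.path.basename(path))
    return removed


# Demo on a throwaway directory holding three 100-byte entries.
demo_dir = tempfile.mkdtemp()
for i, name in enumerate(["a", "b", "c"]):
    path = os.path.join(demo_dir, name)
    with open(path, "wb") as f:
        f.write(b"x" * 100)
    os.utime(path, (1000 + i, 1000 + i))  # "a" is the least recently accessed

removed = evict_lru(demo_dir, max_bytes=200)
print(removed)
```
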
Comment on lines +93 to +95
When the signature for a function created using the parameters above matches
that of a compiled function in the persistent cache, the function will not be compiled,
but will just be read and deserialized from the persistent cache.
Member

This paragraph is redundant, because this is just how a cache works. It can simply mention that the cache key consists of these items.
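
The suggestion can be pictured as a hash over the inputs that define a compiled program. The field names below are hypothetical stand-ins, and this is not how JAX actually computes its key. It only shows that equal inputs give an equal key and that any differing input gives a different one.

```python
import hashlib
import json


def toy_cache_key(program_text, device_kind, num_devices, compile_options):
    """Hash all inputs into one hex digest (illustrative only)."""
    payload = json.dumps(
        {
            "program": program_text,
            "device_kind": device_kind,
            "num_devices": num_devices,
            "options": compile_options,
        },
        sort_keys=True,  # make the serialization independent of dict ordering
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


key_a = toy_cache_key("add(x, y)", "gpu", 1, {"opt_level": 2})
key_b = toy_cache_key("add(x, y)", "gpu", 1, {"opt_level": 2})
key_c = toy_cache_key("add(x, y)", "gpu", 8, {"opt_level": 2})
print(key_a == key_b, key_a == key_c)
```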


## Different Runtimes

Below we outline some observed behavior of the persistent compilation cache
Member

Why are these called "observed behavior"? Aren't the behaviors defined in the code?

Comment on lines +117 to +125
### Single node with single process and single device

This is the simplest runtime: the only process that compiles and writes to the compilation cache is the single process.
The number of devices does not matter to the cache key in this setup, only the type of device does.

### Single node with single process and multiple devices

Once again, the only process that compiles and writes to the compilation cache is the single process.
The difference between this setup and the previous is that now the number of devices matters in addition to the type of device.
Member

We should only focus on the number of nodes here. Since there is only one node, only one binary will be produced. Therefore, these two points are essentially the same. The number of devices does not matter here. It is a part of the cache key, which has already been mentioned above.

Once again, the only process that compiles and writes to the compilation cache is the single process.
The difference between this setup and the previous is that now the number of devices matters in addition to the type of device.

### Multiple processes and multiple devices (on either single or multiple nodes)
Member

There is a jax flag jax_share_binary_between_hosts and an issue #18819 which might be related to this topic.

It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
While there is no singular canonical way of debugging and examining what's happening in the compilation cache,
here are a few suggestions on how to begin.

Member

Should mention that users can enable the logging of related source files by placing

import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"

on the top of the script.

## Logging cache activity

It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
While there is no singular canonical way of debugging and examining what's happening in the compilation cache,
Member

This sentence sounds passive. I think that there is no "singular canonical way" of debugging anything. Should just remove this sentence.

- The HLO of the function contains a pointer to the custom_partitioning callback, and leads to different cache keys for the same computation across runs.
- In this situation, caching still proceeds, but a different key is produced every time, making the cache ineffective.
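
The effect described in these two bullets can be reproduced with a toy hash. The HLO-like strings and pointer values below are made up, and `strip_pointers` is a hypothetical helper; whether the actual workaround normalizes the key this way is not stated here.

```python
import hashlib
import re


def cache_key(hlo_text: str) -> str:
    """Toy cache key: a hash of the program text."""
    return hashlib.sha256(hlo_text.encode("utf-8")).hexdigest()


def strip_pointers(hlo_text: str) -> str:
    """Replace hexadecimal addresses with a fixed placeholder."""
    return re.sub(r"0x[0-9a-fA-F]+", "<ptr>", hlo_text)


# The same computation in two runs; only the embedded callback address differs.
run1 = 'custom-call(x), backend_config="0x7f3a10c0"'
run2 = 'custom-call(x), backend_config="0x7f9b22e8"'

print(cache_key(run1) == cache_key(run2))
print(cache_key(strip_pointers(run1)) == cache_key(strip_pointers(run2)))
```

The raw keys differ between the two runs, so every run misses the cache; once the address is normalized, the keys match.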

### Working around custom_partitioning
Member

nit: quote `custom_partitioning`

Member

@ayaka14732 ayaka14732 left a comment

Let's get this merged first!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jul 4, 2024
@copybara-service copybara-service bot merged commit 1e14157 into google:main Jul 4, 2024
16 checks passed
Labels
pull ready Ready for copybara import and testing
4 participants