Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement LRU cache eviction for persistent compilation cache #21394

Merged
merged 1 commit into from
Jun 11, 2024

Conversation

ayaka14732
Copy link
Member

@ayaka14732 ayaka14732 commented May 23, 2024

This PR is part of the implementation of LRU cache eviction using the mtime attribute provided by the filesystem. The current PR does not support GCS, but this problem will be solved in a subsequent PR.

More details in the design doc: https://docs.google.com/document/d/111YibwGXOFb_hMm-lua1u63QooAzIBEH-xfRPGmibis/edit?usp=sharing

@ayaka14732 ayaka14732 self-assigned this May 23, 2024
@ayaka14732 ayaka14732 force-pushed the lru-cache branch 5 times, most recently from 906ee17 to 0de3c9c Compare May 30, 2024 11:24
@ayaka14732 ayaka14732 marked this pull request as ready for review May 30, 2024 11:25
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
tests/lru_cache_test.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
tests/lru_cache_test.py Outdated Show resolved Hide resolved
tests/lru_cache_test.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
@ayaka14732 ayaka14732 force-pushed the lru-cache branch 2 times, most recently from e03c0d5 to c7999c8 Compare May 30, 2024 13:37
@ayaka14732 ayaka14732 requested review from hawkinsp and skye May 30, 2024 13:38
@ayaka14732 ayaka14732 force-pushed the lru-cache branch 2 times, most recently from 758c287 to f28de67 Compare May 30, 2024 13:59
Copy link
Collaborator

@nouiz nouiz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you link to the design doc?
Also, would be good to have it documented somewhere?
Like in the https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html file?

Note, I saw some NFS server being configured to not update mtime to speed up the server. Maybe document that this can happen and in that case, this will revert to creation time?
The first time I saw the behavior without knowing the reason, it took times to understand what was going on.

jax/_src/config.py Outdated Show resolved Hide resolved
@hawkinsp
Copy link
Member

Note, I saw some NFS server being configured to not update mtime to speed up the server. Maybe document that this can happen and in that case, this will revert to creation time? The first time I saw the behavior without knowing the reason, it took times to understand what was going on.

Yes, I suspect there's a chance you might see stale mtime values if you stick the cache on NFS and you're accessing it concurrently from multiple clients (see lookupcache in the NFS docs). I'm not sure there's a lot we can do about that, though.

Copy link
Collaborator

@skye skye left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall nice work!

jax/_src/compilation_cache_interface.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
tests/lru_cache_test.py Outdated Show resolved Hide resolved
tests/lru_cache_test.py Show resolved Hide resolved
tests/lru_cache_test.py Outdated Show resolved Hide resolved
@skye
Copy link
Collaborator

skye commented May 31, 2024

This is a first cut at the LRU eviction implementation, so it isn't expected to work well with network file systems yet (notably GCS, which many Cloud TPU users use for their cache storage). We'll iterate from here. I don't think we should publicly document this until it works well across filesystems, but absolutely agree this should eventually be in https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html.

jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
Copy link
Member

@superbobry superbobry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall, but please address the comment in tests.

jax/_src/compilation_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
@ayaka14732
Copy link
Member Author

Test fails because the test utilises filelock, which is not installed.

=================================== FAILURES ===================================
_______________________ LRUCacheTest.test_cache_eviction _______________________
[gw5] linux -- Python 3.12.3 /opt/hostedtoolcache/Python/3.12.3/x64/bin/python
tests/lru_cache_test.py:39: in test_cache_eviction
    cache = Impl(path, max_size=884700)
jax/_src/lru_cache.py:57: in __init__
    raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`")
E   RuntimeError: Please install filelock package to set `jax_compilation_cache_max_size`

@superbobry
Copy link
Member

superbobry commented Jun 7, 2024

Add filelock to build/test-requirements.txt and to the deps in tests/BUILD.

Or skip the test for now if filelock is not importable.

@ayaka14732
Copy link
Member Author

Just realised that JAX had a FileSystemCache that supports LRU cache eviction introduced in #6869, but was subsequently removed in #10771 to support GCS. This is exactly one of the challenges that I faced in this PR. Fortunately, I've devised potential solutions to simultaneously support LRU cache eviction and GCS compatibility. This is going to be completely solved in a subsequent PR.

@ayaka14732
Copy link
Member Author

All comments resolved

jax/BUILD Outdated
":monitoring",
":path",
"//jax/_src/lib",
"//third_party/py/filelock",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add that to lru_cache deps?

You will want to use the py_deps macros for this: py_deps("filelock"). We have a list of deps in jax.bzl.

Copy link
Member Author

@ayaka14732 ayaka14732 Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where can I find the lru_cache deps?

copybara-service bot pushed a commit that referenced this pull request Jun 7, 2024
This should unblock #21394, which uses filelock in the compilation cache.

PiperOrigin-RevId: 641310140
copybara-service bot pushed a commit that referenced this pull request Jun 7, 2024
This should unblock #21394, which uses filelock in the compilation cache.

PiperOrigin-RevId: 641310140
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jun 11, 2024
@ayaka14732 ayaka14732 force-pushed the lru-cache branch 2 times, most recently from 079b4a6 to f46d41f Compare June 11, 2024 12:21
@gnecula
Copy link
Collaborator

gnecula commented Jun 11, 2024

Thank you for preparing this. Please squash the long chain of commits, or at least most of them.

Copy link
Member

@superbobry superbobry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please squash the commits.

tests/lru_cache_test.py Outdated Show resolved Hide resolved
tests/lru_cache_test.py Outdated Show resolved Hide resolved
tests/lru_cache_test.py Outdated Show resolved Hide resolved
jax/_src/lru_cache.py Outdated Show resolved Hide resolved
jax/_src/compilation_cache.py Outdated Show resolved Hide resolved
@copybara-service copybara-service bot merged commit ce4a56a into google:main Jun 11, 2024
13 of 14 checks passed
@ayaka14732 ayaka14732 mentioned this pull request Jun 12, 2024
3 tasks
@ayaka14732 ayaka14732 deleted the lru-cache branch June 12, 2024 12:01
@ayaka14732
Copy link
Member Author

ayaka14732 commented Jun 12, 2024

Can you link to the design doc?

@nouiz I've just added the link to the first comment.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants