diff --git a/jax/_src/config.py b/jax/_src/config.py index 2afa9aea31c8..4a5209152b0f 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1236,11 +1236,11 @@ def _update_jax_memories_thread_local(val): compilation_cache_max_size = define_int_state( name='jax_compilation_cache_max_size', default=-1, - help=('The maximum size of the persistent compilation cache. ' - 'This value is used in the LRU cache eviction. The cache eviction ' - 'happens when the size of the cache directort would be larger than ' - 'the size specified here (in bytes). A special value of -1 ' - 'indicates that the size of the cache directory can be infinite.'), + help=('The maximum size (in bytes) allowed for the persistent compilation ' + 'cache. When set, this value triggers the LRU cache eviction once the ' + 'total cache directory size exceeds the specified limit. A special ' + 'value of -1 indicates no upper limit, allowing the cache size to ' + 'grow indefinitely.'), ) default_dtype_bits = define_enum_state( diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py index 327934eea742..f7ed07de862d 100644 --- a/jax/_src/lru_cache.py +++ b/jax/_src/lru_cache.py @@ -40,8 +40,8 @@ class Impl: def __init__(self, path: str, *, max_size: int, timeout_secs: float | None=10): path = pathlib.Path(path) - if filelock is None: - raise RuntimeError("Please install filelock to use the LRUCache") + if max_size != -1 and filelock is None: + raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`") self.path = path self.max_size = max_size