Skip to content

Commit

Permalink
Merge pull request #21394 from ayaka14732:lru-cache
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642333998
  • Loading branch information
jax authors committed Jun 11, 2024
2 parents 5cf52b8 + 1a3a15c commit ce4a56a
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 1 deletion.
10 changes: 10 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ pytype_strict_library(
":compilation_cache_interface",
":config",
":gfile_cache",
":lru_cache",
":monitoring",
":path",
"//jax/_src/lib",
Expand All @@ -415,6 +416,15 @@ pytype_strict_library(
],
)

# Library target for `_src/lru_cache.py`, the LRU-evicting persistent
# compilation cache. The `filelock` package is appended through `py_deps`.
# NOTE(review): `lru_cache.py` as added in this commit does not appear to
# import `config`; confirm whether the `:config` dep is actually needed.
pytype_strict_library(
name = "lru_cache",
srcs = ["_src/lru_cache.py"],
deps = [
":compilation_cache_interface",
":config",
] + py_deps("filelock"),
)

pytype_strict_library(
name = "config",
srcs = ["_src/config.py"],
Expand Down
16 changes: 15 additions & 1 deletion jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from jax._src.gfile_cache import GFileCache
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lru_cache import LRUCache


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -66,7 +67,20 @@ def set_once_cache_used(f) -> None:

def get_file_cache(path: str) -> tuple[CacheInterface, str] | None:
  """Builds the persistent cache backend for ``path``.

  Args:
    path: Location of the cache directory, either a plain filesystem path or
      a URL-style path containing ``://``.

  Returns:
    A ``(cache, path)`` tuple, where ``cache`` is an ``LRUCache`` for local
    paths and a ``GFileCache`` otherwise, and ``path`` is the argument
    returned unchanged.
  """
  # A path counts as local when it carries no URL scheme at all, or when its
  # scheme is explicitly `file://`.
  on_local_filesystem = "://" not in path or path.startswith("file://")

  if on_local_filesystem:
    size_limit = config.compilation_cache_max_size.value
    return LRUCache(path, max_size=size_limit), path

  # `LRUCache` currently only supports the local filesystem, so for any other
  # location we fall back to the older `GFileCache`, which does not support
  # LRU eviction.
  # TODO(ayx): Add cloud storage support for `LRUCache`, so that all of this
  # fallback code can be removed.
  return GFileCache(path), path


def set_cache_dir(path) -> None:
Expand Down
12 changes: 12 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,18 @@ def _update_jax_memories_thread_local(val):
'2. The value of this flag set in the command line or by default.'),
)

# Size cap, in bytes, for the persistent compilation cache directory. The
# default of -1 means "no limit".
compilation_cache_max_size = define_int_state(
    name='jax_compilation_cache_max_size',
    help=(
        'The maximum size (in bytes) allowed for the persistent compilation cache. '
        'When set, the least recently accessed cache entry(s) will be deleted once '
        'the total cache directory size exceeds the specified limit. '
        'Caching will be disabled if this value is set to 0. '
        'A special value of -1 indicates no limit, allowing the cache size to grow indefinitely.'
    ),
    default=-1,
)

default_dtype_bits = define_enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/gfile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from jax._src import path as pathlib
from jax._src.compilation_cache_interface import CacheInterface


# TODO(ayx): This class will ultimately be removed once `lru_cache.py` is
# finished. It still exists because the current `lru_cache.py` only supports
# the local filesystem and therefore cannot handle paths such as `gs://`.
class GFileCache(CacheInterface):

def __init__(self, path: str):
Expand Down
184 changes: 184 additions & 0 deletions jax/_src/lru_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import heapq
import logging
import pathlib
import warnings

from jax._src.compilation_cache_interface import CacheInterface


try:
import filelock
except ImportError:
filelock = None


logger = logging.getLogger(__name__)


class LRUCache(CacheInterface):
"""Bounded cache with least-recently-used (LRU) eviction policy.
This implementation includes cache reading, writing and eviction
based on the LRU policy.
Notably, when ``max_size`` is set to -1, the cache eviction
is disabled, and the LRU cache functions as a normal cache
without any size limitations.
"""

def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None = 10):
"""Args:
path: The path to the cache directory.
max_size: The maximum size of the cache in bytes. Caching will be
disabled if this value is set to ``0``. A special value of ``-1``
indicates no limit, allowing the cache size to grow indefinitely.
lock_timeout_secs: (optional) The timeout for acquiring a file lock.
"""
# TODO(ayx): add support for cloud other filesystems such as GCS
if not self._is_local_filesystem(path):
raise NotImplementedError("LRUCache only supports local filesystem at this time.")

self.path = pathlib.Path(path)
self.path.mkdir(parents=True, exist_ok=True)

# TODO(ayx): having a `self._path` is required by the base class
# `CacheInterface`, but the base class can be removed after `LRUCache`
# and the original `GFileCache` are unified
self._path = self.path

self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1

if self.eviction_enabled:
if filelock is None:
raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`")

self.max_size = max_size
self.lock_timeout_secs = lock_timeout_secs

lock_path = self.path / ".lockfile"
self.lock = filelock.FileLock(lock_path)

def get(self, key: str) -> bytes | None:
"""Retrieves the cached value for the given key.
Args:
key: The key for which the cache value is retrieved.
Returns:
The cached data as bytes if available; ``None`` otherwise.
"""
if not key:
raise ValueError("key cannot be empty")

file = self.path / key

if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)

try:
if not file.exists():
logger.debug(f"Cache miss for key: {key!r}")
return None

logger.debug(f"Cache hit for key: {key!r}")
file.touch() # update mtime
return file.read_bytes()

finally:
if self.eviction_enabled:
self.lock.release()

def put(self, key: str, val: bytes) -> None:
"""Adds a new entry to the cache.
If a cache item with the same key already exists, no action
will be taken, even if the value is different.
Args:
key: The key under which the data will be stored.
val: The data to be stored.
"""
if not key:
raise ValueError("key cannot be empty")

# prevent adding entries that exceed the maximum size limit of the cache
if self.eviction_enabled and len(val) > self.max_size:
msg = (f"Cache value for key {key!r} of size {len(val)} bytes exceeds "
f"the maximum cache size of {self.max_size} bytes")
warnings.warn(msg)
return

file = self.path / key

if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)

try:
if file.exists():
return

self._evict_if_needed(additional_size=len(val))
file.write_bytes(val)

finally:
if self.eviction_enabled:
self.lock.release()

def _evict_if_needed(self, *, additional_size: int = 0) -> None:
"""Evicts the least recently used items from the cache if necessary
to ensure the cache does not exceed its maximum size.
Args:
additional_size: The size of the new entry being added to the cache.
This is included to account for the new entry when checking if
eviction is needed.
"""
if not self.eviction_enabled:
return

# a priority queue, each element is a tuple `(file_mtime, file, file_size)`
h: list[tuple[int, pathlib.Path, int]] = []
dir_size = 0
for file in self.path.iterdir():
if file.is_file():
file_size = file.stat().st_size
file_mtime = file.stat().st_mtime_ns

dir_size += file_size
heapq.heappush(h, (file_mtime, file, file_size))

target_size = self.max_size - additional_size
# evict files until the directory size is less than or equal
# to `target_size`
while dir_size > target_size:
file_mtime, file, file_size = heapq.heappop(h)
msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, "
f"target cache size {target_size} bytes")
logger.debug(msg)
file.unlink()
dir_size -= file_size

# See comments in `jax.src.compilation_cache.get_file_cache()` for details.
# TODO(ayx): This function has a duplicate in that place, and there is
# redundancy here. However, this code is temporary, and once the issue
# is fixed, this code can be removed.
@staticmethod
def _is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ warn_unused_ignores = true
module = [
"absl.*",
"colorama.*",
"filelock.*",
"importlib_metadata.*",
"IPython.*",
"numpy.*",
Expand Down
10 changes: 10 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,16 @@ py_test(
],
)

# Test target for `lru_cache_test.py`. It depends on the `//jax:lru_cache`
# library under test, and `filelock` is appended through `py_deps`.
py_test(
name = "lru_cache_test",
srcs = ["lru_cache_test.py"],
deps = [
"//jax",
"//jax:lru_cache",
"//jax:test_util",
] + py_deps("filelock"),
)

jax_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.py"],
Expand Down
Loading

0 comments on commit ce4a56a

Please sign in to comment.