Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement LRU cache eviction for persistent compilation cache #21394

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ pytype_strict_library(
":compilation_cache_interface",
":config",
":gfile_cache",
":lru_cache",
":monitoring",
":path",
"//jax/_src/lib",
Expand All @@ -415,6 +416,15 @@ pytype_strict_library(
],
)

# Library target for the LRU-evicting persistent compilation cache
# (`_src/lru_cache.py`). `filelock` is pulled in through `py_deps`; the module
# itself imports it inside a try/except and only requires it when eviction is
# enabled.
pytype_strict_library(
name = "lru_cache",
srcs = ["_src/lru_cache.py"],
deps = [
":config",
":compilation_cache_interface",
] + py_deps("filelock"),
)

pytype_strict_library(
name = "config",
srcs = ["_src/config.py"],
Expand Down
16 changes: 15 additions & 1 deletion jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from jax._src.gfile_cache import GFileCache
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lru_cache import LRUCache


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -66,7 +67,20 @@ def set_once_cache_used(f) -> None:

def get_file_cache(path: str) -> tuple[CacheInterface, str] | None:
  """Returns the file cache and the path to the cache.

  Local paths (no URL scheme, or an explicit ``file://`` scheme) are served by
  `LRUCache`, sized by the `jax_compilation_cache_max_size` config option.
  Every other path falls back to `GFileCache`, which has no LRU eviction.

  Args:
    path: Location of the cache directory; either a plain filesystem path or
      a URL-style path such as ``gs://...``.

  Returns:
    A ``(cache, path)`` tuple, where ``path`` is the argument unchanged.
  """
  # A path is considered local when it carries no scheme at all or when the
  # scheme is explicitly `file://`.
  on_local_fs = "://" not in path or path.startswith("file://")

  if on_local_fs:
    size_limit = config.compilation_cache_max_size.value
    return LRUCache(path, max_size=size_limit), path

  # `LRUCache` currently only supports the local filesystem, so remote paths
  # keep using the older `GFileCache`.
  # TODO(ayx): Add cloud storage support for `LRUCache`, so that this
  # fallback can be removed.
  return GFileCache(path), path


def set_cache_dir(path) -> None:
Expand Down
12 changes: 12 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,18 @@ def _update_jax_memories_thread_local(val):
'2. The value of this flag set in the command line or by default.'),
)

# Size cap, in bytes, for the persistent compilation cache directory.
# Registered as the `jax_compilation_cache_max_size` option with a default of
# -1 (no limit). `get_file_cache()` in `compilation_cache.py` reads this value
# and passes it to `LRUCache` as `max_size`.
compilation_cache_max_size = define_int_state(
name='jax_compilation_cache_max_size',
default=-1,
help=('The maximum size (in bytes) allowed for the persistent compilation '
'cache. When set, the least recently accessed cache entry(s) '
'will be deleted once the total cache directory size '
'exceeds the specified limit. '
'Caching will be disabled if this value is set to 0. A '
'special value of -1 indicates no limit, allowing the cache '
'size to grow indefinitely.'),
)

default_dtype_bits = define_enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/gfile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from jax._src import path as pathlib
from jax._src.compilation_cache_interface import CacheInterface


# TODO(ayx): This class will ultimately be removed once `lru_cache.py` is
# finished. It exists because the current `lru_cache.py` does not support
# `gs://`.
class GFileCache(CacheInterface):

def __init__(self, path: str):
Expand Down
184 changes: 184 additions & 0 deletions jax/_src/lru_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import heapq
import logging
import pathlib
import warnings

from jax._src.compilation_cache_interface import CacheInterface


try:
import filelock
except ImportError:
filelock = None


logger = logging.getLogger(__name__)


class LRUCache(CacheInterface):
"""Bounded cache with least-recently-used (LRU) eviction policy.

This implementation includes cache reading, writing and eviction
based on the LRU policy.

Notably, when ``max_size`` is set to -1, the cache eviction
is disabled, and the LRU cache functions as a normal cache
without any size limitations.
"""

def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None = 10):
"""Args:

path: The path to the cache directory.
max_size: The maximum size of the cache in bytes. Caching will be
disabled if this value is set to ``0``. A special value of ``-1``
indicates no limit, allowing the cache size to grow indefinitely.
lock_timeout_secs: (optional) The timeout for acquiring a file lock.
"""
# TODO(ayx): add support for cloud other filesystems such as GCS
if not self._is_local_filesystem(path):
raise NotImplementedError("LRUCache only supports local filesystem at this time.")

self.path = pathlib.Path(path)
self.path.mkdir(parents=True, exist_ok=True)

# TODO(ayx): having a `self._path` is required by the base class
# `CacheInterface`, but the base class can be removed after `LRUCache`
# and the original `GFileCache` are unified
self._path = self.path

self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1

if self.eviction_enabled:
if filelock is None:
raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`")

self.max_size = max_size
self.lock_timeout_secs = lock_timeout_secs

lock_path = self.path / ".lockfile"
self.lock = filelock.FileLock(lock_path)

def get(self, key: str) -> bytes | None:
"""Retrieves the cached value for the given key.

Args:
key: The key for which the cache value is retrieved.

Returns:
The cached data as bytes if available; ``None`` otherwise.
"""
if not key:
raise ValueError("key cannot be empty")

file = self.path / key
ayaka14732 marked this conversation as resolved.
Show resolved Hide resolved

if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)

try:
if not file.exists():
logger.debug(f"Cache miss for key: {key!r}")
return None

logger.debug(f"Cache hit for key: {key!r}")
file.touch() # update mtime
return file.read_bytes()

finally:
if self.eviction_enabled:
self.lock.release()

def put(self, key: str, val: bytes) -> None:
"""Adds a new entry to the cache.

If a cache item with the same key already exists, no action
will be taken, even if the value is different.

Args:
key: The key under which the data will be stored.
val: The data to be stored.
"""
if not key:
raise ValueError("key cannot be empty")

# prevent adding entries that exceed the maximum size limit of the cache
if self.eviction_enabled and len(val) > self.max_size:
ayaka14732 marked this conversation as resolved.
Show resolved Hide resolved
msg = (f"Cache value for key {key!r} of size {len(val)} bytes exceeds "
f"the maximum cache size of {self.max_size} bytes")
warnings.warn(msg)
return

file = self.path / key

if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)

try:
if file.exists():
return
ayaka14732 marked this conversation as resolved.
Show resolved Hide resolved

self._evict_if_needed(additional_size=len(val))
file.write_bytes(val)

finally:
if self.eviction_enabled:
self.lock.release()

def _evict_if_needed(self, *, additional_size: int = 0) -> None:
"""Evicts the least recently used items from the cache if necessary
to ensure the cache does not exceed its maximum size.

Args:
additional_size: The size of the new entry being added to the cache.
This is included to account for the new entry when checking if
eviction is needed.
"""
if not self.eviction_enabled:
return

# a priority queue, each element is a tuple `(file_mtime, file, file_size)`
h: list[tuple[int, pathlib.Path, int]] = []
dir_size = 0
for file in self.path.iterdir():
if file.is_file():
file_size = file.stat().st_size
file_mtime = file.stat().st_mtime_ns

dir_size += file_size
heapq.heappush(h, (file_mtime, file, file_size))

target_size = self.max_size - additional_size
# evict files until the directory size is less than or equal
# to `target_size`
while dir_size > target_size:
file_mtime, file, file_size = heapq.heappop(h)
msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, "
f"target cache size {target_size} bytes")
logger.debug(msg)
file.unlink()
ayaka14732 marked this conversation as resolved.
Show resolved Hide resolved
dir_size -= file_size

# See comments in `jax.src.compilation_cache.get_file_cache()` for details.
# TODO(ayx): This function has a duplicate in that place, and there is
# redundancy here. However, this code is temporary, and once the issue
# is fixed, this code can be removed.
@staticmethod
def _is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ warn_unused_ignores = true
module = [
"absl.*",
"colorama.*",
"filelock.*",
"importlib_metadata.*",
"IPython.*",
"numpy.*",
Expand Down
10 changes: 10 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,16 @@ py_test(
],
)

# Unit test target for `//jax:lru_cache`. `filelock` is listed via `py_deps`
# in addition to the library under test.
py_test(
name = "lru_cache_test",
srcs = ["lru_cache_test.py"],
deps = [
"//jax",
"//jax:lru_cache",
"//jax:test_util",
] + py_deps("filelock"),
)

jax_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.py"],
Expand Down
Loading