Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
Use pathlib API

rename timeout -> timeout_secs

Build file

Update docstring

Use heapq

style

style

style

style

Move GFileCache/LRUCache logit outside the LRUCache class

fix

Guarantee that dir_size never exceeds max_size

Add docstring

Improve docstring

Fix bug

Fix bug

Fix bugs

Don't instantiate FileLock if cache eviction is not enabled

Try solving import-not-found error from linter

Fix test error

Update docstring and comments

Update docstring

Improve docstring

Try add filelock to test dependency

Avoid random tests

More tests

Add more tests and fix bug

Improve tests

Improve tests

Try fix pytype error

Specify the behaviour of max_size=0

Improve code

Added ":lru_cache" to the BUILD file

Apply review comments

Apply review comments

Apply review comments

Apply review comments

Apply review comments

Apply review comments

Utilize `setUp()` and `tearDown()` for test cases

Merge `LRUCache` and `Impl` into one class

Minor changes

Apply review comments

Apply review comments

Try fix test error by reverting adding type annotations

Apply review comments

Co-authored-by: Sergei Lebedev <[email protected]>
  • Loading branch information
ayaka14732 and superbobry committed Jun 11, 2024
1 parent 1256ceb commit 1eb11f9
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 1 deletion.
1 change: 1 addition & 0 deletions build/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ absl-py
build
cloudpickle
colorama>=0.4.4
filelock
flatbuffers
hypothesis
mpmath>=1.3
Expand Down
10 changes: 10 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ pytype_strict_library(
":compilation_cache_interface",
":config",
":gfile_cache",
":lru_cache",
":monitoring",
":path",
"//jax/_src/lib",
Expand All @@ -415,6 +416,15 @@ pytype_strict_library(
],
)

pytype_strict_library(
name = "lru_cache",
srcs = ["_src/lru_cache.py"],
deps = [
":config",
":compilation_cache_interface",
] + py_deps("filelock"),
)

pytype_strict_library(
name = "config",
srcs = ["_src/config.py"],
Expand Down
16 changes: 15 additions & 1 deletion jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from jax._src import config
from jax._src import monitoring
from jax._src.gfile_cache import GFileCache
from jax._src.lru_cache import LRUCache
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir

Expand Down Expand Up @@ -66,7 +67,20 @@ def set_once_cache_used(f) -> None:

def get_file_cache(path: str) -> tuple[CacheInterface, str] | None:
"""Returns the file cache and the path to the cache."""
return GFileCache(path), path

def is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path

# `LRUCache` currently only supports local filesystem. Therefore, if `path`
# is not on a local filesystem, instead of using `LRUCache`, we
# fallback to the old `GFileCache`, which does not support LRU eviction.
# TODO(ayx): Add cloud storage support for `LRUCache`, so that all these code
# can be removed.
if not is_local_filesystem(path):
return GFileCache(path), path

max_size = config.compilation_cache_max_size.value
return LRUCache(path, max_size=max_size), path


def set_cache_dir(path) -> None:
Expand Down
12 changes: 12 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,18 @@ def _update_jax_memories_thread_local(val):
'2. The value of this flag set in the command line or by default.'),
)

compilation_cache_max_size = define_int_state(
name='jax_compilation_cache_max_size',
default=-1,
help=('The maximum size (in bytes) allowed for the persistent compilation '
'cache. When set, the least recently accessed cache entry(s) '
'will be deleted once the total cache directory size '
'exceeds the specified limit. '
'Caching will be disabled if this value is set to 0. A '
'special value of -1 indicates no limit, allowing the cache '
'size to grow indefinitely.'),
)

default_dtype_bits = define_enum_state(
name='jax_default_dtype_bits',
enum_values=['32', '64'],
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/gfile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from jax._src import path as pathlib
from jax._src.compilation_cache_interface import CacheInterface


# TODO (ayx): This class will be ultimately removed after `lru_cache.py` is
# finished. It exists because the current `lru_cache.py` does not support
# `gs://`.
class GFileCache(CacheInterface):

def __init__(self, path: str):
Expand Down
184 changes: 184 additions & 0 deletions jax/_src/lru_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import heapq
import logging
import pathlib
import warnings

from jax._src.compilation_cache_interface import CacheInterface


try:
import filelock
except ImportError:
filelock = None


logger = logging.getLogger(__name__)


class LRUCache(CacheInterface):
"""Bounded cache with least-recently-used (LRU) eviction policy.
This implementation includes cache reading, writing and eviction
based on the LRU policy.
Notably, when ``max_size`` is set to -1, the cache eviction
is disabled, and the LRU cache functions as a normal cache
without any size limitations.
"""

def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None = 10):
"""Args:
path: The path to the cache directory.
max_size: The maximum size of the cache in bytes. Caching will be
disabled if this value is set to ``0``. A special value of ``-1``
indicates no limit, allowing the cache size to grow indefinitely.
lock_timeout_secs: (optional) The timeout for acquiring a file lock.
"""
# TODO(ayx): add support for cloud other filesystems such as GCS
if not self._is_local_filesystem(path):
raise NotImplementedError("LRUCache only supports local filesystem at this time.")

self.path = pathlib.Path(path)
self.path.mkdir(parents=True, exist_ok=True)

# TODO(ayx): having a `self._path` is required by the base class
# `CacheInterface`, but the base class can be removed after `LRUCache`
# and the original `GFileCache` are unified
self._path = self.path

self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1

if self.eviction_enabled:
if filelock is None:
raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`")

self.max_size = max_size
self.lock_timeout_secs = lock_timeout_secs

lock_path = self.path / ".lockfile"
self.lock = filelock.FileLock(lock_path)

def get(self, key: str) -> bytes | None:
"""Retrieves the cached value for the given key.
Args:
key: The key for which the cache value is retrieved.
Returns:
The cached data as bytes if available; ``None`` otherwise.
"""
if not key:
raise ValueError("key cannot be empty")

file = self.path / key

if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)

try:
if not file.exists():
logger.debug(f"Cache miss for key: {key!r}")
return None

logger.debug(f"Cache hit for key: {key!r}")
file.touch() # update mtime
return file.read_bytes()

finally:
if self.eviction_enabled:
self.lock.release()

def put(self, key: str, val: bytes) -> None:
"""Adds a new entry to the cache.
If a cache item with the same key already exists, no action
will be taken, even if the value is different.
Args:
key: The key under which the data will be stored.
val: The data to be stored.
"""
if not key:
raise ValueError("key cannot be empty")

# prevent adding entries that exceed the maximum size limit of the cache
if self.eviction_enabled and len(val) > self.max_size:
msg = (f"Cache value for key {key!r} of size {len(val)} bytes exceeds "
f"the maximum cache size of {self.max_size} bytes")
warnings.warn(msg)
return

file = self.path / key

if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)

try:
if file.exists():
return

self._evict_if_needed(additional_size=len(val))
file.write_bytes(val)

finally:
if self.eviction_enabled:
self.lock.release()

def _evict_if_needed(self, *, additional_size: int = 0) -> None:
"""Evicts the least recently used items from the cache if necessary
to ensure the cache does not exceed its maximum size.
Args:
additional_size: The size of the new entry being added to the cache.
This is included to account for the new entry when checking if
eviction is needed.
"""
if not self.eviction_enabled:
return

# a priority queue, each element is a tuple `(file_mtime, file, file_size)`
h: list[tuple[int, pathlib.Path, int]] = []
dir_size = 0
for file in self.path.iterdir():
if file.is_file():
file_size = file.stat().st_size
file_mtime = file.stat().st_mtime_ns

dir_size += file_size
heapq.heappush(h, (file_mtime, file, file_size))

target_size = self.max_size - additional_size
# evict files until the directory size is less than or equal
# to `target_size`
while dir_size > target_size:
file_mtime, file, file_size = heapq.heappop(h)
msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, "
f"target cache size {target_size} bytes")
logger.debug(msg)
file.unlink()
dir_size -= file_size

# See comments in `jax.src.compilation_cache.get_file_cache()` for details.
# TODO(ayx): This function has a duplicate in that place, and there is
# redundancy here. However, this code is temporary, and once the issue
# is fixed, this code can be removed.
@staticmethod
def _is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path
10 changes: 10 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,16 @@ py_test(
],
)

py_test(
name = "lru_cache_test",
srcs = ["lru_cache_test.py"],
deps = [
"//jax",
"//jax:lru_cache",
"//jax:test_util",
] + py_deps("filelock"),
)

jax_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.py"],
Expand Down
Loading

0 comments on commit 1eb11f9

Please sign in to comment.