From 1a3a15c9e3b347a925925446361a5ea8c7aadc05 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 30 May 2024 17:59:05 +0400 Subject: [PATCH] Implement LRU cache eviction for persistent compilation cache Co-authored-by: Sergei Lebedev --- jax/BUILD | 10 ++ jax/_src/compilation_cache.py | 16 ++- jax/_src/config.py | 12 +++ jax/_src/gfile_cache.py | 4 + jax/_src/lru_cache.py | 184 ++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + tests/BUILD | 10 ++ tests/lru_cache_test.py | 151 ++++++++++++++++++++++++++++ 8 files changed, 387 insertions(+), 1 deletion(-) create mode 100644 jax/_src/lru_cache.py create mode 100644 tests/lru_cache_test.py diff --git a/jax/BUILD b/jax/BUILD index 70a9eda86b32..81050790bc38 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -390,6 +390,7 @@ pytype_strict_library( ":compilation_cache_interface", ":config", ":gfile_cache", + ":lru_cache", ":monitoring", ":path", "//jax/_src/lib", @@ -415,6 +416,15 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "lru_cache", + srcs = ["_src/lru_cache.py"], + deps = [ + ":config", + ":compilation_cache_interface", + ] + py_deps("filelock"), +) + pytype_strict_library( name = "config", srcs = ["_src/config.py"], diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 7884cda7c134..9c276151741d 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -34,6 +34,7 @@ from jax._src.gfile_cache import GFileCache from jax._src.lib import xla_client from jax._src.lib.mlir import ir +from jax._src.lru_cache import LRUCache logger = logging.getLogger(__name__) @@ -66,7 +67,20 @@ def set_once_cache_used(f) -> None: def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: """Returns the file cache and the path to the cache.""" - return GFileCache(path), path + + def is_local_filesystem(path: str) -> bool: + return path.startswith("file://") or "://" not in path + + # `LRUCache` currently only supports local filesystem. 
Therefore, if `path` + # is not on a local filesystem, instead of using `LRUCache`, we + # fall back to the old `GFileCache`, which does not support LRU eviction. + # TODO(ayx): Add cloud storage support for `LRUCache`, so that all this code + # can be removed. + if not is_local_filesystem(path): + return GFileCache(path), path + + max_size = config.compilation_cache_max_size.value + return LRUCache(path, max_size=max_size), path def set_cache_dir(path) -> None: diff --git a/jax/_src/config.py b/jax/_src/config.py index ba4c18fa0c71..949da0358aca 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1281,6 +1281,18 @@ def _update_jax_memories_thread_local(val): '2. The value of this flag set in the command line or by default.'), ) +compilation_cache_max_size = define_int_state( + name='jax_compilation_cache_max_size', + default=-1, + help=('The maximum size (in bytes) allowed for the persistent compilation ' + 'cache. When set, the least recently accessed cache entry(s) ' + 'will be deleted once the total cache directory size ' + 'exceeds the specified limit. ' + 'Caching will be disabled if this value is set to 0. A ' + 'special value of -1 indicates no limit, allowing the cache ' + 'size to grow indefinitely.'), +) + default_dtype_bits = define_enum_state( name='jax_default_dtype_bits', enum_values=['32', '64'], diff --git a/jax/_src/gfile_cache.py b/jax/_src/gfile_cache.py index 301f61cc6bdb..989844b10ddb 100644 --- a/jax/_src/gfile_cache.py +++ b/jax/_src/gfile_cache.py @@ -17,6 +17,10 @@ from jax._src import path as pathlib from jax._src.compilation_cache_interface import CacheInterface + +# TODO(ayx): This class will ultimately be removed after `lru_cache.py` is +# finished. It exists because the current `lru_cache.py` does not support +# `gs://`. 
class GFileCache(CacheInterface): def __init__(self, path: str): diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py new file mode 100644 index 000000000000..480d4915eceb --- /dev/null +++ b/jax/_src/lru_cache.py @@ -0,0 +1,184 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import heapq +import logging +import pathlib +import warnings + +from jax._src.compilation_cache_interface import CacheInterface + + +try: + import filelock +except ImportError: + filelock = None + + +logger = logging.getLogger(__name__) + + +class LRUCache(CacheInterface): + """Bounded cache with least-recently-used (LRU) eviction policy. + + This implementation includes cache reading, writing and eviction + based on the LRU policy. + + Notably, when ``max_size`` is set to -1, the cache eviction + is disabled, and the LRU cache functions as a normal cache + without any size limitations. + """ + + def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None = 10): + """Args: + + path: The path to the cache directory. + max_size: The maximum size of the cache in bytes. Caching will be + disabled if this value is set to ``0``. A special value of ``-1`` + indicates no limit, allowing the cache size to grow indefinitely. + lock_timeout_secs: (optional) The timeout for acquiring a file lock. 
+ """ + # TODO(ayx): add support for other filesystems such as GCS + if not self._is_local_filesystem(path): + raise NotImplementedError("LRUCache only supports local filesystem at this time.") + + self.path = pathlib.Path(path) + self.path.mkdir(parents=True, exist_ok=True) + + # TODO(ayx): having a `self._path` is required by the base class + # `CacheInterface`, but the base class can be removed after `LRUCache` + # and the original `GFileCache` are unified + self._path = self.path + + self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1 + + if self.eviction_enabled: + if filelock is None: + raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`") + + self.max_size = max_size + self.lock_timeout_secs = lock_timeout_secs + + lock_path = self.path / ".lockfile" + self.lock = filelock.FileLock(lock_path) + + def get(self, key: str) -> bytes | None: + """Retrieves the cached value for the given key. + + Args: + key: The key for which the cache value is retrieved. + + Returns: + The cached data as bytes if available; ``None`` otherwise. + """ + if not key: + raise ValueError("key cannot be empty") + + file = self.path / key + + if self.eviction_enabled: + self.lock.acquire(timeout=self.lock_timeout_secs) + + try: + if not file.exists(): + logger.debug(f"Cache miss for key: {key!r}") + return None + + logger.debug(f"Cache hit for key: {key!r}") + file.touch() # update mtime + return file.read_bytes() + + finally: + if self.eviction_enabled: + self.lock.release() + + def put(self, key: str, val: bytes) -> None: + """Adds a new entry to the cache. + + If a cache item with the same key already exists, no action + will be taken, even if the value is different. + + Args: + key: The key under which the data will be stored. + val: The data to be stored. 
+ """ + if not key: + raise ValueError("key cannot be empty") + + # prevent adding entries that exceed the maximum size limit of the cache + if self.eviction_enabled and len(val) > self.max_size: + msg = (f"Cache value for key {key!r} of size {len(val)} bytes exceeds " + f"the maximum cache size of {self.max_size} bytes") + warnings.warn(msg) + return + + file = self.path / key + + if self.eviction_enabled: + self.lock.acquire(timeout=self.lock_timeout_secs) + + try: + if file.exists(): + return + + self._evict_if_needed(additional_size=len(val)) + file.write_bytes(val) + + finally: + if self.eviction_enabled: + self.lock.release() + + def _evict_if_needed(self, *, additional_size: int = 0) -> None: + """Evicts the least recently used items from the cache if necessary + to ensure the cache does not exceed its maximum size. + + Args: + additional_size: The size of the new entry being added to the cache. + This is included to account for the new entry when checking if + eviction is needed. + """ + if not self.eviction_enabled: + return + + # a priority queue, each element is a tuple `(file_mtime, file, file_size)` + h: list[tuple[int, pathlib.Path, int]] = [] + dir_size = 0 + for file in self.path.iterdir(): + if file.is_file(): + file_size = file.stat().st_size + file_mtime = file.stat().st_mtime_ns + + dir_size += file_size + heapq.heappush(h, (file_mtime, file, file_size)) + + target_size = self.max_size - additional_size + # evict files until the directory size is less than or equal + # to `target_size` + while dir_size > target_size: + file_mtime, file, file_size = heapq.heappop(h) + msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, " + f"target cache size {target_size} bytes") + logger.debug(msg) + file.unlink() + dir_size -= file_size + + # See comments in `jax._src.compilation_cache.get_file_cache()` for details. + # TODO(ayx): This function has a duplicate in that place, and there is + # redundancy here. 
However, this code is temporary, and once the issue + # is fixed, this code can be removed. + @staticmethod + def _is_local_filesystem(path: str) -> bool: + return path.startswith("file://") or "://" not in path diff --git a/pyproject.toml b/pyproject.toml index 630f25835a0d..8148308aa26a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ warn_unused_ignores = true module = [ "absl.*", "colorama.*", + "filelock.*", "importlib_metadata.*", "IPython.*", "numpy.*", diff --git a/tests/BUILD b/tests/BUILD index 69d6aff38009..21fe476e3ff9 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1135,6 +1135,16 @@ py_test( ], ) +py_test( + name = "lru_cache_test", + srcs = ["lru_cache_test.py"], + deps = [ + "//jax", + "//jax:lru_cache", + "//jax:test_util", + ] + py_deps("filelock"), +) + jax_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.py"], diff --git a/tests/lru_cache_test.py b/tests/lru_cache_test.py new file mode 100644 index 000000000000..483f65975845 --- /dev/null +++ b/tests/lru_cache_test.py @@ -0,0 +1,151 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import tempfile +import time + +from absl.testing import absltest + +from jax._src import path as pathlib +from jax._src.lru_cache import LRUCache +import jax._src.test_util as jtu + + +class LRUCacheTestCase(jtu.JaxTestCase): + name: str | None + path: pathlib.Path | None + + def setUp(self): + super().setUp() + tmpdir = tempfile.TemporaryDirectory() + self.enter_context(tmpdir) + self.name = tmpdir.name + self.path = pathlib.Path(self.name) + + def tearDown(self): + self.path = None + self.name = None + super().tearDown() + + +class LRUCacheTest(LRUCacheTestCase): + + def test_get_nonexistent_key(self): + cache = LRUCache(self.name, max_size=-1) + self.assertIsNone(cache.get("cache-a")) + + def test_put_and_get_key(self): + cache = LRUCache(self.name, max_size=-1) + + cache.put("cache-a", b"a") + self.assertEqual(cache.get("cache-a"), b"a") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a"}) + + cache.put("cache-b", b"b") + self.assertEqual(cache.get("cache-a"), b"a") + self.assertEqual(cache.get("cache-b"), b"b") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"}) + + def test_put_empty_value(self): + cache = LRUCache(self.name, max_size=-1) + + cache.put("cache-a", b"") + self.assertEqual(cache.get("cache-a"), b"") + + def test_put_empty_key(self): + cache = LRUCache(self.name, max_size=-1) + + with self.assertRaisesRegex(ValueError, r"key cannot be empty"): + cache.put("", b"a") + + def test_eviction(self): + cache = LRUCache(self.name, max_size=2) + + cache.put("cache-a", b"a") + cache.put("cache-b", b"b") + + # `sleep()` is necessary to guarantee that `cache-b`'s timestamp is strictly greater than `cache-a`'s + time.sleep(1) + cache.get("cache-b") + + # write `cache-c`, evict `cache-a` + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-c"}) + + # calling `get()` on `cache-b` 
makes `cache-c` least recently used + time.sleep(1) + cache.get("cache-b") + + # write `cache-d`, evict `cache-c` + cache.put("cache-d", b"d") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-d"}) + + def test_eviction_with_empty_value(self): + cache = LRUCache(self.name, max_size=1) + + cache.put("cache-a", b"a") + + # write `cache-b` with length 0 + # eviction should not happen even though the cache is full + cache.put("cache-b", b"") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"}) + + # calling `get()` on `cache-a` makes `cache-b` least recently used + time.sleep(1) + cache.get("cache-a") + + # writing `cache-c` should result in evicting the + # least recently used file (`cache-b`) first, + # but this is not sufficient to make room for `cache-c`, + # so `cache-a` should be evicted as well + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-c"}) + + def test_existing_cache_dir(self): + cache = LRUCache(self.name, max_size=2) + + cache.put("cache-a", b"a") + + # simulates reinitializing the cache in another process + del cache + cache = LRUCache(self.name, max_size=2) + + self.assertEqual(cache.get("cache-a"), b"a") + + # ensure that the LRU policy survives cache reinitialization + cache.put("cache-b", b"b") + + # calling `get()` on `cache-a` makes `cache-b` least recently used + time.sleep(1) + cache.get("cache-a") + + # write `cache-c`, evict `cache-b` + cache.put("cache-c", b"c") + self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-c"}) + + def test_max_size(self): + cache = LRUCache(self.name, max_size=1) + + msg = (r"Cache value for key .+? 
of size \d+ bytes exceeds the maximum " + r"cache size of \d+ bytes") + with self.assertWarnsRegex(UserWarning, msg): + cache.put("cache-a", b"aaaa") + self.assertIsNone(cache.get("cache-a")) + self.assertEqual(set(self.path.glob("cache-*")), set()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())