From 1eb11f919ac6a593c305c969c2ba23c6fedcfa0c Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 30 May 2024 17:59:05 +0400 Subject: [PATCH] Initial commit Use pathlib API rename timeout -> timeout_secs Build file Update docstring Use heapq style style style style Move GFileCache/LRUCache logit outside the LRUCache class fix Guarantee that dir_size never exceeds max_size Add docstring Improve docstring Fix bug Fix bug Fix bugs Don't instantiate FileLock if cache eviction is not enabled Try solving import-not-found error from linter Fix test error Update docstring and comments Update docstring Improve docstring Try add filelock to test dependency Avoid random tests More tests Add more tests and fix bug Improve tests Improve tests Try fix pytype error Specify the behaviour of max_size=0 Improve code Added ":lru_cache" to the BUILD file Apply review comments Apply review comments Apply review comments Apply review comments Apply review comments Apply review comments Utilize `setUp()` and `tearDown()` for test cases Merge `LRUCache` and `Impl` into one class Minor changes Apply review comments Apply review comments Try fix test error by reverting adding type annotations Apply review comments Co-authored-by: Sergei Lebedev --- build/test-requirements.txt | 1 + jax/BUILD | 10 ++ jax/_src/compilation_cache.py | 16 ++- jax/_src/config.py | 12 +++ jax/_src/gfile_cache.py | 4 + jax/_src/lru_cache.py | 184 ++++++++++++++++++++++++++++++++++ tests/BUILD | 10 ++ tests/lru_cache_test.py | 151 ++++++++++++++++++++++++++++ 8 files changed, 387 insertions(+), 1 deletion(-) create mode 100644 jax/_src/lru_cache.py create mode 100644 tests/lru_cache_test.py diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 800bc735daf7..4f9d19e76ba2 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -2,6 +2,7 @@ absl-py build cloudpickle colorama>=0.4.4 +filelock flatbuffers hypothesis mpmath>=1.3 diff --git a/jax/BUILD b/jax/BUILD index 
70a9eda86b32..81050790bc38 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -390,6 +390,7 @@ pytype_strict_library( ":compilation_cache_interface", ":config", ":gfile_cache", + ":lru_cache", ":monitoring", ":path", "//jax/_src/lib", @@ -415,6 +416,15 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "lru_cache", + srcs = ["_src/lru_cache.py"], + deps = [ + ":config", + ":compilation_cache_interface", + ] + py_deps("filelock"), +) + pytype_strict_library( name = "config", srcs = ["_src/config.py"], diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 7884cda7c134..a4c31f547a9f 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -32,6 +32,7 @@ from jax._src import config from jax._src import monitoring from jax._src.gfile_cache import GFileCache +from jax._src.lru_cache import LRUCache from jax._src.lib import xla_client from jax._src.lib.mlir import ir @@ -66,7 +67,20 @@ def set_once_cache_used(f) -> None: def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: """Returns the file cache and the path to the cache.""" - return GFileCache(path), path + + def is_local_filesystem(path: str) -> bool: + return path.startswith("file://") or "://" not in path + + # `LRUCache` currently only supports local filesystem. Therefore, if `path` + # is not on a local filesystem, instead of using `LRUCache`, we + # fall back to the old `GFileCache`, which does not support LRU eviction. + # TODO(ayx): Add cloud storage support for `LRUCache`, so that all this code + # can be removed. + if not is_local_filesystem(path): + return GFileCache(path), path + + max_size = config.compilation_cache_max_size.value + return LRUCache(path, max_size=max_size), path def set_cache_dir(path) -> None: diff --git a/jax/_src/config.py b/jax/_src/config.py index ba4c18fa0c71..949da0358aca 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1281,6 +1281,18 @@ def _update_jax_memories_thread_local(val): '2.
The value of this flag set in the command line or by default.'), ) +compilation_cache_max_size = define_int_state( + name='jax_compilation_cache_max_size', + default=-1, + help=('The maximum size (in bytes) allowed for the persistent compilation ' + 'cache. When set, the least recently accessed cache entry(s) ' + 'will be deleted once the total cache directory size ' + 'exceeds the specified limit. ' + 'Caching will be disabled if this value is set to 0. A ' + 'special value of -1 indicates no limit, allowing the cache ' + 'size to grow indefinitely.'), +) + default_dtype_bits = define_enum_state( name='jax_default_dtype_bits', enum_values=['32', '64'], diff --git a/jax/_src/gfile_cache.py b/jax/_src/gfile_cache.py index 301f61cc6bdb..989844b10ddb 100644 --- a/jax/_src/gfile_cache.py +++ b/jax/_src/gfile_cache.py @@ -17,6 +17,10 @@ from jax._src import path as pathlib from jax._src.compilation_cache_interface import CacheInterface + +# TODO (ayx): This class will be ultimately removed after `lru_cache.py` is +# finished. It exists because the current `lru_cache.py` does not support +# `gs://`. class GFileCache(CacheInterface): def __init__(self, path: str): diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py new file mode 100644 index 000000000000..480d4915eceb --- /dev/null +++ b/jax/_src/lru_cache.py @@ -0,0 +1,184 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import contextlib
+import heapq
+import logging
+import pathlib
+import warnings
+
+from jax._src.compilation_cache_interface import CacheInterface
+
+
+try:
+  import filelock
+except ImportError:
+  filelock = None
+
+
+logger = logging.getLogger(__name__)
+
+
+class LRUCache(CacheInterface):
+  """Bounded cache with least-recently-used (LRU) eviction policy.
+
+  Entries are plain files named after their key inside the cache directory;
+  a file's mtime records when the entry was last written or read.
+
+  Notably, when ``max_size`` is set to -1, the cache eviction
+  is disabled, and the LRU cache functions as a normal cache
+  without any size limitations.
+  """
+
+  _LOCK_FILE_NAME = ".lockfile"
+
+  def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None = 10):
+    """Args:
+      path: The path to the cache directory. A ``file://`` prefix is accepted.
+      max_size: The maximum size of the cache in bytes. Caching will be
+        disabled if this value is set to ``0``. A special value of ``-1``
+        indicates no limit, allowing the cache size to grow indefinitely.
+      lock_timeout_secs: (optional) The timeout for acquiring a file lock.
+
+    Raises:
+      NotImplementedError: If ``path`` is not on a local filesystem.
+      RuntimeError: If eviction is requested but ``filelock`` is missing.
+    """
+    # TODO(ayx): add support for other filesystems such as GCS
+    if not self._is_local_filesystem(path):
+      raise NotImplementedError("LRUCache only supports local filesystem at this time.")
+
+    # `pathlib` would treat the `file://` scheme as a directory named `file:`.
+    if path.startswith("file://"):
+      path = path[len("file://"):]
+
+    self.path = pathlib.Path(path)
+    self.path.mkdir(parents=True, exist_ok=True)
+
+    # TODO(ayx): having a `self._path` is required by the base class
+    # `CacheInterface`, but the base class can be removed after `LRUCache`
+    # and the original `GFileCache` are unified
+    self._path = self.path
+
+    self.eviction_enabled = max_size != -1  # no eviction if `max_size` is set to -1
+
+    if self.eviction_enabled:
+      if filelock is None:
+        raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`")
+
+      self.max_size = max_size
+      self.lock_timeout_secs = lock_timeout_secs
+      self.lock = filelock.FileLock(self.path / self._LOCK_FILE_NAME)
+
+  def _locked(self):
+    """Returns a context manager that holds the directory lock, if any."""
+    if not self.eviction_enabled:
+      return contextlib.nullcontext()  # no eviction, hence no lock to take
+    return self.lock.acquire(timeout=self.lock_timeout_secs)
+
+  def get(self, key: str) -> bytes | None:
+    """Retrieves the cached value for the given key.
+
+    Args:
+      key: The key for which the cache value is retrieved. Must not be empty.
+
+    Returns:
+      The cached data as bytes if available; ``None`` otherwise.
+    """
+    if not key:
+      raise ValueError("key cannot be empty")
+
+    file = self.path / key
+
+    with self._locked():
+      if not file.exists():
+        logger.debug("Cache miss for key: %r", key)
+        return None
+
+      logger.debug("Cache hit for key: %r", key)
+      file.touch()  # update mtime so the entry counts as most recently used
+      return file.read_bytes()
+
+  def put(self, key: str, val: bytes) -> None:
+    """Adds a new entry to the cache.
+
+    If a cache item with the same key already exists, no action
+    will be taken, even if the value is different. When eviction is enabled,
+    a value larger than ``max_size`` is skipped with a warning.
+
+    Args:
+      key: The key under which the data will be stored. Must not be empty.
+      val: The data to be stored.
+    """
+    if not key:
+      raise ValueError("key cannot be empty")
+
+    # prevent adding entries that exceed the maximum size limit of the cache
+    if self.eviction_enabled and len(val) > self.max_size:
+      msg = (f"Cache value for key {key!r} of size {len(val)} bytes exceeds "
+             f"the maximum cache size of {self.max_size} bytes")
+      warnings.warn(msg)
+      return
+
+    file = self.path / key
+
+    with self._locked():
+      if file.exists():
+        return
+
+      self._evict_if_needed(additional_size=len(val))
+      file.write_bytes(val)
+
+  def _evict_if_needed(self, *, additional_size: int = 0) -> None:
+    """Evicts the least recently used items from the cache if necessary
+    to ensure the cache does not exceed its maximum size.
+
+    Args:
+      additional_size: The size of the new entry being added to the cache.
+        This is included to account for the new entry when checking if
+        eviction is needed.
+    """
+    if not self.eviction_enabled:
+      return
+
+    # a priority queue, each element is a tuple `(file_mtime, file, file_size)`
+    h: list[tuple[int, pathlib.Path, int]] = []
+    dir_size = 0
+    for file in self.path.iterdir():
+      # The lock file is bookkeeping rather than a cache entry; it must never
+      # be unlinked here, least of all while it guards this very directory.
+      if file.name != self._LOCK_FILE_NAME and file.is_file():
+        file_stat = file.stat()  # one syscall provides both size and mtime
+        dir_size += file_stat.st_size
+        heapq.heappush(h, (file_stat.st_mtime_ns, file, file_stat.st_size))
+
+    target_size = self.max_size - additional_size
+    # evict files until the directory size is less than or equal
+    # to `target_size`
+    while dir_size > target_size:
+      _, file, file_size = heapq.heappop(h)
+      logger.debug("Evicting cache file %s: file size %s bytes, "
+                   "target cache size %s bytes", file.name, file_size, target_size)
+      file.unlink()
+      dir_size -= file_size
+
+  # See comments in `jax.src.compilation_cache.get_file_cache()` for details.
+  # TODO(ayx): This function has a duplicate in that place, and there is
+  # redundancy here. However, this code is temporary, and once the issue
+  # is fixed, this code can be removed.
+  @staticmethod
+  def _is_local_filesystem(path: str) -> bool:
+    return path.startswith("file://") or "://" not in path
diff --git a/tests/BUILD b/tests/BUILD
index 69d6aff38009..21fe476e3ff9 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -1135,6 +1135,16 @@ py_test(
     ],
 )
 
+py_test(
+    name = "lru_cache_test",
+    srcs = ["lru_cache_test.py"],
+    deps = [
+        "//jax",
+        "//jax:lru_cache",
+        "//jax:test_util",
+    ] + py_deps("filelock"),
+)
+
 jax_test(
     name = "compilation_cache_test",
     srcs = ["compilation_cache_test.py"],
diff --git a/tests/lru_cache_test.py b/tests/lru_cache_test.py
new file mode 100644
index 000000000000..483f65975845
--- /dev/null
+++ b/tests/lru_cache_test.py
@@ -0,0 +1,151 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import tempfile
+import time
+
+from absl.testing import absltest
+
+from jax._src import path as pathlib
+from jax._src.lru_cache import LRUCache
+import jax._src.test_util as jtu
+
+
+class LRUCacheTestCase(jtu.JaxTestCase):
+  """Base fixture: every test gets its own temporary cache directory."""
+
+  name: str | None
+  path: pathlib.Path | None
+
+  def setUp(self):
+    super().setUp()
+    self.name = self.enter_context(tempfile.TemporaryDirectory())
+    self.path = pathlib.Path(self.name)
+
+  def tearDown(self):
+    self.name = None
+    self.path = None
+    super().tearDown()
+
+  def assertCacheFiles(self, *keys):
+    """Asserts that the `cache-*` files on disk are exactly `keys`."""
+    expected = {self.path / key for key in keys}
+    self.assertEqual(set(self.path.glob("cache-*")), expected)
+
+
+class LRUCacheTest(LRUCacheTestCase):
+
+  def test_get_nonexistent_key(self):
+    lru = LRUCache(self.name, max_size=-1)
+    self.assertIsNone(lru.get("cache-a"))
+
+  def test_put_and_get_key(self):
+    lru = LRUCache(self.name, max_size=-1)
+
+    lru.put("cache-a", b"a")
+    self.assertEqual(lru.get("cache-a"), b"a")
+    self.assertCacheFiles("cache-a")
+
+    lru.put("cache-b", b"b")
+    self.assertEqual(lru.get("cache-a"), b"a")
+    self.assertEqual(lru.get("cache-b"), b"b")
+    self.assertCacheFiles("cache-a", "cache-b")
+
+  def test_put_empty_value(self):
+    lru = LRUCache(self.name, max_size=-1)
+    lru.put("cache-a", b"")
+    self.assertEqual(lru.get("cache-a"), b"")
+
+  def test_put_empty_key(self):
+    lru = LRUCache(self.name, max_size=-1)
+    with self.assertRaisesRegex(ValueError, r"key cannot be empty"):
+      lru.put("", b"a")
+
+  def test_eviction(self):
+    lru = LRUCache(self.name, max_size=2)
+
+    lru.put("cache-a", b"a")
+    lru.put("cache-b", b"b")
+
+    # sleep so that `cache-b`'s timestamp is strictly greater than `cache-a`'s
+    time.sleep(1)
+    lru.get("cache-b")
+
+    # `cache-a` is now the least recently used entry: `cache-c` displaces it
+    lru.put("cache-c", b"c")
+    self.assertCacheFiles("cache-b", "cache-c")
+
+    # reading `cache-b` again leaves `cache-c` as the least recently used
+    time.sleep(1)
+    lru.get("cache-b")
+
+    # hence writing `cache-d` evicts `cache-c`
+    lru.put("cache-d", b"d")
+    self.assertCacheFiles("cache-b", "cache-d")
+
+  def test_eviction_with_empty_value(self):
+    lru = LRUCache(self.name, max_size=1)
+
+    lru.put("cache-a", b"a")
+
+    # a zero-length `cache-b` fits without eviction although the cache is full
+    lru.put("cache-b", b"")
+    self.assertCacheFiles("cache-a", "cache-b")
+
+    # reading `cache-a` makes `cache-b` the least recently used entry
+    time.sleep(1)
+    lru.get("cache-a")
+
+    # `cache-b` goes first when `cache-c` is written, but dropping zero bytes
+    # frees no room, so `cache-a` has to be evicted as well
+    lru.put("cache-c", b"c")
+    self.assertCacheFiles("cache-c")
+
+  def test_existing_cache_dir(self):
+    lru = LRUCache(self.name, max_size=2)
+
+    lru.put("cache-a", b"a")
+
+    # simulates reinitializing the cache in another process
+    del lru
+    lru = LRUCache(self.name, max_size=2)
+
+    self.assertEqual(lru.get("cache-a"), b"a")
+
+    # ensure that the LRU policy survives cache reinitialization
+    lru.put("cache-b", b"b")
+
+    # reading `cache-a` makes `cache-b` the least recently used entry
+    time.sleep(1)
+    lru.get("cache-a")
+
+    # writing `cache-c` therefore evicts `cache-b`
+    lru.put("cache-c", b"c")
+    self.assertCacheFiles("cache-a", "cache-c")
+
+  def test_max_size(self):
+    lru = LRUCache(self.name, max_size=1)
+
+    msg = (r"Cache value for key .+? of size \d+ bytes exceeds the maximum "
+           r"cache size of \d+ bytes")
+    with self.assertWarnsRegex(UserWarning, msg):
+      lru.put("cache-a", b"aaaa")
+    self.assertIsNone(lru.get("cache-a"))
+    self.assertCacheFiles()
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())