From 934142dff4716cffa46d0a72b3fefdf0e51ded89 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 17 Jun 2024 23:58:05 +0400 Subject: [PATCH] Storing the last access time of a cache entry in a separate file --- jax/_src/lru_cache.py | 64 ++++++++++++++++++--------- tests/lru_cache_test.py | 95 +++++++++++++++++++++-------------------- 2 files changed, 93 insertions(+), 66 deletions(-) diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py index 3b1f9df07210..0b9a178e15ac 100644 --- a/jax/_src/lru_cache.py +++ b/jax/_src/lru_cache.py @@ -17,6 +17,7 @@ import heapq import logging import pathlib +import time import warnings from jax._src.compilation_cache_interface import CacheInterface @@ -31,6 +32,10 @@ logger = logging.getLogger(__name__) +_CACHE_SUFFIX = "-cache" +_ATIME_SUFFIX = "-atime" + + class LRUCache(CacheInterface): """Bounded cache with least-recently-used (LRU) eviction policy. @@ -87,19 +92,25 @@ def get(self, key: str) -> bytes | None: if not key: raise ValueError("key cannot be empty") - file = self.path / key + cache_path = self.path / f"{key}{_CACHE_SUFFIX}" + atime_path = self.path / f"{key}{_ATIME_SUFFIX}" if self.eviction_enabled: self.lock.acquire(timeout=self.lock_timeout_secs) try: - if not file.exists(): + if not cache_path.exists(): logger.debug(f"Cache miss for key: {key!r}") return None logger.debug(f"Cache hit for key: {key!r}") - file.touch() # update mtime - return file.read_bytes() + + val = cache_path.read_bytes() + + timestamp = time.time_ns().to_bytes(8, "little") + atime_path.write_bytes(timestamp) + + return val finally: if self.eviction_enabled: @@ -125,17 +136,22 @@ def put(self, key: str, val: bytes) -> None: warnings.warn(msg) return - file = self.path / key + cache_path = self.path / f"{key}{_CACHE_SUFFIX}" + atime_path = self.path / f"{key}{_ATIME_SUFFIX}" if self.eviction_enabled: self.lock.acquire(timeout=self.lock_timeout_secs) try: - if file.exists(): + if cache_path.exists(): return self._evict_if_needed(additional_size=len(val)) - file.write_bytes(val) + + cache_path.write_bytes(val) + + timestamp = time.time_ns().to_bytes(8, "little") + atime_path.write_bytes(timestamp) finally: if self.eviction_enabled: @@ -153,26 +169,34 @@ def _evict_if_needed(self, *, additional_size: int = 0) -> None: if not self.eviction_enabled: return - # a priority queue, each element is a tuple `(file_mtime, file, file_size)` - h: list[tuple[int, pathlib.Path, int]] = [] + # a priority queue, each element is a tuple `(file_atime, key, file_size)` + h: list[tuple[int, str, int]] = [] dir_size = 0 - for file in self.path.iterdir(): - if file.is_file() and file != self.lock_path: - file_size = file.stat().st_size - file_mtime = file.stat().st_mtime_ns + for cache_path in self.path.glob(f"*{_CACHE_SUFFIX}"): + file_size = cache_path.stat().st_size + + key = cache_path.name.removesuffix(_CACHE_SUFFIX) + atime_path = self.path / f"{key}{_ATIME_SUFFIX}" + file_atime = int.from_bytes(atime_path.read_bytes(), "little") - dir_size += file_size - heapq.heappush(h, (file_mtime, file, file_size)) + dir_size += file_size + heapq.heappush(h, (file_atime, key, file_size)) target_size = self.max_size - additional_size # evict files until the directory size is less than or equal # to `target_size` while dir_size > target_size: - file_mtime, file, file_size = heapq.heappop(h) - msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, " - f"target cache size {target_size} bytes") - logger.debug(msg) - file.unlink() + file_atime, key, file_size = heapq.heappop(h) + + logger.debug("Evicting cache entry %r: file size %d bytes, " + "target cache size %d bytes", key, file_size, target_size) + + cache_path = self.path / f"{key}{_CACHE_SUFFIX}" + atime_path = self.path / f"{key}{_ATIME_SUFFIX}" + + cache_path.unlink() + atime_path.unlink() + dir_size -= file_size # See comments in `jax.src.compilation_cache.get_file_cache()` for details. diff --git a/tests/lru_cache_test.py b/tests/lru_cache_test.py index fb999cbef0cf..588eccddbc0f 100644 --- a/tests/lru_cache_test.py +++ b/tests/lru_cache_test.py @@ -21,7 +21,7 @@ from absl.testing import absltest from jax._src import path as pathlib -from jax._src.lru_cache import LRUCache +from jax._src.lru_cache import _CACHE_SUFFIX, LRUCache import jax._src.test_util as jtu @@ -44,30 +44,33 @@ def tearDown(self): self.name = None super().tearDown() + def assertCacheKeys(self, keys): + self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), {self.path / f"{key}{_CACHE_SUFFIX}" for key in keys}) + class LRUCacheTest(LRUCacheTestCase): def test_get_nonexistent_key(self): cache = LRUCache(self.name, max_size=-1) - self.assertIsNone(cache.get("cache-a")) + self.assertIsNone(cache.get("a")) def test_put_and_get_key(self): cache = LRUCache(self.name, max_size=-1) - cache.put("cache-a", b"a") - self.assertEqual(cache.get("cache-a"), b"a") - self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a"}) + cache.put("a", b"a") + self.assertEqual(cache.get("a"), b"a") + self.assertCacheKeys(("a",)) - cache.put("cache-b", b"b") - self.assertEqual(cache.get("cache-a"), b"a") - self.assertEqual(cache.get("cache-b"), b"b") - self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"}) + cache.put("b", b"b") + self.assertEqual(cache.get("a"), b"a") + self.assertEqual(cache.get("b"), b"b") + self.assertCacheKeys(("a", "b")) def test_put_empty_value(self): cache = LRUCache(self.name, max_size=-1) - cache.put("cache-a", b"") - self.assertEqual(cache.get("cache-a"), b"") + cache.put("a", b"") + self.assertEqual(cache.get("a"), b"") def test_put_empty_key(self): cache = LRUCache(self.name, max_size=-1) @@ -78,67 +81,67 @@ def test_put_empty_key(self): def test_eviction(self): cache = LRUCache(self.name, max_size=2) - cache.put("cache-a", b"a") - cache.put("cache-b", b"b") + cache.put("a", b"a") + cache.put("b", b"b") - # `sleep()` is necessary to guarantee that `cache-b`"s timestamp is strictly greater than `cache-a`"s + # `sleep()` is necessary to guarantee that `b`'s timestamp is strictly greater than `a`'s time.sleep(1) - cache.get("cache-b") + cache.get("b") - # write `cache-c`, evict `cache-a` - cache.put("cache-c", b"c") - self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-c"}) + # write `c`. `a` should be evicted + cache.put("c", b"c") + self.assertCacheKeys(("b", "c")) - # calling `get()` on `cache-b` makes `cache-c` least recently used + # calling `get()` on `b` makes `c` least recently used time.sleep(1) - cache.get("cache-b") + cache.get("b") - # write `cache-d`, evict `cache-c` - cache.put("cache-d", b"d") - self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-d"}) + # write `d`. `c` should be evicted + cache.put("d", b"d") + self.assertCacheKeys(("b", "d")) def test_eviction_with_empty_value(self): cache = LRUCache(self.name, max_size=1) - cache.put("cache-a", b"a") + cache.put("a", b"a") - # write `cache-b` with length 0 + # write `b` with length 0 # eviction should not happen even though the cache is full - cache.put("cache-b", b"") - self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"}) + cache.put("b", b"") + self.assertCacheKeys(("a", "b")) - # calling `get()` on `cache-a` makes `cache-b` least recently used + # calling `get()` on `a` makes `b` least recently used time.sleep(1) - cache.get("cache-a") + cache.get("a") - # writing `cache-c` should result in evicting the - # least recent used file (`cache-b`) first, - # but this is not sufficient to make room for `cache-c`, - # so `cache-a` should be evicted as well - cache.put("cache-c", b"c") - self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-c"}) + # writing `c` should result in evicting the + # least recent used file (`b`) first, + # but this is not sufficient to make room for `c`, + # so `a` should be evicted as well + cache.put("c", b"c") + self.assertCacheKeys(("c",)) def test_existing_cache_dir(self): cache = LRUCache(self.name, max_size=2) - cache.put("cache-a", b"a") + cache.put("a", b"a") # simulates reinitializing the cache in another process del cache cache = LRUCache(self.name, max_size=2) - self.assertEqual(cache.get("cache-a"), b"a") + self.assertEqual(cache.get("a"), b"a") # ensure that the LRU policy survives cache reinitialization - cache.put("cache-b", b"b") + cache.put("b", b"b") - # calling `get()` on `cache-a` makes `cache-b` least recently used + # calling `get()` on `a` makes `b` least recently used time.sleep(1) - cache.get("cache-a") + cache.get("a") - # write `cache-c`, evict `cache-b` - cache.put("cache-c", b"c") - self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-c"}) + # write `c`. `b` should be evicted + cache.put("c", b"c") + self.assertCacheKeys(("a", "c")) def test_max_size(self): cache = LRUCache(self.name, max_size=1) @@ -146,9 +149,9 @@ def test_max_size(self): msg = (r"Cache value for key .+? of size \d+ bytes exceeds the maximum " r"cache size of \d+ bytes") with self.assertWarnsRegex(UserWarning, msg): - cache.put("cache-a", b"aaaa") - self.assertIsNone(cache.get("cache-a")) - self.assertEqual(set(self.path.glob("cache-*")), set()) + cache.put("a", b"aaaa") + self.assertIsNone(cache.get("a")) + self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), set()) if __name__ == "__main__":