Skip to content

Commit

Permalink
Merge pull request #21926 from ayaka14732:lru-cache-3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 649019092
  • Loading branch information
jax authors committed Jul 3, 2024
2 parents 8844877 + 934142d commit 467c62c
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 66 deletions.
64 changes: 44 additions & 20 deletions jax/_src/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import heapq
import logging
import pathlib
import time
import warnings

from jax._src.compilation_cache_interface import CacheInterface
Expand All @@ -31,6 +32,10 @@
logger = logging.getLogger(__name__)


_CACHE_SUFFIX = "-cache"
_ATIME_SUFFIX = "-atime"


class LRUCache(CacheInterface):
"""Bounded cache with least-recently-used (LRU) eviction policy.
Expand Down Expand Up @@ -87,19 +92,25 @@ def get(self, key: str) -> bytes | None:
if not key:
raise ValueError("key cannot be empty")

file = self.path / key
cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"

if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)

try:
if not file.exists():
if not cache_path.exists():
logger.debug(f"Cache miss for key: {key!r}")
return None

logger.debug(f"Cache hit for key: {key!r}")
file.touch() # update mtime
return file.read_bytes()

val = cache_path.read_bytes()

timestamp = time.time_ns().to_bytes(8, "little")
atime_path.write_bytes(timestamp)

return val

finally:
if self.eviction_enabled:
Expand All @@ -125,17 +136,22 @@ def put(self, key: str, val: bytes) -> None:
warnings.warn(msg)
return

file = self.path / key
cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"

if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)

try:
if file.exists():
if cache_path.exists():
return

self._evict_if_needed(additional_size=len(val))
file.write_bytes(val)

cache_path.write_bytes(val)

timestamp = time.time_ns().to_bytes(8, "little")
atime_path.write_bytes(timestamp)

finally:
if self.eviction_enabled:
Expand All @@ -153,26 +169,34 @@ def _evict_if_needed(self, *, additional_size: int = 0) -> None:
if not self.eviction_enabled:
return

# a priority queue, each element is a tuple `(file_mtime, file, file_size)`
h: list[tuple[int, pathlib.Path, int]] = []
# a priority queue, each element is a tuple `(file_atime, key, file_size)`
h: list[tuple[int, str, int]] = []
dir_size = 0
for file in self.path.iterdir():
if file.is_file() and file != self.lock_path:
file_size = file.stat().st_size
file_mtime = file.stat().st_mtime_ns
for cache_path in self.path.glob(f"*{_CACHE_SUFFIX}"):
file_size = cache_path.stat().st_size

key = cache_path.name.removesuffix(_CACHE_SUFFIX)
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
file_atime = int.from_bytes(atime_path.read_bytes(), "little")

dir_size += file_size
heapq.heappush(h, (file_mtime, file, file_size))
dir_size += file_size
heapq.heappush(h, (file_atime, key, file_size))

target_size = self.max_size - additional_size
# evict files until the directory size is less than or equal
# to `target_size`
while dir_size > target_size:
file_mtime, file, file_size = heapq.heappop(h)
msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, "
f"target cache size {target_size} bytes")
logger.debug(msg)
file.unlink()
file_atime, key, file_size = heapq.heappop(h)

logger.debug("Evicting cache entry %r: file size %d bytes, "
"target cache size %d bytes", key, file_size, target_size)

cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"

cache_path.unlink()
atime_path.unlink()

dir_size -= file_size

# See comments in `jax._src.compilation_cache.get_file_cache()` for details.
Expand Down
95 changes: 49 additions & 46 deletions tests/lru_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import absltest

from jax._src import path as pathlib
from jax._src.lru_cache import LRUCache
from jax._src.lru_cache import _CACHE_SUFFIX, LRUCache
import jax._src.test_util as jtu


Expand All @@ -44,30 +44,33 @@ def tearDown(self):
self.name = None
super().tearDown()

def assertCacheKeys(self, keys):
self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), {self.path / f"{key}{_CACHE_SUFFIX}" for key in keys})


class LRUCacheTest(LRUCacheTestCase):

def test_get_nonexistent_key(self):
cache = LRUCache(self.name, max_size=-1)
self.assertIsNone(cache.get("cache-a"))
self.assertIsNone(cache.get("a"))

def test_put_and_get_key(self):
cache = LRUCache(self.name, max_size=-1)

cache.put("cache-a", b"a")
self.assertEqual(cache.get("cache-a"), b"a")
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a"})
cache.put("a", b"a")
self.assertEqual(cache.get("a"), b"a")
self.assertCacheKeys(("a",))

cache.put("cache-b", b"b")
self.assertEqual(cache.get("cache-a"), b"a")
self.assertEqual(cache.get("cache-b"), b"b")
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"})
cache.put("b", b"b")
self.assertEqual(cache.get("a"), b"a")
self.assertEqual(cache.get("b"), b"b")
self.assertCacheKeys(("a", "b"))

def test_put_empty_value(self):
cache = LRUCache(self.name, max_size=-1)

cache.put("cache-a", b"")
self.assertEqual(cache.get("cache-a"), b"")
cache.put("a", b"")
self.assertEqual(cache.get("a"), b"")

def test_put_empty_key(self):
cache = LRUCache(self.name, max_size=-1)
Expand All @@ -78,77 +81,77 @@ def test_put_empty_key(self):
def test_eviction(self):
cache = LRUCache(self.name, max_size=2)

cache.put("cache-a", b"a")
cache.put("cache-b", b"b")
cache.put("a", b"a")
cache.put("b", b"b")

# `sleep()` is necessary to guarantee that `cache-b`"s timestamp is strictly greater than `cache-a`"s
# `sleep()` is necessary to guarantee that `b`'s timestamp is strictly greater than `a`'s
time.sleep(1)
cache.get("cache-b")
cache.get("b")

# write `cache-c`, evict `cache-a`
cache.put("cache-c", b"c")
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-c"})
# write `c`. `a` should be evicted
cache.put("c", b"c")
self.assertCacheKeys(("b", "c"))

# calling `get()` on `cache-b` makes `cache-c` least recently used
# calling `get()` on `b` makes `c` least recently used
time.sleep(1)
cache.get("cache-b")
cache.get("b")

# write `cache-d`, evict `cache-c`
cache.put("cache-d", b"d")
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-d"})
# write `d`. `c` should be evicted
cache.put("d", b"d")
self.assertCacheKeys(("b", "d"))

def test_eviction_with_empty_value(self):
cache = LRUCache(self.name, max_size=1)

cache.put("cache-a", b"a")
cache.put("a", b"a")

# write `cache-b` with length 0
# write `b` with length 0
# eviction should not happen even though the cache is full
cache.put("cache-b", b"")
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"})
cache.put("b", b"")
self.assertCacheKeys(("a", "b"))

# calling `get()` on `cache-a` makes `cache-b` least recently used
# calling `get()` on `a` makes `b` least recently used
time.sleep(1)
cache.get("cache-a")
cache.get("a")

# writing `cache-c` should result in evicting the
# least recently used file (`cache-b`) first,
# but this is not sufficient to make room for `cache-c`,
# so `cache-a` should be evicted as well
cache.put("cache-c", b"c")
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-c"})
# writing `c` should result in evicting the
# least recently used file (`b`) first,
# but this is not sufficient to make room for `c`,
# so `a` should be evicted as well
cache.put("c", b"c")
self.assertCacheKeys(("c",))

def test_existing_cache_dir(self):
cache = LRUCache(self.name, max_size=2)

cache.put("cache-a", b"a")
cache.put("a", b"a")

# simulates reinitializing the cache in another process
del cache
cache = LRUCache(self.name, max_size=2)

self.assertEqual(cache.get("cache-a"), b"a")
self.assertEqual(cache.get("a"), b"a")

# ensure that the LRU policy survives cache reinitialization
cache.put("cache-b", b"b")
cache.put("b", b"b")

# calling `get()` on `cache-a` makes `cache-b` least recently used
# calling `get()` on `a` makes `b` least recently used
time.sleep(1)
cache.get("cache-a")
cache.get("a")

# write `cache-c`, evict `cache-b`
cache.put("cache-c", b"c")
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-c"})
# write `c`. `b` should be evicted
cache.put("c", b"c")
self.assertCacheKeys(("a", "c"))

def test_max_size(self):
cache = LRUCache(self.name, max_size=1)

msg = (r"Cache value for key .+? of size \d+ bytes exceeds the maximum "
r"cache size of \d+ bytes")
with self.assertWarnsRegex(UserWarning, msg):
cache.put("cache-a", b"aaaa")
self.assertIsNone(cache.get("cache-a"))
self.assertEqual(set(self.path.glob("cache-*")), set())
cache.put("a", b"aaaa")
self.assertIsNone(cache.get("a"))
self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), set())


if __name__ == "__main__":
Expand Down

0 comments on commit 467c62c

Please sign in to comment.