diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index e60a29e274e9..367515928e7a 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -31,7 +31,7 @@ from jax._src.compilation_cache_interface import CacheInterface from jax._src import config from jax._src import monitoring -from jax._src.gfile_cache import GFileCache +from jax._src.lru_cache import LRUCache from jax._src.lib import xla_client from jax._src.lib.mlir import ir @@ -65,7 +65,7 @@ def set_once_cache_used(f) -> None: def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: """Returns the file cache and the path to the cache.""" - return GFileCache(path), path + return LRUCache(path), path def set_cache_dir(path) -> None: diff --git a/jax/_src/compilation_cache_interface.py b/jax/_src/compilation_cache_interface.py index 95d557c5531e..b3c042055814 100644 --- a/jax/_src/compilation_cache_interface.py +++ b/jax/_src/compilation_cache_interface.py @@ -14,12 +14,10 @@ from abc import abstractmethod -from jax._src import path as pathlib from jax._src import util class CacheInterface(util.StrictABC): - _path: pathlib.Path @abstractmethod def get(self, key: str): diff --git a/jax/_src/gfile_cache.py b/jax/_src/gfile_cache.py index 301f61cc6bdb..98c524189c4b 100644 --- a/jax/_src/gfile_cache.py +++ b/jax/_src/gfile_cache.py @@ -15,9 +15,12 @@ import os from jax._src import path as pathlib -from jax._src.compilation_cache_interface import CacheInterface -class GFileCache(CacheInterface): + +# TODO (ayx): This class will be ultimately removed after `lru_cache.py` is +# finished. It exists because the current `lru_cache.py` does not support +# `gs://`. +class GFileCacheImpl: def __init__(self, path: str): """Sets up a cache at 'path'. Cached values may already be present.""" diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py new file mode 100644 index 000000000000..1e92d4348d05 --- /dev/null +++ b/jax/_src/lru_cache.py @@ -0,0 +1,135 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import tensorflow as tf +from typing import Any, Optional + +from jax._src.compilation_cache_interface import CacheInterface +from jax._src.gfile_cache import GFileCacheImpl + + +filelock: Any = None +try: + import filelock +except ImportError: + pass + + +logger = logging.getLogger(__name__) + + +def is_path_gs(path: str) -> bool: + return path.startswith("gs://") + + +def get_size(path: str) -> int: + return tf.io.gfile.stat(path).length + + +def get_mtime(path: str) -> int: + return tf.io.gfile.stat(path).mtime_nsec + + +def update_mtime(path: str) -> None: + if not is_path_gs(path): + os.utime(path) # set mtime (and also atime) to current time + else: + tmp_path = f"{path}.tmp" + tf.io.gfile.rename(path, tmp_path) + tf.io.gfile.rename(tmp_path, path) + + +class LRUCacheImpl: + + def __init__(self, cache_dir: str, max_cache_size: int, timeout=10): + if filelock is None: + raise RuntimeError("Please install filelock to use the LRUCache") + + self.cache_dir = cache_dir + self.max_cache_size = max_cache_size + os.makedirs(cache_dir, exist_ok=True) + self.timeout = timeout + + self.lock_file = os.path.join(self.cache_dir, ".lockfile") + self.lock = filelock.FileLock(self.lock_file) + + def get(self, key: str) -> Optional[bytes]: + with self.lock.acquire(timeout=self.timeout): + file_path = os.path.join(self.cache_dir, key) + + if not os.path.exists(file_path): + logger.debug("Cache miss") + return None + + logger.debug("Cache hit") + update_mtime(file_path) + with open(file_path, "rb") as f: + return f.read() + + def put(self, key: str, val: bytes) -> None: + with self.lock.acquire(timeout=self.timeout): + file_path = os.path.join(self.cache_dir, key) + if os.path.exists(file_path): + return + + self._evict_if_needed() + with open(file_path, "wb") as f: + f.write(val) + update_mtime(file_path) + + def _evict_if_needed(self) -> None: + if self.max_cache_size == -1: + return # max_cache_size == -1: no limit on cache size + + files = os.listdir(self.cache_dir) + + mtime_path_sizes = [] + cache_dir_size = 0 + for filename in files: + file_path = os.path.join(self.cache_dir, filename) + + file_size = get_size(file_path) + file_mtime = get_mtime(file_path) + + cache_dir_size += file_size + mtime_path_sizes.append((file_mtime, file_path, file_size)) + + # sort by mtime, descending + mtime_path_sizes.sort(key=lambda xyz: xyz[0], reverse=True) + + while cache_dir_size >= self.max_cache_size: + file_mtime, file_path, file_size = mtime_path_sizes[-1] + os.remove(file_path) + cache_dir_size -= file_size + mtime_path_sizes.pop() + +class LRUCache(CacheInterface): + + def __init__(self, path: str): + """Sets up a cache at `path`. Cached values may already be present.""" + if is_path_gs(path): + # gs:// does not support cache eviction yet + self.cache = GFileCacheImpl(path=path) + else: + self.cache = LRUCacheImpl(cache_dir=path, max_cache_size=1000000) + + def get(self, key: str) -> Optional[bytes]: + """Returns None if `key` isn't present.""" + return self.cache.get(key) + + def put(self, key: str, value: bytes) -> None: + """Adds new cache entry.""" + self.cache.put(key, value) diff --git a/tests/lru_cache_test.py b/tests/lru_cache_test.py new file mode 100644 index 000000000000..ce4761a9f615 --- /dev/null +++ b/tests/lru_cache_test.py @@ -0,0 +1,50 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string +import tempfile + +from absl.testing import absltest + +from jax._src.lru_cache import LRUCacheImpl +import jax._src.test_util as jtu + + +class LRUCacheTest(jtu.JaxTestCase): + + def test_cache_eviction(self): + def generate_random_k(): + # simulate keys with many collisions + return random.choice(string.ascii_lowercase[:12]) + + def generate_random_v(k): + # simulate large values while ensures that one k corresponds to one v + return bytes(f"{k}abcdefghijklmnopqrstuvwxyz" * 8192, encoding="utf-8") + + with tempfile.TemporaryDirectory() as tmpdirname: + cache = LRUCacheImpl(tmpdirname, 884700) + + for _ in range(100000): + k = generate_random_k() + + if cache.get(k) is not None: + pass # cache hit + else: + v = generate_random_v(k) + cache.put(k, v) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())