diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 9c276151741d..ba084704735c 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -28,10 +28,8 @@ zstandard = None from jax._src import cache_key -from jax._src.compilation_cache_interface import CacheInterface from jax._src import config from jax._src import monitoring -from jax._src.gfile_cache import GFileCache from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lru_cache import LRUCache @@ -39,7 +37,7 @@ logger = logging.getLogger(__name__) -_cache: CacheInterface | None = None +_cache: LRUCache | None = None _cache_initialized: bool = False @@ -65,20 +63,8 @@ def set_once_cache_used(f) -> None: f() -def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: +def get_file_cache(path: str) -> tuple[LRUCache, str] | None: """Returns the file cache and the path to the cache.""" - - def is_local_filesystem(path: str) -> bool: - return path.startswith("file://") or "://" not in path - - # `LRUCache` currently only supports local filesystem. Therefore, if `path` - # is not on a local filesystem, instead of using `LRUCache`, we - # fallback to the old `GFileCache`, which does not support LRU eviction. - # TODO(ayx): Add cloud storage support for `LRUCache`, so that all these code - # can be removed. - if not is_local_filesystem(path): - return GFileCache(path), path - max_size = config.compilation_cache_max_size.value return LRUCache(path, max_size=max_size), path @@ -149,7 +135,7 @@ def _initialize_cache() -> None: logger.debug("Initialized persistent compilation cache at %s", path) -def _get_cache(backend) -> CacheInterface | None: +def _get_cache(backend) -> LRUCache | None: # TODO(b/289098047): consider making this an API and changing the callers of # get_executable_and_time() and put_executable_and_time() to call get_cache() # and passing the result to them. @@ -198,7 +184,7 @@ def get_executable_and_time( logger.debug("get_executable_and_time: cache is disabled/not initialized") return None, None executable_and_time = cache.get(cache_key) - if not executable_and_time: + if executable_and_time is None: return None, None executable_and_time = decompress_executable(executable_and_time) @@ -274,7 +260,7 @@ def reset_cache() -> None: global _cache_initialized global _cache_used logger.info("Resetting cache at %s.", - _cache._path if _cache is not None else "") + _cache.path if _cache is not None else "") _cache = None with _cache_initialized_mutex: _cache_initialized = False diff --git a/jax/_src/compilation_cache_interface.py b/jax/_src/compilation_cache_interface.py deleted file mode 100644 index 95d557c5531e..000000000000 --- a/jax/_src/compilation_cache_interface.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2021 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import abstractmethod - -from jax._src import path as pathlib -from jax._src import util - - -class CacheInterface(util.StrictABC): - _path: pathlib.Path - - @abstractmethod - def get(self, key: str): - pass - - @abstractmethod - def put(self, key: str, value: bytes): - pass diff --git a/jax/_src/gfile_cache.py b/jax/_src/gfile_cache.py deleted file mode 100644 index 989844b10ddb..000000000000 --- a/jax/_src/gfile_cache.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2022 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from jax._src import path as pathlib -from jax._src.compilation_cache_interface import CacheInterface - - -# TODO (ayx): This class will be ultimately removed after `lru_cache.py` is -# finished. It exists because the current `lru_cache.py` does not support -# `gs://`. -class GFileCache(CacheInterface): - - def __init__(self, path: str): - """Sets up a cache at 'path'. Cached values may already be present.""" - self._path = pathlib.Path(path) - self._path.mkdir(parents=True, exist_ok=True) - - def get(self, key: str): - """Returns None if 'key' isn't present.""" - if not key: - raise ValueError("key cannot be empty") - path_to_key = self._path / key - if path_to_key.exists(): - return path_to_key.read_bytes() - else: - return None - - def put(self, key: str, value: bytes): - """Adds new cache entry.""" - if not key: - raise ValueError("key cannot be empty") - path_to_new_file = self._path / key - if str(path_to_new_file).startswith('gs://'): - # Writes to gcs are atomic. - path_to_new_file.write_bytes(value) - elif str(path_to_new_file).startswith('file://') or '://' not in str(path_to_new_file): - tmp_path = self._path / f"_temp_{key}" - with open(str(tmp_path), "wb") as f: - f.write(value) - f.flush() - os.fsync(f.fileno()) - os.replace(tmp_path, path_to_new_file) - else: - tmp_path = self._path / f"_temp_{key}" - tmp_path.write_bytes(value) - tmp_path.replace(str(path_to_new_file)) diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py index 0b9a178e15ac..6c2d9e1d5620 100644 --- a/jax/_src/lru_cache.py +++ b/jax/_src/lru_cache.py @@ -16,18 +16,15 @@ import heapq import logging -import pathlib import time import warnings -from jax._src.compilation_cache_interface import CacheInterface - - try: import filelock except ImportError: filelock = None +from jax._src import path as pathlib logger = logging.getLogger(__name__) @@ -36,7 +33,7 @@ _ATIME_SUFFIX = "-atime" -class LRUCache(CacheInterface): +class LRUCache: """Bounded cache with least-recently-used (LRU) eviction policy. This implementation includes cache reading, writing and eviction @@ -56,29 +53,29 @@ def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None indicates no limit, allowing the cache size to grow indefinitely. lock_timeout_secs: (optional) The timeout for acquiring a file lock. """ - # TODO(ayx): add support for cloud other filesystems such as GCS - if not self._is_local_filesystem(path): - raise NotImplementedError("LRUCache only supports local filesystem at this time.") + if not self._is_local_filesystem(path) and not pathlib.epath_installed: + raise RuntimeError("Please install the `etils[epath]` package to specify a cache directory on a non-local filesystem") self.path = pathlib.Path(path) self.path.mkdir(parents=True, exist_ok=True) - # TODO(ayx): having a `self._path` is required by the base class - # `CacheInterface`, but the base class can be removed after `LRUCache` - # and the original `GFileCache` are unified - self._path = self.path - self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1 if self.eviction_enabled: if filelock is None: - raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`") + raise RuntimeError("Please install the `filelock` package to set `jax_compilation_cache_max_size`") self.max_size = max_size self.lock_timeout_secs = lock_timeout_secs self.lock_path = self.path / ".lockfile" - self.lock = filelock.FileLock(self.lock_path) + + if self._is_local_filesystem(path): + FileLock = filelock.FileLock + else: + FileLock = filelock.SoftFileLock + + self.lock = FileLock(self.lock_path) def get(self, key: str) -> bytes | None: """Retrieves the cached value for the given key. @@ -199,10 +196,6 @@ def _evict_if_needed(self, *, additional_size: int = 0) -> None: dir_size -= file_size - # See comments in `jax.src.compilation_cache.get_file_cache()` for details. - # TODO(ayx): This function has a duplicate in that place, and there is - # redundancy here. However, this code is temporary, and once the issue - # is fixed, this code can be removed. @staticmethod def _is_local_filesystem(path: str) -> bool: return path.startswith("file://") or "://" not in path diff --git a/jax/_src/path.py b/jax/_src/path.py index 1f92dd504282..1dd523249692 100644 --- a/jax/_src/path.py +++ b/jax/_src/path.py @@ -19,8 +19,10 @@ try: import etils.epath as epath + epath_installed = True except: epath = None + epath_installed = False # If etils.epath (aka etils[epath] to pip) is present, we prefer it because it # can read and write to, e.g., GCS buckets. Otherwise we use the builtin diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index a3e9c623db4c..9f059f8434d3 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -37,7 +37,6 @@ from jax._src import path as pathlib from jax._src import test_util as jtu from jax._src import xla_bridge -from jax._src.compilation_cache_interface import CacheInterface from jax._src.lib import xla_client from jax._src.maps import xmap from jax.experimental.pjit import pjit @@ -63,11 +62,11 @@ def increment_event_count(event): _counts[event] += 1 -class InMemoryCache(CacheInterface): +class InMemoryCache: """An in-memory cache for testing purposes.""" - # not used, but required by `CacheInterface` - _path = pathlib.Path() + # not used, but needed to simulate an `LRUCache` object + path = pathlib.Path() def __init__(self): self._cache: dict[str, bytes] = {} diff --git a/tests/gfile_cache_test.py b/tests/gfile_cache_test.py deleted file mode 100644 index 1ccaaf8ba50f..000000000000 --- a/tests/gfile_cache_test.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2021 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tempfile -import threading - -from absl.testing import absltest - -from jax._src.gfile_cache import GFileCache -import jax._src.test_util as jtu - - -class FileSystemCacheTest(jtu.JaxTestCase): - - def test_get_nonexistent_key(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - self.assertEqual(cache.get("nonExistentKey"), None) - - def test_put_and_get_key(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - cache.put("foo", b"bar") - self.assertEqual(cache.get("foo"), b"bar") - - def test_existing_cache_path(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache1 = GFileCache(tmpdir) - cache1.put("foo", b"bar") - del cache1 - cache2 = GFileCache(tmpdir) - self.assertEqual(cache2.get("foo"), b"bar") - - def test_empty_value_put(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - cache.put("foo", b"") - self.assertEqual(cache.get("foo"), b"") - - def test_empty_key_put(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - with self.assertRaisesRegex(ValueError, r"key cannot be empty"): - cache.put("", b"bar") - - def test_empty_key_get(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - with self.assertRaisesRegex(ValueError, r"key cannot be empty"): - cache.get("") - - def test_threads(self): - file_contents1 = "1" * (65536 + 1) - file_contents2 = "2" * (65536 + 1) - - def call_multiple_puts_and_gets(cache): - for _ in range(50): - cache.put("foo", file_contents1.encode("utf-8").strip()) - cache.put("foo", file_contents2.encode("utf-8").strip()) - cache.get("foo") - self.assertEqual( - cache.get("foo"), file_contents2.encode("utf-8").strip() - ) - - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - threads = [] - for _ in range(50): - t = threading.Thread(target=call_multiple_puts_and_gets(cache)) - t.start() - threads.append(t) - for t in threads: - t.join() - - self.assertEqual(cache.get("foo"), file_contents2.encode("utf-8").strip()) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader())