From 6c05aa2f3251a91e73d704bca2fe511e6f59b2e2 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 3 Jul 2024 13:20:55 +0000 Subject: [PATCH] Clean up --- jax/BUILD | 12 +---- jax/_src/compilation_cache.py | 21 ++------ jax/_src/compiler.py | 2 - jax/_src/gfile_cache.py | 59 ----------------------- jax/_src/lru_cache.py | 44 ++++++++--------- jax/_src/path.py | 2 + tests/BUILD | 10 ---- tests/gfile_cache_test.py | 90 ----------------------------------- 8 files changed, 27 insertions(+), 213 deletions(-) delete mode 100644 jax/_src/gfile_cache.py delete mode 100644 tests/gfile_cache_test.py diff --git a/jax/BUILD b/jax/BUILD index a0de4f19ce5c..7fcf5cfa9884 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -389,7 +389,6 @@ pytype_strict_library( ":cache_key", ":compilation_cache_interface", ":config", - ":gfile_cache", ":lru_cache", ":monitoring", ":path", @@ -421,7 +420,7 @@ pytype_strict_library( srcs = ["_src/lru_cache.py"], deps = [ ":compilation_cache_interface", - ":config", + ":path", ] + py_deps("filelock"), ) @@ -515,15 +514,6 @@ pytype_strict_library( ] + py_deps("numpy"), ) -pytype_strict_library( - name = "gfile_cache", - srcs = ["_src/gfile_cache.py"], - deps = [ - ":compilation_cache_interface", - ":path", - ], -) - pytype_strict_library( name = "hardware_utils", srcs = ["_src/hardware_utils.py"], diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 9c276151741d..03b462b33806 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -28,10 +28,9 @@ zstandard = None from jax._src import cache_key -from jax._src.compilation_cache_interface import CacheInterface from jax._src import config from jax._src import monitoring -from jax._src.gfile_cache import GFileCache +from jax._src.compilation_cache_interface import CacheInterface from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lru_cache import LRUCache @@ -67,18 +66,6 @@ def set_once_cache_used(f) -> None: def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: """Returns the file cache and the path to the cache.""" - - def is_local_filesystem(path: str) -> bool: - return path.startswith("file://") or "://" not in path - - # `LRUCache` currently only supports local filesystem. Therefore, if `path` - # is not on a local filesystem, instead of using `LRUCache`, we - # fallback to the old `GFileCache`, which does not support LRU eviction. - # TODO(ayx): Add cloud storage support for `LRUCache`, so that all these code - # can be removed. - if not is_local_filesystem(path): - return GFileCache(path), path - max_size = config.compilation_cache_max_size.value return LRUCache(path, max_size=max_size), path @@ -161,14 +148,14 @@ def _get_cache(backend) -> CacheInterface | None: return _cache -def compress_executable(executable): +def compress_executable(executable: bytes) -> bytes: if zstandard: compressor = zstandard.ZstdCompressor() return compressor.compress(executable) else: return zlib.compress(executable) -def decompress_executable(executable): +def decompress_executable(executable: bytes) -> bytes: if zstandard: decompressor = zstandard.ZstdDecompressor() return decompressor.decompress(executable) @@ -198,7 +185,7 @@ def get_executable_and_time( logger.debug("get_executable_and_time: cache is disabled/not initialized") return None, None executable_and_time = cache.get(cache_key) - if not executable_and_time: + if executable_and_time is None: return None, None executable_and_time = decompress_executable(executable_and_time) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 438f1f9e5183..b9906ad3c0fa 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -557,8 +557,6 @@ def _compile_and_share_module( ) -> xc.LoadedExecutable: share_timeout = config.share_binary_between_hosts_timeout_ms.value - # TODO: We need a proper eviction protocol here, otherwise all compiled - # modules will pile in memory. if cache_key in _compile_and_share_module.modules_cache: return _compile_and_share_module.modules_cache[cache_key] diff --git a/jax/_src/gfile_cache.py b/jax/_src/gfile_cache.py deleted file mode 100644 index 989844b10ddb..000000000000 --- a/jax/_src/gfile_cache.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2022 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from jax._src import path as pathlib -from jax._src.compilation_cache_interface import CacheInterface - - -# TODO (ayx): This class will be ultimately removed after `lru_cache.py` is -# finished. It exists because the current `lru_cache.py` does not support -# `gs://`. -class GFileCache(CacheInterface): - - def __init__(self, path: str): - """Sets up a cache at 'path'. Cached values may already be present.""" - self._path = pathlib.Path(path) - self._path.mkdir(parents=True, exist_ok=True) - - def get(self, key: str): - """Returns None if 'key' isn't present.""" - if not key: - raise ValueError("key cannot be empty") - path_to_key = self._path / key - if path_to_key.exists(): - return path_to_key.read_bytes() - else: - return None - - def put(self, key: str, value: bytes): - """Adds new cache entry.""" - if not key: - raise ValueError("key cannot be empty") - path_to_new_file = self._path / key - if str(path_to_new_file).startswith('gs://'): - # Writes to gcs are atomic. - path_to_new_file.write_bytes(value) - elif str(path_to_new_file).startswith('file://') or '://' not in str(path_to_new_file): - tmp_path = self._path / f"_temp_{key}" - with open(str(tmp_path), "wb") as f: - f.write(value) - f.flush() - os.fsync(f.fileno()) - os.replace(tmp_path, path_to_new_file) - else: - tmp_path = self._path / f"_temp_{key}" - tmp_path.write_bytes(value) - tmp_path.replace(str(path_to_new_file)) diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py index 0b9a178e15ac..d476fe172dbf 100644 --- a/jax/_src/lru_cache.py +++ b/jax/_src/lru_cache.py @@ -16,18 +16,16 @@ import heapq import logging -import pathlib import time import warnings -from jax._src.compilation_cache_interface import CacheInterface - - try: import filelock except ImportError: filelock = None +from jax._src import path as pathlib +from jax._src.compilation_cache_interface import CacheInterface logger = logging.getLogger(__name__) @@ -36,6 +34,10 @@ _ATIME_SUFFIX = "-atime" +def _is_local_filesystem(path: str) -> bool: + return path.startswith("file://") or "://" not in path + + class LRUCache(CacheInterface): """Bounded cache with least-recently-used (LRU) eviction policy. @@ -56,29 +58,26 @@ def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None indicates no limit, allowing the cache size to grow indefinitely. lock_timeout_secs: (optional) The timeout for acquiring a file lock. """ - # TODO(ayx): add support for cloud other filesystems such as GCS - if not self._is_local_filesystem(path): - raise NotImplementedError("LRUCache only supports local filesystem at this time.") + if not _is_local_filesystem(path) and not pathlib.epath_installed: + raise RuntimeError("Please install the `etils[epath]` package to specify a cache directory on a non-local filesystem") - self.path = pathlib.Path(path) + self.path = self._path = pathlib.Path(path) self.path.mkdir(parents=True, exist_ok=True) - # TODO(ayx): having a `self._path` is required by the base class - # `CacheInterface`, but the base class can be removed after `LRUCache` - # and the original `GFileCache` are unified - self._path = self.path - self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1 if self.eviction_enabled: if filelock is None: - raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`") + raise RuntimeError("Please install the `filelock` package to set `jax_compilation_cache_max_size`") self.max_size = max_size self.lock_timeout_secs = lock_timeout_secs self.lock_path = self.path / ".lockfile" - self.lock = filelock.FileLock(self.lock_path) + if _is_local_filesystem(path): + self.lock = filelock.FileLock(self.lock_path) + else: + self.lock = filelock.SoftFileLock(self.lock_path) def get(self, key: str) -> bytes | None: """Retrieves the cached value for the given key. @@ -173,7 +172,12 @@ def _evict_if_needed(self, *, additional_size: int = 0) -> None: h: list[tuple[int, str, int]] = [] dir_size = 0 for cache_path in self.path.glob(f"*{_CACHE_SUFFIX}"): - file_size = cache_path.stat().st_size + file_stat = cache_path.stat() + + # `pathlib` and `etils[epath]` have different API for obtaining the size + # of a file, and we need to support them both. + # See also https://github.com/google/etils/issues/630 + file_size = file_stat.st_size if not pathlib.epath_installed else file_stat.length # pytype: disable=attribute-error key = cache_path.name.removesuffix(_CACHE_SUFFIX) atime_path = self.path / f"{key}{_ATIME_SUFFIX}" @@ -198,11 +202,3 @@ def _evict_if_needed(self, *, additional_size: int = 0) -> None: atime_path.unlink() dir_size -= file_size - - # See comments in `jax.src.compilation_cache.get_file_cache()` for details. - # TODO(ayx): This function has a duplicate in that place, and there is - # redundancy here. However, this code is temporary, and once the issue - # is fixed, this code can be removed. - @staticmethod - def _is_local_filesystem(path: str) -> bool: - return path.startswith("file://") or "://" not in path diff --git a/jax/_src/path.py b/jax/_src/path.py index 1f92dd504282..1dd523249692 100644 --- a/jax/_src/path.py +++ b/jax/_src/path.py @@ -19,8 +19,10 @@ try: import etils.epath as epath + epath_installed = True except: epath = None + epath_installed = False # If etils.epath (aka etils[epath] to pip) is present, we prefer it because it # can read and write to, e.g., GCS buckets. Otherwise we use the builtin diff --git a/tests/BUILD b/tests/BUILD index b5a9b916d134..cb36d7d10055 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1145,16 +1145,6 @@ py_test( ] + py_deps("absl/logging"), ) -py_test( - name = "gfile_cache_test", - srcs = ["gfile_cache_test.py"], - deps = [ - "//jax", - "//jax:gfile_cache", - "//jax:test_util", - ], -) - py_test( name = "lru_cache_test", srcs = ["lru_cache_test.py"], diff --git a/tests/gfile_cache_test.py b/tests/gfile_cache_test.py deleted file mode 100644 index 1ccaaf8ba50f..000000000000 --- a/tests/gfile_cache_test.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2021 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tempfile -import threading - -from absl.testing import absltest - -from jax._src.gfile_cache import GFileCache -import jax._src.test_util as jtu - - -class FileSystemCacheTest(jtu.JaxTestCase): - - def test_get_nonexistent_key(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - self.assertEqual(cache.get("nonExistentKey"), None) - - def test_put_and_get_key(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - cache.put("foo", b"bar") - self.assertEqual(cache.get("foo"), b"bar") - - def test_existing_cache_path(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache1 = GFileCache(tmpdir) - cache1.put("foo", b"bar") - del cache1 - cache2 = GFileCache(tmpdir) - self.assertEqual(cache2.get("foo"), b"bar") - - def test_empty_value_put(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - cache.put("foo", b"") - self.assertEqual(cache.get("foo"), b"") - - def test_empty_key_put(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - with self.assertRaisesRegex(ValueError, r"key cannot be empty"): - cache.put("", b"bar") - - def test_empty_key_get(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - with self.assertRaisesRegex(ValueError, r"key cannot be empty"): - cache.get("") - - def test_threads(self): - file_contents1 = "1" * (65536 + 1) - file_contents2 = "2" * (65536 + 1) - - def call_multiple_puts_and_gets(cache): - for _ in range(50): - cache.put("foo", file_contents1.encode("utf-8").strip()) - cache.put("foo", file_contents2.encode("utf-8").strip()) - cache.get("foo") - self.assertEqual( - cache.get("foo"), file_contents2.encode("utf-8").strip() - ) - - with tempfile.TemporaryDirectory() as tmpdir: - cache = GFileCache(tmpdir) - threads = [] - for _ in range(50): - t = threading.Thread(target=call_multiple_puts_and_gets(cache)) - t.start() - threads.append(t) - for t in threads: - t.join() - - self.assertEqual(cache.get("foo"), file_contents2.encode("utf-8").strip()) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader())