Skip to content

Commit

Permalink
Merge pull request #22253 from ayaka14732:lru-cache-5
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 649412314
  • Loading branch information
jax authors committed Jul 4, 2024
2 parents 061ccd4 + 6c05aa2 commit 9214ace
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 213 deletions.
12 changes: 1 addition & 11 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ pytype_strict_library(
":cache_key",
":compilation_cache_interface",
":config",
":gfile_cache",
":lru_cache",
":monitoring",
":path",
Expand Down Expand Up @@ -421,7 +420,7 @@ pytype_strict_library(
srcs = ["_src/lru_cache.py"],
deps = [
":compilation_cache_interface",
":config",
":path",
] + py_deps("filelock"),
)

Expand Down Expand Up @@ -515,15 +514,6 @@ pytype_strict_library(
] + py_deps("numpy"),
)

pytype_strict_library(
name = "gfile_cache",
srcs = ["_src/gfile_cache.py"],
deps = [
":compilation_cache_interface",
":path",
],
)

pytype_strict_library(
name = "hardware_utils",
srcs = ["_src/hardware_utils.py"],
Expand Down
21 changes: 4 additions & 17 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
zstandard = None

from jax._src import cache_key
from jax._src.compilation_cache_interface import CacheInterface
from jax._src import config
from jax._src import monitoring
from jax._src.gfile_cache import GFileCache
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lru_cache import LRUCache
Expand Down Expand Up @@ -67,18 +66,6 @@ def set_once_cache_used(f) -> None:

def get_file_cache(path: str) -> tuple[CacheInterface, str] | None:
"""Returns the file cache and the path to the cache."""

def is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path

# `LRUCache` currently only supports local filesystem. Therefore, if `path`
# is not on a local filesystem, instead of using `LRUCache`, we
# fallback to the old `GFileCache`, which does not support LRU eviction.
# TODO(ayx): Add cloud storage support for `LRUCache`, so that all these code
# can be removed.
if not is_local_filesystem(path):
return GFileCache(path), path

max_size = config.compilation_cache_max_size.value
return LRUCache(path, max_size=max_size), path

Expand Down Expand Up @@ -161,14 +148,14 @@ def _get_cache(backend) -> CacheInterface | None:
return _cache


def compress_executable(executable):
def compress_executable(executable: bytes) -> bytes:
if zstandard:
compressor = zstandard.ZstdCompressor()
return compressor.compress(executable)
else:
return zlib.compress(executable)

def decompress_executable(executable):
def decompress_executable(executable: bytes) -> bytes:
if zstandard:
decompressor = zstandard.ZstdDecompressor()
return decompressor.decompress(executable)
Expand Down Expand Up @@ -198,7 +185,7 @@ def get_executable_and_time(
logger.debug("get_executable_and_time: cache is disabled/not initialized")
return None, None
executable_and_time = cache.get(cache_key)
if not executable_and_time:
if executable_and_time is None:
return None, None

executable_and_time = decompress_executable(executable_and_time)
Expand Down
2 changes: 0 additions & 2 deletions jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,6 @@ def _compile_and_share_module(
) -> xc.LoadedExecutable:
share_timeout = config.share_binary_between_hosts_timeout_ms.value

# TODO: We need a proper eviction protocol here, otherwise all compiled
# modules will pile in memory.
if cache_key in _compile_and_share_module.modules_cache:
return _compile_and_share_module.modules_cache[cache_key]

Expand Down
59 changes: 0 additions & 59 deletions jax/_src/gfile_cache.py

This file was deleted.

44 changes: 20 additions & 24 deletions jax/_src/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,16 @@

import heapq
import logging
import pathlib
import time
import warnings

from jax._src.compilation_cache_interface import CacheInterface


try:
import filelock
except ImportError:
filelock = None

from jax._src import path as pathlib
from jax._src.compilation_cache_interface import CacheInterface

logger = logging.getLogger(__name__)

Expand All @@ -36,6 +34,10 @@
_ATIME_SUFFIX = "-atime"


def _is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path


class LRUCache(CacheInterface):
"""Bounded cache with least-recently-used (LRU) eviction policy.
Expand All @@ -56,29 +58,26 @@ def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None
indicates no limit, allowing the cache size to grow indefinitely.
lock_timeout_secs: (optional) The timeout for acquiring a file lock.
"""
# TODO(ayx): add support for cloud other filesystems such as GCS
if not self._is_local_filesystem(path):
raise NotImplementedError("LRUCache only supports local filesystem at this time.")
if not _is_local_filesystem(path) and not pathlib.epath_installed:
raise RuntimeError("Please install the `etils[epath]` package to specify a cache directory on a non-local filesystem")

self.path = pathlib.Path(path)
self.path = self._path = pathlib.Path(path)
self.path.mkdir(parents=True, exist_ok=True)

# TODO(ayx): having a `self._path` is required by the base class
# `CacheInterface`, but the base class can be removed after `LRUCache`
# and the original `GFileCache` are unified
self._path = self.path

self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1

if self.eviction_enabled:
if filelock is None:
raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`")
raise RuntimeError("Please install the `filelock` package to set `jax_compilation_cache_max_size`")

self.max_size = max_size
self.lock_timeout_secs = lock_timeout_secs

self.lock_path = self.path / ".lockfile"
self.lock = filelock.FileLock(self.lock_path)
if _is_local_filesystem(path):
self.lock = filelock.FileLock(self.lock_path)
else:
self.lock = filelock.SoftFileLock(self.lock_path)

def get(self, key: str) -> bytes | None:
"""Retrieves the cached value for the given key.
Expand Down Expand Up @@ -173,7 +172,12 @@ def _evict_if_needed(self, *, additional_size: int = 0) -> None:
h: list[tuple[int, str, int]] = []
dir_size = 0
for cache_path in self.path.glob(f"*{_CACHE_SUFFIX}"):
file_size = cache_path.stat().st_size
file_stat = cache_path.stat()

# `pathlib` and `etils[epath]` have different API for obtaining the size
# of a file, and we need to support them both.
# See also https://github.com/google/etils/issues/630
file_size = file_stat.st_size if not pathlib.epath_installed else file_stat.length # pytype: disable=attribute-error

key = cache_path.name.removesuffix(_CACHE_SUFFIX)
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
Expand All @@ -198,11 +202,3 @@ def _evict_if_needed(self, *, additional_size: int = 0) -> None:
atime_path.unlink()

dir_size -= file_size

# See comments in `jax.src.compilation_cache.get_file_cache()` for details.
# TODO(ayx): This function has a duplicate in that place, and there is
# redundancy here. However, this code is temporary, and once the issue
# is fixed, this code can be removed.
@staticmethod
def _is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path
2 changes: 2 additions & 0 deletions jax/_src/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

try:
import etils.epath as epath
epath_installed = True
except:
epath = None
epath_installed = False

# If etils.epath (aka etils[epath] to pip) is present, we prefer it because it
# can read and write to, e.g., GCS buckets. Otherwise we use the builtin
Expand Down
10 changes: 0 additions & 10 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1145,16 +1145,6 @@ py_test(
] + py_deps("absl/logging"),
)

py_test(
name = "gfile_cache_test",
srcs = ["gfile_cache_test.py"],
deps = [
"//jax",
"//jax:gfile_cache",
"//jax:test_util",
],
)

py_test(
name = "lru_cache_test",
srcs = ["lru_cache_test.py"],
Expand Down
90 changes: 0 additions & 90 deletions tests/gfile_cache_test.py

This file was deleted.

0 comments on commit 9214ace

Please sign in to comment.