Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Jul 3, 2024
1 parent 467c62c commit 3f823c0
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 221 deletions.
24 changes: 5 additions & 19 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,16 @@
zstandard = None

from jax._src import cache_key
from jax._src.compilation_cache_interface import CacheInterface
from jax._src import config
from jax._src import monitoring
from jax._src.gfile_cache import GFileCache
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lru_cache import LRUCache


logger = logging.getLogger(__name__)

_cache: CacheInterface | None = None
_cache: LRUCache | None = None

_cache_initialized: bool = False

Expand All @@ -65,20 +63,8 @@ def set_once_cache_used(f) -> None:
f()


def get_file_cache(path: str) -> tuple[CacheInterface, str] | None:
def get_file_cache(path: str) -> tuple[LRUCache, str] | None:
"""Returns the file cache and the path to the cache."""

def is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path

# `LRUCache` currently only supports local filesystem. Therefore, if `path`
# is not on a local filesystem, instead of using `LRUCache`, we
# fallback to the old `GFileCache`, which does not support LRU eviction.
# TODO(ayx): Add cloud storage support for `LRUCache`, so that all these code
# can be removed.
if not is_local_filesystem(path):
return GFileCache(path), path

max_size = config.compilation_cache_max_size.value
return LRUCache(path, max_size=max_size), path

Expand Down Expand Up @@ -149,7 +135,7 @@ def _initialize_cache() -> None:
logger.debug("Initialized persistent compilation cache at %s", path)


def _get_cache(backend) -> CacheInterface | None:
def _get_cache(backend) -> LRUCache | None:
# TODO(b/289098047): consider making this an API and changing the callers of
# get_executable_and_time() and put_executable_and_time() to call get_cache()
# and passing the result to them.
Expand Down Expand Up @@ -198,7 +184,7 @@ def get_executable_and_time(
logger.debug("get_executable_and_time: cache is disabled/not initialized")
return None, None
executable_and_time = cache.get(cache_key)
if not executable_and_time:
if executable_and_time is not None:
return None, None

executable_and_time = decompress_executable(executable_and_time)
Expand Down Expand Up @@ -274,7 +260,7 @@ def reset_cache() -> None:
global _cache_initialized
global _cache_used
logger.info("Resetting cache at %s.",
_cache._path if _cache is not None else "<empty>")
_cache.path if _cache is not None else "<empty>")
_cache = None
with _cache_initialized_mutex:
_cache_initialized = False
Expand Down
30 changes: 0 additions & 30 deletions jax/_src/compilation_cache_interface.py

This file was deleted.

59 changes: 0 additions & 59 deletions jax/_src/gfile_cache.py

This file was deleted.

31 changes: 12 additions & 19 deletions jax/_src/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,15 @@

import heapq
import logging
import pathlib
import time
import warnings

from jax._src.compilation_cache_interface import CacheInterface


try:
import filelock
except ImportError:
filelock = None

from jax._src import path as pathlib

logger = logging.getLogger(__name__)

Expand All @@ -36,7 +33,7 @@
_ATIME_SUFFIX = "-atime"


class LRUCache(CacheInterface):
class LRUCache:
"""Bounded cache with least-recently-used (LRU) eviction policy.
This implementation includes cache reading, writing and eviction
Expand All @@ -56,29 +53,29 @@ def __init__(self, path: str, *, max_size: int, lock_timeout_secs: float | None
indicates no limit, allowing the cache size to grow indefinitely.
lock_timeout_secs: (optional) The timeout for acquiring a file lock.
"""
# TODO(ayx): add support for cloud other filesystems such as GCS
if not self._is_local_filesystem(path):
raise NotImplementedError("LRUCache only supports local filesystem at this time.")
if not self._is_local_filesystem(path) and not pathlib.epath_installed:
raise RuntimeError("Please install the `etils[epath]` package to specify a cache directory on a non-local filesystem")

self.path = pathlib.Path(path)
self.path.mkdir(parents=True, exist_ok=True)

# TODO(ayx): having a `self._path` is required by the base class
# `CacheInterface`, but the base class can be removed after `LRUCache`
# and the original `GFileCache` are unified
self._path = self.path

self.eviction_enabled = max_size != -1 # no eviction if `max_size` is set to -1

if self.eviction_enabled:
if filelock is None:
raise RuntimeError("Please install filelock package to set `jax_compilation_cache_max_size`")
raise RuntimeError("Please install the `filelock` package to set `jax_compilation_cache_max_size`")

self.max_size = max_size
self.lock_timeout_secs = lock_timeout_secs

self.lock_path = self.path / ".lockfile"
self.lock = filelock.FileLock(self.lock_path)

if self._is_local_filesystem(path):
FileLock = filelock.FileLock
else:
FileLock = filelock.SoftFileLock

self.lock = FileLock(self.lock_path)

def get(self, key: str) -> bytes | None:
"""Retrieves the cached value for the given key.
Expand Down Expand Up @@ -199,10 +196,6 @@ def _evict_if_needed(self, *, additional_size: int = 0) -> None:

dir_size -= file_size

# See comments in `jax.src.compilation_cache.get_file_cache()` for details.
# TODO(ayx): This function has a duplicate in that place, and there is
# redundancy here. However, this code is temporary, and once the issue
# is fixed, this code can be removed.
@staticmethod
def _is_local_filesystem(path: str) -> bool:
return path.startswith("file://") or "://" not in path
2 changes: 2 additions & 0 deletions jax/_src/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

try:
import etils.epath as epath
epath_installed = True
except:
epath = None
epath_installed = False

# If etils.epath (aka etils[epath] to pip) is present, we prefer it because it
# can read and write to, e.g., GCS buckets. Otherwise we use the builtin
Expand Down
7 changes: 3 additions & 4 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from jax._src import path as pathlib
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client
from jax._src.maps import xmap
from jax.experimental.pjit import pjit
Expand All @@ -63,11 +62,11 @@ def increment_event_count(event):
_counts[event] += 1


class InMemoryCache(CacheInterface):
class InMemoryCache:
"""An in-memory cache for testing purposes."""

# not used, but required by `CacheInterface`
_path = pathlib.Path()
# not used, but needed to simulate an `LRUCache` object
path = pathlib.Path()

def __init__(self):
self._cache: dict[str, bytes] = {}
Expand Down
90 changes: 0 additions & 90 deletions tests/gfile_cache_test.py

This file was deleted.

0 comments on commit 3f823c0

Please sign in to comment.