Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed May 30, 2024
1 parent 47420a3 commit c7999c8
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 6 deletions.
4 changes: 2 additions & 2 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from jax._src.compilation_cache_interface import CacheInterface
from jax._src import config
from jax._src import monitoring
from jax._src.gfile_cache import GFileCache
from jax._src.lru_cache import LRUCache
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir

Expand Down Expand Up @@ -65,7 +65,7 @@ def set_once_cache_used(f) -> None:

def get_file_cache(path: str) -> tuple[CacheInterface, str] | None:
"""Returns the file cache and the path to the cache."""
return GFileCache(path), path
return LRUCache(path), path


def set_cache_dir(path) -> None:
Expand Down
2 changes: 0 additions & 2 deletions jax/_src/compilation_cache_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@

from abc import abstractmethod

from jax._src import path as pathlib
from jax._src import util


class CacheInterface(util.StrictABC):
_path: pathlib.Path

@abstractmethod
def get(self, key: str):
Expand Down
7 changes: 5 additions & 2 deletions jax/_src/gfile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
import os

from jax._src import path as pathlib
from jax._src.compilation_cache_interface import CacheInterface

class GFileCache(CacheInterface):

# TODO (ayx): This class will ultimately be removed once `lru_cache.py` is
# finished. It exists only because the current `lru_cache.py` does not yet
# support `gs://` paths.
class GFileCacheImpl:

def __init__(self, path: str):
"""Sets up a cache at 'path'. Cached values may already be present."""
Expand Down
134 changes: 134 additions & 0 deletions jax/_src/lru_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
import os
import tensorflow as tf
from typing import Any

from jax._src.compilation_cache_interface import CacheInterface
from jax._src.gfile_cache import GFileCacheImpl


# `filelock` is an optional dependency: the name stays bound to `None` when the
# package cannot be imported, and `LRUCacheImpl.__init__` raises a
# `RuntimeError` in that case.
filelock: Any = None
try:
import filelock
except ImportError:
pass


logger = logging.getLogger(__name__)


def is_path_gs(path: str) -> bool:
  """Report whether `path` is written with the `gs://` scheme prefix."""
  gcs_prefix = "gs://"
  return path.startswith(gcs_prefix)


def get_size(path: str) -> int:
  """Return the `length` field that `tf.io.gfile.stat` reports for `path`."""
  file_stats = tf.io.gfile.stat(path)
  return file_stats.length


def get_mtime(path: str) -> int:
  """Return the `mtime_nsec` field that `tf.io.gfile.stat` reports for `path`."""
  file_stats = tf.io.gfile.stat(path)
  return file_stats.mtime_nsec


def update_mtime(path: str) -> None:
  """Refresh the modification time of `path`.

  Local paths are touched with `os.utime`. `gs://` paths are renamed to
  `<path>.tmp` and back again -- presumably so that the object store records a
  fresh modification time (TODO confirm).
  """
  if is_path_gs(path):
    staging_path = f"{path}.tmp"
    tf.io.gfile.rename(path, staging_path)
    tf.io.gfile.rename(staging_path, path)
    return

  # With no explicit times, both mtime and atime are set to the current time.
  os.utime(path)


class LRUCacheImpl:
  """Directory-backed cache that evicts the least recently used entries.

  Every entry is stored as a file named after its key directly inside
  `cache_dir`. Reading or writing an entry refreshes the file's modification
  time, and the files with the oldest modification time are deleted first when
  room is needed. `get` and `put` run while holding a `filelock.FileLock` on
  `<cache_dir>/.lockfile`.
  """

  def __init__(self, cache_dir: str, max_cache_size: int, timeout=10):
    """Sets up a cache at `cache_dir`, creating the directory if necessary.

    Args:
      cache_dir: directory that holds one file per cached key.
      max_cache_size: limit, in bytes, on the combined size of the cached
        files. `-1` disables eviction entirely.
      timeout: value passed as `timeout` to `FileLock.acquire` by `get` and
        `put`.

    Raises:
      RuntimeError: if the optional `filelock` package is not installed.
    """
    if filelock is None:
      raise RuntimeError("Please install filelock to use the LRUCache")

    self.cache_dir = cache_dir
    self.max_cache_size = max_cache_size
    os.makedirs(cache_dir, exist_ok=True)
    self.timeout = timeout

    self.lock_file = os.path.join(self.cache_dir, ".lockfile")
    self.lock = filelock.FileLock(self.lock_file)

  def get(self, key: str) -> bytes | None:
    """Returns the bytes stored under `key`, or `None` on a cache miss.

    A hit refreshes the entry's modification time so that it becomes the most
    recently used entry.
    """
    with self.lock.acquire(timeout=self.timeout):
      file_path = os.path.join(self.cache_dir, key)

      if not os.path.exists(file_path):
        logger.debug("Cache miss")
        return None

      logger.debug("Cache hit")
      update_mtime(file_path)
      with open(file_path, "rb") as f:
        return f.read()

  def put(self, key: str, val: bytes) -> None:
    """Stores `val` under `key` unless an entry for `key` already exists.

    Older entries are evicted first so that the cache, including `val`, fits
    within `max_cache_size`. A value that is larger than the limit on its own
    is still stored after every other entry has been evicted.
    """
    with self.lock.acquire(timeout=self.timeout):
      file_path = os.path.join(self.cache_dir, key)
      if os.path.exists(file_path):
        return

      # Reserve room for the incoming value; evicting only down to the limit
      # would let the write that follows push the directory past it.
      self._evict_if_needed(additional_size=len(val))
      with open(file_path, "wb") as f:
        f.write(val)
      update_mtime(file_path)

  def _evict_if_needed(self, *, additional_size: int = 0) -> None:
    """Deletes least recently used entries until `additional_size` bytes fit.

    Args:
      additional_size: number of bytes about to be added to the cache.
    """
    if self.max_cache_size == -1:
      return  # max_cache_size == -1: no limit on cache size

    # The lock file lives in the cache directory but is not a cache entry. It
    # must never be counted or deleted, least of all while the lock is held.
    lock_file_name = os.path.basename(self.lock_file)

    mtime_path_sizes = []
    cache_dir_size = 0
    for filename in os.listdir(self.cache_dir):
      if filename == lock_file_name:
        continue

      file_path = os.path.join(self.cache_dir, filename)
      file_size = get_size(file_path)
      file_mtime = get_mtime(file_path)

      cache_dir_size += file_size
      mtime_path_sizes.append((file_mtime, file_path, file_size))

    # Sort by mtime, descending, so that the least recently used entry sits at
    # the end of the list and can be popped cheaply.
    mtime_path_sizes.sort(key=lambda entry: entry[0], reverse=True)

    target_size = self.max_cache_size - max(0, additional_size)
    # Stop once nothing is left to evict instead of indexing an empty list.
    while mtime_path_sizes and cache_dir_size > target_size:
      _, file_path, file_size = mtime_path_sizes.pop()
      os.remove(file_path)
      cache_dir_size -= file_size

class LRUCache(CacheInterface):
  """`CacheInterface` implementation that picks a backend from the path.

  `gs://` paths are served by `GFileCacheImpl`, which performs no eviction.
  Every other path is served by `LRUCacheImpl`.
  """

  def __init__(self, path: str, max_cache_size: int = 1000000):
    """Sets up a cache at `path`.

    Args:
      path: cache location; a path starting with `gs://` selects the
        `GFileCacheImpl` backend.
      max_cache_size: size limit in bytes handed to `LRUCacheImpl`. It is
        ignored for `gs://` paths.
    """
    if is_path_gs(path):
      # gs:// does not support cache eviction yet
      self.cache = GFileCacheImpl(path=path)
    else:
      self.cache = LRUCacheImpl(cache_dir=path, max_cache_size=max_cache_size)

  def get(self, key: str) -> bytes | None:
    """Returns whatever the selected backend returns for `key`."""
    return self.cache.get(key)

  def put(self, key: str, value: bytes) -> None:
    """Forwards `key` and `value` to the selected backend's `put`."""
    self.cache.put(key, value)
50 changes: 50 additions & 0 deletions tests/lru_cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import string
import tempfile

from absl.testing import absltest

from jax._src.lru_cache import LRUCacheImpl
import jax._src.test_util as jtu


class LRUCacheTest(jtu.JaxTestCase):
  """Stress test for `LRUCacheImpl` under constant eviction pressure."""

  def test_cache_eviction(self):
    """Hits must return the stored bytes even while entries are evicted."""

    def generate_random_k():
      # simulate keys with many collisions
      return random.choice(string.ascii_lowercase[:12])

    def generate_random_v(k):
      # simulate large values while ensuring that one k corresponds to one v
      return bytes(f"{k}abcdefghijklmnopqrstuvwxyz" * 8192, encoding="utf-8")

    with tempfile.TemporaryDirectory() as tmpdirname:
      # Each value is 27 * 8192 = 221184 bytes, so the limit of 884700 bytes
      # is just below the size of four values (884736 bytes).
      cache = LRUCacheImpl(tmpdirname, 884700)

      for _ in range(100000):
        k = generate_random_k()
        cached = cache.get(k)

        if cached is None:
          cache.put(k, generate_random_v(k))
        else:
          # cache hit: the entry must hold exactly the value written for k
          self.assertEqual(cached, generate_random_v(k))


# Run this module's tests through absl with JAX's test loader when the file is
# executed as a script.
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit c7999c8

Please sign in to comment.