Skip to content

Commit

Permalink
Merge pull request #10771 from sshahrokhi:gfilecache
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 454692872
  • Loading branch information
jax authors committed Jun 13, 2022
2 parents 1089c79 + 498ee60 commit b174b77
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 237 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

## jax 0.3.14 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main).
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` does not support
`max_cache_size_ bytes` anymore and will not get that as an input.
* Changes
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
that allows selection between an LU-decomposition based implementation and
Expand Down Expand Up @@ -50,6 +53,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
In a future release, this will become an error. An example of an unsafe implicit
cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was
silently truncated to `1`.
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
bucket path as input.

## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
Expand Down
8 changes: 3 additions & 5 deletions jax/experimental/compilation_cache/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,19 @@
from typing import List, Optional

import jax
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
from jax.experimental.compilation_cache.gfile_cache import GFileCache
import jax._src.lib
from jax._src.lib import xla_client
from absl import logging

_cache = None

def initialize_cache(path, max_cache_size_bytes=32 * 2**30):
def initialize_cache(path):
"""Creates a global cache object. Should only be called once per process.
max_cache_sixe defaults to 32GiB.
"""
global _cache
assert _cache == None, f"The cache path has already been initialized to {_cache._path}"
_cache = FileSystemCache(path, max_cache_size_bytes)
_cache = GFileCache(path)
logging.warning("Initialized persistent compilation cache at %s", path)

def get_executable(xla_computation, compile_options, backend) -> Optional[xla_client.Executable]:
Expand Down
82 changes: 0 additions & 82 deletions jax/experimental/compilation_cache/file_system_cache.py

This file was deleted.

57 changes: 57 additions & 0 deletions jax/experimental/compilation_cache/gfile_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib

from jax.experimental.compilation_cache.cache_interface import CacheInterface
from etils import epath
from absl import logging

class GFileCache(CacheInterface):

def __init__(self, path: str):
"""Sets up a cache at 'path'. Cached values may already be present."""
self._path = epath.Path(path)
self._path.mkdir(parents=True, exist_ok=True)

def get(self, key: str):
"""Returns None if 'key' isn't present."""
if not key:
raise ValueError("key cannot be empty")
path_to_key = self._path / key
if path_to_key.exists():
return path_to_key.read_bytes()
else:
return None

def put(self, key: str, value: bytes):
"""Adds new cache entry."""
if not key:
raise ValueError("key cannot be empty")
path_to_new_file = self._path / key
if str(path_to_new_file).startswith('gs://'):
# Writes to gcs are atomic.
path_to_new_file.write_bytes(value)
elif str(path_to_new_file).startswith('file://') or '://' not in str(path_to_new_file):
tmp_path = self._path / f"_temp_{key}"
with open(str(tmp_path), "wb") as f:
f.write(value)
f.flush()
os.fsync(f.fileno())
os.rename(tmp_path, path_to_new_file)
else:
tmp_path = self._path / f"_temp_{key}"
tmp_path.write_bytes(value)
tmp_path.rename(str(path_to_new_file))
150 changes: 0 additions & 150 deletions tests/file_system_cache_test.py

This file was deleted.

Loading

0 comments on commit b174b77

Please sign in to comment.