Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using GFile for persistent compilation caching in JAX #10771

Merged
merged 1 commit into from
Jun 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

## jax 0.3.14 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main).
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` does not support `max_cache_size_ bytes` anymore and will not get that as an input.
* Changes
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
that allows selection between an LU-decomposition based implementation and
Expand Down Expand Up @@ -45,6 +47,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
traces as an alternative to the Tensorboard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`).
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
bucket path as input.

## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
Expand Down
8 changes: 3 additions & 5 deletions jax/experimental/compilation_cache/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,19 @@
from typing import List, Optional

import jax
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
from jax.experimental.compilation_cache.gfile_cache import GFileCache
import jax._src.lib
from jax._src.lib import xla_client
from absl import logging

_cache = None

def initialize_cache(path, max_cache_size_bytes=32 * 2**30):
def initialize_cache(path):
"""Creates a global cache object. Should only be called once per process.

max_cache_sixe defaults to 32GiB.
"""
sshahrokhi marked this conversation as resolved.
Show resolved Hide resolved
global _cache
assert _cache == None, f"The cache path has already been initialized to {_cache._path}"
_cache = FileSystemCache(path, max_cache_size_bytes)
_cache = GFileCache(path)
logging.warning("Initialized persistent compilation cache at %s", path)

def get_executable(xla_computation, compile_options, backend) -> Optional[xla_client.Executable]:
Expand Down
82 changes: 0 additions & 82 deletions jax/experimental/compilation_cache/file_system_cache.py

This file was deleted.

57 changes: 57 additions & 0 deletions jax/experimental/compilation_cache/gfile_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib

from jax.experimental.compilation_cache.cache_interface import CacheInterface
from etils import epath
from absl import logging

class GFileCache(CacheInterface):

def __init__(self, path: str):
"""Sets up a cache at 'path'. Cached values may already be present."""
self._path = epath.Path(path)
self._path.mkdir(parents=True, exist_ok=True)

def get(self, key: str):
"""Returns None if 'key' isn't present."""
if not key:
raise ValueError("key cannot be empty")
path_to_key = self._path / key
if path_to_key.exists():
return path_to_key.read_bytes()
else:
return None

def put(self, key: str, value: bytes):
"""Adds new cache entry."""
if not key:
raise ValueError("key cannot be empty")
path_to_new_file = self._path / key
if str(path_to_new_file).startswith('gs://'):
# Writes to gcs are atomic.
path_to_new_file.write_bytes(value)
sshahrokhi marked this conversation as resolved.
Show resolved Hide resolved
elif str(path_to_new_file).startswith('file://') or '://' not in str(path_to_new_file):
tmp_path = self._path / f"_temp_{key}"
with open(str(tmp_path), "wb") as f:
f.write(value)
f.flush()
os.fsync(f.fileno())
os.rename(tmp_path, path_to_new_file)
else:
tmp_path = self._path / f"_temp_{key}"
tmp_path.write_bytes(value)
sshahrokhi marked this conversation as resolved.
Show resolved Hide resolved
tmp_path.rename(str(path_to_new_file))
150 changes: 0 additions & 150 deletions tests/file_system_cache_test.py

This file was deleted.

Loading