Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Jun 20, 2024
1 parent de8fd3b commit 57fd15c
Showing 1 changed file with 51 additions and 38 deletions.
89 changes: 51 additions & 38 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
from collections import Counter
from functools import partial
import math
import os
import platform
import tempfile
import unittest
from unittest import mock
from unittest import SkipTest
Expand All @@ -34,8 +32,10 @@
from jax._src import config
from jax._src import distributed
from jax._src import monitoring
from jax._src import path as pathlib
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client
from jax._src.maps import xmap
from jax.experimental.pjit import pjit
Expand All @@ -60,20 +60,47 @@ def tearDownModule():
def increment_event_count(event):
_counts[event] += 1


class InMemoryCache(CacheInterface):
'''An in-memory cache for testing purposes.'''

# not used, but required by `CacheInterface`
_path = pathlib.Path()

def __init__(self):
self._cache: dict[str, bytes] = {}

def get(self, key: str) -> bytes | None:
return self._cache.get(key)

def put(self, key: str, value: bytes) -> None:
self._cache[key] = value

def clear(self) -> None:
self._cache.clear()

def __len__(self) -> int:
return len(self._cache)


def count_cache_items() -> int:
return 0 if cc._cache is None else len(cc._cache)


def clear_cache() -> None:
if cc._cache is not None:
cc._cache.clear()


class CompilationCacheTestCase(jtu.JaxTestCase):
tmpdir: str

def setUp(self):
super().setUp()
cc.reset_cache()
tmpdir = tempfile.TemporaryDirectory()
self.enter_context(tmpdir)
self.enter_context(config.compilation_cache_dir(tmpdir.name))
self.tmpdir = tmpdir.name
cc._cache = InMemoryCache()

def tearDown(self):
cc.reset_cache()
self.tmpdir = ""
super().tearDown()


Expand Down Expand Up @@ -158,38 +185,32 @@ def test_pmap(self):
f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i")
x = np.arange(jax.device_count(), dtype=np.int64)
f(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(jax.device_count(), dtype=np.float32)
f(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)
# TODO: create a test for calling pmap with the same input more than once

def test_jit(self):
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
f(1.0)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

def test_xla_autofdo_profile_version(self):
original_profile_version = config.jax_xla_profile_version.value
with config.jax_xla_profile_version(original_profile_version + 1):
f = jit(lambda x: x * x)
f(1)
files_in_cache_directory = os.listdir(self.tmpdir)
self.assertLen(files_in_cache_directory, 1)
self.assertEqual(count_cache_items(), 1)
# Clear the cache directory, then update the profile version and execute
# again. The in-memory caches should be invalidated and a new persistent
# cache entry created.
os.unlink(os.path.join(self.tmpdir, files_in_cache_directory[0]))
clear_cache()
with config.jax_xla_profile_version(original_profile_version + 2):
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)

@jtu.with_mesh([("x", 2)])
def test_pjit(self):
Expand All @@ -200,12 +221,10 @@ def f(x, y):
shape = (8, 8)
x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

@jtu.with_mesh([("x", 2)])
def test_xmap(self):
Expand All @@ -219,14 +238,12 @@ def f(x):
xmap(
f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
)(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
xmap(
f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
)(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

def test_cache_write_warning(self):
f = jit(lambda x: x * x)
Expand Down Expand Up @@ -280,8 +297,7 @@ def test_min_entry_size(self):
config.persistent_cache_min_entry_size_bytes(1048576), # 1MiB
):
jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 0)
self.assertEqual(count_cache_items(), 0)

def test_min_compile_time(self):
with (
Expand All @@ -291,14 +307,12 @@ def test_min_compile_time(self):
# Mock time to progress in small intervals so compilation time is small.
with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)):
jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 0)
self.assertEqual(count_cache_items(), 0)

# Mock time to progress in large intervals so compilation time is large.
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
jit(lambda x: x + 2)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 1)
self.assertEqual(count_cache_items(), 1)

# This is perhaps related to mocking time.monotonic?
@unittest.skipIf(platform.system() == "Windows", "Test fails on Windows")
Expand Down Expand Up @@ -408,7 +422,7 @@ def test_cache_write_with_process_restriction(self, process_id):
):
jit(lambda x: x + 1)(1)

files_in_directory = len(os.listdir(self.tmpdir))
files_in_directory = count_cache_items()
if process_id == 0:
self.assertEqual(files_in_directory, 1)
elif process_id == 1:
Expand Down Expand Up @@ -446,8 +460,7 @@ def test_jit(self):
with config.enable_compilation_cache(False):
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 0)
self.assertEqual(count_cache_items(), 0)


if __name__ == "__main__":
Expand Down

0 comments on commit 57fd15c

Please sign in to comment.