Skip to content

Commit

Permalink
Improve compilation cache tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Jun 27, 2024
1 parent de8fd3b commit bc7addf
Showing 1 changed file with 53 additions and 38 deletions.
91 changes: 53 additions & 38 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import Counter
from functools import partial
import math
import os
import platform
import tempfile
import unittest
from unittest import mock
from unittest import SkipTest
Expand All @@ -34,8 +34,10 @@
from jax._src import config
from jax._src import distributed
from jax._src import monitoring
from jax._src import path as pathlib
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client
from jax._src.maps import xmap
from jax.experimental.pjit import pjit
Expand All @@ -60,20 +62,47 @@ def tearDownModule():
def increment_event_count(event):
_counts[event] += 1


class InMemoryCache(CacheInterface):
"""An in-memory cache for testing purposes."""

# not used, but required by `CacheInterface`
_path = pathlib.Path()

def __init__(self):
self._cache: dict[str, bytes] = {}

def get(self, key: str) -> bytes | None:
return self._cache.get(key)

def put(self, key: str, value: bytes) -> None:
self._cache[key] = value

def clear(self) -> None:
self._cache.clear()

def __len__(self) -> int:
return len(self._cache)


def count_cache_items() -> int:
return 0 if cc._cache is None else len(cc._cache)


def clear_cache() -> None:
if cc._cache is not None:
cc._cache.clear()


class CompilationCacheTestCase(jtu.JaxTestCase):
tmpdir: str

def setUp(self):
super().setUp()
cc.reset_cache()
tmpdir = tempfile.TemporaryDirectory()
self.enter_context(tmpdir)
self.enter_context(config.compilation_cache_dir(tmpdir.name))
self.tmpdir = tmpdir.name
cc._cache = InMemoryCache()

def tearDown(self):
cc.reset_cache()
self.tmpdir = ""
super().tearDown()


Expand Down Expand Up @@ -158,38 +187,32 @@ def test_pmap(self):
f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i")
x = np.arange(jax.device_count(), dtype=np.int64)
f(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(jax.device_count(), dtype=np.float32)
f(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)
# TODO: create a test for calling pmap with the same input more than once

def test_jit(self):
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
f(1.0)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

def test_xla_autofdo_profile_version(self):
original_profile_version = config.jax_xla_profile_version.value
with config.jax_xla_profile_version(original_profile_version + 1):
f = jit(lambda x: x * x)
f(1)
files_in_cache_directory = os.listdir(self.tmpdir)
self.assertLen(files_in_cache_directory, 1)
self.assertEqual(count_cache_items(), 1)
# Clear the cache directory, then update the profile version and execute
# again. The in-memory caches should be invalidated and a new persistent
# cache entry created.
os.unlink(os.path.join(self.tmpdir, files_in_cache_directory[0]))
clear_cache()
with config.jax_xla_profile_version(original_profile_version + 2):
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)

@jtu.with_mesh([("x", 2)])
def test_pjit(self):
Expand All @@ -200,12 +223,10 @@ def f(x, y):
shape = (8, 8)
x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

@jtu.with_mesh([("x", 2)])
def test_xmap(self):
Expand All @@ -219,14 +240,12 @@ def f(x):
xmap(
f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
)(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
xmap(
f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
)(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

def test_cache_write_warning(self):
f = jit(lambda x: x * x)
Expand Down Expand Up @@ -280,8 +299,7 @@ def test_min_entry_size(self):
config.persistent_cache_min_entry_size_bytes(1048576), # 1MiB
):
jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 0)
self.assertEqual(count_cache_items(), 0)

def test_min_compile_time(self):
with (
Expand All @@ -291,14 +309,12 @@ def test_min_compile_time(self):
# Mock time to progress in small intervals so compilation time is small.
with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)):
jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 0)
self.assertEqual(count_cache_items(), 0)

# Mock time to progress in large intervals so compilation time is large.
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
jit(lambda x: x + 2)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 1)
self.assertEqual(count_cache_items(), 1)

# This is perhaps related to mocking time.monotonic?
@unittest.skipIf(platform.system() == "Windows", "Test fails on Windows")
Expand Down Expand Up @@ -408,7 +424,7 @@ def test_cache_write_with_process_restriction(self, process_id):
):
jit(lambda x: x + 1)(1)

files_in_directory = len(os.listdir(self.tmpdir))
files_in_directory = count_cache_items()
if process_id == 0:
self.assertEqual(files_in_directory, 1)
elif process_id == 1:
Expand Down Expand Up @@ -446,8 +462,7 @@ def test_jit(self):
with config.enable_compilation_cache(False):
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 0)
self.assertEqual(count_cache_items(), 0)


if __name__ == "__main__":
Expand Down

0 comments on commit bc7addf

Please sign in to comment.