Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve compilation cache tests #21982

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 53 additions & 38 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import Counter
from functools import partial
import math
import os
import platform
import tempfile
import unittest
from unittest import mock
from unittest import SkipTest
Expand All @@ -34,8 +34,10 @@
from jax._src import config
from jax._src import distributed
from jax._src import monitoring
from jax._src import path as pathlib
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client
from jax._src.maps import xmap
from jax.experimental.pjit import pjit
Expand All @@ -60,20 +62,47 @@ def tearDownModule():
def increment_event_count(event):
_counts[event] += 1


class InMemoryCache(CacheInterface):
"""An in-memory cache for testing purposes."""

# not used, but required by `CacheInterface`
_path = pathlib.Path()

def __init__(self):
self._cache: dict[str, bytes] = {}

def get(self, key: str) -> bytes | None:
return self._cache.get(key)

def put(self, key: str, value: bytes) -> None:
self._cache[key] = value

def clear(self) -> None:
self._cache.clear()

def __len__(self) -> int:
return len(self._cache)


def count_cache_items() -> int:
return 0 if cc._cache is None else len(cc._cache)


def clear_cache() -> None:
if cc._cache is not None:
cc._cache.clear()


class CompilationCacheTestCase(jtu.JaxTestCase):
tmpdir: str

def setUp(self):
super().setUp()
cc.reset_cache()
tmpdir = tempfile.TemporaryDirectory()
self.enter_context(tmpdir)
self.enter_context(config.compilation_cache_dir(tmpdir.name))
self.tmpdir = tmpdir.name
cc._cache = InMemoryCache()

def tearDown(self):
cc.reset_cache()
self.tmpdir = ""
super().tearDown()


Expand Down Expand Up @@ -158,38 +187,32 @@ def test_pmap(self):
f = pmap(lambda x: x - lax.psum(x, "i"), axis_name="i")
x = np.arange(jax.device_count(), dtype=np.int64)
f(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(jax.device_count(), dtype=np.float32)
f(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)
# TODO: create a test for calling pmap with the same input more than once

def test_jit(self):
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
f(1.0)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

def test_xla_autofdo_profile_version(self):
original_profile_version = config.jax_xla_profile_version.value
with config.jax_xla_profile_version(original_profile_version + 1):
f = jit(lambda x: x * x)
f(1)
files_in_cache_directory = os.listdir(self.tmpdir)
self.assertLen(files_in_cache_directory, 1)
self.assertEqual(count_cache_items(), 1)
# Clear the cache directory, then update the profile version and execute
# again. The in-memory caches should be invalidated and a new persistent
# cache entry created.
os.unlink(os.path.join(self.tmpdir, files_in_cache_directory[0]))
clear_cache()
with config.jax_xla_profile_version(original_profile_version + 2):
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)

@jtu.with_mesh([("x", 2)])
def test_pjit(self):
Expand All @@ -200,12 +223,10 @@ def f(x, y):
shape = (8, 8)
x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

@jtu.with_mesh([("x", 2)])
def test_xmap(self):
Expand All @@ -219,14 +240,12 @@ def f(x):
xmap(
f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
)(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 1)
self.assertEqual(count_cache_items(), 1)
x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
xmap(
f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
)(x)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 2)
self.assertEqual(count_cache_items(), 2)

def test_cache_write_warning(self):
f = jit(lambda x: x * x)
Expand Down Expand Up @@ -280,8 +299,7 @@ def test_min_entry_size(self):
config.persistent_cache_min_entry_size_bytes(1048576), # 1MiB
):
jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 0)
self.assertEqual(count_cache_items(), 0)

def test_min_compile_time(self):
with (
Expand All @@ -291,14 +309,12 @@ def test_min_compile_time(self):
# Mock time to progress in small intervals so compilation time is small.
with mock.patch("time.monotonic", side_effect=np.arange(0, 10, 0.1)):
jit(lambda x: x + 1)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 0)
self.assertEqual(count_cache_items(), 0)

# Mock time to progress in large intervals so compilation time is large.
with mock.patch("time.monotonic", side_effect=np.arange(0, 100, 10)):
jit(lambda x: x + 2)(1)
files_in_cache = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_cache, 1)
self.assertEqual(count_cache_items(), 1)

# This is perhaps related to mocking time.monotonic?
@unittest.skipIf(platform.system() == "Windows", "Test fails on Windows")
Expand Down Expand Up @@ -408,7 +424,7 @@ def test_cache_write_with_process_restriction(self, process_id):
):
jit(lambda x: x + 1)(1)

files_in_directory = len(os.listdir(self.tmpdir))
files_in_directory = count_cache_items()
if process_id == 0:
self.assertEqual(files_in_directory, 1)
elif process_id == 1:
Expand Down Expand Up @@ -446,8 +462,7 @@ def test_jit(self):
with config.enable_compilation_cache(False):
f = jit(lambda x: x * x)
f(1)
files_in_directory = len(os.listdir(self.tmpdir))
self.assertEqual(files_in_directory, 0)
self.assertEqual(count_cache_items(), 0)


if __name__ == "__main__":
Expand Down