
Commit

fix: do not skip tests, use cache consistently
CompRhys committed Jul 12, 2024
1 parent ef7778c commit 9255c56
Showing 5 changed files with 8 additions and 20 deletions.
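
All four data modules now spell the cache decorator the same way. The change is behaviour-preserving: functools.cache (added in Python 3.9) is simply lru_cache(maxsize=None), and the # noqa: B019 comments suppress flake8-bugbear's warning that caching a method keeps its instances alive. A minimal sketch of the before/after spellings (illustrative function, not aviary code):

    from functools import cache, lru_cache

    @lru_cache(maxsize=None)   # old spelling: unbounded LRU cache
    def featurise_old(symbol: str) -> int:
        return len(symbol)

    @cache                     # new spelling: same unbounded cache, clearer intent
    def featurise_new(symbol: str) -> int:
        return len(symbol)

    assert featurise_old("Fe") == featurise_new("Fe") == 2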
4 changes: 2 additions & 2 deletions aviary/cgcnn/data.py
@@ -1,8 +1,8 @@
 from __future__ import annotations

-import functools
 import itertools
 import json
+from functools import cache
 from typing import TYPE_CHECKING, Any

 import numpy as np
@@ -125,7 +125,7 @@ def __repr__(self) -> str:
         return f"{type(self).__name__}({df_repr}, task_dict={self.task_dict})"

     # Cache loaded structures
-    @functools.cache  # noqa: B019
+    @cache  # noqa: B019
     def __getitem__(self, idx: int):
         """Get an entry out of the Dataset.
4 changes: 2 additions & 2 deletions aviary/roost/data.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

-import functools
 import json
+from functools import cache
 from typing import TYPE_CHECKING, Any

 import numpy as np
@@ -76,7 +76,7 @@ def __repr__(self) -> str:
         return f"{type(self).__name__}({df_repr}, task_dict={self.task_dict})"

     # Cache data for faster training
-    @functools.cache  # noqa: B019
+    @cache  # noqa: B019
     def __getitem__(self, idx: int):
         """Get an entry out of the Dataset.
4 changes: 2 additions & 2 deletions aviary/wren/data.py
@@ -1,8 +1,8 @@
 from __future__ import annotations

-import functools
 import json
 import re
+from functools import cache
 from itertools import groupby
 from typing import TYPE_CHECKING, Any

@@ -90,7 +90,7 @@ def __repr__(self) -> str:
         df_repr = f"cols=[{', '.join(self.df.columns)}], len={len(self.df)}"
         return f"{type(self).__name__}({df_repr}, task_dict={self.task_dict})"

-    @functools.cache  # noqa: B019
+    @cache  # noqa: B019
     def __getitem__(self, idx: int):
         """Get an entry out of the Dataset.
4 changes: 2 additions & 2 deletions aviary/wrenformer/data.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import json
-from functools import lru_cache
+from functools import cache
 from typing import TYPE_CHECKING, Any, Literal

 import numpy as np
@@ -65,7 +65,7 @@ def collate_batch(
     elem_features = json.load(file)


-@lru_cache(None)
+@cache
 def get_wyckoff_features(
     equivalent_wyckoff_set: list[tuple], spg_num: int
 ) -> np.ndarray:
12 changes: 0 additions & 12 deletions tests/test_core.py
@@ -1,5 +1,3 @@
-import os
-
 import numpy as np
 import pytest
 import torch
@@ -28,15 +26,6 @@ def test_np_softmax():
     assert np.allclose(out.sum(axis=axis), 1)


-reason = """
-our use of torch.where() requires torch pre-release 1.12.0 which handles type promotion
-(pytorch#76691) and avoids:
-RuntimeError: expected scalar type long int but found double
-skipif CI can be removed once 1.12 is released as stable
-"""
-
-
-@pytest.mark.skipif("CI" in os.environ, reason=reason)
 def test_masked_mean():
     # test 1d tensor
     x1 = torch.arange(5).float()
@@ -53,7 +42,6 @@ def test_masked_mean():
     assert masked_mean(x2, mask2, dim=1) == pytest.approx([1.5, 6])


-@pytest.mark.skipif("CI" in os.environ, reason=reason)
 def test_masked_std():
     # test 1d tensor
     x1 = torch.arange(5).float()
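
The two skipif guards removed above existed only because torch.where() lacked dtype promotion before torch 1.12; with 1.12 now stable, test_masked_mean and test_masked_std run unconditionally in CI. A hedged illustration of the promotion behaviour the removed skip reason described (a standalone reproduction, not aviary's masked_mean implementation):

    import torch

    ints = torch.arange(5)            # dtype=torch.int64
    floats = torch.full((5,), 0.5)    # dtype=torch.float32

    # Before torch 1.12, mixing dtypes like this raised a RuntimeError of the
    # form quoted in the removed skip reason; from 1.12 onwards (pytorch#76691)
    # the operands are promoted to a common dtype.
    out = torch.where(ints > 2, ints, floats)
    print(out.dtype)  # torch.float32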
