Skip to content

Commit

Permalink
[NSETM-2312] Implement cache.skip_features, to skip writing the features DataFrames (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
GianlucaFicarelli committed Apr 24, 2024
1 parent 53bce0b commit 5602fd3
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ New Features

- Add ``cache.readonly``, to be able to use an existing cache without exclusive locking [NSETM-2310].
- Add ``cache.store_type``, to change the file format (experimental).
- Add ``cache.skip_features``, to skip writing the features DataFrames (not implemented yet).
- Add ``cache.skip_features``, to skip writing the features DataFrames [NSETM-2312].

Deprecations
~~~~~~~~~~~~
Expand Down
12 changes: 10 additions & 2 deletions src/blueetl/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def from_config(
cls,
global_config: dict,
base_path: StrOrPath,
extra_params: dict[str, Any],
extra_params: Optional[dict[str, Any]] = None,
) -> "MultiAnalyzer":
"""Initialize the MultiAnalyzer from the given configuration.
Expand Down Expand Up @@ -223,7 +223,9 @@ def _init_analyzers(self) -> dict[str, Analyzer]:
}

@classmethod
def from_file(cls, path: StrOrPath, extra_params: dict[str, Any]) -> "MultiAnalyzer":
def from_file(
cls, path: StrOrPath, extra_params: Optional[dict[str, Any]] = None
) -> "MultiAnalyzer":
"""Return a new instance loaded using the given configuration file."""
return cls.from_config(
global_config=load_yaml(path),
Expand Down Expand Up @@ -337,6 +339,7 @@ def run_from_file(
show: bool = False,
clear_cache: Optional[bool] = None,
readonly_cache: Optional[bool] = None,
skip_features_cache: Optional[bool] = None,
loglevel: Optional[int] = None,
) -> MultiAnalyzer:
"""Initialize and return the MultiAnalyzer.
Expand All @@ -353,11 +356,15 @@ def run_from_file(
readonly_cache: if None, use the value from the configuration file. Otherwise:
if True, use the existing cache if possible, or raise an error;
if False, use the existing cache if possible, or update it.
skip_features_cache: if None, use the value from the configuration file. Otherwise:
if True, do not write the features to the cache;
if False, write the features to the cache after calculating them.
loglevel: if specified, used to set up logging.
Returns:
a new MultiAnalyzer instance.
"""
# pylint: disable=too-many-arguments
if loglevel is not None:
setup_logging(loglevel=loglevel, force=True)
if seed is not None:
Expand All @@ -368,6 +375,7 @@ def run_from_file(
extra_params={
"clear_cache": clear_cache,
"readonly_cache": readonly_cache,
"skip_features_cache": skip_features_cache,
},
)
if extract:
Expand Down
9 changes: 8 additions & 1 deletion src/blueetl/apps/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
help="If True, use the existing cache if possible, or raise an error if not.",
default=None,
)
@click.option(
"--skip-features-cache/--no-skip-features-cache",
help="If True, do not write the features to the cache.",
default=None,
)
@click.option("-i", "--interactive/--no-interactive", help="Start an interactive IPython shell.")
@click.option("-v", "--verbose", count=True, help="-v for INFO, -vv for DEBUG")
def run(
Expand All @@ -35,11 +40,12 @@ def run(
show,
clear_cache,
readonly_cache,
skip_features_cache,
interactive,
verbose,
):
"""Run the analysis."""
# pylint: disable=unused-variable,unused-import,import-outside-toplevel,too-many-arguments
# pylint: disable=unused-variable,unused-import,import-outside-toplevel,too-many-arguments,too-many-locals
loglevel = (logging.WARNING, logging.INFO, logging.DEBUG)[min(verbose, 2)]
# assign the result to a local variable to make it available in the interactive shell
ma = run_from_file( # noqa
Expand All @@ -50,6 +56,7 @@ def run(
show=show,
clear_cache=clear_cache,
readonly_cache=readonly_cache,
skip_features_cache=skip_features_cache,
loglevel=loglevel,
)
if interactive:
Expand Down
16 changes: 13 additions & 3 deletions src/blueetl/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,15 @@ def __init__(
}
store_class = store_classes[cache_config.store_type]

self.readonly = cache_config.readonly
self._readonly = cache_config.readonly
self._skip_features = cache_config.skip_features

self._version = 1
self._repo_store = store_class(repo_dir)
self._features_store = store_class(features_dir)

self._lock_manager: LockManagerProtocol = LockManager(self._output_dir)
self._lock_manager.lock(mode=LockManager.LOCK_SH if self.readonly else LockManager.LOCK_EX)
self._lock_manager.lock(mode=LockManager.LOCK_SH if self._readonly else LockManager.LOCK_EX)

self._cached_analysis_config_path = config_dir / "analysis_config.cached.yaml"
self._cached_simulations_config_path = config_dir / "simulations_config.cached.yaml"
Expand Down Expand Up @@ -205,7 +207,7 @@ def __setstate__(self, state: dict) -> None:
"""Set the object state when the object is unpickled."""
self.__dict__.update(state)
# The unpickled object must always be readonly, even when the pickled object isn't.
self.readonly = True
self._readonly = True
# A new lock isn't created in the subprocess b/c we want to be able to read the cache.
self._lock_manager = DummyLockManager()

Expand All @@ -221,6 +223,11 @@ def close(self) -> None:
"""
self._lock_manager.unlock()

@property
def readonly(self) -> bool:
"""Return True if the cache manager is set to read-only, False otherwise."""
return self._readonly

@_raise_if(locked=False)
def to_readonly(self) -> "CacheManager":
"""Return a read-only copy of the object.
Expand Down Expand Up @@ -539,6 +546,9 @@ def dump_features(
features_dict: dict of features to be written.
features_config: configuration dict of the features to be written.
"""
if self._skip_features:
L.info("Skipping writing features to cache")
return
config_checksum = features_config.checksum()
old_checksums = self._cached_checksums["features"].pop(config_checksum, None)
new_checksums = self._cached_checksums["features"][config_checksum] = {}
Expand Down
8 changes: 5 additions & 3 deletions src/blueetl/config/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from copy import deepcopy
from itertools import chain
from pathlib import Path
from typing import Any, NamedTuple, Union
from typing import Any, NamedTuple, Optional, Union

from blueetl.config.analysis_model import (
FeaturesConfig,
Expand All @@ -28,6 +28,8 @@ def _override_params(global_config: dict[str, Any], extra_params: dict[str, Any]
cache_config["clear"] = value
if (value := extra_params.get("readonly_cache")) is not None:
cache_config["readonly"] = value
if (value := extra_params.get("skip_features_cache")) is not None:
cache_config["skip_features"] = value
return global_config


Expand Down Expand Up @@ -176,10 +178,10 @@ def _resolve_analysis_configs(global_config: MultiAnalysisConfig) -> None:


def init_multi_analysis_configuration(
global_config: dict, base_path: Path, extra_params: dict[str, Any]
global_config: dict, base_path: Path, extra_params: Optional[dict[str, Any]]
) -> MultiAnalysisConfig:
"""Return a config object from a config dict."""
global_config = _override_params(global_config, extra_params)
global_config = _override_params(global_config, extra_params or {})
validate_config(global_config, schema=read_schema("analysis_config"))
config = MultiAnalysisConfig.model_validate(global_config)
_resolve_paths(config, base_path=base_path)
Expand Down
14 changes: 13 additions & 1 deletion tests/unit/apps/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,30 @@
from blueetl.apps import run as test_module


@pytest.mark.parametrize("skip_features_cache", [True, False, None])
@pytest.mark.parametrize("readonly_cache", [True, False, None])
@pytest.mark.parametrize("clear_cache", [True, False, None])
@pytest.mark.parametrize("show", [True, False])
@pytest.mark.parametrize("calculate", [True, False])
@pytest.mark.parametrize("extract", [True, False])
@patch(test_module.__name__ + ".run_from_file")
def test_run(mock_run_from_file, tmp_path, extract, calculate, show, clear_cache, readonly_cache):
def test_run(
mock_run_from_file,
tmp_path,
extract,
calculate,
show,
clear_cache,
readonly_cache,
skip_features_cache,
):
options_dict = {
"extract": extract,
"calculate": calculate,
"show": show,
"clear-cache": clear_cache,
"readonly-cache": readonly_cache,
"skip-features-cache": skip_features_cache,
}
options = [f"--{k}" if v else f"--no-{k}" for k, v in options_dict.items() if v is not None]
analysis_config_file = "config.yaml"
Expand All @@ -41,6 +52,7 @@ def test_run(mock_run_from_file, tmp_path, extract, calculate, show, clear_cache
show=show,
clear_cache=clear_cache,
readonly_cache=readonly_cache,
skip_features_cache=skip_features_cache,
loglevel=logging.DEBUG,
)

Expand Down
14 changes: 13 additions & 1 deletion tests/unit/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,23 @@ def _prepare_env(path):
)


@pytest.mark.parametrize("skip_features_cache", [True, False, None])
@pytest.mark.parametrize("readonly_cache", [True, False, None])
@pytest.mark.parametrize("clear_cache", [True, False, None])
@pytest.mark.parametrize("show", [True, False])
@pytest.mark.parametrize("calculate", [True, False])
@pytest.mark.parametrize("extract", [True, False])
@patch.object(test_module.MultiAnalyzer, "from_file")
def test_run_from_file(from_file, tmp_path, extract, calculate, show, clear_cache, readonly_cache):
def test_run_from_file(
from_file,
tmp_path,
extract,
calculate,
show,
clear_cache,
readonly_cache,
skip_features_cache,
):
analysis_config_file = tmp_path / "config.yaml"
analysis_config_file.write_text("---")

Expand All @@ -41,13 +51,15 @@ def test_run_from_file(from_file, tmp_path, extract, calculate, show, clear_cach
show=show,
clear_cache=clear_cache,
readonly_cache=readonly_cache,
skip_features_cache=skip_features_cache,
)

from_file.assert_called_once_with(
analysis_config_file,
extra_params={
"clear_cache": clear_cache,
"readonly_cache": readonly_cache,
"skip_features_cache": skip_features_cache,
},
)
assert instance.extract_repo.call_count == int(extract)
Expand Down

0 comments on commit 5602fd3

Please sign in to comment.