Ensure that repo and features are filtered if needed (#34)
It can happen:
- when they are loaded from the cache using a more restrictive filter
- when apply_filter is executed
GianlucaFicarelli committed Apr 12, 2024
1 parent 370ba6d commit 6c64821
Showing 9 changed files with 99 additions and 42 deletions.
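The fix hinges on the notion of a subfilter: cached data written with filter Cached can serve a request with filter Actual only if Actual selects a subset of what Cached selected. A minimal sketch of that check (hypothetical; the real is_subfilter lives in blueetl-core and also understands richer query values than plain scalars and lists):

# Hypothetical sketch of the subfilter check; not blueetl-core's implementation.
def is_subfilter(actual: dict, cached: dict, strict: bool = False) -> bool:
    """Return True if `actual` selects a subset of what `cached` selects."""
    def as_set(value):
        return set(value) if isinstance(value, (list, set, tuple)) else {value}

    for key, cached_value in cached.items():
        if key not in actual:
            return False  # actual is less constrained on this key
        if not as_set(actual[key]) <= as_set(cached_value):
            return False  # actual allows values the cache never contained
    if strict and actual == cached:
        return False  # equal filters don't count as "more restrictive"
    return True

assert is_subfilter({"seed": [1]}, {"seed": [1, 2]})                # narrower: cache usable
assert not is_subfilter({"seed": [1, 2]}, {"seed": [1]})            # broader: cache unusable
assert not is_subfilter({"seed": [1]}, {"seed": [1]}, strict=True)  # equal: no re-filtering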
4 changes: 2 additions & 2 deletions CHANGELOG.rst
@@ -12,8 +12,8 @@ Bug Fixes
 Improvements
 ~~~~~~~~~~~~
 
-- Filter features only when ``apply_filter`` is called to save some time.
-- Improve logging in ``utils.timed()``.
+- Filter features only when ``apply_filter`` is called to save some time, but ensure that repo and features are filtered when they are loaded from the cache using a more restrictive filter.
+- Improve logging in ``blueetl.utils.timed()``.
 - Improve tests coverage.
 
 Version 0.8.2
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Bio-Informatics",
 ]
 dependencies = [
-    "blueetl-core>=0.1.0",
+    "blueetl-core>=0.2.3",
     "bluepysnap>=1.0.7",
     "click>=8",
     "jsonschema>=4.0",
58 changes: 48 additions & 10 deletions src/blueetl/cache.py
@@ -285,9 +285,7 @@ def _check_config_cache(self) -> bool:
             return False
 
         # check the criteria used to filter the simulations
-        actual_filter = self._analysis_configs.actual.simulations_filter
-        cached_filter = self._analysis_configs.cached.simulations_filter
-        if not is_subfilter(actual_filter, cached_filter):
+        if not self._is_subfilter(strict=False):
             # the filter is less specific, so the cache cannot be used
             self._invalidate_cached_checksums()
             return False
@@ -394,6 +392,15 @@ def _initialize_cache(self) -> None:
         self._dump_simulations_config()
         self._dump_cached_checksums()
 
+    @_raise_if(locked=False)
+    def is_repo_cached(self, name: str) -> bool:
+        """Return whether a specific repo dataframe is present in the cache."""
+        # the checksums have been checked in _initialize_cache/_delete_cached_repo_files,
+        # so they are not calculated again here
+        return bool(
+            self._cached_checksums["repo"].get(name) and self._repo_store.path(name).is_file()
+        )
+
     @_raise_if(locked=False)
     def load_repo(self, name: str) -> Optional[pd.DataFrame]:
         """Load a specific repo dataframe from the cache.
@@ -404,10 +411,7 @@ def load_repo(self, name: str) -> Optional[pd.DataFrame]:
         Returns:
             The loaded dataframe, or None if it's not cached.
         """
-        is_cached = bool(self._cached_checksums["repo"].get(name))
-        L.debug("The repository %s is cached: %s", name, is_cached)
-        # the checksums have been checked in _initialize_cache/_delete_cached_repo_files,
-        # so they are not calculate again here
+        is_cached = self.is_repo_cached(name)
         return self._repo_store.load(name) if is_cached else None
 
     @_raise_if(readonly=True)
@@ -419,7 +423,6 @@ def dump_repo(self, df: pd.DataFrame, name: str) -> None:
             df: dataframe to be saved.
             name: name of the repo dataframe.
         """
-        L.info("Writing cached %s", name)
         self._repo_store.dump(df, name)
         self._cached_checksums["repo"][name] = self._repo_store.checksum(name)
         self._dump_cached_checksums()
@@ -431,7 +434,6 @@ def get_cached_features_checksums(
         """Return the cached features checksums, or an empty dict if the cache doesn't exist."""
         config_checksum = features_config.checksum()
         cached = self._cached_checksums["features"].get(config_checksum, {})
-        L.debug("The features %s are cached: %s", config_checksum[:8], bool(cached))
         return cached
 
     @_raise_if(locked=False)
@@ -469,7 +471,6 @@ def dump_features(
             features_dict: dict of features to be written.
             features_config: configuration dict of the features to be written.
         """
-        L.info("Writing cached features")
         config_checksum = features_config.checksum()
         old_checksums = self._cached_checksums["features"].pop(config_checksum, None)
         new_checksums = self._cached_checksums["features"][config_checksum] = {}
@@ -484,3 +485,40 @@ def dump_features(
             len(set(new_checksums).difference(old_checksums)) == 0
         ), "Some features have been found only in the new cached data"
         self._dump_cached_checksums()
+
+    def _is_subfilter(self, strict: bool) -> bool:
+        """Check whether the actual filter is more or less specific than the cached filter.
+
+        Args:
+            strict: affects the result only when the two filters Actual and Cached are equal.
+                If True, the filter Actual isn't considered a subfilter of Cached.
+                If False, the filter Actual is considered a subfilter of Cached.
+
+        Returns:
+            True if the actual filter is more specific than the cached filter.
+            False if the actual filter is less specific than the cached filter.
+        """
+        if not self._analysis_configs.cached:
+            return False
+        actual_filter = self._analysis_configs.actual.simulations_filter
+        cached_filter = self._analysis_configs.cached.simulations_filter
+        return is_subfilter(actual_filter, cached_filter, strict=strict)
+
+    @_raise_if(locked=False)
+    def repo_cache_needs_filter(self, name: str) -> bool:
+        """Return True if the cached repo needs to be filtered.
+
+        This happens when the cache is used, but the actual filter
+        is more specific than the cached filter.
+        """
+        return self.is_repo_cached(name) and self._is_subfilter(strict=True)
+
+    @_raise_if(locked=False)
+    def features_cache_needs_filter(self, features_config: FeaturesConfig) -> bool:
+        """Return True if the cached features need to be filtered.
+
+        This happens when the cache is used, but the actual filter
+        is more specific than the cached filter.
+        """
+        cached_checksums = self.get_cached_features_checksums(features_config)
+        return len(cached_checksums) > 0 and self._is_subfilter(strict=True)
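Taken together, the non-strict check in _check_config_cache and the strict checks above implement a three-way decision. A compact restatement, reusing the is_subfilter stand-in sketched after the commit message (decide_cache_action is an illustrative name, not part of the codebase):

from typing import Optional

def decide_cache_action(actual: Optional[dict], cached: Optional[dict]) -> str:
    """Return what the cache manager does for a given pair of filters."""
    actual, cached = actual or {}, cached or {}
    if not is_subfilter(actual, cached, strict=False):
        return "rebuild"      # _check_config_cache() invalidates the cache
    if is_subfilter(actual, cached, strict=True):
        return "load+filter"  # *_cache_needs_filter() returns True
    return "load"             # equal filters: cached data is already exact

assert decide_cache_action({"seed": [1]}, {"seed": [1, 2]}) == "load+filter"
assert decide_cache_action({"seed": [1]}, {"seed": [1]}) == "load"
assert decide_cache_action({"seed": [1, 2]}, {"seed": [1]}) == "rebuild"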
3 changes: 2 additions & 1 deletion src/blueetl/config/analysis.py
@@ -44,7 +44,8 @@ def _resolve_trial_steps(global_config: MultiAnalysisConfig, base_path: Path):
    """
    for config in global_config.analysis.values():
        for trial_steps_config in config.extraction.trial_steps.values():
-            trial_steps_config.base_path = str(global_config.output)
+            if not trial_steps_config.base_path:
+                trial_steps_config.base_path = global_config.output
            if path := trial_steps_config.node_sets_file:
                path = base_path / path
                trial_steps_config.node_sets_file = path
3 changes: 2 additions & 1 deletion src/blueetl/config/analysis_model.py
@@ -30,7 +30,7 @@ def dict(self, *args, by_alias=True, **kwargs):
     def json(self, *args, sort_keys=False, **kwargs):
         """Generate a JSON representation of the model, using by_alias=True by default."""
         # use json.dumps because model_dump_json in pydantic v2 doesn't support sort_keys
-        return json.dumps(self.dict(*args, **kwargs), sort_keys=sort_keys)
+        return json.dumps(self.dict(*args, **kwargs), sort_keys=sort_keys, default=str)
 
     def dump(self, path: Path) -> None:
         """Dump the model to file in yaml format."""
@@ -92,6 +92,7 @@ class TrialStepsConfig(BaseModel):
     node_sets_file: Optional[Path] = None
     node_sets_checksum: Optional[str] = None  # to invalidate the cache when the file changes
     limit: Optional[int] = None
+    base_path: Optional[Path] = None  # can be used in the function calculating the trial steps
 
     @model_validator(mode="after")
     def forbid_fields(self):
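The two hunks above are linked: the new base_path field is a pathlib.Path, and json.dumps cannot serialize Path objects natively, which is plausibly what the added default=str covers. A minimal reproduction (the example value is illustrative, not from the repo):

import json
from pathlib import Path

data = {"base_path": Path("/tmp/analysis_output")}  # illustrative value

try:
    json.dumps(data, sort_keys=True)
except TypeError as err:
    print(err)  # Object of type PosixPath is not JSON serializable (on Linux)

print(json.dumps(data, sort_keys=True, default=str))
# {"base_path": "/tmp/analysis_output"}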
2 changes: 1 addition & 1 deletion src/blueetl/external/bnac/calculate_features.py
@@ -168,7 +168,7 @@ def calculate_features_multi(repo, key, df, params):
                 "smoothed_3ms_spike_times_max_normalised_hist_1ms_bin"
             ],
         }
-    ).rename_axis(BIN)
+    ).rename_axis(index=BIN)
 
     return {
         "by_gid": by_gid,
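For a single label the two rename_axis forms are equivalent; the keyword form simply states explicitly that it is the index axis being renamed. A quick check (BIN here stands in for the constant imported in the real module):

import pandas as pd

BIN = "bin"  # stand-in for the constant used in the real module
df = pd.DataFrame({"count": [3, 1]})

a = df.rename_axis(BIN)        # positional mapper: sets the index name
b = df.rename_axis(index=BIN)  # explicit keyword: same result, unambiguous
assert a.index.name == b.index.name == BIN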
11 changes: 8 additions & 3 deletions src/blueetl/features.py
@@ -267,8 +267,11 @@ def _group_features_by_attributes() -> tuple[
 
         def _process_cached_features(cached: list[FeaturesConfig]) -> None:
             for n, features_config in enumerate(cached, 1):
+                query = None
+                if self._repo.cache_manager.features_cache_needs_filter(features_config):
+                    query = {SIMULATION_ID: self._repo.simulation_ids}
                 df_dict = self.cache_manager.load_features(features_config=features_config)
-                features = _calculate_cached(features_config, df_dict)
+                features = _calculate_cached(features_config, df_dict, query=query)
                 _process_features(features_config, features)
                 _log_features(features, n, len(cached), features_config.id)
@@ -347,10 +350,12 @@ def _dataframes_to_features(
 
 
 def _calculate_cached(
-    features_config: FeaturesConfig, df_dict: dict[str, pd.DataFrame]
+    features_config: FeaturesConfig,
+    df_dict: dict[str, pd.DataFrame],
+    query: Optional[dict],
 ) -> dict[str, Feature]:
     """Load cached features from a dict of DataFrames."""
-    return _dataframes_to_features(df_dict, config=features_config, cached=True, query=None)
+    return _dataframes_to_features(df_dict, config=features_config, cached=True, query=query)
 
 
 def _calculate_new(
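The query built above, {SIMULATION_ID: self._repo.simulation_ids}, is ultimately used to drop rows whose simulation_id fell out of the more specific filter. A plain-pandas equivalent of its effect (blueetl itself applies it through the df.etl.q accessor, so this is only an illustration):

import pandas as pd

SIMULATION_ID = "simulation_id"  # same constant name used in the diff

df = pd.DataFrame({SIMULATION_ID: [0, 1, 2, 3], "rate": [1.0, 2.0, 3.0, 4.0]})
query = {SIMULATION_ID: [0, 2]}  # ids surviving the more specific filter

filtered = df[df[SIMULATION_ID].isin(query[SIMULATION_ID])]
print(filtered[SIMULATION_ID].tolist())  # [0, 2]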
57 changes: 34 additions & 23 deletions src/blueetl/repository.py
@@ -39,7 +39,7 @@ def extract_new(self) -> ExtractorT:
         """Instantiate an object from the configuration."""
 
     @abstractmethod
-    def extract_cached(self, df: pd.DataFrame) -> ExtractorT:
+    def extract_cached(self, df: pd.DataFrame, name: str) -> ExtractorT:
         """Instantiate an object from a cached DataFrame."""
 
     def extract(self, name: str) -> ExtractorT:
@@ -51,7 +51,7 @@ def extract(self, name: str) -> ExtractorT:
         with timed(L.debug, f"Extracting {name}") as messages:
             df = self._repo.cache_manager.load_repo(name)
             if df is not None:
-                instance = self.extract_cached(df)
+                instance = self.extract_cached(df, name)
             else:
                 instance = self.extract_new()
             assert instance is not None, "The extraction didn't return a valid instance."
@@ -73,9 +73,12 @@ def extract_new(self) -> Simulations:
             query=self._repo.simulations_filter,
         )
 
-    def extract_cached(self, df: pd.DataFrame) -> Simulations:
+    def extract_cached(self, df: pd.DataFrame, name: str) -> Simulations:
         """Instantiate an object from a cached DataFrame."""
-        return Simulations.from_pandas(df, query=self._repo.simulations_filter, cached=True)
+        query = None
+        if self._repo.needs_filter(name):
+            query = self._repo.simulations_filter
+        return Simulations.from_pandas(df, query=query, cached=True)
 
 
 class NeuronsExtractor(BaseExtractor[Neurons]):
@@ -88,10 +91,10 @@ def extract_new(self) -> Neurons:
             neuron_classes=self._repo.extraction_config.neuron_classes,
         )
 
-    def extract_cached(self, df: pd.DataFrame) -> Neurons:
+    def extract_cached(self, df: pd.DataFrame, name: str) -> Neurons:
         """Instantiate an object from a cached DataFrame."""
-        query = {}
-        if self._repo.simulations_filter:
+        query = None
+        if self._repo.needs_filter(name):
             selected_sims = self._repo.simulations.df.etl.q(simulation_id=self._repo.simulation_ids)
             query = {CIRCUIT_ID: sorted(set(selected_sims[CIRCUIT_ID]))}
         return Neurons.from_pandas(df, query=query, cached=True)
@@ -106,10 +109,10 @@ def extract_new(self) -> NeuronClasses:
             neurons=self._repo.neurons, neuron_classes=self._repo.extraction_config.neuron_classes
         )
 
-    def extract_cached(self, df: pd.DataFrame) -> NeuronClasses:
+    def extract_cached(self, df: pd.DataFrame, name: str) -> NeuronClasses:
         """Instantiate an object from a cached DataFrame."""
-        query = {}
-        if self._repo.simulations_filter:
+        query = None
+        if self._repo.needs_filter(name):
             selected_sims = self._repo.simulations.df.etl.q(simulation_id=self._repo.simulation_ids)
             query = {CIRCUIT_ID: sorted(set(selected_sims[CIRCUIT_ID]))}
         return NeuronClasses.from_pandas(df, query=query, cached=True)
@@ -128,10 +131,10 @@ def extract_new(self) -> Windows:
             resolver=self._repo.resolver,
         )
 
-    def extract_cached(self, df: pd.DataFrame) -> Windows:
+    def extract_cached(self, df: pd.DataFrame, name: str) -> Windows:
         """Instantiate an object from a cached DataFrame."""
-        query = {}
-        if self._repo.simulations_filter:
+        query = None
+        if self._repo.needs_filter(name):
             query = {SIMULATION_ID: self._repo.simulation_ids}
         return Windows.from_pandas(df, query=query, cached=True)
@@ -149,10 +152,10 @@ def extract_new(self) -> Spikes:
             name=self._repo.extraction_config.report.name,
         )
 
-    def extract_cached(self, df: pd.DataFrame) -> Spikes:
+    def extract_cached(self, df: pd.DataFrame, name: str) -> Spikes:
         """Instantiate an object from a cached DataFrame."""
-        query = {}
-        if self._repo.simulations_filter:
+        query = None
+        if self._repo.needs_filter(name):
             query = {SIMULATION_ID: self._repo.simulation_ids}
         return Spikes.from_pandas(df, query=query, cached=True)
@@ -170,10 +173,10 @@ def extract_new(self) -> SomaReport:
             name=self._repo.extraction_config.report.name,
         )
 
-    def extract_cached(self, df: pd.DataFrame) -> SomaReport:
+    def extract_cached(self, df: pd.DataFrame, name: str) -> SomaReport:
         """Instantiate an object from a cached DataFrame."""
-        query = {}
-        if self._repo.simulations_filter:
+        query = None
+        if self._repo.needs_filter(name):
             query = {SIMULATION_ID: self._repo.simulation_ids}
         return SomaReport.from_pandas(df, query=query, cached=True)
@@ -191,10 +194,10 @@ def extract_new(self) -> CompartmentReport:
             name=self._repo.extraction_config.report.name,
         )
 
-    def extract_cached(self, df: pd.DataFrame) -> CompartmentReport:
+    def extract_cached(self, df: pd.DataFrame, name: str) -> CompartmentReport:
         """Instantiate an object from a cached DataFrame."""
-        query = {}
-        if self._repo.simulations_filter:
+        query = None
+        if self._repo.needs_filter(name):
             query = {SIMULATION_ID: self._repo.simulation_ids}
         return CompartmentReport.from_pandas(df, query=query, cached=True)
@@ -374,6 +377,10 @@ def apply_filter(self, simulations_filter: dict[str, Any]) -> "Repository":
         """Apply the given filter and return a new object."""
        return FilteredRepository(parent=self, simulations_filter=simulations_filter)
 
+    def needs_filter(self, name: str) -> bool:
+        """Return True if the repository needs to be filtered during the extraction."""
+        return bool(self.simulations_filter) and self.cache_manager.repo_cache_needs_filter(name)
+
 
 class FilteredRepository(Repository):
     """FilteredRepository class."""
@@ -396,5 +403,9 @@ def _assign_from_dataframes(self, dicts: dict[str, pd.DataFrame]) -> None:
         """Assign the repository properties from the given dict of DataFrames."""
         for name, df in dicts.items():
             assert name not in self.__dict__
-            value = self._mapping[name](self).extract_cached(df)
+            value = self._mapping[name](self).extract_cached(df, name)
             setattr(self, name, value)
+
+    def needs_filter(self, name: str) -> bool:
+        """Return True if the repository needs to be filtered during the extraction."""
+        return bool(self.simulations_filter)
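The extractors above now follow one template: load from the cache if present, re-filter only when needs_filter(name) says so, otherwise extract from scratch. A toy, self-contained version of that flow (all names hypothetical; the real code goes through CacheManager, pandas, and the blueetl extractor classes):

from typing import Optional

class ToyCache:
    """Stand-in for CacheManager, holding prebuilt rows per repo name."""
    def __init__(self, store: dict, cached_filter_was_broader: bool):
        self.store = store
        self._broader = cached_filter_was_broader

    def load_repo(self, name: str) -> Optional[list]:
        return self.store.get(name)

    def repo_cache_needs_filter(self, name: str) -> bool:
        return name in self.store and self._broader

def extract(name: str, cache: ToyCache, simulation_ids: set) -> list:
    rows = cache.load_repo(name)
    if rows is None:  # cache miss: build from scratch (extract_new)
        return [{"simulation_id": i} for i in sorted(simulation_ids)]
    if cache.repo_cache_needs_filter(name):  # cached with a broader filter
        rows = [r for r in rows if r["simulation_id"] in simulation_ids]
    return rows

cache = ToyCache({"windows": [{"simulation_id": i} for i in range(4)]},
                 cached_filter_was_broader=True)
print(extract("windows", cache, {0, 2}))
# [{'simulation_id': 0}, {'simulation_id': 2}]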
1 change: 1 addition & 0 deletions tests/unit/conftest.py
@@ -26,6 +26,7 @@ def repo(global_config):
     simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
     extraction_config = global_config.analysis["spikes"].extraction
     cache_manager = PicklableMock(
+        is_repo_cached=PicklableMock(return_value=False),
         load_repo=PicklableMock(return_value=None),
         load_features=PicklableMock(return_value=None),
         get_cached_features_checksums=PicklableMock(return_value={}),
