diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py
index eb2f00da4..990acdf30 100644
--- a/flepimop/gempyor_pkg/src/gempyor/utils.py
+++ b/flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -1,39 +1,101 @@
-import os
 import datetime
 import functools
+import logging
 import numbers
+import os
+from pathlib import Path
+import shutil
+import subprocess
 import time
+from typing import List, Dict, Literal
+
 import confuse
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 import pyarrow as pa
 import scipy.stats
 import sympy.parsing.sympy_parser
-import subprocess
-import shutil
-import logging
+
 from gempyor import file_paths
-from typing import List, Dict
+
 logger = logging.getLogger(__name__)
 config = confuse.Configuration("flepiMoP", read=False)
 
 
-def write_df(fname: str, df: pd.DataFrame, extension: str = ""):
-    """write without index, so assume the index has been put a column"""
-    # cast to str to use .split in case fname is a PosixPath
-    fname = str(fname)
-    if extension:  # Empty strings are falsy in python
-        fname = f"{fname}.{extension}"
-    extension = fname.split(".")[-1]
-    if extension == "csv":
-        df.to_csv(fname, index=False)
-    elif extension == "parquet":
-        df = pa.Table.from_pandas(df, preserve_index=False)
-        pa.parquet.write_table(df, fname)
-    else:
-        raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'")
+def write_df(
+    fname: str | bytes | os.PathLike,
+    df: pd.DataFrame,
+    extension: Literal[None, "", "csv", "parquet"] = "",
+) -> None:
+    """Writes a pandas DataFrame without its index to a file.
+
+    Writes a pandas DataFrame to either a CSV or Parquet file without its index,
+    inferring which format to use from the extension given in `fname` or from an
+    explicitly provided `extension`.
+
+    Args:
+        fname: The name of the file to write to.
+        df: A pandas DataFrame whose contents to write, but without its index.
+        extension: A user specified extension to use for the file if not contained in
+            `fname` already.
+
+    Returns:
+        None
+
+    Raises:
+        NotImplementedError: The given output extension is not supported yet.
+    """
+    # Decipher the path given
+    fname = fname.decode() if isinstance(fname, bytes) else fname
+    path = Path(f"{fname}.{extension}") if extension else Path(fname)
+    # Write df to either a csv or parquet or raise if an invalid extension
+    if path.suffix == ".csv":
+        return df.to_csv(path, index=False)
+    elif path.suffix == ".parquet":
+        return df.to_parquet(path, index=False, engine="pyarrow")
+    raise NotImplementedError(
+        f"Invalid extension {path.suffix[1:]}. Must be 'csv' or 'parquet'."
+    )
+
+
+def read_df(
+    fname: str | bytes | os.PathLike,
+    extension: Literal[None, "", "csv", "parquet"] = "",
+) -> pd.DataFrame:
+    """Reads a pandas DataFrame from either a CSV or Parquet file.
+
+    Reads a pandas DataFrame from either a CSV or Parquet file, inferring which
+    format to use from the extension given in `fname` or from an explicitly
+    provided `extension`. If the file being read is a CSV with a column called
+    'subpop' then that column will be cast as a string.
+
+    Args:
+        fname: The name of the file to read from.
+        extension: A user specified extension to use for the file if not contained in
+            `fname` already.
+
+    Returns:
+        A pandas DataFrame parsed from the file given.
+
+    Raises:
+        NotImplementedError: The given output extension is not supported yet.
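+
+    Examples:
+        An illustrative sketch of a round trip through `write_df` and `read_df`
+        using a scratch directory (the file name here is hypothetical):
+        >>> import pandas as pd
+        >>> from tempfile import TemporaryDirectory
+        >>> from gempyor.utils import read_df, write_df
+        >>> df = pd.DataFrame({"subpop": ["06000"], "value": [1.5]})
+        >>> with TemporaryDirectory() as tmpdir:
+        ...     write_df(f"{tmpdir}/example", df, extension="parquet")
+        ...     read_df(f"{tmpdir}/example", extension="parquet").equals(df)
+        True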
+ """ + # Decipher the path given + fname = fname.decode() if isinstance(fname, bytes) else fname + path = Path(f"{fname}.{extension}") if extension else Path(fname) + # Read df from either a csv or parquet or raise if an invalid extension + if path.suffix == ".csv": + return pd.read_csv( + path, converters={"subpop": lambda x: str(x)}, skipinitialspace=True + ) + elif path.suffix == ".parquet": + return pd.read_parquet(path, engine="pyarrow") + raise NotImplementedError( + f"Invalid extension {path.suffix[1:]}. Must be 'csv' or 'parquet'." + ) def command_safe_run(command, command_name="mycommand", fail_on_fail=True): @@ -60,23 +122,6 @@ def command_safe_run(command, command_name="mycommand", fail_on_fail=True): return sr.returncode, stdout, stderr -def read_df(fname: str, extension: str = "") -> pd.DataFrame: - """Load a dataframe from a file, agnostic to whether it is a parquet or a csv. The extension - can be provided as an argument or it is infered""" - fname = str(fname) - if extension: # Empty strings are falsy in python - fname = f"{fname}.{extension}" - extension = fname.split(".")[-1] - if extension == "csv": - # The converter prevents e.g leading geoid (0600) to be converted as int; and works when the column is absent - df = pd.read_csv(fname, converters={"subpop": lambda x: str(x)}, skipinitialspace=True) - elif extension == "parquet": - df = pa.parquet.read_table(fname).to_pandas() - else: - raise NotImplementedError(f"Invalid extension {extension}. Must be 'csv' or 'parquet'") - return df - - def add_method(cls): "Decorator to add a method to a class" @@ -206,14 +251,50 @@ def as_evaled_expression(self): raise ValueError(f"expected numeric or string expression [got: {value}]") -def get_truncated_normal(*, mean=0, sd=1, a=0, b=10): - "Returns the truncated normal distribution" +def get_truncated_normal( + mean: float | int = 0, + sd: float | int = 1, + a: float | int = 0, + b: float | int = 10, +) -> scipy.stats._distn_infrastructure.rv_frozen: + """Returns a truncated normal distribution. + + This function constructs a truncated normal distribution with the specified + mean, standard deviation, and bounds. The truncated normal distribution is + a normal distribution bounded within the interval [a, b]. + + Args: + mean: The mean of the truncated normal distribution. Defaults to 0. + sd: The standard deviation of the truncated normal distribution. Defaults to 1. + a: The lower bound of the truncated normal distribution. Defaults to 0. + b: The upper bound of the truncated normal distribution. Defaults to 10. + + Returns: + rv_frozen: A frozen instance of the truncated normal distribution with the + specified parameters. + """ + lower = (a - mean) / sd + upper = (b - mean) / sd + return scipy.stats.truncnorm(lower, upper, loc=mean, scale=sd) - return scipy.stats.truncnorm((a - mean) / sd, (b - mean) / sd, loc=mean, scale=sd) +def get_log_normal( + meanlog: float | int, + sdlog: float | int, +) -> scipy.stats._distn_infrastructure.rv_frozen: + """Returns a log normal distribution. -def get_log_normal(meanlog, sdlog): - "Returns the log normal distribution" + This function constructs a log normal distribution with the specified + log mean and log standard deviation. + + Args: + meanlog: The log of the mean of the log normal distribution. + sdlog: The log of the standard deviation of the log normal distribution. + + Returns: + rv_frozen: A frozen instance of the log normal distribution with the + specified parameters. 
+ """ return scipy.stats.lognorm(s=sdlog, scale=np.exp(meanlog), loc=0) @@ -263,45 +344,91 @@ def as_random_distribution(self): return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) -def list_filenames(folder: str = ".", filters: list = []) -> list: - """ - return the list of all filename and path in the provided folders. - If filters [list] is provided, then only the files that contains each of the - substrings in filter will be returned. Example to get all hosp file: - ``` - gempyor.utils.list_filenames(folder="model_output/", filters=["hosp"]) - ``` - and be sure we only get parquet: - ``` - gempyor.utils.list_filenames(folder="model_output/", filters=["hosp" , ".parquet"]) - ``` - """ - from pathlib import Path - - fn_list = [] - for f in Path(str(folder)).rglob(f"*"): - if f.is_file(): # not a folder - f = str(f) - if not filters: - fn_list.append(f) - else: - if all(c in f for c in filters): - fn_list.append(str(f)) - else: - pass - return fn_list +def list_filenames( + folder: str | bytes | os.PathLike = ".", + filters: str | list[str] = [], +) -> list[str]: + """Return the list of all filenames and paths in the provided folder. + + This function lists all files in the specified folder and its subdirectories. + If filters are provided, only the files containing each of the substrings + in the filters will be returned. + + Example: + To get all files containing "hosp": + ``` + gempyor.utils.list_filenames( + folder="model_output/", + filters=["hosp"], + ) + ``` + + To get only "hosp" files with a ".parquet" extension: + ``` + gempyor.utils.list_filenames( + folder="model_output/", + filters=["hosp", ".parquet"], + ) + ``` + Args: + folder: The directory to search for files. Defaults to the current directory. + filters: A string or a list of strings to filter filenames. Only files + containing all the provided substrings will be returned. Defaults to an + empty list. -def rolling_mean_pad(data, window): + Returns: + A list of strings representing the paths to the files that match the filters. """ - Calculates rolling mean with centered window and pads the edges. + filters = [filters] if not isinstance(filters, list) else filters + filters = filters if len(filters) else [""] + folder = Path(folder.decode() if isinstance(folder, bytes) else folder) + files = [ + str(file) + for file in folder.rglob("*") + if file.is_file() and all(f in str(file) for f in filters) + ] + return files + + +def rolling_mean_pad( + data: npt.NDArray[np.number], + window: int, +) -> npt.NDArray[np.number]: + """ + Calculates the column-wise rolling mean with centered window. Args: - data: A NumPy array !!! shape must be (n_days, nsubpops). + data: A two dimensional numpy array, typically the row dimension is time and the + column dimension is subpop. window: The window size for the rolling mean. Returns: - A NumPy array with the padded rolling mean (n_days, nsubpops). + A two dimensional numpy array that is the same shape as `data`. + + Examples: + Below is a brief set of examples showcasing how to smooth a metric, like + hospitalizations, using this function. 
+    """
+    return scipy.stats.lognorm(s=sdlog, scale=np.exp(meanlog), loc=0)
 
 
@@ -263,45 +344,91 @@ def as_random_distribution(self):
             return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),)
 
 
-def list_filenames(folder: str = ".", filters: list = []) -> list:
-    """
-    return the list of all filename and path in the provided folders.
-    If filters [list] is provided, then only the files that contains each of the
-    substrings in filter will be returned. Example to get all hosp file:
-    ```
-    gempyor.utils.list_filenames(folder="model_output/", filters=["hosp"])
-    ```
-    and be sure we only get parquet:
-    ```
-    gempyor.utils.list_filenames(folder="model_output/", filters=["hosp" , ".parquet"])
-    ```
-    """
-    from pathlib import Path
-
-    fn_list = []
-    for f in Path(str(folder)).rglob(f"*"):
-        if f.is_file():  # not a folder
-            f = str(f)
-            if not filters:
-                fn_list.append(f)
-            else:
-                if all(c in f for c in filters):
-                    fn_list.append(str(f))
-                else:
-                    pass
-    return fn_list
+def list_filenames(
+    folder: str | bytes | os.PathLike = ".",
+    filters: str | list[str] = [],
+) -> list[str]:
+    """Return the list of all filenames and paths in the provided folder.
+
+    This function lists all files in the specified folder and its subdirectories.
+    If filters are provided, only the files containing each of the substrings
+    in the filters will be returned.
+
+    Examples:
+        To get all files containing "hosp":
+        ```
+        gempyor.utils.list_filenames(
+            folder="model_output/",
+            filters=["hosp"],
+        )
+        ```
+
+        To get only "hosp" files with a ".parquet" extension:
+        ```
+        gempyor.utils.list_filenames(
+            folder="model_output/",
+            filters=["hosp", ".parquet"],
+        )
+        ```
+
+    Args:
+        folder: The directory to search for files. Defaults to the current directory.
+        filters: A string or a list of strings to filter filenames. Only files
+            containing all the provided substrings will be returned. Defaults to an
+            empty list.
+
+    Returns:
+        A list of strings representing the paths to the files that match the filters.
+    """
+    filters = [filters] if not isinstance(filters, list) else filters
+    filters = filters if len(filters) else [""]
+    folder = Path(folder.decode() if isinstance(folder, bytes) else folder)
+    files = [
+        str(file)
+        for file in folder.rglob("*")
+        if file.is_file() and all(f in str(file) for f in filters)
+    ]
+    return files
 
 
-def rolling_mean_pad(data, window):
+def rolling_mean_pad(
+    data: npt.NDArray[np.number],
+    window: int,
+) -> npt.NDArray[np.number]:
     """
-    Calculates rolling mean with centered window and pads the edges.
+    Calculates the column-wise rolling mean with a centered window.
 
     Args:
-        data: A NumPy array !!! shape must be (n_days, nsubpops).
+        data: A two dimensional numpy array, typically the row dimension is time and
+            the column dimension is subpop.
         window: The window size for the rolling mean.
 
     Returns:
-        A NumPy array with the padded rolling mean (n_days, nsubpops).
+        A two dimensional numpy array that is the same shape as `data`.
+
+    Examples:
+        Below is a brief set of examples showcasing how to smooth a metric, like
+        hospitalizations, using this function.
+
+        >>> import numpy as np
+        >>> from gempyor.utils import rolling_mean_pad
+        >>> hospitalizations = np.arange(1., 29.).reshape((7, 4))
+        >>> hospitalizations
+        array([[ 1.,  2.,  3.,  4.],
+               [ 5.,  6.,  7.,  8.],
+               [ 9., 10., 11., 12.],
+               [13., 14., 15., 16.],
+               [17., 18., 19., 20.],
+               [21., 22., 23., 24.],
+               [25., 26., 27., 28.]])
+        >>> rolling_mean_pad(hospitalizations, 5)
+        array([[ 3.4,  4.4,  5.4,  6.4],
+               [ 5.8,  6.8,  7.8,  8.8],
+               [ 9. , 10. , 11. , 12. ],
+               [13. , 14. , 15. , 16. ],
+               [17. , 18. , 19. , 20. ],
+               [20.2, 21.2, 22.2, 23.2],
+               [22.6, 23.6, 24.6, 25.6]])
     """
     padding_size = (window - 1) // 2
     padded_data = np.pad(data, ((padding_size, padding_size), (0, 0)), mode="edge")
diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
new file mode 100644
index 000000000..367a7f550
--- /dev/null
+++ b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+import scipy.stats
+
+from gempyor.utils import get_log_normal
+
+
+class TestGetLogNormal:
+    """Unit tests for the `gempyor.utils.get_log_normal` function."""
+
+    @pytest.mark.parametrize(
+        "meanlog,sdlog",
+        [
+            (1.0, 1.0),
+            (0.0, 2.0),
+            (10.0, 30.0),
+            (0.33, 4.56),
+            (9.87, 4.21),
+            (1, 1),
+            (0, 2),
+            (10, 30),
+        ],
+    )
+    def test_construct_distribution(
+        self,
+        meanlog: float | int,
+        sdlog: float | int,
+    ) -> None:
+        """Test the construction of a log normal distribution.
+
+        This test checks whether the `get_log_normal` function correctly constructs
+        a log normal distribution with the specified parameters. It verifies that
+        the returned object is an instance of `rv_frozen`, and that its support and
+        parameters (derived from `meanlog` and `sdlog`) are correctly set.
+
+        Args:
+            meanlog: The mean of the distribution's logarithm.
+            sdlog: The standard deviation of the distribution's logarithm.
+        """
+        dist = get_log_normal(meanlog=meanlog, sdlog=sdlog)
+        assert isinstance(dist, scipy.stats._distn_infrastructure.rv_frozen)
+        lower, upper = dist.support()
+        assert np.isclose(lower, 0.0)
+        assert np.isclose(upper, np.inf)
+        assert np.isclose(dist.kwds.get("s"), sdlog)
+        assert np.isclose(dist.kwds.get("scale"), np.exp(meanlog))
+        assert np.isclose(dist.kwds.get("loc"), 0.0)
diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py
new file mode 100644
index 000000000..23e4fad58
--- /dev/null
+++ b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+import scipy.stats
+
+from gempyor.utils import get_truncated_normal
+
+
+class TestGetTruncatedNormal:
+    """Unit tests for the `gempyor.utils.get_truncated_normal` function."""
+
+    @pytest.mark.parametrize(
+        "mean,sd,a,b",
+        [
+            (0.0, 1.0, 0.0, 10.0),
+            (0.0, 2.0, -4.0, 4.0),
+            (-5.0, 3.0, -5.0, 10.0),
+            (-3.25, 1.4, -8.74, 4.89),
+            (0, 1, 0, 10),
+            (0, 2, -4, 4),
+            (-5, 3, -5, 10),
+        ],
+    )
+    def test_construct_distribution(
+        self,
+        mean: float | int,
+        sd: float | int,
+        a: float | int,
+        b: float | int,
+    ) -> None:
+        """Test the construction of a truncated normal distribution.
+
+        This test checks whether the `get_truncated_normal` function correctly
+        constructs a truncated normal distribution with the specified parameters.
+        It verifies that the returned object is an instance of `rv_frozen`, and that
+        its support and parameters (mean and standard deviation) are correctly set.
+
+        Args:
+            mean: The mean of the truncated normal distribution.
+            sd: The standard deviation of the truncated normal distribution.
+            a: The lower bound of the truncated normal distribution.
+            b: The upper bound of the truncated normal distribution.
+        """
+        dist = get_truncated_normal(mean=mean, sd=sd, a=a, b=b)
+        assert isinstance(dist, scipy.stats._distn_infrastructure.rv_frozen)
+        lower, upper = dist.support()
+        assert np.isclose(lower, a)
+        assert np.isclose(upper, b)
+        assert np.isclose(dist.kwds.get("loc"), mean)
+        assert np.isclose(dist.kwds.get("scale"), sd)
diff --git a/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py
new file mode 100644
index 000000000..f24c6aaa6
--- /dev/null
+++ b/flepimop/gempyor_pkg/tests/utils/test_list_filenames.py
@@ -0,0 +1,227 @@
+"""Unit tests for the `gempyor.utils.list_filenames` function.
+
+These tests cover scenarios for finding files in both flat and nested directories.
+"""
+
+from collections.abc import Generator
+import os
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from gempyor.utils import list_filenames
+
+
+@pytest.fixture(scope="class")
+def create_directories_with_files(
+    request: pytest.FixtureRequest,
+) -> Generator[tuple[TemporaryDirectory, TemporaryDirectory], None, None]:
+    """Fixture to create temporary directories with files for testing.
+
+    This fixture creates two temporary directories:
+    - A flat directory with files.
+    - A nested directory with files organized in subdirectories.
+
+    The directories and files are cleaned up after the tests are run.
+
+    Args:
+        request: The pytest fixture request object.
+
+    Yields:
+        tuple: A tuple containing the flat and nested `TemporaryDirectory` objects.
+ """ + # Create a flat and nested directories + flat_temp_dir = TemporaryDirectory() + nested_temp_dir = TemporaryDirectory() + # Setup flat directory + for file in ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]: + Path(f"{flat_temp_dir.name}/{file}").touch() + # Setup nested directory structure + for file in [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ]: + path = Path(f"{nested_temp_dir.name}/{file}") + path.parent.mkdir(parents=True, exist_ok=True) + path.touch() + # Yield + request.cls.flat_temp_dir = flat_temp_dir + request.cls.nested_temp_dir = nested_temp_dir + yield (flat_temp_dir, nested_temp_dir) + # Clean up directories on test end + flat_temp_dir.cleanup() + nested_temp_dir.cleanup() + + +@pytest.mark.usefixtures("create_directories_with_files") +class TestListFilenames: + """Unit tests for the `gempyor.utils.list_filenames` function.""" + + @pytest.mark.parametrize( + "filters,expected_basenames", + [ + ("", ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]), + ([], ["hosp.csv", "hosp.parquet", "spar.csv", "spar.parquet"]), + ("hosp", ["hosp.csv", "hosp.parquet"]), + (["hosp"], ["hosp.csv", "hosp.parquet"]), + ("spar", ["spar.csv", "spar.parquet"]), + (["spar"], ["spar.csv", "spar.parquet"]), + (".parquet", ["hosp.parquet", "spar.parquet"]), + ([".parquet"], ["hosp.parquet", "spar.parquet"]), + (".csv", ["hosp.csv", "spar.csv"]), + ([".csv"], ["hosp.csv", "spar.csv"]), + (".tsv", []), + ([".tsv"], []), + (["hosp", ".csv"], ["hosp.csv"]), + (["spar", ".parquet"], ["spar.parquet"]), + (["hosp", "spar"], []), + ([".csv", ".parquet"], []), + ], + ) + def test_finds_files_in_flat_directory( + self, + filters: str | list[str], + expected_basenames: list[str], + ) -> None: + """Test `list_filenames` in a flat directory. + + Args: + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. 
+ """ + self._test_list_filenames( + folder=self.flat_temp_dir.name, + filters=filters, + expected_basenames=expected_basenames, + ) + self._test_list_filenames( + folder=self.flat_temp_dir.name.encode(), + filters=filters, + expected_basenames=expected_basenames, + ) + self._test_list_filenames( + folder=Path(self.flat_temp_dir.name), + filters=filters, + expected_basenames=expected_basenames, + ) + + @pytest.mark.parametrize( + "filters,expected_basenames", + [ + ( + "", + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ], + ), + ( + [], + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + "seed/001.csv", + "seed/002.csv", + ], + ), + ( + "hpar", + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + ], + ), + ( + ["hpar"], + [ + "hpar/chimeric/001.parquet", + "hpar/chimeric/002.parquet", + "hpar/global/001.parquet", + "hpar/global/002.parquet", + ], + ), + ("seed", ["seed/001.csv", "seed/002.csv"]), + (["seed"], ["seed/001.csv", "seed/002.csv"]), + ("global", ["hpar/global/001.parquet", "hpar/global/002.parquet"]), + (["global"], ["hpar/global/001.parquet", "hpar/global/002.parquet"]), + ( + "001", + [ + "hpar/chimeric/001.parquet", + "hpar/global/001.parquet", + "seed/001.csv", + ], + ), + ( + ["001"], + [ + "hpar/chimeric/001.parquet", + "hpar/global/001.parquet", + "seed/001.csv", + ], + ), + (["hpar", "001"], ["hpar/chimeric/001.parquet", "hpar/global/001.parquet"]), + (["seed", "002"], ["seed/002.csv"]), + (["hpar", "001", "global"], ["hpar/global/001.parquet"]), + (".tsv", []), + ([".tsv"], []), + ], + ) + def test_find_files_in_nested_directory( + self, + filters: str | list[str], + expected_basenames: list[str], + ) -> None: + """Test `list_filenames` in a nested directory. + + Args: + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. + """ + self._test_list_filenames( + folder=self.nested_temp_dir.name, + filters=filters, + expected_basenames=expected_basenames, + ) + self._test_list_filenames( + folder=self.nested_temp_dir.name.encode(), + filters=filters, + expected_basenames=expected_basenames, + ) + self._test_list_filenames( + folder=Path(self.nested_temp_dir.name), + filters=filters, + expected_basenames=expected_basenames, + ) + + def _test_list_filenames( + self, + folder: str | bytes | os.PathLike, + filters: str | list[str], + expected_basenames: list[str], + ) -> None: + """Helper method to test `list_filenames`. + + Args: + folder: The directory to search for files. + filters: List of filters to apply to filenames. + expected_basenames: List of expected filenames that match the filters. 
+ """ + files = list_filenames(folder=folder, filters=filters) + assert len(files) == len(expected_basenames) + folder = folder.decode() if isinstance(folder, bytes) else str(folder) + basenames = [f.removeprefix(f"{folder}/") for f in files] + assert sorted(basenames) == sorted(expected_basenames) diff --git a/flepimop/gempyor_pkg/tests/utils/test_read_df.py b/flepimop/gempyor_pkg/tests/utils/test_read_df.py new file mode 100644 index 000000000..7a0a0c581 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_read_df.py @@ -0,0 +1,172 @@ +import os +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import Callable, Any, Literal + +import pytest +import pandas as pd +from pandas.api.types import is_object_dtype, is_numeric_dtype + +from gempyor.utils import read_df + + +class TestReadDf: + """ + Unit tests for the `gempyor.utils.read_df` function. + """ + + sample_df: pd.DataFrame = pd.DataFrame( + { + "abc": [1, 2, 3, 4, 5], + "def": ["v", "w", "x", "y", "z"], + "ghi": [True, False, False, None, True], + "jkl": [1.2, 3.4, 5.6, 7.8, 9.0], + } + ) + + subpop_df: pd.DataFrame = pd.DataFrame( + { + "subpop": [1, 2, 3, 4], + "value": [5, 6, 7, 8], + } + ) + + def test_raises_not_implemented_error(self) -> None: + """ + Tests that read_df raises a NotImplementedError for unsupported file + extensions. + """ + with pytest.raises( + expected_exception=NotImplementedError, + match="Invalid extension txt. Must be 'csv' or 'parquet'.", + ) as _: + with NamedTemporaryFile(suffix=".txt") as temp_file: + read_df(fname=temp_file.name) + with pytest.raises( + expected_exception=NotImplementedError, + match="Invalid extension txt. Must be 'csv' or 'parquet'.", + ) as _: + with NamedTemporaryFile(suffix=".txt") as temp_file: + fname = temp_file.name[:-4] + read_df(fname=fname, extension="txt") + + @pytest.mark.parametrize( + "fname_transformer,extension", + [ + (lambda x: str(x), ""), + (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), + (lambda x: f"{x.parent}/{x.stem}", "csv"), + (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"), + ], + ) + def test_read_csv_dataframe( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: Literal[None, "", "csv", "parquet"], + ) -> None: + """ + Tests reading a DataFrame from a CSV file. + + Args: + fname_transformer: A function that transforms the file name to create the + `fname` arg. + extension: The file extension to use, provided directly to + `gempyor.utils.read_df`. + """ + self._test_read_df( + fname_transformer=fname_transformer, + extension=extension, + suffix=".csv", + path_writer=lambda p, df: df.to_csv(p, index=False), + ) + + @pytest.mark.parametrize( + "fname_transformer,extension", + [ + (lambda x: str(x), ""), + (lambda x: x, ""), + (lambda x: str(x), None), + (lambda x: x, None), + (lambda x: f"{x.parent}/{x.stem}", "parquet"), + (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"), + ], + ) + def test_read_parquet_dataframe( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: Literal[None, "", "csv", "parquet"], + ) -> None: + """ + Tests reading a DataFrame from a Parquet file. + + Args: + fname_transformer: A function that transforms the file name to create the + `fname` arg. + extension: The file extension to use, provided directly to + `gempyor.utils.read_df`. 
+ """ + self._test_read_df( + fname_transformer=fname_transformer, + extension=extension, + suffix=".parquet", + path_writer=lambda p, df: df.to_parquet(p, engine="pyarrow", index=False), + ) + + def test_subpop_is_cast_as_str(self) -> None: + """ + Tests that read_df returns an object dtype for the column 'subpop' when reading + a csv file, but not when reading a parquet file. + """ + # First confirm the dtypes of our test DataFrame + assert is_numeric_dtype(self.subpop_df["subpop"]) + assert is_numeric_dtype(self.subpop_df["value"]) + # Test that the subpop column is converted to a string for a csv file + with NamedTemporaryFile(suffix=".csv") as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + assert self.subpop_df.to_csv(temp_path, index=False) is None + assert temp_path.stat().st_size > 0 + test_df = read_df(fname=temp_path) + assert isinstance(test_df, pd.DataFrame) + assert is_object_dtype(test_df["subpop"]) + assert is_numeric_dtype(test_df["value"]) + # Test that the subpop column remains unaltered for a parquet file + with NamedTemporaryFile(suffix=".parquet") as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + assert ( + self.subpop_df.to_parquet(temp_path, engine="pyarrow", index=False) + is None + ) + assert temp_path.stat().st_size > 0 + test_df = read_df(fname=temp_path) + assert isinstance(test_df, pd.DataFrame) + assert is_numeric_dtype(test_df["subpop"]) + assert is_numeric_dtype(test_df["value"]) + + def _test_read_df( + self, + fname_transformer: Callable[[os.PathLike], Any], + extension: Literal[None, "", "csv", "parquet"], + suffix: str | None, + path_writer: Callable[[os.PathLike, pd.DataFrame], None], + ) -> None: + """ + Helper method to test writing a DataFrame to a file. + + Args: + fname_transformer: A function that transforms the file name. + extension: The file extension to use. + suffix: The suffix to use for the temporary file. + path_writer: A function to write the DataFrame to the file. 
+ """ + with NamedTemporaryFile(suffix=suffix) as temp_file: + temp_path = Path(temp_file.name) + assert temp_path.stat().st_size == 0 + path_writer(temp_path, self.sample_df) + test_df = read_df(fname=fname_transformer(temp_path), extension=extension) + assert isinstance(test_df, pd.DataFrame) + assert temp_path.stat().st_size > 0 + assert test_df.equals(self.sample_df) diff --git a/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py b/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py new file mode 100644 index 000000000..94be3394a --- /dev/null +++ b/flepimop/gempyor_pkg/tests/utils/test_rolling_mean_pad.py @@ -0,0 +1,149 @@ +import numpy as np +import numpy.typing as npt +import pytest + +from gempyor.utils import rolling_mean_pad + + +class TestRollingMeanPad: + """Unit tests for the `gempyor.utils.rolling_mean_pad` function.""" + + # Test data for various matrix configurations + test_data = { + # 1x1 matrices + "one_by_one_const": np.array([[1.0]]), + "one_by_one_nan": np.array([[np.nan]]), + "one_by_one_rand": np.random.uniform(size=(1, 1)), + # 1xN matrices + "one_by_many_const": np.arange(start=1.0, stop=6.0).reshape((1, 5)), + "one_by_many_nan": np.repeat(np.nan, 5).reshape((1, 5)), + "one_by_many_rand": np.random.uniform(size=(1, 5)), + # Mx1 matrices + "many_by_one_const": np.arange(start=3.0, stop=9.0).reshape((6, 1)), + "many_by_one_nan": np.repeat(np.nan, 6).reshape((6, 1)), + "many_by_one_rand": np.random.uniform(size=(6, 1)), + # MxN matrices + "many_by_many_const": np.arange(start=1.0, stop=49.0).reshape((12, 4)), + "many_by_many_nan": np.repeat(np.nan, 48).reshape((12, 4)), + "many_by_many_rand": np.random.uniform(size=(12, 4)), + } + + @pytest.mark.parametrize( + "test_data_name,expected_shape,window,put_nans", + [ + # 1x1 matrices + ("one_by_one_const", (1, 1), 3, []), + ("one_by_one_const", (1, 1), 4, []), + ("one_by_one_nan", (1, 1), 3, []), + ("one_by_one_nan", (1, 1), 4, []), + ("one_by_one_rand", (1, 1), 3, []), + ("one_by_one_rand", (1, 1), 4, []), + ("one_by_one_rand", (1, 1), 5, []), + ("one_by_one_rand", (1, 1), 6, []), + # 1xN matrices + ("one_by_many_const", (1, 5), 3, []), + ("one_by_many_const", (1, 5), 4, []), + ("one_by_many_nan", (1, 5), 3, []), + ("one_by_many_nan", (1, 5), 4, []), + ("one_by_many_rand", (1, 5), 3, []), + ("one_by_many_rand", (1, 5), 4, []), + ("one_by_many_rand", (1, 5), 5, []), + ("one_by_many_rand", (1, 5), 6, []), + # Mx1 matrices + ("many_by_one_const", (6, 1), 3, []), + ("many_by_one_const", (6, 1), 4, []), + ("many_by_one_nan", (6, 1), 3, []), + ("many_by_one_nan", (6, 1), 4, []), + ("many_by_one_rand", (6, 1), 3, []), + ("many_by_one_rand", (6, 1), 4, []), + ("many_by_one_rand", (6, 1), 5, []), + ("many_by_one_rand", (6, 1), 6, []), + # MxN matrices + ("many_by_many_const", (12, 4), 3, []), + ("many_by_many_const", (12, 4), 4, []), + ("many_by_many_const", (12, 4), 5, []), + ("many_by_many_const", (12, 4), 6, []), + ("many_by_many_nan", (12, 4), 3, []), + ("many_by_many_nan", (12, 4), 4, []), + ("many_by_many_nan", (12, 4), 5, []), + ("many_by_many_nan", (12, 4), 6, []), + ("many_by_many_rand", (12, 4), 3, []), + ("many_by_many_rand", (12, 4), 4, []), + ("many_by_many_rand", (12, 4), 5, []), + ("many_by_many_rand", (12, 4), 6, []), + ("many_by_many_rand", (12, 4), 7, []), + ("many_by_many_rand", (12, 4), 8, []), + ("many_by_many_rand", (12, 4), 9, []), + ("many_by_many_rand", (12, 4), 10, []), + ("many_by_many_rand", (12, 4), 11, []), + ("many_by_many_rand", (12, 4), 12, []), + ("many_by_many_rand", (12, 4), 13, 
+            ("many_by_many_rand", (12, 4), 14, []),
+            ("many_by_many_rand", (12, 4), 3, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 4, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 5, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 6, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 7, [(2, 2), (4, 4)]),
+            ("many_by_many_rand", (12, 4), 8, [(2, 2), (4, 4)]),
+        ],
+    )
+    def test_rolling_mean_pad(
+        self,
+        test_data_name: str,
+        expected_shape: tuple[int, int],
+        window: int,
+        put_nans: list[tuple[int, int]],
+    ) -> None:
+        """Tests the `rolling_mean_pad` function with various configurations of input
+        data.
+
+        Args:
+            test_data_name: The name of the test data set to use.
+            expected_shape: The expected shape of the output array.
+            window: The size of the rolling window.
+            put_nans: A list of indices to insert NaNs into the input data.
+
+        Raises:
+            AssertionError: If the shape or contents of the output do not match the
+                expected values.
+        """
+        test_data = self.test_data.get(test_data_name).copy()
+        if put_nans:
+            np.put(test_data, put_nans, np.nan)
+        rolling_mean_data = rolling_mean_pad(test_data, window)
+        rolling_mean_reference = self._rolling_mean_pad_reference(test_data, window)
+        assert rolling_mean_data.shape == expected_shape
+        assert np.isclose(
+            rolling_mean_data, rolling_mean_reference, equal_nan=True
+        ).all()
+
+    def _rolling_mean_pad_reference(
+        self, data: npt.NDArray[np.number], window: int
+    ) -> npt.NDArray[np.number]:
+        """Generates a reference rolling mean with padding.
+
+        This implementation should match the `gempyor.utils.rolling_mean_pad`
+        implementation, but is written for readability. As a result this
+        reference implementation is extremely slow.
+
+        Args:
+            data: The input array for which to compute the rolling mean.
+            window: The size of the rolling window.
+
+        Returns:
+            An array of the same shape as `data` containing the rolling mean values.
+        """
+        # Setup
+        rows, cols = data.shape
+        output = np.zeros((rows, cols), dtype=data.dtype)
+        # Slow but intuitive triple loop
+        for i in range(rows):
+            for j in range(cols):
+                # On the last row with an even window, shrink the window by one,
+                # so 4 -> 3, but 5 -> 5.
+                sub_window = window - 1 if window % 2 == 0 and i == rows - 1 else window
+                weight = 1.0 / sub_window
+                for l in range(-((sub_window - 1) // 2), 1 + (sub_window // 2)):
+                    i_star = min(max(i + l, 0), rows - 1)
+                    output[i, j] += weight * data[i_star, j]
+        # Done
+        return output
diff --git a/flepimop/gempyor_pkg/tests/utils/test_write_df.py b/flepimop/gempyor_pkg/tests/utils/test_write_df.py
new file mode 100644
index 000000000..b13e0b948
--- /dev/null
+++ b/flepimop/gempyor_pkg/tests/utils/test_write_df.py
@@ -0,0 +1,138 @@
+import os
+from tempfile import NamedTemporaryFile
+from pathlib import Path
+from typing import Callable, Any, Literal
+
+import pytest
+import pandas as pd
+
+from gempyor.utils import write_df
+
+
+class TestWriteDf:
+    """
+    Unit tests for the `gempyor.utils.write_df` function.
+    """
+
+    sample_df: pd.DataFrame = pd.DataFrame(
+        {
+            "abc": [1, 2, 3, 4, 5],
+            "def": ["v", "w", "x", "y", "z"],
+            "ghi": [True, False, False, None, True],
+            "jkl": [1.2, 3.4, 5.6, 7.8, 9.0],
+        }
+    )
+
+    def test_raises_not_implemented_error(self) -> None:
+        """
+        Tests that `write_df` raises a `NotImplementedError` for unsupported file
+        extensions.
+        """
+        with pytest.raises(
+            expected_exception=NotImplementedError,
+            match="Invalid extension txt. Must be 'csv' or 'parquet'.",
+        ) as _:
+            with NamedTemporaryFile(suffix=".txt") as temp_file:
+                write_df(fname=temp_file.name, df=self.sample_df)
+        with pytest.raises(
+            expected_exception=NotImplementedError,
+            match="Invalid extension txt. Must be 'csv' or 'parquet'.",
+        ) as _:
+            with NamedTemporaryFile(suffix=".txt") as temp_file:
+                fname = temp_file.name[:-4]
+                write_df(fname=fname, df=self.sample_df, extension="txt")
+
+    @pytest.mark.parametrize(
+        "fname_transformer,extension",
+        [
+            (lambda x: str(x), ""),
+            (lambda x: x, ""),
+            (lambda x: str(x), None),
+            (lambda x: x, None),
+            (lambda x: f"{x.parent}/{x.stem}", "csv"),
+            (lambda x: Path(f"{x.parent}/{x.stem}"), "csv"),
+        ],
+    )
+    def test_write_csv_dataframe(
+        self,
+        fname_transformer: Callable[[os.PathLike], Any],
+        extension: Literal[None, "", "csv", "parquet"],
+    ) -> None:
+        """
+        Tests writing a DataFrame to a CSV file.
+
+        Args:
+            fname_transformer: A function that transforms the file name to create the
+                `fname` arg.
+            extension: The file extension to use, provided directly to
+                `gempyor.utils.write_df`.
+        """
+        self._test_write_df(
+            fname_transformer=fname_transformer,
+            df=self.sample_df,
+            extension=extension,
+            suffix=".csv",
+            path_reader=lambda x: pd.read_csv(x, index_col=False),
+        )
+
+    @pytest.mark.parametrize(
+        "fname_transformer,extension",
+        [
+            (lambda x: str(x), ""),
+            (lambda x: x, ""),
+            (lambda x: str(x), None),
+            (lambda x: x, None),
+            (lambda x: f"{x.parent}/{x.stem}", "parquet"),
+            (lambda x: Path(f"{x.parent}/{x.stem}"), "parquet"),
+        ],
+    )
+    def test_write_parquet_dataframe(
+        self,
+        fname_transformer: Callable[[os.PathLike], Any],
+        extension: Literal[None, "", "csv", "parquet"],
+    ) -> None:
+        """
+        Tests writing a DataFrame to a Parquet file.
+
+        Args:
+            fname_transformer: A function that transforms the file name to create the
+                `fname` arg.
+            extension: The file extension to use, provided directly to
+                `gempyor.utils.write_df`.
+        """
+        self._test_write_df(
+            fname_transformer=fname_transformer,
+            df=self.sample_df,
+            extension=extension,
+            suffix=".parquet",
+            path_reader=lambda x: pd.read_parquet(x, engine="pyarrow"),
+        )
+
+    def _test_write_df(
+        self,
+        fname_transformer: Callable[[os.PathLike], Any],
+        df: pd.DataFrame,
+        extension: Literal[None, "", "csv", "parquet"],
+        suffix: str | None,
+        path_reader: Callable[[os.PathLike], pd.DataFrame],
+    ) -> None:
+        """
+        Helper method to test writing a DataFrame to a file.
+
+        Args:
+            fname_transformer: A function that transforms the file name.
+            df: The DataFrame to write.
+            extension: The file extension to use.
+            suffix: The suffix to use for the temporary file.
+            path_reader: A function to read the DataFrame from the file.
+        """
+        with NamedTemporaryFile(suffix=suffix) as temp_file:
+            temp_path = Path(temp_file.name)
+            assert temp_path.stat().st_size == 0
+            assert (
+                write_df(fname=fname_transformer(temp_path), df=df, extension=extension)
+                is None
+            )
+            assert temp_path.stat().st_size > 0
+            test_df = path_reader(temp_path)
+            assert test_df.equals(df)