Commit 4485458

Merge pull request #247 from HopkinsIDD/enhancement/GH-246-document-test-write_df-read_df

Document/Test `gempyor.utils.read_df/write_df/get_truncated_normal/get_log_normal/list_filenames/rolling_mean_pad`
TimothyWillard committed Jul 11, 2024
2 parents 530afbc + a91fc6d commit 4485458
Showing 7 changed files with 985 additions and 74 deletions.
275 changes: 201 additions & 74 deletions flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -1,41 +1,103 @@
import datetime
import functools
import logging
import numbers
import os
from pathlib import Path
import shutil
import subprocess
import time
from typing import List, Dict, Literal

import boto3
from botocore.exceptions import ClientError
import confuse
import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
import scipy.stats
import sympy.parsing.sympy_parser

from gempyor import file_paths


logger = logging.getLogger(__name__)

config = confuse.Configuration("flepiMoP", read=False)


def write_df(
    fname: str | bytes | os.PathLike,
    df: pd.DataFrame,
    extension: Literal[None, "", "csv", "parquet"] = "",
) -> None:
    """Writes a pandas DataFrame without its index to a file.

    Writes a pandas DataFrame to either a CSV or Parquet file without its index and can
    infer which format to use based on the extension given in `fname` or based on the
    explicit `extension`.

    Args:
        fname: The name of the file to write to.
        df: A pandas DataFrame whose contents to write, but without its index.
        extension: A user specified extension to use for the file if not contained in
            `fname` already.

    Returns:
        None

    Raises:
        NotImplementedError: The given output extension is not supported yet.
    """
    # Decipher the path given
    fname = fname.decode() if isinstance(fname, bytes) else fname
    path = Path(f"{fname}.{extension}") if extension else Path(fname)
    # Write df to either a csv or parquet or raise if an invalid extension
    if path.suffix == ".csv":
        return df.to_csv(path, index=False)
    elif path.suffix == ".parquet":
        return df.to_parquet(path, index=False, engine="pyarrow")
    raise NotImplementedError(
        f"Invalid extension {path.suffix[1:]}. Must be 'csv' or 'parquet'."
    )
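As a usage sketch of the extension handling (file names are hypothetical, and the working directory must be writable):

```python
import pandas as pd
from gempyor.utils import write_df

df = pd.DataFrame({"subpop": ["06000", "12000"], "value": [1.2, 3.4]})

# Format inferred from the suffix already present in `fname`
write_df("results.parquet", df)

# Or given explicitly; this writes 'results.csv'
write_df("results", df, extension="csv")
```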


def read_df(
    fname: str | bytes | os.PathLike,
    extension: Literal[None, "", "csv", "parquet"] = "",
) -> pd.DataFrame:
    """Reads a pandas DataFrame from either a CSV or Parquet file.

    Reads a pandas DataFrame from either a CSV or Parquet file and can infer which
    format to use based on the extension given in `fname` or based on the explicit
    `extension`. If the file being read is a csv with a column called 'subpop' then
    that column will be cast as a string.

    Args:
        fname: The name of the file to read from.
        extension: A user specified extension to use for the file if not contained in
            `fname` already.

    Returns:
        A pandas DataFrame parsed from the file given.

    Raises:
        NotImplementedError: The given extension is not supported yet.
    """
    # Decipher the path given
    fname = fname.decode() if isinstance(fname, bytes) else fname
    path = Path(f"{fname}.{extension}") if extension else Path(fname)
    # Read df from either a csv or parquet or raise if an invalid extension
    if path.suffix == ".csv":
        return pd.read_csv(
            path, converters={"subpop": lambda x: str(x)}, skipinitialspace=True
        )
    elif path.suffix == ".parquet":
        return pd.read_parquet(path, engine="pyarrow")
    raise NotImplementedError(
        f"Invalid extension {path.suffix[1:]}. Must be 'csv' or 'parquet'."
    )
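A hypothetical round trip showing why the 'subpop' converter matters: FIPS-like codes keep their leading zeros through a CSV.

```python
import pandas as pd
from gempyor.utils import read_df, write_df

df = pd.DataFrame({"subpop": ["06000", "01000"], "incidence": [10, 20]})
write_df("example", df, extension="csv")

# Without the converter, '06000' would come back as the integer 6000
round_tripped = read_df("example.csv")
assert round_tripped["subpop"].tolist() == ["06000", "01000"]
```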


@@ -62,23 +124,6 @@ def command_safe_run(command, command_name="mycommand", fail_on_fail=True):
    return sr.returncode, stdout, stderr


def add_method(cls):
    "Decorator to add a method to a class"

@@ -208,14 +253,50 @@ def as_evaled_expression(self):
        raise ValueError(f"expected numeric or string expression [got: {value}]")


def get_truncated_normal(
    mean: float | int = 0,
    sd: float | int = 1,
    a: float | int = 0,
    b: float | int = 10,
) -> scipy.stats._distn_infrastructure.rv_frozen:
    """Returns a truncated normal distribution.

    This function constructs a truncated normal distribution with the specified
    mean, standard deviation, and bounds. The truncated normal distribution is
    a normal distribution bounded within the interval [a, b].

    Args:
        mean: The mean of the truncated normal distribution. Defaults to 0.
        sd: The standard deviation of the truncated normal distribution. Defaults to 1.
        a: The lower bound of the truncated normal distribution. Defaults to 0.
        b: The upper bound of the truncated normal distribution. Defaults to 10.

    Returns:
        rv_frozen: A frozen instance of the truncated normal distribution with the
        specified parameters.
    """
    lower = (a - mean) / sd
    upper = (b - mean) / sd
    return scipy.stats.truncnorm(lower, upper, loc=mean, scale=sd)
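As a usage sketch, the returned frozen distribution exposes the usual scipy.stats methods; note that `a` and `b` are given on the data scale and converted to standard-deviation units internally:

```python
from gempyor.utils import get_truncated_normal

# A normal(mean=5, sd=2) restricted to [0, 10]
dist = get_truncated_normal(mean=5, sd=2, a=0, b=10)
samples = dist.rvs(size=1000, random_state=42)
assert samples.min() >= 0 and samples.max() <= 10
```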


def get_log_normal(
    meanlog: float | int,
    sdlog: float | int,
) -> scipy.stats._distn_infrastructure.rv_frozen:
    """Returns a log normal distribution.

    This function constructs a log normal distribution with the specified
    mean and standard deviation on the log scale.

    Args:
        meanlog: The mean of the distribution on the log scale.
        sdlog: The standard deviation of the distribution on the log scale.

    Returns:
        rv_frozen: A frozen instance of the log normal distribution with the
        specified parameters.
    """
    return scipy.stats.lognorm(s=sdlog, scale=np.exp(meanlog), loc=0)
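A short check of the parametrization: the log of samples drawn from the returned distribution is approximately normal with mean `meanlog` and standard deviation `sdlog`.

```python
import numpy as np
from gempyor.utils import get_log_normal

dist = get_log_normal(meanlog=1.0, sdlog=0.5)
samples = dist.rvs(size=100_000, random_state=42)

# Taking logs recovers the underlying normal(meanlog, sdlog)
assert np.isclose(np.log(samples).mean(), 1.0, atol=0.01)
assert np.isclose(np.log(samples).std(), 0.5, atol=0.01)
```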


@@ -265,45 +346,91 @@ def as_random_distribution(self):
        return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),)


def list_filenames(
    folder: str | bytes | os.PathLike = ".",
    filters: str | list[str] = [],
) -> list[str]:
    """Return the list of all filenames and paths in the provided folder.

    This function lists all files in the specified folder and its subdirectories.
    If filters are provided, only the files containing each of the substrings
    in the filters will be returned.

    Example:
        To get all files containing "hosp":
        ```
        gempyor.utils.list_filenames(
            folder="model_output/",
            filters=["hosp"],
        )
        ```

        To get only "hosp" files with a ".parquet" extension:
        ```
        gempyor.utils.list_filenames(
            folder="model_output/",
            filters=["hosp", ".parquet"],
        )
        ```

    Args:
        folder: The directory to search for files. Defaults to the current directory.
        filters: A string or a list of strings to filter filenames. Only files
            containing all the provided substrings will be returned. Defaults to an
            empty list.

    Returns:
        A list of strings representing the paths to the files that match the filters.
    """
    filters = [filters] if not isinstance(filters, list) else filters
    filters = filters if len(filters) else [""]
    folder = Path(folder.decode() if isinstance(folder, bytes) else folder)
    files = [
        str(file)
        for file in folder.rglob("*")
        if file.is_file() and all(f in str(file) for f in filters)
    ]
    return files
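Beyond the docstring examples, `filters` also accepts a bare string. A quick, hypothetical check that the two spellings agree (the "model_output/" directory is assumed to exist):

```python
from gempyor.utils import list_filenames

parquet_files = list_filenames(folder="model_output/", filters=".parquet")
assert parquet_files == list_filenames(folder="model_output/", filters=[".parquet"])
```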


def rolling_mean_pad(
    data: npt.NDArray[np.number],
    window: int,
) -> npt.NDArray[np.number]:
    """
    Calculates the column-wise rolling mean with centered window.

    Args:
        data: A two dimensional numpy array, typically the row dimension is time and
            the column dimension is subpop.
        window: The window size for the rolling mean.

    Returns:
        A two dimensional numpy array that is the same shape as `data`.

    Examples:
        Below is a brief set of examples showcasing how to smooth a metric, like
        hospitalizations, using this function.

        >>> import numpy as np
        >>> from gempyor.utils import rolling_mean_pad
        >>> hospitalizations = np.arange(1., 29.).reshape((7, 4))
        >>> hospitalizations
        array([[ 1.,  2.,  3.,  4.],
               [ 5.,  6.,  7.,  8.],
               [ 9., 10., 11., 12.],
               [13., 14., 15., 16.],
               [17., 18., 19., 20.],
               [21., 22., 23., 24.],
               [25., 26., 27., 28.]])
        >>> rolling_mean_pad(hospitalizations, 5)
        array([[ 3.4,  4.4,  5.4,  6.4],
               [ 5.8,  6.8,  7.8,  8.8],
               [ 9. , 10. , 11. , 12. ],
               [13. , 14. , 15. , 16. ],
               [17. , 18. , 19. , 20. ],
               [20.2, 21.2, 22.2, 23.2],
               [22.6, 23.6, 24.6, 25.6]])
    """
    padding_size = (window - 1) // 2
    padded_data = np.pad(data, ((padding_size, padding_size), (0, 0)), mode="edge")
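The diff collapses the remainder of this function. For intuition, here is a minimal sketch of a shape-preserving centered rolling mean that reproduces the docstring example, assuming an odd `window`; it is an illustration, not necessarily the package's actual implementation:

```python
import numpy as np
import numpy.typing as npt

def rolling_mean_pad_sketch(
    data: npt.NDArray[np.number], window: int
) -> npt.NDArray[np.number]:
    # Pad the time dimension with edge values so the output keeps data's shape
    # (exact for odd windows; even windows need extra handling not shown here).
    padding_size = (window - 1) // 2
    padded = np.pad(data, ((padding_size, padding_size), (0, 0)), mode="edge")
    # Average each centered window along the time axis.
    windows = np.lib.stride_tricks.sliding_window_view(padded, window, axis=0)
    return windows.mean(axis=-1)
```

On the 7x4 `hospitalizations` array above with `window=5`, the first output row averages three padded copies of row one with rows two and three: (1+1+1+5+9)/5 = 3.4, matching the doctest.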
49 changes: 49 additions & 0 deletions flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py
@@ -0,0 +1,49 @@
import numpy as np
import pytest
import scipy.stats

from gempyor.utils import get_log_normal


class TestGetLogNormal:
    """Unit tests for the `gempyor.utils.get_log_normal` function."""

    @pytest.mark.parametrize(
        "meanlog,sdlog",
        [
            (1.0, 1.0),
            (0.0, 2.0),
            (10.0, 30.0),
            (0.33, 4.56),
            (9.87, 4.21),
            (1, 1),
            (0, 2),
            (10, 30),
        ],
    )
    def test_construct_distribution(
        self,
        meanlog: float | int,
        sdlog: float | int,
    ) -> None:
        """Test the construction of a log normal distribution.

        This test checks whether the `get_log_normal` function correctly constructs
        a log normal distribution with the specified parameters. It verifies that
        the returned object is an instance of `rv_frozen`, and that its support and
        parameters (log-scale mean and standard deviation) are correctly set.

        Args:
            meanlog: The mean of the distribution on the log scale.
            sdlog: The standard deviation of the distribution on the log scale.
        """
        dist = get_log_normal(meanlog=meanlog, sdlog=sdlog)
        assert isinstance(dist, scipy.stats._distn_infrastructure.rv_frozen)
        lower, upper = dist.support()
        assert np.isclose(lower, 0.0)
        assert np.isclose(upper, np.inf)
        assert np.isclose(dist.kwds.get("s"), sdlog)
        assert np.isclose(dist.kwds.get("scale"), np.exp(meanlog))
        assert np.isclose(dist.kwds.get("loc"), 0.0)