Skip to content

Commit

Permalink
Documented rolling_mean_pad
Browse files Browse the repository at this point in the history
* Added type hints for `gempyor.utils.rolling_mean_pad`.
* Expanded the existing docstring and included an example.
  • Loading branch information
TimothyWillard committed Jul 8, 2024
1 parent e073f88 commit a91fc6d
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions flepimop/gempyor_pkg/src/gempyor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from botocore.exceptions import ClientError
import confuse
import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
import scipy.stats
Expand Down Expand Up @@ -392,16 +393,44 @@ def list_filenames(
return files


def rolling_mean_pad(data, window):
def rolling_mean_pad(
data: npt.NDArray[np.number],
window: int,
) -> npt.NDArray[np.number]:
"""
Calculates rolling mean with centered window and pads the edges.
Calculates the column-wise rolling mean with centered window.
Args:
data: A NumPy array !!! shape must be (n_days, nsubpops).
data: A two dimensional numpy array, typically the row dimension is time and the
column dimension is subpop.
window: The window size for the rolling mean.
Returns:
A NumPy array with the padded rolling mean (n_days, nsubpops).
A two dimensional numpy array that is the same shape as `data`.
Examples:
Below is a brief set of examples showcasing how to smooth a metric, like
hospitalizations, using this function.
>>> import numpy as np
>>> from gempyor.utils import rolling_mean_pad
>>> hospitalizations = np.arange(1., 29.).reshape((7, 4))
>>> hospitalizations
array([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.],
[25., 26., 27., 28.]])
>>> rolling_mean_pad(hospitalizations, 5)
array([[ 3.4, 4.4, 5.4, 6.4],
[ 5.8, 6.8, 7.8, 8.8],
[ 9. , 10. , 11. , 12. ],
[13. , 14. , 15. , 16. ],
[17. , 18. , 19. , 20. ],
[20.2, 21.2, 22.2, 23.2],
[22.6, 23.6, 24.6, 25.6]])
"""
padding_size = (window - 1) // 2
padded_data = np.pad(data, ((padding_size, padding_size), (0, 0)), mode="edge")
Expand Down

0 comments on commit a91fc6d

Please sign in to comment.