Skip to content

Commit

Permalink
v0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrilJl committed May 13, 2024
1 parent 212c2e4 commit 5cd8ff9
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 8 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,30 @@ np.allclose(a, b)
>>> 306 ms ± 5.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

## NaN handling possibility

While the previous `Batch*` classes exclude every sample containing at least one NaN from the computations, the `BatchNan*` classes adopt a more flexible approach to handling NaN values, similar to `np.nansum`, `np.nanmean`, etc. Consequently, the outputted statistics can be computed from various numbers of samples for each feature:

```python
import numpy as np
from batchstats import BatchNanSum

m, n = 1_000_000, 50
nan_ratio = 0.05
n_batches = 17

data = np.random.randn(m, n)
num_nans = int(m * n * nan_ratio)
nan_indices = np.random.choice(range(m * n), num_nans, replace=False)
data.ravel()[nan_indices] = np.nan

batchsum = BatchNanSum()
for batch_data in np.array_split(data, n_batches):
batchsum.update_batch(batch=batch_data)
np.allclose(np.nansum(data, axis=0), batchsum())
>>> True
```

## Documentation

The documentation is available [here](https://batchstats.readthedocs.io/en/latest/).
Expand Down
2 changes: 1 addition & 1 deletion batchstats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchStat', 'BatchSum', 'BatchVar',
'BatchNanMean', 'BatchNanStat', 'BatchNanSum']

__version__ = '0.2'
__version__ = '0.3'
76 changes: 75 additions & 1 deletion batchstats/nanstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,31 @@


class BatchNanStat:
"""
Base class for calculating statistics over batches of data that can contain NaN values.
Attributes:
n_samples (numpy.ndarray): Total number of samples processed, accounting for NaN values.
"""

def __init__(self):
"""
Initialize the BatchNanStat object.
"""
self.n_samples = None

def _process_batch(self, batch):
"""
Process the input batch, counting NaN values.
Args:
batch (numpy.ndarray): Input batch.
Returns:
numpy.ndarray: Processed batch.
"""
batch = np.atleast_2d(np.asarray(batch))
axis = tuple(range(1, batch.ndim))
if self.n_samples is None:
Expand All @@ -18,11 +39,29 @@ def _process_batch(self, batch):


class BatchNanSum(BatchNanStat):
"""
Class for calculating the sum of batches of data that can contain NaN values.
"""

def __init__(self):
"""
Initialize the BatchNanSum object.
"""
super().__init__()
self.sum = None

def update_batch(self, batch):
"""
Update the sum with a new batch of data that can contain NaN values.
Args:
batch (numpy.ndarray): Input batch.
Returns:
BatchNanSum: Updated BatchNanSum object.
"""
batch = self._process_batch(batch)
axis = tuple(range(1, batch.ndim))
if self.sum is None:
Expand All @@ -32,20 +71,55 @@ def update_batch(self, batch):
return self

def __call__(self):
"""
Calculate the sum of the batches that can contain NaN values.
Returns:
numpy.ndarray: Sum of the batches.
Raises:
NoValidSamplesError: If no valid samples are available.
"""
if self.sum is None:
raise NoValidSamplesError()
else:
return np.where(self.n_samples > 0, self.sum, np.nan)


class BatchNanMean(BatchNanStat):
"""
Class for calculating the mean of batches of data that can contain NaN values.
"""

def __init__(self):
"""
Initialize the BatchNanMean object.
"""
super().__init__()
self.sum = BatchNanSum()

def update_batch(self, batch):
"""
Update the mean with a new batch of data that can contain NaN values.
Args:
batch (numpy.ndarray): Input batch.
Returns:
BatchNanMean: Updated BatchNanMean object.
"""
self.sum.update_batch(batch)
return self

def __call__(self):
return self.sum()/self.sum.n_samples
"""
Calculate the mean of the batches that can contain NaN values.
Returns:
numpy.ndarray: Mean of the batches.
"""
return self.sum() / self.sum.n_samples
10 changes: 7 additions & 3 deletions batchstats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ class BatchVar(BatchMean):
"""
Class for calculating the variance of batches of data.
Args:
ddof (int, optional): Means Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. By default ddof is zero.
"""

def __init__(self, ddof=0):
Expand Down Expand Up @@ -272,9 +274,9 @@ def init_var(cls, v, vm):
def compute_incremental_variance(v, p, u):
"""
Compute incremental variance.
For v 2D and p/u 1D, equivalent to ((v-p).T@(v-u)).sum(axis=0) or
np.einsum('ji,ji->i', v - p, v - u). faster and less memory consumer because
no intermediate 2d array are created.
For v 2D and p/u 1D, equivalent to ``((v-p).T@(v-u)).sum(axis=0)`` or
``np.einsum('ji,ji->i', v - p, v - u)``. Faster and less memory consumer because
no intermediate 2D array are created.
Args:
v (numpy.ndarray): Input data.
Expand Down Expand Up @@ -341,6 +343,8 @@ class BatchCov(BatchStat):
"""
Class for calculating the covariance of batches of data.
Args:
ddof (int, optional): Means Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. By default ddof is zero.
"""

def __init__(self, ddof=0):
Expand Down
16 changes: 14 additions & 2 deletions docs/source/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,19 @@ The class ``BatchStat`` is the parent class from which other classes inherit. It
The following classes inherit from ``BatchStat``, and enable the user to compute various statistics over batch-accessed data:

.. automodule:: batchstats
:members:
:members: BatchCov, BatchMax, BatchMean, BatchMin, BatchSum, BatchVar
:undoc-members:
:show-inheritance:
:exclude-members: BatchStat


The class ``BatchNanStat`` is the parent class from which other classes that can treat NaNs inherit. It allows for the factorization of the ``_process_batch`` method, which keeps track of the number of NaNs per feature.

.. autoclass:: batchstats.BatchNanStat

The following classes inherit from ``BatchNanStat``:

.. automodule:: batchstats
:members: BatchNanMean, BatchNanSum
:undoc-members:
:show-inheritance:
:no-index:
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
templates_path = ['_templates']
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

Expand Down

0 comments on commit 5cd8ff9

Please sign in to comment.