Skip to content

Commit

Permalink
v0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrilJl committed May 12, 2024
1 parent 151c764 commit 86ec257
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 8 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# VisualStudio
.vscode

# Dev
dev

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,33 @@ test_mean(data, n_batches)
>>> True
```

## Performance

Beyond the accuracy of the results, particular attention has been paid to computation time and memory usage. For example, computing the variance with `batchstats` requires almost no additional RAM and is also faster than `numpy.var`:

```python
%load_ext memory_profiler
import numpy as np
from batchstats import BatchVar

data = np.random.randn(100_000, 1000)
print(data.nbytes/2**20)

%memit a = np.var(data, axis=0)
%memit b = BatchVar().update_batch(data)()
np.allclose(a, b)
>>> 762.939453125
>>> peak memory: 1604.63 MiB, increment: 763.35 MiB
>>> peak memory: 842.62 MiB, increment: 0.91 MiB
>>> True
%timeit a = np.var(data, axis=0)
%timeit b = BatchVar().update_batch(data)()
>>> 510 ms ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> 306 ms ± 5.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```



## Requesting Additional Statistics

If you require additional statistics that are not currently implemented in `batchstats`, feel free to open an issue on the GitHub repository or submit a pull request with your suggested feature. We welcome contributions and feedback from the community to improve `batchstats` and make it more versatile for various data analysis tasks.
2 changes: 2 additions & 0 deletions batchstats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .core import BatchCov, BatchMax, BatchMean, BatchMin, BatchStat, BatchSum, BatchVar

# Names exported by `from batchstats import *`; same set as the import from .core above.
__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchStat', 'BatchSum', 'BatchVar']

# Package version. setup.py extracts this value by scanning this file for the
# line that starts with `__version__`, so keep it a simple one-line assignment.
__version__ = '0.2'
20 changes: 15 additions & 5 deletions batchstats/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import string

import numpy as np


Expand Down Expand Up @@ -307,8 +309,13 @@ def compute_incremental_variance(v, p, u):
numpy.ndarray: Incremental variance.
"""
ret = np.einsum('ij,ij->j', v, v)
ret -= np.einsum('j,ij->j', p + u, v)
alphabet = string.ascii_lowercase
ndim = v.ndim
assert p.ndim == u.ndim == ndim-1
ij, j = alphabet[:ndim], alphabet[1:ndim]

ret = np.einsum(f'{ij},{ij}->{j}', v, v)
ret -= np.einsum(f'{j},{ij}->{j}', p + u, v)
ret += len(v)*p*u
return ret

Expand Down Expand Up @@ -367,7 +374,7 @@ def __init__(self, ddof=0):
self.cov = None
self.ddof = ddof

def _process_batch(self, batch1, batch2, assume_valid=False):
def _process_batch(self, batch1, batch2=None, assume_valid=False):
"""
Process the input batches, handling NaN values if necessary.
Expand All @@ -383,7 +390,10 @@ def _process_batch(self, batch1, batch2, assume_valid=False):
UnequalSamplesNumber: If the batches have unequal lengths.
"""
batch1, batch2 = np.atleast_2d(np.asarray(batch1)), np.atleast_2d(np.asarray(batch2))
if batch2 is None:
batch1 = batch2 = np.atleast_2d(np.asarray(batch1))
else:
batch1, batch2 = np.atleast_2d(np.asarray(batch1)), np.atleast_2d(np.asarray(batch2))
if assume_valid:
self.n_samples += len(batch1)
return batch1, batch2
Expand All @@ -398,7 +408,7 @@ def _process_batch(self, batch1, batch2, assume_valid=False):
else:
return batch1[mask], batch2[mask]

def update_batch(self, batch1, batch2, assume_valid=False):
def update_batch(self, batch1, batch2=None, assume_valid=False):
"""
Update the covariance with new batches of data.
Expand Down
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from setuptools import find_packages, setup

# Read the version from the package's __init__.py so that it is defined in a
# single place. The file is scanned as text rather than imported.
version = None
with open("batchstats/__init__.py", "r", encoding="utf-8") as f:
    for line in f:
        if line.startswith("__version__"):
            # Keep what is right of '=', drop any trailing comment, then strip
            # whitespace and quotes. This accepts `__version__ = '0.2'`,
            # `__version__='0.2'` and `__version__ = "0.2"  # note` alike.
            value = line.partition("=")[2].partition("#")[0]
            version = value.strip().strip("'\"")
            break

# Fail with an explicit message instead of a NameError inside setup() when the
# version line is missing or empty.
if not version:
    raise RuntimeError("Unable to find __version__ in batchstats/__init__.py")

setup(
name='batchstats',
version='0.1',
version=version,
author='Cyril Joly',
description='Efficient batch statistics computation library for Python.',
long_description=open('README.md').read(),
Expand Down
43 changes: 41 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ def data():
return np.random.randn(m, n)


@pytest.fixture
def data_2d_features():
    """Random sample whose features span two axes (samples on axis 0).

    Returns an array of shape (100_000, 50, 60) drawn with ``np.random.randn``.
    """
    # NOTE(review): 100_000 * 50 * 60 float64 values is roughly 2.4 GB;
    # confirm the test environment has enough memory for this fixture.
    shape = (100_000, 50, 60)
    return np.random.randn(*shape)


@pytest.fixture
def n_batches():
    """Number of chunks the tests split their data into via ``np.array_split``."""
    return 31
Expand Down Expand Up @@ -76,19 +82,52 @@ def test_var_ddof(data, n_batches):
assert np.allclose(true_stat, batch_stat)


def test_cov(data, n_batches):
def test_cov_1(data, n_batches):
true_cov = np.cov(data.T, ddof=0)
true_var = np.var(data, axis=0)

batchvar = BatchVar()
batchcov = BatchCov()
for batch_data in np.array_split(data, n_batches):
batchvar.update_batch(batch_data)
batchcov.update_batch(batch_data, batch_data)
batchcov.update_batch(batch_data)

cov = batchcov()
var = batchvar()

assert np.allclose(cov, true_cov)
assert np.allclose(var, true_var)
assert np.allclose(var, np.diag(cov))


def test_cov_2(data, n_batches):
    """Check BatchCov with two different inputs against ``np.cov``.

    The second input is the first 25 columns of the first one, so the
    accumulated result must equal the matching columns of the full
    covariance matrix.
    """
    cols = np.arange(25)
    full_cov = np.cov(data.T, ddof=0)

    accumulator = BatchCov()
    for chunk in np.array_split(data, n_batches):
        accumulator.update_batch(chunk, chunk[:, cols])

    assert np.allclose(accumulator(), full_cov[:, cols])


def test_mean_2d_features(data_2d_features, n_batches):
    """Check BatchMean on 3-D input against ``np.mean`` along axis 0."""
    expected = np.mean(data_2d_features, axis=0)

    batchmean = BatchMean()
    for chunk in np.array_split(data_2d_features, n_batches):
        batchmean.update_batch(batch=chunk)

    assert np.allclose(expected, batchmean())


def test_var_2d_features(data_2d_features, n_batches):
    """Check BatchVar on 3-D input against ``np.var`` along axis 0."""
    expected = np.var(data_2d_features, axis=0)

    accumulator = BatchVar()
    for chunk in np.array_split(data_2d_features, n_batches):
        accumulator.update_batch(batch=chunk)

    assert np.allclose(expected, accumulator())

0 comments on commit 86ec257

Please sign in to comment.