Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scale function(_get_mean_var) updated for dense array, speedup upto ~4.65x #3099

Open
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

ashish615
Copy link
Contributor

@ashish615 ashish615 commented Jun 5, 2024

Hi,
We are submitting PR for speed up of the _get_mean_var function.

Time(sec)
Original 18.49
Updated 3.97
Speedup 4.65743073

experiment setup : AWS r7i.24xlarge

import time
import numpy as np

import pandas as pd

import scanpy as sc
from sklearn.cluster import KMeans

import os
import wget

import warnings



warnings.filterwarnings('ignore', 'Expected ')
warnings.simplefilter('ignore')
input_file = "./1M_brain_cells_10X.sparse.h5ad"

if not os.path.exists(input_file):
    print('Downloading import file...')
    wget.download('https://rapids-single-cell-examples.s3.us-east-2.amazonaws.com/1M_brain_cells_10X.sparse.h5ad',input_file)


# marker genes
MITO_GENE_PREFIX = "mt-" # Prefix for mitochondrial genes to regress out
markers = ["Stmn2", "Hes1", "Olig1"] # Marker genes for visualization

# filtering cells
min_genes_per_cell = 200 # Filter out cells with fewer genes than this expressed
max_genes_per_cell = 6000 # Filter out cells with more genes than this expressed

# filtering genes
min_cells_per_gene = 1 # Filter out genes expressed in fewer cells than this
n_top_genes = 4000 # Number of highly variable genes to retain

# PCA
n_components = 50 # Number of principal components to compute

# t-SNE
tsne_n_pcs = 20 # Number of principal components to use for t-SNE

# k-means
k = 35 # Number of clusters for k-means

# Gene ranking

ranking_n_top_genes = 50 # Number of differential genes to compute for each cluster

# Number of parallel jobs
sc._settings.ScanpyConfig.n_jobs = os.cpu_count()

start=time.time()
tr=time.time()
adata = sc.read(input_file)
adata.var_names_make_unique()
adata.shape
print("Total read time : %s" % (time.time()-tr))



tr=time.time()
# To reduce the number of cells:
USE_FIRST_N_CELLS = 1300000
adata = adata[0:USE_FIRST_N_CELLS]
adata.shape

sc.pp.filter_cells(adata, min_genes=min_genes_per_cell)
sc.pp.filter_cells(adata, max_genes=max_genes_per_cell)
sc.pp.filter_genes(adata, min_cells=min_cells_per_gene)
sc.pp.normalize_total(adata, target_sum=1e4)
print("Total filter and normalize time : %s" % (time.time()-tr))


tr=time.time()
sc.pp.log1p(adata)
print("Total log time : %s" % (time.time()-tr))


# Select highly variable genes
sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor = "cell_ranger")

# Retain marker gene expression
for marker in markers:
        adata.obs[marker + "_raw"] = adata.X[:, adata.var.index == marker].toarray().ravel()

# Filter matrix to only variable genes
adata = adata[:, adata.var.highly_variable]

ts=time.time()
#Regress out confounding factors (number of counts, mitochondrial gene expression)
mito_genes = adata.var_names.str.startswith(MITO_GENE_PREFIX)
n_counts = np.array(adata.X.sum(axis=1))
adata.obs['percent_mito'] = np.array(np.sum(adata[:, mito_genes].X, axis=1)) / n_counts
adata.obs['n_counts'] = n_counts


sc.pp.regress_out(adata, ['n_counts', 'percent_mito'])
print("Total regress out time : %s" % (time.time()-ts))

#scale

ts=time.time()
sc.pp.scale(adata)
print("Total scale time : %s" % (time.time()-ts))

add timer around _get_mean_var call

mean, var = _get_mean_var(X)

we can also create _get_mean_var_std function that return std as well so we don't require to compute it in scale function(L168-L169).

@ashish615 ashish615 changed the title scale function updated for dense array, speedup upto ~4.65x scale function(_get_mean_var) updated for dense array, speedup upto ~4.65x Jun 5, 2024
Copy link

codecov bot commented Jun 5, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 76.31%. Comparing base (896e249) to head (7a1a62e).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3099      +/-   ##
==========================================
- Coverage   76.31%   76.31%   -0.01%     
==========================================
  Files         109      109              
  Lines       12513    12516       +3     
==========================================
+ Hits         9549     9551       +2     
- Misses       2964     2965       +1     
Files Coverage Δ
src/scanpy/preprocessing/_utils.py 95.12% <100.00%> (-2.25%) ⬇️

@Zethson Zethson requested a review from Intron7 June 13, 2024 13:16
@Zethson Zethson added this to the 1.10.2 milestone Jun 16, 2024
Copy link

scverse-benchmark bot commented Jun 17, 2024

Benchmark changes

Change Before [ad657ed] After [e7a4662] Ratio Benchmark (Parameter)
+ 259M 310M 1.2 preprocessing_log.FastSuite.peakmem_mean_var('pbmc68k_reduced')
+ 1.16±0.04ms 1.97±0.5ms 1.69 preprocessing_log.FastSuite.time_mean_var('pbmc68k_reduced')
+ 255M 315M 1.23 preprocessing_log.peakmem_highly_variable_genes('pbmc68k_reduced')
- 373M 322M 0.86 preprocessing_log.peakmem_pca('pbmc68k_reduced')
- 1.03G 779M 0.76 preprocessing_log.peakmem_scale('pbmc3k')
- 729±5ms 517±5ms 0.71 preprocessing_log.time_scale('pbmc3k')

Comparison: https://github.com/scverse/scanpy/compare/ad657edfb52e9957b9a93b3a16fc8a87852f3f09..e7a466265b08f6973a5cf3fecfc27879104c02f4
Last changed: Tue, 18 Jun 2024 18:39:49 +0000

More details: https://github.com/scverse/scanpy/pull/3099/checks?check_run_id=26384736173

@Intron7
Copy link
Member

Intron7 commented Jun 20, 2024

I have some small improvements that I would like to add next week for more precision for larger matrices

@ilan-gold ilan-gold modified the milestones: 1.10.2, 1.10.3 Jun 25, 2024
@Intron7
Copy link
Member

Intron7 commented Jun 26, 2024

@ashish615 after doing some benchmarking myself I found out that your solution for axis=1 is under performing compared to axis=0 for larger arrays. I think that is because of the memory access pattern you choose. I rewrote the function with that in mind. I'll again make a PR to you, because for some reason you disallow us from making changes to your PR.

@Intron7 Intron7 self-requested a review June 26, 2024 11:09
Copy link
Member

@Intron7 Intron7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please merge IntelLabs#2

docs/release-notes/1.10.2.md Outdated Show resolved Hide resolved
@Intron7 Intron7 self-requested a review June 26, 2024 14:02
Copy link
Member

@Intron7 Intron7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes the issues

src/scanpy/preprocessing/_utils.py Outdated Show resolved Hide resolved
src/scanpy/preprocessing/_utils.py Outdated Show resolved Hide resolved
remove casting to match previous behavior

Co-authored-by: Severin Dicks <[email protected]>
Copy link
Member

@Intron7 Intron7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me

Copy link
Member

@flying-sheep flying-sheep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t see the claimed speedup in the benchmarks, what’s missing?

Also is numba.get_num_threads() safe? E.g. I think _get_mean_var is also called in each dask chunks. Will numba.get_num_threads() return a reasonable number in that case?

Otherwise nice! I’m not a huge fan of how unpythonic numba code looks, but I don’t think anything can be done about that.

Comment on lines +46 to +47
# enforce R convention (unbiased estimator) for variance
var *= X.shape[axis] / (X.shape[axis] - 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before your change, this line ran unconditionally, now it only runs for the not isinstance(X, np.ndarray) case. Is that intentional? Then you should mention that in _compute_mean_var’s docstring.



@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have _get_mean_var. Maybe rename this to _get_mean_var_ndarray or _get_mean_var_dense?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can rename the kernel


@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads: int = 1

Comment on lines +55 to +83
if axis == 0:
axis_i = 1
sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
mean = np.zeros(X.shape[axis_i], dtype=np.float64)
var = np.zeros(X.shape[axis_i], dtype=np.float64)
n = X.shape[axis]
for i in numba.prange(n_threads):
for r in range(i, n, n_threads):
for c in range(X.shape[axis_i]):
value = X[r, c]
sums[i, c] += value
sums_squared[i, c] += value * value
for c in numba.prange(X.shape[axis_i]):
sum_ = sums[:, c].sum()
mean[c] = sum_ / n
var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
else:
axis_i = 0
mean = np.zeros(X.shape[axis_i], dtype=np.float64)
var = np.zeros(X.shape[axis_i], dtype=np.float64)
for r in numba.prange(X.shape[0]):
for c in range(X.shape[1]):
value = X[r, c]
mean[r] += value
var[r] += value * value
for c in numba.prange(X.shape[0]):
mean[c] = mean[c] / X.shape[1]
var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don’t duplicate identical lines.

Suggested change
if axis == 0:
axis_i = 1
sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
mean = np.zeros(X.shape[axis_i], dtype=np.float64)
var = np.zeros(X.shape[axis_i], dtype=np.float64)
n = X.shape[axis]
for i in numba.prange(n_threads):
for r in range(i, n, n_threads):
for c in range(X.shape[axis_i]):
value = X[r, c]
sums[i, c] += value
sums_squared[i, c] += value * value
for c in numba.prange(X.shape[axis_i]):
sum_ = sums[:, c].sum()
mean[c] = sum_ / n
var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
else:
axis_i = 0
mean = np.zeros(X.shape[axis_i], dtype=np.float64)
var = np.zeros(X.shape[axis_i], dtype=np.float64)
for r in numba.prange(X.shape[0]):
for c in range(X.shape[1]):
value = X[r, c]
mean[r] += value
var[r] += value * value
for c in numba.prange(X.shape[0]):
mean[c] = mean[c] / X.shape[1]
var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)
axis_i = 1 - axis
mean = np.zeros(X.shape[axis_i], dtype=np.float64)
var = np.zeros(X.shape[axis_i], dtype=np.float64)
if axis == 0:
sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
n = X.shape[axis]
for i in numba.prange(n_threads):
for r in range(i, n, n_threads):
for c in range(X.shape[axis_i]):
value = X[r, c]
sums[i, c] += value
sums_squared[i, c] += value * value
for c in numba.prange(X.shape[axis_i]):
sum_ = sums[:, c].sum()
mean[c] = sum_ / n
var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
else:
for r in numba.prange(X.shape[0]):
for c in range(X.shape[1]):
value = X[r, c]
mean[r] += value
var[r] += value * value
for c in numba.prange(X.shape[0]):
mean[c] = mean[c] / X.shape[1]
var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can slim this down a bit. The two different loops need to be separate though

@Intron7
Copy link
Member

Intron7 commented Jun 27, 2024

The function should also work for 1 thread. numba.get_num_threads() is fine it works well with the sparse arrays. But I have no experience with it inside of dask.


@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think _SupportedArray is the right type annotation here. This doesn't run directly on dask.Array, unless I am misunderstanding something.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants