Skip to content

Commit

Permalink
v0.1.33
Browse files Browse the repository at this point in the history
  • Loading branch information
buswinka committed Oct 15, 2021
1 parent bb112d4 commit 9175150
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#HCAT - Hair Cell Analysis Tools
# HCAT - Hair Cell Analysis Tools

Hcat is a suite of machine learning enabled algorithms for performing common image analyses in the hearing field.
At present, it performs two fully automated analyses: (1) Volumetric hair cell segmentation and (2) 2D hair cell detection.
Expand Down
Binary file modified hcat/backends/__pycache__/backend.cpython-38.pyc
Binary file not shown.
Binary file modified hcat/backends/__pycache__/spatial_embedding.cpython-38.pyc
Binary file not shown.
7 changes: 3 additions & 4 deletions hcat/backends/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,17 +319,16 @@ def FasterRCNN_from_url(url: str, device: str, model: FasterRCNN = HairCellFaste
""" loads model from url """
path = os.path.join(hcat.__path__[0], 'Max_project_detection_resnet.trch')

if not os.path.exists(path) and _is_url(url):
if not os.path.exists(path):
print('Downloading Model File: ')
wget.download(url=url, out=path)
print(' ')
else:
raise ValueError(f'Url is not valid: {url}')


model = model.requires_grad_(False)
checkpoint = torch.load(path, map_location=torch.device('cpu'))
model = model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.load_state_dict(checkpoint)

for m in model.modules():
if isinstance(m, nn.BatchNorm3d):
Expand Down
18 changes: 10 additions & 8 deletions hcat/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ def _detect(f: str, curve_path: str = None, cell_detection_threshold: float = 0.

with torch.no_grad():

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
print('\x1b[1;32;40mCUDA: GPU successfully initialized!\x1b[0m')
else:
print('\x1b[1;33;40m'
'WARNING: GPU not present or CUDA is not correctly intialized for GPU accelerated computation. '
'Analysis may be slow.'
'\x1b[0m')

# Load and preprocess Image
image_base = load(f, 'TileScan 1 Merged', verbose=True) # from hcat.lib.utils
image_base = image_base[[2, 3],...].max(-1) if image_base.ndim == 4 else image_base
Expand All @@ -50,7 +59,7 @@ def _detect(f: str, curve_path: str = None, cell_detection_threshold: float = 0.

dtype = image_base.dtype if dtype is None else dtype
scale: int = hcat.lib.utils.get_dtype_offset(dtype)
device = 'cuda' if torch.cuda.is_available() else 'cpu'


temp = np.zeros(shape)
temp = np.concatenate((temp, image_base)) / scale * 255
Expand Down Expand Up @@ -83,13 +92,6 @@ def _detect(f: str, curve_path: str = None, cell_detection_threshold: float = 0.
image_base.sub_(0.5).div_(0.5)


if device == 'cuda':
print('\x1b[1;32;40mCUDA: GPU successfully initialized!\x1b[0m')
else:
print('\x1b[1;33;40m'
'WARNING: GPU not present or CUDA is not correctly intialized for GPU accelerated computation. '
'Analysis may be slow.'
'\x1b[0m')

# Initalize the model...
model = FasterRCNN_from_url(url='https://github.com/buswinka/hcat/blob/master/modelfiles/detection.trch?raw=true', device=device)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = hcat
version = 0.1.32
version = 0.1.33
author = Chris Buswinka
author_email = [email protected]
classifiers = Programming Language :: Python :: 3
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = fh.read()

setuptools.setup(
description="Whole cochlea hair cell segmentation toolkit",
description="A Hair Cell Analysis Toolbox",
long_description=long_description,
long_description_content_type="text/markdown",
packages=setuptools.find_packages(),
Expand All @@ -13,7 +13,7 @@
install_requires = [
'kornia>=0.5.2',
'numpy>=1.10.0',
'torch>=1.7.0',
'torch>=1.9.0',
'matplotlib>=3.3.2',
'scipy>=1.5.4',
'scikit-image>=0.17.2',
Expand All @@ -22,7 +22,7 @@
'lz4>=3.1.3',
'scikit-learn>=0.24.2',
'GPy>=1.10.0',
'torchvision>=0.8.1',
'torchvision>=0.10.0',
'elasticdeform>=0.4.9',
'wget>=3.2']
)

0 comments on commit 9175150

Please sign in to comment.