Skip to content

Commit

Permalink
Updated GUI to allow for selecting custom models. Minor Bug Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris committed Feb 22, 2023
1 parent 0d258d6 commit d5a7c25
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 27 deletions.
24 changes: 12 additions & 12 deletions hcat/backends/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,31 @@ def init_model():
return HairCellConvNext


def FasterRCNN_from_url(url: str, device: str):
def FasterRCNN_from_url(url: str, device: str, path: str = None):
"""
Loads a FasterRCNN model from a url OR from a local source if available.
:param url: URL of pretrained model path. Will save the model to the source directory of HCAT.
:param device: Device to load the model to.
:return:
"""
path = os.path.join(hcat.__path__[0], 'detection_trained_model.trch')
print(path)

# Research Purposes...
# convnext = '/home/chris/Dropbox (Partners HealthCare)/HairCellInstance/Max_project_detection_resnet_NeXT101_04March_2022_MinValidation.trch'
# convnext = '/home/chris/Dropbox (Partners HealthCare)/trainHairCellDetection/models/Apr05_09-52-47_CHRISUBUNTU.trch'
# convnext = '/home/chris/Dropbox (Partners HealthCare)/trainHairCellDetection/models/Apr25_10-08-35_CHRISUBUNTU.trch'
# convnext = '/home/chris/Dropbox (Partners HealthCare)/trainHairCellDetection/models/Jul15_17-25-42_CHRISUBUNTU.trch'
# convnext_nocommunitydata = '/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/models/Max_project_detection_resnet_NeXT101_28January_2022_eFinal_just_arturlabtrainingdata.trch'
if path is None: # we dont have a user defined file...
path = os.path.join(hcat.__path__[0], 'detection_trained_model.trch')
if not os.path.exists(path):
wget.download(url=url, out=path) # this will download the file...

# Check if the path exists from the user...
if not os.path.exists(path):
wget.download(url=url, out=path)
raise RuntimeError(f'Could not locate the file at path: {path}')

model = init_model()

print(path)
checkpoint = torch.load(path, map_location=torch.device('cpu'))
try:
checkpoint = torch.load(path, map_location=torch.device('cpu'))
except:
raise RuntimeError(f'Could not load torch model from file: {path}')

if 'model_state_dict' in checkpoint:
checkpoint = checkpoint['model_state_dict']
model = model.to(device)
Expand Down
2 changes: 1 addition & 1 deletion hcat/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _detect(f: Optional[str] = None,
'limited. Consider Manual Calculation.', color='yellow')

# curvature estimation really only works if there is a lot of tissue...
if distance is not None and distance.max() > 4000:
if distance is not None and distance.max() > 100:
for c in cells: c.calculate_frequency(curvature[[0, 1], :], distance) # calculate cell's best frequency
cells = [c for c in cells if not c._distance_is_far_away] # remove a cell if its far away from curve

Expand Down
41 changes: 35 additions & 6 deletions hcat/detect_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ class gui:
def __init__(self):
sg.theme('DarkGrey5')
plt.ioff()
device_str = 'CUDA' if torch.cuda.is_available() else 'CPU'
# sg.set_options(font='Any')

button_column = [
[sg.FileBrowse(size=(16, 1), enable_events=True), ],
[sg.Button('Load', size=(16, 1)), ],
[sg.Button('Save', size=(16, 1))],
[sg.Button("⇦",k="previous_image", size=(4, 1)), sg.Push(), sg.Button("⇨", k="next_image", size=(4,1))],
[sg.Button("⇦", k="previous_image", size=(4, 1)), sg.Push(), sg.Button("⇨", k="next_image", size=(4,1))],
[sg.HorizontalSeparator(p=(0, 20))],
[sg.Text('Cell Diameter\n(In Pixels)')],
[sg.Input(size=(10, 1), enable_events=True, default_text=30, key='Diameter'),
Expand All @@ -49,6 +50,12 @@ def __init__(self):
[sg.Text('Overlap Threshold')],
[sg.Slider(range=(0, 100), orientation='h', enable_events=True, default_value=30, key='NMS', expand_x=True)],
[sg.HorizontalSeparator(p=(0, 20))],
[sg.Text('Model Selection', size=(15, 1), pad=(0, 10))],
[sg.FileBrowse(button_text='Select', k='select_model',
size=(15, 1), enable_events=True, file_types=('Pytorch Model', '*.trch'))],
[sg.Button('Reset', k='reset_model', size=(15,1))],
[sg.Text('Default', k='model_selection_text', size=(15, 1))],
[sg.HorizontalSeparator(p=(0, 20))],
[sg.Button('Run Analysis', size=(15, 1))],
[sg.Check(text=' Live Update', key='live_update', enable_events=True)],
[sg.Check(k='savexml', text=' Save XML', default=False, enable_events=True)]
Expand Down Expand Up @@ -96,6 +103,7 @@ def __init__(self):
[sg.HorizontalSeparator(pad=(0, 30))],
[sg.Text('OHC: None', key='OHC_count')],
[sg.Text('IHC: None', key='IHC_count')],
[sg.Text(f'Hardware Accelerator: {device_str}', key='device')]
]

layout = [[sg.Column(button_column, vertical_alignment='Top'),
Expand Down Expand Up @@ -141,7 +149,10 @@ def __init__(self):
self.rgb = None
self.contrast = None
self.brightness = None

self.model = None
self.model_path = None

self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
self.cochlea = None

Expand All @@ -153,7 +164,6 @@ def __init__(self):
self.labels = None
self.scores = None

self.model = None


def main_loop(self):
Expand All @@ -171,6 +181,22 @@ def main_loop(self):
if event == 'pan':
print('CLICKED: ', values['pan'])

if event == 'select_model':
self.model_path = values['select_model']
self.window['model_selection_text'].update(f'Custom')
try:
self.model = FasterRCNN_from_url(url=None, device=self.device, path=self.model_path)
except:
sg.popup_ok('Could not load the model from file.')
self.model_path = None
self.window['model_selection_text'].update(f'Default')

if event == 'reset_model':
self.model_path = None
self.model = None
self.window['model_selection_text'].update(f'Default')


if event == 'Exit' or event == sg.WIN_CLOSED:
return

Expand Down Expand Up @@ -202,7 +228,6 @@ def main_loop(self):
self.current_image_index = i
break


if self.valid_image_files[self.current_image_index] != values['Browse']:
print(f'{self.current_image_index}, {self.valid_image_files}')

Expand Down Expand Up @@ -283,8 +308,8 @@ def main_loop(self):

if event == 'Run Analysis' and self.__LOADED__:
# self.run_detection_model()
self.fast_model()
self.threshold_and_nms(values['Threshold'], values['NMS'])
self.fast_model() # runs model on image...
self.threshold_and_nms(values['Threshold'], values['NMS']) # removing junk data
self.draw_image()

if event in ['Threshold', 'NMS']:
Expand Down Expand Up @@ -566,7 +591,11 @@ def fast_model(self):
if self.model is None:
__model_url__ = 'https://www.dropbox.com/s/opf43jwcbgz02vm/detection_trained_model.trch?dl=1'
sg.popup_quick_message('Loading model from file. May take a while.')
self.model = FasterRCNN_from_url(url=__model_url__, device=self.device)
try:
self.model = FasterRCNN_from_url(url=__model_url__, device=self.device)
except:
sg.popup_ok('Could not load the model from file.')
return None

_image = self.scaled_image.clone()[0:3,...]

Expand Down
8 changes: 4 additions & 4 deletions hcat/lib/cochlea.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def write_csv(self, filename: Optional[Union[bool, str]] = None) -> None:
:return: None
"""
if self.analysis_type == 'segment':
label = 'cellID,frequency,percent_loc,x_loc,y_loc,z_loc,volume,summed,'
label = 'cellID,distance,frequency,percent_loc,x_loc,y_loc,z_loc,volume,summed,'
for c in ['myo', 'dapi', 'actin', 'gfp']:
label += f'{c}_mean,{c}_median,{c}_std,{c}_var,{c}_min,{c}_max,{c}_%zero,{c}_%saturated,'

Expand All @@ -266,7 +266,7 @@ def write_csv(self, filename: Optional[Union[bool, str]] = None) -> None:
f.write(label[:-1:] + '\n') # index to remove final comma

for cell in self.cells:
f.write(f'{cell.id},{cell.frequency},{cell.percent_loc},')
f.write(f'{cell.id},{cell.distance},{cell.frequency},{cell.percent_loc},')
f.write(f'{cell.loc[1]},{cell.loc[2]},{cell.loc[3]},{cell.volume},{cell.summed},')

for id in cell.channel_names:
Expand All @@ -277,7 +277,7 @@ def write_csv(self, filename: Optional[Union[bool, str]] = None) -> None:
f.write('\n')
f.close()
elif self.analysis_type == 'detect':
label = 'cellID,type,score,frequency,percent_loc,x_loc,y_loc'
label = 'cellID,type,score,distance,frequency,percent_loc,x_loc,y_loc'

if filename is None and self.path is not None:
filename = os.path.splitext(self.path)[0] + '.csv' # Remove .lif and add .csv
Expand All @@ -292,7 +292,7 @@ def write_csv(self, filename: Optional[Union[bool, str]] = None) -> None:
f.write(label[:-1:] + '\n') # index to remove final comma

for cell in self.cells:
f.write(f'{cell.id},{cell.type},{cell.scores},{cell.frequency},{cell.percent_loc},')
f.write(f'{cell.id},{cell.type},{cell.scores},{cell.distance},{cell.frequency},{cell.percent_loc},')
f.write(f'{cell.loc[1]},{cell.loc[2]}')
f.write('\n')
f.close()
Expand Down
7 changes: 6 additions & 1 deletion hcat/lib/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,12 @@ def __call__(self,
3) The estimated apex of the cochlea
"""
curvature, distance, apex = None, None, None
if self.method is None:

if csv is not None:
print('RUNNING EPL METHOD')
curvature, distance, apex = self.fitEPL(csv)

elif self.method is None:
if csv is not None:
curvature, distance, apex = self.fitEPL(csv)

Expand Down
4 changes: 3 additions & 1 deletion hcat/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def load(file: str, header_name: Optional[str] = 'TileScan 1 Merged',
pass
elif image_base.ndim == 3 and image_base.shape[0] <= 4: # Suppose you load a 2D image! with multiple channels
pass
elif image_base.ndim == 2:
image_base = image_base[np.newaxis, ...]
else:
print(
f'\x1b[1;31;40m' + f'Cannot load: \'{file}\'. Unsupported number of dimmensions: {image_base.ndim}' + '\x1b[0m')
Expand Down Expand Up @@ -478,7 +480,7 @@ def cochlea_to_xml(cochlea, filename: str) -> None:
tree.write(filename + '.xml')


def normalize_image(image: Tensor, verbose: Optional[bool] = False) -> Tensor:
def normalize_image(image: Tensor, *args, verbose: Optional[bool] = False) -> Tensor:
"""
Normalizes each channel in an image such that a majority of the image lies between 0 and 1.
Calculates the maximum value of the image following a gaussian blur with a 7x7 kernel size,
Expand Down
9 changes: 8 additions & 1 deletion hcat/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from hcat.detect import _detect
from hcat.detect_gui import gui
import os.path

import glob

Expand Down Expand Up @@ -32,7 +33,7 @@ def cli(ctx):
@click.option('--normalize', is_flag=True, help='Threshold (between 0 and 1) of cell detection.')
@click.option('--pixel_size', default=None, help='Pixel size in nm')
@click.option('--cell_diameter', default=None, help='Cell diameter in pixels')
@click.option('--predict_curvature', default=None, help='Cell diameter in pixels')
@click.option('--predict_curvature', is_flag=True, help='Cell diameter in pixels')
@click.option('--silent', default=False, help="Suppresses most of HCAT's logging ")
def detect(f: str, curve_path, cell_detection_threshold, nms_threshold, save_xml, save_png,
save_fig, normalize, pixel_size, dtype, cell_diameter, predict_curvature, silent):
Expand All @@ -44,6 +45,12 @@ def detect(f: str, curve_path, cell_detection_threshold, nms_threshold, save_xml

files = glob.glob(f)
for filename in files:
curve_path = filename[:-4:] + '_path.csv'
print('CURVE PATH: ', curve_path)

if not os.path.exists(filename[:-4:] + '_path.csv'):
raise ValueError(curve_path)

try:
_detect(f=filename,
curve_path=curve_path,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = hcat
version = 0.2.24
version = 0.2.26
author = Chris Buswinka
author_email = [email protected]
classifiers = Programming Language :: Python :: 3
Expand Down

0 comments on commit d5a7c25

Please sign in to comment.