updating for ARO 2022 updated model loader
buswinka committed Feb 8, 2022
1 parent c4bf0e4 commit 6dbc861
Showing 3 changed files with 31 additions and 39 deletions.
hcat/backends/detection.py (63 changes: 29 additions & 34 deletions)
@@ -3,63 +3,58 @@
 import re
 import torch
 import wget
+import hcat
 from hcat.backends.convNeXt import ConvNeXt
 
 from torchvision.models.detection import FasterRCNN
 from torchvision.models.detection.anchor_utils import AnchorGenerator
 
-# load a pre-trained model for classification and return
-backbone = ConvNeXt(in_channels=3, dims=[128, 256, 512, 1024], depths=[3, 3, 27, 3], out_channels=256)
-path = '/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/models/convnext_base_22k_1k_384.pth'
+def _init_model():
+    backbone = ConvNeXt(in_channels=3, dims=[128, 256, 512, 1024], depths=[3, 3, 27, 3], out_channels=256)
+    path = '/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/models/convnext_base_22k_1k_384.pth'
 
-if os.path.exists(path):
-    state_dict = torch.load(path)
-    backbone.load_state_dict(state_dict['model'], strict=False)
+    if os.path.exists(path):
+        state_dict = torch.load(path)
+        backbone.load_state_dict(state_dict['model'], strict=False)
 
-backbone.out_channels = 256
+    backbone.out_channels = 256
 
-backbone = torch.jit.script(backbone)
+    backbone = torch.jit.script(backbone)
 
-anchor_sizes = ((16,), (32,), (64,), (128,))
-aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-anchor_generator = AnchorGenerator(sizes=anchor_sizes,
-                                   aspect_ratios=aspect_ratios)
+    anchor_sizes = ((16,), (32,), (64,), (128,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    anchor_generator = AnchorGenerator(sizes=anchor_sizes,
+                                       aspect_ratios=aspect_ratios)
 
-roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=7, sampling_ratio=2)
+    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=7, sampling_ratio=2)
 
-HairCellConvNext = FasterRCNN(backbone,
-                              num_classes=3,
-                              rpn_anchor_generator=anchor_generator,
-                              # box_roi_pool=roi_pooler,
-                              min_size=256, max_size=600, )
+    HairCellConvNext = FasterRCNN(backbone,
+                                  num_classes=3,
+                                  rpn_anchor_generator=anchor_generator,
+                                  # box_roi_pool=roi_pooler,
+                                  min_size=256, max_size=600, )
+    return HairCellConvNext
+
+
+HairCellConvNext = _init_model()
 
 
 def FasterRCNN_from_url(url: str, device: str, model: FasterRCNN = HairCellConvNext):
     """ loads model from url """
-    # path = os.path.join(hcat.__path__[0], 'Max_project_detection_resnet.trch')
-    path = '/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/models/Max_project_detection_resnet_NeXT101_24January_2022_MinValidation.trch' # os.path.join(hcat.__path__[0], 'Max_project_detection_resnet_NeXT101_Jan7_2022_eFinal.trch')
 
-    convnext = '/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/models/Max_project_detection_resnet_convNeXt_28January_2022_MinValidation.trch'
-    # convnext_nocommunitydata = '/media/DataStorage/Dropbox (Partners HealthCare)/HairCellInstance/models/Max_project_detection_resnet_NeXT101_28January_2022_eFinal_just_arturlabtrainingdata.trch'
+    path = os.path.join(hcat.__path__[0], 'detection.trch')
 
     if not os.path.exists(path):
         print('Downloading Model File: ')
         wget.download(url=url, out=path)
         print(' ')
 
     model = model.requires_grad_(False)
-    try:
-        checkpoint = torch.load(path, map_location=torch.device('cpu'))
-        if 'model_state_dict' in checkpoint:
-            checkpoint = checkpoint['model_state_dict']
-        model = model.to(device)
-        model.load_state_dict(checkpoint)
-    except:
-        checkpoint = torch.load(convnext, map_location=torch.device('cpu'))
-        if 'model_state_dict' in checkpoint:
-            checkpoint = checkpoint['model_state_dict']
-        model = model.to(device)
-        model.load_state_dict(checkpoint)
+    checkpoint = torch.load(path, map_location=torch.device('cpu'))
+    if 'model_state_dict' in checkpoint:
+        checkpoint = checkpoint['model_state_dict']
+    model = model.to(device)
+    model.load_state_dict(checkpoint)
 
     return model.eval().to(memory_format=torch.channels_last)
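For readers unfamiliar with the pattern, _init_model() above follows torchvision's custom-backbone recipe: any module that exposes an out_channels attribute can be handed to FasterRCNN together with an AnchorGenerator (and, optionally, a MultiScaleRoIAlign pooler). A minimal sketch of the same recipe with a stock single-feature-map backbone is shown below; mobilenet_v2 is purely illustrative and not part of this commit, which uses hcat's multi-scale ConvNeXt and therefore declares four anchor levels.

# Illustrative sketch, not part of the commit: the torchvision
# custom-backbone recipe mirrored by _init_model(), using a stock
# mobilenet_v2 feature extractor in place of hcat's ConvNeXt.
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator

backbone = torchvision.models.mobilenet_v2().features
backbone.out_channels = 1280  # FasterRCNN reads this attribute off the backbone

# mobilenet_v2().features emits a single feature map, so the anchor generator
# takes one tuple of sizes; the ConvNeXt backbone above emits four feature
# maps, hence its four single-size tuples.
anchor_generator = AnchorGenerator(sizes=((16, 32, 64, 128),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=7,
                                                sampling_ratio=2)

model = FasterRCNN(backbone,
                   num_classes=3,  # same class count as the commit
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler,
                   min_size=256, max_size=600)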

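A hedged usage sketch of the updated loader follows; the URL is a placeholder, since the real weight-hosting address is defined elsewhere in hcat, and the loader now caches the file as detection.trch inside the installed hcat package directory.

# Usage sketch only -- the URL below is a stand-in, not hcat's real endpoint.
import torch
from hcat.backends.detection import FasterRCNN_from_url

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# First call downloads detection.trch into hcat.__path__[0]; later calls
# reuse the cached file. Returns an eval-mode, channels_last model.
model = FasterRCNN_from_url(url='https://example.com/detection.trch', device=device)

# Faster R-CNN in eval mode takes a list of CHW float tensors and returns one
# dict with 'boxes', 'labels', and 'scores' per image.
with torch.no_grad():
    predictions = model([torch.rand(3, 256, 256, device=device)])
print(predictions[0]['boxes'].shape)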
modelfiles/detection.trch (4 changes: 2 additions & 2 deletions)
Git LFS file not shown
modelfiles/fasterrcnn.trch (3 changes: 0 additions & 3 deletions)

This file was deleted.
