Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Add]: TensorRT Support #14

Merged
merged 4 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions SimpleHigherHRNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import numpy as np
import torch
from torchvision.transforms import transforms

import tensorrt as trt
from models.higherhrnet import HigherHRNet
from misc.HeatmapParser import HeatmapParser
from misc.utils import get_multi_scale_size, resize_align_multi_scale, get_multi_stage_outputs, aggregate_results, get_final_preds, bbox_iou
from misc.utils import get_multi_scale_size, resize_align_multi_scale, get_multi_stage_outputs, aggregate_results, get_final_preds, bbox_iou,TRTModule_hrnet
from collections import OrderedDict,namedtuple


class SimpleHigherHRNet:
Expand All @@ -30,7 +31,8 @@ def __init__(self,
filter_redundant_poses=True,
max_nof_people=30,
max_batch_size=32,
device=torch.device("cpu")):
device=torch.device("cpu"),
enable_tensorrt=False):
"""
Initializes a new SimpleHigherHRNet object.
HigherHRNet is initialized on the torch.device("device") and
Expand Down Expand Up @@ -74,6 +76,7 @@ def __init__(self,
self.max_nof_people = max_nof_people
self.max_batch_size = max_batch_size
self.device = device
self.enable_tensorrt=enable_tensorrt

# assert nof_joints in (14, 15, 17)
if self.nof_joints == 14:
Expand All @@ -90,33 +93,36 @@ def __init__(self,
else:
raise ValueError('Wrong model name.')

checkpoint = torch.load(checkpoint_path, map_location=self.device)
if 'model' in checkpoint:
checkpoint = checkpoint['model']
# fix issue with official high-resolution weights
checkpoint = OrderedDict([(k[2:] if k[:2] == '1.' else k, v) for k, v in checkpoint.items()])
self.model.load_state_dict(checkpoint)

if 'cuda' in str(self.device):
print("device: 'cuda' - ", end="")
if not self.enable_tensorrt:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
if 'model' in checkpoint:
checkpoint = checkpoint['model']
# fix issue with official high-resolution weights
checkpoint = OrderedDict([(k[2:] if k[:2] == '1.' else k, v) for k, v in checkpoint.items()])
self.model.load_state_dict(checkpoint)
if 'cuda' in str(self.device):
print("device: 'cuda' - ", end="")

if 'cuda' == str(self.device):
# if device is set to 'cuda', all available GPUs will be used
print("%d GPU(s) will be used" % torch.cuda.device_count())
device_ids = None
else:
# if device is set to 'cuda:IDS', only that/those device(s) will be used
print("GPU(s) '%s' will be used" % str(self.device))
device_ids = [int(x) for x in str(self.device)[5:].split(',')]
self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)

if 'cuda' == str(self.device):
# if device is set to 'cuda', all available GPUs will be used
print("%d GPU(s) will be used" % torch.cuda.device_count())
device_ids = None
elif 'cpu' == str(self.device):
print("device: 'cpu'")
else:
# if device is set to 'cuda:IDS', only that/those device(s) will be used
print("GPU(s) '%s' will be used" % str(self.device))
device_ids = [int(x) for x in str(self.device)[5:].split(',')]

self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
elif 'cpu' == str(self.device):
print("device: 'cpu'")
raise ValueError('Wrong device name.')
self.model = self.model.to(device)
self.model.eval()
else:
raise ValueError('Wrong device name.')

self.model = self.model.to(device)
self.model.eval()
if device.type == 'cpu':
raise ValueError('TensorRT does not support cpu device.')
self.model=TRTModule_hrnet(path=checkpoint_path,device=self.device)

self.output_parser = HeatmapParser(num_joints=self.nof_joints,
joint_set=self.joint_set,
Expand Down Expand Up @@ -201,6 +207,7 @@ def _predict_batch(self, image):
image = image.to(self.device)
images.append(image)
images = torch.cat(images)
# images=images

# inference
# output: list of HigherHRNet outputs (heatmaps)
Expand Down
Binary file added misc/__pycache__/HeatmapParser.cpython-38.pyc
Binary file not shown.
Binary file added misc/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file added misc/__pycache__/utils.cpython-38.pyc
Binary file not shown.
Binary file added misc/__pycache__/visualization.cpython-38.pyc
Binary file not shown.
117 changes: 116 additions & 1 deletion misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import munkres
import numpy as np
import torch

from collections import OrderedDict,namedtuple
import tensorrt as trt
# starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

# solution proposed in https://github.com/pytorch/pytorch/issues/229#issuecomment-299424875
def flip_tensor(tensor, dim=0):
Expand Down Expand Up @@ -370,7 +372,17 @@ def get_multi_stage_outputs(model, image,
# but it could also be (no checkpoints with this configuration)
# [(batch, nof_joints*2, height//4, width//4), (batch, nof_joints*2, height//2, width//2), (batch, nof_joints, height, width)]
if len(image) <= max_batch_size:
# print(image.size())
# starter.record()

outputs = model(image)

# ender.record()
# WAIT FOR GPU SYNC
# torch.cuda.synchronize()
# curr_time = starter.elapsed_time(ender)
# print(curr_time)

else:
outputs = [
torch.empty((image.shape[0], nof_joints * 2, image.shape[-2] // 4, image.shape[-1] // 4),
Expand Down Expand Up @@ -593,3 +605,106 @@ def get_final_preds(grouped_joints, center, scale, heatmap_size):
return final_results
#
#

def torch_device_from_trt(device):
    """
    Map a TensorRT tensor location to the corresponding torch device.

    Args:
        device: A ``trt.TensorLocation`` value.

    Returns:
        ``torch.device("cuda")`` for ``trt.TensorLocation.DEVICE`` and
        ``torch.device("cpu")`` for ``trt.TensorLocation.HOST``.

    Raises:
        TypeError: If ``device`` is neither of the two locations above.
    """
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    elif device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    else:
        # Raise (not return) the error, consistently with torch_dtype_from_trt:
        # returning the exception instance would hand callers a TypeError
        # object where they expect a torch.device.
        raise TypeError("%s is not supported by torch" % device)
def torch_dtype_from_trt(dtype):
    """
    Map a TensorRT data type to the corresponding torch dtype.

    Args:
        dtype: A TensorRT data type (``trt.int8``, ``trt.bool``,
            ``trt.int32``, ``trt.float16`` or ``trt.float32``).

    Returns:
        The matching ``torch.dtype``.

    Raises:
        TypeError: If ``dtype`` has no mapping here.
    """
    # Compare the major version numerically: the string comparison
    # `trt.__version__ >= '7.0'` is lexicographic and is False for
    # TensorRT 10 and later ('10.0' < '7.0'), which disabled the bool mapping.
    # The version guard is kept in front of `trt.bool` so that attribute is
    # only touched on releases the original code considered to provide it.
    trt_major = int(trt.__version__.split('.')[0])

    if dtype == trt.int8:
        return torch.int8
    elif trt_major >= 7 and dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError("%s is not supported by torch" % dtype)
class TRTModule_hrnet(torch.nn.Module):
    """
    TensorRT wrapper for HigherHRNet.

    Deserializes a serialized TensorRT engine and exposes it as a
    ``torch.nn.Module`` whose ``forward`` feeds torch tensors to the engine.

    Args:
        path: Path to the .engine file for trt inference.
        device: The cuda device to be used

    Raises:
        RuntimeError: If the engine file cannot be deserialized.
    """

    def __init__(self, path=None, device=None):
        super(TRTModule_hrnet, self).__init__()
        logger = trt.Logger(trt.Logger.INFO)
        with open(path, 'rb') as f, trt.Runtime(logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        if self.engine is None:
            # Fail here with a clear message instead of hitting an
            # AttributeError on the first use of the engine/context below.
            raise RuntimeError("Failed to deserialize TensorRT engine from '%s'." % path)
        self.context = self.engine.create_execution_context()

        # The engine input binding is looked up by this fixed name.
        self.input_names = ['images']
        self.output_names = []
        self.input_flattener = None
        self.output_flattener = None
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))

        self.bindings = OrderedDict()
        self.fp16 = False  # set to True below if an input binding is float16
        self.dynamic = False  # set to True below if an input binding has a -1 dimension
        for i in range(self.engine.num_bindings):
            name = self.engine.get_binding_name(i)
            dtype = trt.nptype(self.engine.get_binding_dtype(i))
            if self.engine.binding_is_input(i):
                if -1 in tuple(self.engine.get_binding_shape(i)):  # dynamic
                    self.dynamic = True
                    # use entry [2] of profile 0 as the concrete shape
                    # (presumably the profile's max shape - TODO confirm)
                    self.context.set_binding_shape(i, tuple(self.engine.get_profile_shape(0, i)[2]))
                if dtype == np.float16:
                    self.fp16 = True
            else:
                self.output_names.append(name)
            # one tensor per binding, allocated on `device` with the shape the context reports
            shape = tuple(self.context.get_binding_shape(i))
            im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
            self.bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
        self.batch_size = self.bindings['images'].shape[0]

    def forward(self, *inputs):
        """
        Run the engine on the given input tensors.

        Args:
            *inputs: One tensor per entry of ``self.input_names``, in order.

        Returns:
            The single output tensor if the engine has exactly one output,
            otherwise a tuple of output tensors (or whatever
            ``self.output_flattener.unflatten`` returns when a flattener is set).
        """
        bindings = [None] * (len(self.input_names) + len(self.output_names))

        if self.input_flattener is not None:
            inputs = self.input_flattener.flatten(inputs)

        # Keep every contiguous input referenced until the engine has been
        # enqueued. `.contiguous()` may return a fresh copy; if that temporary
        # is dropped right after `data_ptr()` its memory can be handed out again
        # (e.g. to the output tensors allocated below) before the engine reads it.
        contiguous_inputs = []
        for i, input_name in enumerate(self.input_names):
            idx = self.engine.get_binding_index(input_name)
            tensor = inputs[i].contiguous()
            contiguous_inputs.append(tensor)
            bindings[idx] = tensor.data_ptr()
            self.context.set_binding_shape(idx, tuple(tensor.shape))

        # create output tensors
        outputs = [None] * len(self.output_names)
        for i, output_name in enumerate(self.output_names):
            idx = self.engine.get_binding_index(output_name)
            dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
            # output shape is queried after the input shapes were set above
            shape = tuple(self.context.get_binding_shape(idx))
            device = torch_device_from_trt(self.engine.get_location(idx))
            output = torch.empty(size=shape, dtype=dtype, device=device)
            outputs[i] = output
            bindings[idx] = output.data_ptr()

        # enqueue on the current torch CUDA stream
        self.context.execute_async_v2(
            bindings, torch.cuda.current_stream().cuda_stream
        )

        if self.output_flattener is not None:
            outputs = self.output_flattener.unflatten(outputs)
        else:
            outputs = tuple(outputs)
            if len(outputs) == 1:
                outputs = outputs[0]

        return outputs

    def enable_profiling(self):
        """Attach a ``trt.Profiler`` to the execution context if none is set."""
        if not self.context.profiler:
            self.context.profiler = trt.Profiler()
Loading