Skip to content

Commit

Permalink
feat(ml): tensorrt support for regression models
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and sileht committed Oct 21, 2020
1 parent a8b81f2 commit 77a016b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,7 @@ measure | array | yes | empty | Output measures requested,
template | string | yes | empty | Output template in Mustache format
confidence_threshold | double | yes | 0.0 | only returns classifications or detections with probability strictly above threshold
bbox | bool | yes | false | returns bounding boxes around object when using an object detection model
regression | bool | yes | false | whether the output of a model is a regression target (i.e. vector of one or more floats)
The variables that are usable in the output template format are those from the standard JSON output. See the [output template](#output-templates) dedicated section for more details and examples.
Expand Down
23 changes: 17 additions & 6 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ namespace dd
_readEngine = tl._readEngine;
_writeEngine = tl._writeEngine;
_TRTContextReady = tl._TRTContextReady;
_timeserie = tl._timeserie;
_buffers = tl._buffers;
_bbox = tl._bbox;
_ctc = tl._ctc;
_timeserie = tl._timeserie;
_regression = tl._regression;
_inputIndex = tl._inputIndex;
_outputIndex0 = tl._outputIndex0;
_outputIndex1 = tl._outputIndex1;
Expand Down Expand Up @@ -383,6 +384,9 @@ namespace dd
{
if (ad_output.has("bbox"))
_bbox = ad_output.get("bbox").get<bool>();
if (ad_output.has("regression"))
_regression = ad_output.get("regression").get<bool>();

// Ctc model
if (ad_output.has("ctc"))
{
Expand Down Expand Up @@ -410,6 +414,8 @@ namespace dd
throw MLLibBadParamException(
"timeseries not yet implemented over tensorRT backend");
}
else if (_regression)
out_blob = "pred";

if (_nclasses == 0)
{
Expand Down Expand Up @@ -529,9 +535,12 @@ namespace dd
throw MLLibBadParamException(
"timeseries not yet implemented over tensorRT backend");
}
else // classification
else // classification / regression
{
_buffers.resize(2);
if (_regression)
_buffers.resize(1);
else
_buffers.resize(2);
_floatOut.resize(_max_batch_size * this->_nclasses);
if (inputc._bw)
cudaMalloc(&_buffers.data()[_inputIndex],
Expand Down Expand Up @@ -612,7 +621,7 @@ namespace dd
throw MLLibBadParamException(
"timeseries not yet implemented over tensorRT backend");
}
else // classification
else // classification / regression
{
if (inputc._bw)
cudaMemcpyAsync(_buffers.data()[_inputIndex], inputc.data(),
Expand Down Expand Up @@ -745,7 +754,7 @@ namespace dd
throw MLLibBadParamException(
"timeseries not yet implemented over tensorRT backend");
}
else // classification
else // classification / regression
{
for (int j = 0; j < num_processed; j++)
{
Expand All @@ -761,7 +770,7 @@ namespace dd
for (int i = 0; i < _nclasses; i++)
{
double prob = _floatOut.at(j * _nclasses + i);
if (prob < confidence_threshold)
if (prob < confidence_threshold && !_regression)
continue;
probs.push_back(prob);
cats.push_back(this->_mlmodel.get_hcorresp(i));
Expand All @@ -782,6 +791,8 @@ namespace dd
out.add("nclasses", this->_nclasses);
if (_bbox)
out.add("bbox", true);
if (_regression)
out.add("regression", true);
out.add("roi", false);
out.add("multibox_rois", false);
tout.finalize(ad.getobj("parameters").getobj("output"), out,
Expand Down
3 changes: 2 additions & 1 deletion src/backends/tensorrt/tensorrtlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ namespace dd

bool _bbox = false;
bool _ctc = false;
bool _regression = false;
bool _timeserie = false;

std::vector<void *> _buffers;

bool _TRTContextReady = false;
bool _timeserie = false;

int _inputIndex;
int _outputIndex0;
Expand Down

0 comments on commit 77a016b

Please sign in to comment.