Skip to content

Commit

Permalink
refactor target assigner
Browse files Browse the repository at this point in the history
  • Loading branch information
jhultman committed Mar 4, 2020
1 parent ca40082 commit 575c555
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 24 deletions.
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
Pointnet2.PyTorch/*
spconv/*
torchsearchsorted/*
thirdparty/*
data/*

pythonpath.sh
**/__pycache__/*
*.bin
*.pth
*.so
Expand Down
43 changes: 26 additions & 17 deletions pvrcnn/core/proposal_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch import nn

from pvrcnn.ops import box_iou_rotated, Matcher
from .anchor_generator import AnchorGenerator


class ProposalTargetAssigner(nn.Module):
Expand All @@ -11,10 +12,10 @@ class ProposalTargetAssigner(nn.Module):
TODO: Make this run faster if possible.
"""

def __init__(self, cfg, anchors):
def __init__(self, cfg):
super(ProposalTargetAssigner, self).__init__()
self.cfg = cfg
self.anchors = anchors.cuda()
self.anchors = AnchorGenerator(cfg).anchors.cuda()
self.matchers = self.build_matchers(cfg)

def build_matchers(self, cfg):
Expand Down Expand Up @@ -65,25 +66,33 @@ def get_reg_targets(self, boxes, box_idx, G_cls):
G_reg = torch.zeros_like(self.anchors).masked_scatter_(M_reg, G_reg)
return G_reg, M_reg

def get_matches(self, boxes, class_idx):
def match_class_i(self, boxes, class_idx, full_idx, i):
class_mask = class_idx == i
anchors = self.anchors[i].view(-1, self.cfg.BOX_DOF)
iou = self.compute_iou(boxes[class_mask], anchors)
matches, labels = self.matchers[i](iou)
if (class_mask).any():
matches = full_idx[class_mask][matches]
return matches, labels

def apply_ignore_mask(self, matches, labels, box_ignore):
"""Ignore anchors matched to boxes[i] if box_ignore[i].
E.g., boxes containing too few lidar points."""
labels[box_ignore[matches] & (labels != -1)] = -1

def match_all_classes(self, boxes, class_idx, box_ignore):
"""Match boxes to anchors based on IOU."""
full_idx = torch.arange(boxes.shape[0])
matches, match_labels = [], []
for i in range(self.cfg.NUM_CLASSES):
if not (class_idx == i).any():
continue
anchors_i = self.anchors[i].view(-1, self.cfg.BOX_DOF)
iou = self.compute_iou(boxes[class_idx == i].cuda(), anchors_i)
_matches, _match_labels = self.matchers[i](iou)
matches += [full_idx[class_idx == i][_matches]]
match_labels += [_match_labels]
classes = range(self.cfg.NUM_CLASSES)
matches, labels = zip(*[self.match_class_i(
boxes, class_idx, full_idx, i) for i in classes])
matches = torch.stack(matches).view(self.anchors.shape[:-1])
match_labels = torch.stack(match_labels).view(self.anchors.shape[:-1])
return matches, match_labels
labels = torch.stack(labels).view(self.anchors.shape[:-1])
return matches, labels

def forward(self, item):
boxes, class_idx = item['boxes'], item['class_idx']
box_idx, G_cls = self.get_matches(boxes, class_idx)
box_idx, G_cls = self.match_all_classes(
item['boxes'].cuda(), item['class_idx'], item['box_ignore'])
G_cls, M_cls = self.get_cls_targets(G_cls)
G_reg, M_reg = self.get_reg_targets(boxes, box_idx, G_cls)
G_reg, M_reg = self.get_reg_targets(item['boxes'], box_idx, G_cls)
item.update(dict(G_cls=G_cls, G_reg=G_reg, M_cls=M_cls, M_reg=M_reg))
7 changes: 4 additions & 3 deletions pvrcnn/dataset/kitti_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os.path as osp
from torch.utils.data import Dataset

from pvrcnn.core import ProposalTargetAssigner, AnchorGenerator
from pvrcnn.core import ProposalTargetAssigner
from .kitti_utils import read_calib, read_label, read_velo
from .augmentation import ChainedAugmentation
from .database_sampler import DatabaseBuilder
Expand Down Expand Up @@ -88,11 +88,13 @@ def filter_out_of_bounds(self, item):
keep = ((xyz >= lower) & (xyz <= upper)).all(1)
item['boxes'] = item['boxes'][keep]
item['class_idx'] = item['class_idx'][keep]
item['box_ignore'] = np.full(keep.sum(), False)

def to_torch(self, item):
item['points'] = np.float32(item['points'])
item['boxes'] = torch.FloatTensor(item['boxes'])
item['class_idx'] = torch.LongTensor(item['class_idx'])
item['box_ignore'] = torch.BoolTensor(item['box_ignore'])

def drop_keys(self, item):
for key in ['velo_path', 'objects', 'calib']:
Expand All @@ -115,10 +117,9 @@ class KittiDatasetTrain(KittiDataset):

def __init__(self, cfg):
super(KittiDatasetTrain, self).__init__(cfg, split='train')
anchors = AnchorGenerator(cfg).anchors
DatabaseBuilder(cfg, self.annotations)
self.augmentation = ChainedAugmentation(cfg)
self.target_assigner = ProposalTargetAssigner(cfg, anchors)
self.target_assigner = ProposalTargetAssigner(cfg)

def preprocessing(self, item):
"""Applies augmentation and assigns targets."""
Expand Down
2 changes: 1 addition & 1 deletion pvrcnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def build_train_dataloader(cfg, preprocessor):
KittiDatasetTrain(cfg),
collate_fn=preprocessor.collate,
batch_size=cfg.TRAIN.BATCH_SIZE,
num_workers=2,
num_workers=1,
)
return dataloader

Expand Down

0 comments on commit 575c555

Please sign in to comment.