From 0789508438f15fe0537c06713d4afc67b9abe4a3 Mon Sep 17 00:00:00 2001
From: jhultman <27909223+jhultman@users.noreply.github.com>
Date: Fri, 6 Mar 2020 04:20:45 +0000
Subject: [PATCH] add multiclass-batch inference and refactor dataset

---
 .gitignore                         |   5 +-
 configs/second/car.yaml            |   3 +-
 pvrcnn/core/bev_drawer.py          |  84 +++++++++++++++
 pvrcnn/core/box_encode.py          |  36 +++++++++
 pvrcnn/core/config.py              |   3 +
 pvrcnn/core/geometry.py            |  65 ++++++++++++++++
 pvrcnn/core/proposal_targets.py    |  38 ++++------
 pvrcnn/dataset/augmentation.py     |  55 +++++++++++++-
 pvrcnn/dataset/database_sampler.py | 116 -----------------------------
 pvrcnn/dataset/kitti_dataset.py    | 106 ++++++++++++++++----------
 pvrcnn/dataset/kitti_utils.py      |  51 +++++++++----
 pvrcnn/detector/proposal.py        |  66 ++++++++++++++--
 pvrcnn/detector/second.py          |  11 ++-
 pvrcnn/inference.py                |  46 ++++++++----
 pvrcnn/ops/iou_nms.py              |   1 +
 pvrcnn/train.py                    |  21 +++---
 16 files changed, 473 insertions(+), 234 deletions(-)
 create mode 100644 pvrcnn/core/bev_drawer.py
 create mode 100644 pvrcnn/core/box_encode.py
 create mode 100644 pvrcnn/core/geometry.py
 delete mode 100644 pvrcnn/dataset/database_sampler.py

diff --git a/.gitignore b/.gitignore
index ee04276..ceffc74 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,9 +12,10 @@
 build/*
 dist/*
 pvrcnn.egg-info/*
-notebooks/*
+.ipynb_checkpoints/*
 pvrcnn/eval/*
-pvrcnn/core/viz_detections.py
+pvrcnn/make_dataset.py
 images/proposals_*
+notebooks/*
 
 
diff --git a/configs/second/car.yaml b/configs/second/car.yaml
index 2be75a0..8e7cb65 100644
--- a/configs/second/car.yaml
+++ b/configs/second/car.yaml
@@ -6,11 +6,12 @@ ANCHORS: [{
     wlh: [1.6, 3.9, 1.56],
     yaw: [0, 1.501],
     iou_thresh: [0.45, 0.60],
+    score_thresh: 0.3,
     center_z: -1.0,
 }]
 NUM_CLASSES: 1
 TRAIN:
-    BATCH_SIZE: 5
+    BATCH_SIZE: 4
     LAMBDA: 1.0
     EPOCHS: 60
     AUG:
diff --git a/pvrcnn/core/bev_drawer.py b/pvrcnn/core/bev_drawer.py
new file mode 100644
index 0000000..6531d7d
--- /dev/null
+++ b/pvrcnn/core/bev_drawer.py
@@ -0,0 +1,84 @@
+import cv2
+import numpy as np
+
+from .geometry import box3d_to_bev_corners
+
+
+def clipped_percentile(x, p=1):
+    """Transform to unit interval robustly."""
+    p0, p1 = np.percentile(x, [p, 100 - p])
+    x = (np.clip(x, p0, p1) - p0) / (p1 - p0 + 1e-1)
+    return x
+
+
+def make_bev_map(points, pixel_size, bounds):
+    """Scatter points to create sparse occupancy image."""
+    mask = ((points > bounds[:2]) & (points < bounds[2:])).all(1)
+    shape = np.int32(np.ceil((bounds[2:] - bounds[:2]) / pixel_size))[::-1]
+    pixels = np.int32(np.floor((points[mask] - bounds[:2]) / pixel_size))
+    pixels, counts = np.unique(pixels, return_counts=True, axis=0)
+    bev_map = np.zeros(shape, dtype=np.float32)
+    bev_map[tuple(pixels[:, ::-1].T)] = counts
+    bev_map = clipped_percentile(bev_map)
+    return bev_map
+
+
+class Drawer:
+    """Draw BEV occupancy map with boxes.
+    Store in image attribute."""
+    def __init__(self,
+                 points,
+                 boxes=[],
+                 labels=[],
+                 pixel_size=np.r_[0.1, 0.1],
+                 bounds=np.r_[0, -30, 60, 30]):
+        self.pixel_size = pixel_size
+        self.bounds = bounds
+        self.line_kw = dict(thickness=2)
+        self.text_kw = dict(
+            fontScale=0.6,
+            thickness=2,
+            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+        )
+        self.image = self.build_bev(points)
+        self.draw(boxes, labels)
+
+    def build_bev(self, points):
+        image = make_bev_map(
+            points[:, :2], self.pixel_size, self.bounds)
+        image = (image * 255).astype(np.uint8)
+        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        return image
+
+    def get_text_color(self, n):
+        color = np.tile(np.r_[255, 0, 0][None], (n, 1))
+        color = [list(map(int, c)) for c in color]
+        return color
+
+    def get_line_color(self, n):
+        color = np.tile(np.r_[0, 255, 0][None], (n, 1))
+        color = [list(int(c) for c in ci) for ci in color]
+        return color
+
+    def draw_text(self, locs, labels):
+        colors = self.get_text_color(locs.shape[0])
+        locs = map(tuple, locs.astype(np.int32).tolist())
+        for label, loc, color in zip(labels, locs, colors):
+            cv2.putText(self.image, label, loc, color=color, **self.text_kw)
+
+    def draw_lines(self, lines):
+        colors = self.get_line_color(lines.shape[0])
+        lines = map(tuple, lines.astype(np.int32).tolist())
+        for line, color in zip(lines, colors):
+            cv2.line(self.image, line[0:2], line[2:4], color, **self.line_kw)
+
+    def draw(self, boxes_, labels):
+        """No support yet for labels."""
+        for boxes in boxes_:
+            extent = self.bounds[2:] - self.bounds[:2]
+            factor = np.r_[self.image.shape[:2]][::-1] / extent
+            corners = (box3d_to_bev_corners(boxes) - self.bounds[:2]) * factor
+            line_loc = corners[:, [1, 2, 2, 3, 3, 0]].reshape(-1, 4)
+            text_loc = corners[:, 0] + 0.3
+            self.draw_lines(line_loc)
+            self.draw_text(text_loc, labels=[])
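
A minimal sketch of driving the new Drawer (the .bin path is hypothetical; assumes a KITTI-style (N, 4) velodyne scan and the (N, 7) [x, y, z, w, l, h, yaw] box layout used throughout this patch):

    import cv2
    import numpy as np
    from pvrcnn.core.bev_drawer import Drawer

    points = np.fromfile('000000.bin', dtype=np.float32).reshape(-1, 4)
    boxes = np.array([[10.0, 2.0, -1.0, 1.6, 3.9, 1.56, 0.3]], dtype=np.float32)

    # Drawer renders at construction time; note it expects a *list* of box arrays.
    drawer = Drawer(points, boxes=[boxes])
    cv2.imwrite('bev.png', drawer.image[:, :, ::-1])  # RGB -> BGR for OpenCV
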
diff --git a/pvrcnn/core/box_encode.py b/pvrcnn/core/box_encode.py
new file mode 100644
index 0000000..44ba1b3
--- /dev/null
+++ b/pvrcnn/core/box_encode.py
@@ -0,0 +1,36 @@
+import torch
+import math
+
+
+def _anchor_diagonal(A_wlh):
+    """Reference: VoxelNet."""
+    A_wl, A_h = A_wlh.split([2, 1], -1)
+    A_norm = A_wl.norm(dim=-1, keepdim=True).expand_as(A_wl)
+    A_norm = torch.cat((A_norm, A_h), dim=-1)
+    return A_norm
+
+
+def decode(deltas, anchors):
+    """Both inputs of shape (*, 7)."""
+    P_xyz, P_wlh, P_yaw = deltas.split([3, 3, 1], -1)
+    A_xyz, A_wlh, A_yaw = anchors.split([3, 3, 1], -1)
+    A_norm = _anchor_diagonal(A_wlh)
+    boxes = torch.cat((
+        (P_xyz * A_norm + A_xyz),
+        (P_wlh.exp() * A_wlh),
+        (P_yaw + A_yaw)), dim=-1
+    )
+    return boxes
+
+
+def encode(boxes, anchors):
+    """Both inputs of shape (*, 7)."""
+    G_xyz, G_wlh, G_yaw = boxes.split([3, 3, 1], -1)
+    A_xyz, A_wlh, A_yaw = anchors.split([3, 3, 1], -1)
+    A_norm = _anchor_diagonal(A_wlh)
+    deltas = torch.cat((
+        (G_xyz - A_xyz) / A_norm,
+        (G_wlh / A_wlh).log(),
+        (G_yaw - A_yaw) % math.pi), dim=-1
+    )
+    return deltas
diff --git a/pvrcnn/core/config.py b/pvrcnn/core/config.py
index b9bfb08..c865dcd 100644
--- a/pvrcnn/core/config.py
+++ b/pvrcnn/core/config.py
@@ -25,6 +25,7 @@
         'wlh': [1.6, 3.9, 1.56],
         'yaw': [0, np.pi / 2],
         'iou_thresh': [0.45, 0.60],
+        'score_thresh': 0.3,
         'center_z': -1.0,
     },
     {
@@ -32,6 +33,7 @@
         'wlh': [0.6, 0.8, 1.73],
         'yaw': [0, np.pi / 2],
         'iou_thresh': [0.20, 0.35],
+        'score_thresh': 0.3,
         'center_z': -0.6,
     },
     {
@@ -39,6 +41,7 @@
         'wlh': [0.6, 1.76, 1.73],
         'yaw': [0, np.pi / 2],
         'iou_thresh': [0.20, 0.35],
+        'score_thresh': 0.3,
         'center_z': -0.6,
     },
 ]
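
A quick round-trip through the new codec (a self-contained sketch; the values are arbitrary). Note that encode wraps the yaw residual into [0, pi), so headings only round-trip modulo pi:

    import torch
    from pvrcnn.core.box_encode import encode, decode

    anchors = torch.tensor([[10.0, 2.0, -1.0, 1.6, 3.9, 1.56, 0.0]])
    boxes = torch.tensor([[10.5, 2.2, -0.9, 1.7, 4.1, 1.60, 0.2]])

    deltas = encode(boxes, anchors)
    recovered = decode(deltas, anchors)
    assert torch.allclose(recovered, boxes, atol=1e-5)  # yaw residual already in [0, pi)
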
diff --git a/pvrcnn/core/geometry.py b/pvrcnn/core/geometry.py
new file mode 100644
index 0000000..5aa1028
--- /dev/null
+++ b/pvrcnn/core/geometry.py
@@ -0,0 +1,65 @@
+import numpy as np
+
+
+def points_in_convex_polygon(points, polygon, ccw=True):
+    """points (N, 2) | polygon (M, V, 2) | mask (N, M)"""
+    polygon_roll = np.roll(polygon, shift=1, axis=1)
+    polygon_side = (-1) ** ccw * (polygon - polygon_roll)[None]
+    vertex_to_point = polygon[None] - points[:, None, None]
+    mask = (np.cross(polygon_side, vertex_to_point) > 0).all(2)
+    return mask
+
+
+def box3d_to_bev_corners(boxes):
+    """
+    :boxes np.ndarray shape (N, 7)
+    :corners np.ndarray shape (N, 4, 2) (ccw)
+    """
+    xy, _, wl, _, yaw = np.split(boxes, [2, 3, 5, 6], 1)
+    c, s = np.cos(yaw), np.sin(yaw)
+    R = np.stack([c, -s, s, c], -1).reshape(-1, 2, 2)
+    corners = 0.5 * np.r_[-1, -1, +1, -1, +1, +1, -1, +1]
+    corners = (wl[:, None] * corners.reshape(4, 2))
+    corners = np.einsum('ijk,imk->imj', R, corners) + xy[:, None]
+    return corners
+
+
+class PointsInCuboids:
+    """Takes ~10ms for each scene."""
+
+    def __init__(self, points):
+        self.points = points
+
+    def _height_threshold(self, boxes):
+        """Filter to z slice."""
+        z1 = self.points[:, None, 2]
+        z2, h = boxes[:, [2, 5]].T
+        mask = (z1 > z2 - h / 2) & (z1 < z2 + h / 2)
+        return mask
+
+    def _get_mask(self, boxes):
+        polygons = box3d_to_bev_corners(boxes)
+        mask = self._height_threshold(boxes)
+        mask &= points_in_convex_polygon(
+            self.points[:, :2], polygons)
+        return mask
+
+    def __call__(self, boxes):
+        """Return list of points in each box."""
+        mask = self._get_mask(boxes).T
+        points = list(map(self.points.__getitem__, mask))
+        return points
+
+
+class PointsNotInRectangles(PointsInCuboids):
+
+    def _get_mask(self, boxes):
+        polygons = box3d_to_bev_corners(boxes)
+        mask = points_in_convex_polygon(
+            self.points[:, :2], polygons)
+        return mask
+
+    def __call__(self, boxes):
+        """Return array of points not in any box."""
+        mask = ~self._get_mask(boxes).any(1)
+        return self.points[mask]
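
How the relocated helpers compose (shapes per the docstrings; the random scene is a stand-in):

    import numpy as np
    from pvrcnn.core.geometry import PointsInCuboids, PointsNotInRectangles

    points = 20 * np.random.rand(1000, 4).astype(np.float32)   # (N, 4) scan
    boxes = np.array([[10.0, 10.0, 5.0, 2.0, 4.0, 2.0, 0.5]])  # (M, 7)

    per_box = PointsInCuboids(points)(boxes)           # list of M (K_i, 4) arrays
    background = PointsNotInRectangles(points)(boxes)  # points outside every BEV rectangle
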
""" - ambiguous = G_cls.eq(1).int().sum(0) > 1 - G_cls[:, ambiguous] = -1 M_cls = G_cls.ne(-1) G_cls = G_cls.clamp_(min=0) return G_cls, M_cls - def _encode_diagonal(self, A_wlh): - A_wl, A_h = A_wlh.split([2, 1], -1) - A_norm = A_wl.norm(dim=-1, keepdim=True).expand(-1, 2) - A_norm = torch.cat((A_norm, A_h), -1) - return A_norm - def get_reg_targets(self, boxes, box_idx, G_cls): """Standard VoxelNet-style box encoding.""" M_reg = G_cls == 1 - A = self.anchors[M_reg] - G = boxes[box_idx[M_reg]].cuda() - G_xyz, G_wlh, G_yaw = G.split([3, 3, 1], -1) - A_xyz, A_wlh, A_yaw = A.split([3, 3, 1], -1) - A_norm = self._encode_diagonal(A_wlh) - G_reg = torch.cat(( - (G_xyz - A_xyz) / A_norm, - (G_wlh / A_wlh).log(), - (G_yaw - A_yaw) % math.pi), dim=-1 - ) + G_reg = encode(boxes[box_idx[M_reg]], self.anchors[M_reg]) M_reg = M_reg.unsqueeze(-1) G_reg = torch.zeros_like(self.anchors).masked_scatter_(M_reg, G_reg) return G_reg, M_reg @@ -82,7 +66,7 @@ def apply_ignore_mask(self, matches, labels, box_ignore): def match_all_classes(self, boxes, class_idx, box_ignore): """Match boxes to anchors based on IOU.""" - full_idx = torch.arange(boxes.shape[0]) + full_idx = torch.arange(boxes.shape[0], device=boxes.device) classes = range(self.cfg.NUM_CLASSES) matches, labels = zip(*[self.match_class_i( boxes, class_idx, full_idx, i) for i in classes]) @@ -90,9 +74,15 @@ def match_all_classes(self, boxes, class_idx, box_ignore): labels = torch.stack(labels).view(self.anchors.shape[:-1]) return matches, labels + def to_device(self, item): + """Move items to anchors.device for fast rotated IOU.""" + keys = ['boxes', 'class_idx', 'box_ignore'] + items = [item[k].to(self.anchors.device) for k in keys] + return items + def forward(self, item): - box_idx, G_cls = self.match_all_classes( - item['boxes'].cuda(), item['class_idx'], item['box_ignore']) + boxes, class_idx, box_ignore = self.to_device(item) + box_idx, G_cls = self.match_all_classes(boxes, class_idx, box_ignore) G_cls, M_cls = self.get_cls_targets(G_cls) - G_reg, M_reg = self.get_reg_targets(item['boxes'], box_idx, G_cls) + G_reg, M_reg = self.get_reg_targets(boxes, box_idx, G_cls) item.update(dict(G_cls=G_cls, G_reg=G_reg, M_cls=M_cls, M_reg=M_reg)) diff --git a/pvrcnn/dataset/augmentation.py b/pvrcnn/dataset/augmentation.py index cca0e14..de99c26 100644 --- a/pvrcnn/dataset/augmentation.py +++ b/pvrcnn/dataset/augmentation.py @@ -2,10 +2,17 @@ import os.path as osp import numpy as np import torch +from tqdm import tqdm +from collections import defaultdict from itertools import compress -from .database_sampler import PointsNotInRectangles +from .kitti_utils import read_velo from pvrcnn.ops import box_iou_rotated +from pvrcnn.core.geometry import ( + points_in_convex_polygon, + PointsNotInRectangles, + PointsInCuboids, +) class Augmentation: @@ -107,6 +114,7 @@ def __call__(self, points, boxes): class SampleAugmentation(Augmentation): + """Pastes samples from database into scene.""" def __init__(self, cfg): super(SampleAugmentation, self).__init__(cfg) @@ -188,3 +196,48 @@ def __call__(self, points, boxes, class_idx): points, boxes, class_idx = self.cat_samples( samples, points, boxes, class_idx) return points, boxes, class_idx + + +class DatabaseBuilder: + """Builds cached database for SampleAugmentation.""" + + def __init__(self, cfg, annotations): + self.cfg = cfg + self.fpath = osp.join(cfg.DATA.CACHEDIR, 'database.pkl') + if osp.isfile(self.fpath): + print(f'Found cached database: {self.fpath}') + return + self._build(annotations) + + def _build(self, 
diff --git a/pvrcnn/dataset/augmentation.py b/pvrcnn/dataset/augmentation.py
index cca0e14..de99c26 100644
--- a/pvrcnn/dataset/augmentation.py
+++ b/pvrcnn/dataset/augmentation.py
@@ -2,10 +2,17 @@
 import os.path as osp
 import numpy as np
 import torch
+from tqdm import tqdm
+from collections import defaultdict
 from itertools import compress
 
-from .database_sampler import PointsNotInRectangles
+from .kitti_utils import read_velo
 from pvrcnn.ops import box_iou_rotated
+from pvrcnn.core.geometry import (
+    points_in_convex_polygon,
+    PointsNotInRectangles,
+    PointsInCuboids,
+)
 
 
 class Augmentation:
@@ -107,6 +114,7 @@ def __call__(self, points, boxes):
 
 
 class SampleAugmentation(Augmentation):
+    """Pastes samples from database into scene."""
 
     def __init__(self, cfg):
         super(SampleAugmentation, self).__init__(cfg)
@@ -188,3 +196,48 @@ def __call__(self, points, boxes, class_idx):
         points, boxes, class_idx = self.cat_samples(
             samples, points, boxes, class_idx)
         return points, boxes, class_idx
+
+
+class DatabaseBuilder:
+    """Builds cached database for SampleAugmentation."""
+
+    def __init__(self, cfg, annotations):
+        self.cfg = cfg
+        self.fpath = osp.join(cfg.DATA.CACHEDIR, 'database.pkl')
+        if osp.isfile(self.fpath):
+            print(f'Found cached database: {self.fpath}')
+            return
+        self._build(annotations)
+
+    def _build(self, annotations):
+        database = defaultdict(list)
+        for item in tqdm(annotations.values(), desc='Building database'):
+            for key, val in zip(*self._process_item(item)):
+                database[key] += [val]
+        self._save_database(dict(database))
+
+    def _demean(self, points, boxes):
+        """Subtract box center (bird's-eye view)."""
+        _points, _boxes = [], []
+        for points_i, box_i in zip(points, boxes):
+            center, zwlhr = np.split(box_i, [2])
+            xy, zi = np.split(points_i, [2], 1)
+            _points += [np.concatenate((xy - center, zi), 1)]
+            _boxes += [np.concatenate((0 * center, zwlhr))]
+        return _points, _boxes
+
+    def _process_item(self, item):
+        """Retrieve points in each box in scene."""
+        points = read_velo(item['velo_path'])
+        class_idx, boxes = item['class_idx'], item['boxes']
+        points = PointsInCuboids(points)(boxes)
+        keep = [len(p) > self.cfg.AUG.MIN_NUM_SAMPLE_PTS for p in points]
+        class_idx, points, boxes = [
+            compress(t, keep) for t in (class_idx, points, boxes)]
+        points, boxes = self._demean(points, boxes)
+        samples = [dict(points=p, box=b) for (p, b) in zip(points, boxes)]
+        return class_idx, samples
+
+    def _save_database(self, database):
+        with open(self.fpath, 'wb') as f:
+            pickle.dump(database, f)
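
The _demean convention the sampler relies on when pasting: cached points are stored relative to the box's BEV center, with z and intensity left untouched (toy values, pure numpy):

    import numpy as np

    box = np.r_[10.0, 2.0, -1.0, 1.6, 3.9, 1.56, 0.3]  # x y z w l h yaw
    pts = np.array([[10.5, 2.5, -0.8, 0.9]])           # x y z intensity

    center, zwlhr = np.split(box, [2])
    xy, zi = np.split(pts, [2], 1)
    demeaned = np.concatenate((xy - center, zi), 1)    # [[0.5, 0.5, -0.8, 0.9]]
    stored_box = np.concatenate((0 * center, zwlhr))   # BEV center moved to origin
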
diff --git a/pvrcnn/dataset/database_sampler.py b/pvrcnn/dataset/database_sampler.py
deleted file mode 100644
index d5e4625..0000000
--- a/pvrcnn/dataset/database_sampler.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import pickle
-import itertools
-import numpy as np
-import os.path as osp
-from tqdm import tqdm
-from collections import defaultdict
-
-from .kitti_utils import read_velo
-
-
-def points_in_convex_polygon(points, polygon, ccw=True):
-    """points (N, 2) | polygon (M, V, 2) | mask (N, M)"""
-    polygon_roll = np.roll(polygon, shift=1, axis=1)
-    polygon_side = (-1) ** ccw * (polygon - polygon_roll)[None]
-    vertex_to_point = polygon[None] - points[:, None, None]
-    mask = (np.cross(polygon_side, vertex_to_point) > 0).all(2)
-    return mask
-
-
-def center_to_corner_box2d(boxes):
-    """
-    :boxes np.ndarray shape (N, 7)
-    :corners np.ndarray shape (N, 4, 2) (counter-clockwise)
-    """
-    xy, _, wl, _, yaw = np.split(boxes, [2, 3, 5, 6], 1)
-    c, s = np.cos(yaw), np.sin(yaw)
-    R = np.stack([c, -s, s, c], -1).reshape(-1, 2, 2)
-    corners = 0.5 * np.r_[-1, -1, +1, -1, +1, +1, -1, +1]
-    corners = (wl[:, None] * corners.reshape(4, 2))
-    corners = np.einsum('ijk,imk->imj', R, corners) + xy[:, None]
-    return corners
-
-
-class PointsInCuboids:
-    """Takes ~10ms for each scene."""
-
-    def __init__(self, points):
-        self.points = points
-
-    def _height_threshold(self, boxes):
-        """Filter to z slice."""
-        z1 = self.points[:, None, 2]
-        z2, h = boxes[:, [2, 5]].T
-        mask = (z1 > z2 - h / 2) & (z1 < z2 + h / 2)
-        return mask
-
-    def _get_mask(self, boxes):
-        polygons = center_to_corner_box2d(boxes)
-        mask = self._height_threshold(boxes)
-        mask &= points_in_convex_polygon(
-            self.points[:, :2], polygons)
-        return mask
-
-    def __call__(self, boxes):
-        """Return list of points in each box."""
-        mask = self._get_mask(boxes).T
-        points = list(map(self.points.__getitem__, mask))
-        return points
-
-
-class PointsNotInRectangles(PointsInCuboids):
-
-    def _get_mask(self, boxes):
-        polygons = center_to_corner_box2d(boxes)
-        mask = points_in_convex_polygon(
-            self.points[:, :2], polygons)
-        return mask
-
-    def __call__(self, boxes):
-        """Return array of points not in any box."""
-        mask = ~self._get_mask(boxes).any(1)
-        return self.points[mask]
-
-
-class DatabaseBuilder:
-
-    def __init__(self, cfg, annotations):
-        self.cfg = cfg
-        self.fpath = osp.join(cfg.DATA.CACHEDIR, 'database.pkl')
-        if osp.isfile(self.fpath):
-            print(f'Found cached database: {self.fpath}')
-            return
-        self._build(annotations)
-
-    def _build(self, annotations):
-        database = defaultdict(list)
-        for item in tqdm(annotations.values(), desc='Building database'):
-            for key, val in zip(*self._process_item(item)):
-                database[key] += [val]
-        self._save_database(dict(database))
-
-    def _demean(self, points, boxes):
-        """Subtract box center (birds eye view)."""
-        _points, _boxes = [], []
-        for points_i, box_i in zip(points, boxes):
-            center, zwlhr = np.split(box_i, [2])
-            xy, zi = np.split(points_i, [2], 1)
-            _points += [np.concatenate((xy - center, zi), 1)]
-            _boxes += [np.concatenate((0 * center, zwlhr))]
-        return _points, _boxes
-
-    def _process_item(self, item):
-        """Retrieve points in each box in scene."""
-        points = read_velo(item['velo_path'])
-        class_idx, boxes = item['class_idx'], item['boxes']
-        points = PointsInCuboids(points)(boxes)
-        keep = [len(p) > self.cfg.AUG.MIN_NUM_SAMPLE_PTS for p in points]
-        class_idx, points, boxes = [
-            itertools.compress(t, keep) for t in (class_idx, points, boxes)]
-        points, boxes = self._demean(points, boxes)
-        samples = [dict(points=p, box=b) for (p, b) in zip(points, boxes)]
-        return class_idx, samples
-
-    def _save_database(self, database):
-        with open(self.fpath, 'wb') as f:
-            pickle.dump(database, f)
diff --git a/pvrcnn/dataset/kitti_dataset.py b/pvrcnn/dataset/kitti_dataset.py
index f0048e2..d9ffeb9 100644
--- a/pvrcnn/dataset/kitti_dataset.py
+++ b/pvrcnn/dataset/kitti_dataset.py
@@ -1,67 +1,78 @@
-from tqdm import tqdm
 import pickle
 import numpy as np
 import torch
 import os
+from tqdm import tqdm
 from copy import deepcopy
 import os.path as osp
 from torch.utils.data import Dataset
 
 from pvrcnn.core import ProposalTargetAssigner
-from .kitti_utils import read_calib, read_label, read_velo
-from .augmentation import ChainedAugmentation
-from .database_sampler import DatabaseBuilder
+from .kitti_utils import read_calib, read_label, read_velo, filter_camera_fov
+from .augmentation import ChainedAugmentation, DatabaseBuilder
 
 
-class KittiDataset(Dataset):
+class AnnotationLoader:
+    """Load annotations if exist, else create."""
 
-    def __init__(self, cfg, split='val'):
-        super(KittiDataset, self).__init__()
-        self.cfg = cfg
-        self.split = split
-        self.load_annotations(cfg)
+    def __init__(self, cfg, inds, split='val'):
+        super(AnnotationLoader, self).__init__()
+        self.CACHEDIR = cfg.DATA.CACHEDIR
+        self.ROOTDIR = cfg.DATA.ROOTDIR
+        self.inds, self.split = inds, split
+        self.load_annotations()
+        if split == 'train':
+            DatabaseBuilder(cfg, self.annotations)
 
-    def __len__(self):
-        return len(self.inds)
-
-    def read_splitfile(self, cfg):
-        fpath = osp.join(cfg.DATA.SPLITDIR, f'{self.split}.txt')
-        self.inds = np.loadtxt(fpath, dtype=np.int32).tolist()
-
-    def read_cached_annotations(self, cfg):
-        fpath = osp.join(cfg.DATA.CACHEDIR, f'{self.split}.pkl')
+    def read_cached_annotations(self):
+        fpath = osp.join(self.CACHEDIR, f'{self.split}.pkl')
+        print(f'Loading cached annotations: {fpath}')
         with open(fpath, 'rb') as f:
             self.annotations = pickle.load(f)
-        print(f'Found cached annotations: {fpath}')
 
-    def cache_annotations(self, cfg):
-        fpath = osp.join(cfg.DATA.CACHEDIR, f'{self.split}.pkl')
+    def cache_annotations(self):
+        fpath = osp.join(self.CACHEDIR, f'{self.split}.pkl')
+        print(f'Caching annotations: {fpath}')
         with open(fpath, 'wb') as f:
             pickle.dump(self.annotations, f)
 
-    def load_annotations(self, cfg):
-        self.read_splitfile(cfg)
+    def load_annotations(self):
         try:
-            self.read_cached_annotations(cfg)
+            self.read_cached_annotations()
         except FileNotFoundError:
-            os.makedirs(cfg.DATA.CACHEDIR, exist_ok=True)
+            os.makedirs(self.CACHEDIR, exist_ok=True)
             self.create_annotations()
-            self.cache_annotations(cfg)
-
-    def _path_helper(self, folder, idx, suffix):
-        return osp.join(self.cfg.DATA.ROOTDIR, folder, f'{idx:06d}.{suffix}')
+            self.crop_points()
+            self.cache_annotations()
+
+    def crop_points(self):
+        """Limit points to camera FOV (KITTI-specific)."""
+        dir_new = osp.join(self.ROOTDIR, 'velodyne_reduced')
+        if osp.isdir(dir_new):
+            return print(f'Found existing reduced points: {dir_new}')
+        os.makedirs(dir_new, exist_ok=False)
+        for anno in tqdm(self.annotations.values(), desc='Filtering points'):
+            basename = osp.basename(anno['velo_path'])
+            path_old = osp.join(self.ROOTDIR, 'velodyne', basename)
+            points = filter_camera_fov(anno['calib'], read_velo(path_old))
+            points.astype(np.float32).tofile(osp.join(dir_new, basename))
+
+    def _path_helper(self, fdir, idx, end):
+        fpath = osp.join(self.ROOTDIR, fdir, f'{idx:06d}.{end}')
+        return fpath
 
     def create_annotations(self):
         self.annotations = dict()
-        for idx in tqdm(self.inds, desc='Generating annotations'):
+        for idx in tqdm(self.inds, desc='Creating annotations'):
             item = dict(
                 velo_path=self._path_helper('velodyne_reduced', idx, 'bin'),
                 objects=read_label(self._path_helper('label_2', idx, 'txt')),
                 calib=read_calib(self._path_helper('calib', idx, 'txt')),
                 idx=idx,
             )
-            self.annotations[idx] = self.numpify_objects(item)
+            self.numpify_objects(item)
+            self.annotations[idx] = item
 
-    def numpify_object(self, obj, calib):
+    def _numpify_object(self, obj, calib):
         """Converts from camera to velodyne frame."""
         xyz = calib.C2V @ np.r_[calib.R0 @ obj.t, 1]
         box = np.r_[xyz, obj.w, obj.l, obj.h, -obj.ry]
@@ -69,11 +80,29 @@ def numpify_objects(self, item):
-        objects = [self.numpify_object(
+        objects = [self._numpify_object(
             obj, item['calib']) for obj in item['objects']]
         item['boxes'] = np.stack([obj['box'] for obj in objects])
         item['class_idx'] = np.r_[[obj['class_idx'] for obj in objects]]
-        return item
+        item.pop('objects')
+
+
+class KittiDataset(Dataset):
+
+    def __init__(self, cfg, split='val'):
+        super(KittiDataset, self).__init__()
+        self.cfg = cfg
+        self.split = split
+        self.load_annotations(cfg)
+
+    def __len__(self):
+        return len(self.inds)
+
+    def load_annotations(self, cfg):
+        fpath = osp.join(cfg.DATA.SPLITDIR, f'{self.split}.txt')
+        self.inds = np.loadtxt(fpath, dtype=np.int32).tolist()
+        loader = AnnotationLoader(cfg, self.inds, self.split)
+        self.annotations = loader.annotations
 
     def filter_bad_objects(self, item):
         class_idx = item['class_idx'][:, None]
@@ -88,16 +117,16 @@ def filter_out_of_bounds(self, item):
         keep = ((xyz >= lower) & (xyz <= upper)).all(1)
         item['boxes'] = item['boxes'][keep]
         item['class_idx'] = item['class_idx'][keep]
-        item['box_ignore'] = np.full(keep.sum(), False)
 
     def to_torch(self, item):
         item['points'] = np.float32(item['points'])
         item['boxes'] = torch.FloatTensor(item['boxes'])
         item['class_idx'] = torch.LongTensor(item['class_idx'])
-        item['box_ignore'] = torch.BoolTensor(item['box_ignore'])
+        item['box_ignore'] = torch.full_like(
+            item['class_idx'], False, dtype=torch.bool)
 
     def drop_keys(self, item):
-        for key in ['velo_path', 'objects', 'calib']:
+        for key in ['velo_path', 'calib']:
             item.pop(key)
 
     def preprocessing(self, item):
@@ -117,7 +146,6 @@ class KittiDatasetTrain(KittiDataset):
 
     def __init__(self, cfg):
         super(KittiDatasetTrain, self).__init__(cfg, split='train')
-        DatabaseBuilder(cfg, self.annotations)
         self.augmentation = ChainedAugmentation(cfg)
         self.target_assigner = ProposalTargetAssigner(cfg)
 
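
The loader wiring after the refactor, end to end (a sketch: the cfg import path and the absence of a custom collate_fn are assumptions, not shown in this patch):

    from torch.utils.data import DataLoader
    from pvrcnn.core import cfg  # assumed default-config export
    from pvrcnn.dataset.kitti_dataset import KittiDatasetTrain

    # First run creates <CACHEDIR>/train.pkl, velodyne_reduced/, and database.pkl;
    # subsequent runs hit the caches.
    dataset = KittiDatasetTrain(cfg)
    loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE)
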
diff --git a/pvrcnn/dataset/kitti_utils.py b/pvrcnn/dataset/kitti_utils.py
index 8371fc1..a7f3c03 100644
--- a/pvrcnn/dataset/kitti_utils.py
+++ b/pvrcnn/dataset/kitti_utils.py
@@ -25,6 +25,7 @@
 """
 
 import numpy as np
+from collections import namedtuple
 
 
 def read_label(label_filename):
@@ -40,7 +41,21 @@ def read_velo(velo_filename):
 
 
 def read_calib(calib_filename):
-    return Calibration(calib_filename)
+    """Return calib as named tuple."""
+    calib = CalibObject(calib_filename).astuple
+    return calib
+
+
+def filter_camera_fov(calib, points_):
+    """Takes ~3.5 ms in KITTI."""
+    keep = points_[:, 0] > 0
+    p = points_[keep, :3]
+    ones = np.ones_like(p[:, 0:1])
+    p = (calib.R0 @ calib.V2C) @ np.c_[p, ones].T
+    p = calib.P2 @ np.r_[p, ones.T]
+    p = (p / p[2:3])[:2].T
+    keep[keep] &= ((p >= 0) & (p <= calib.WH)).all(1)
+    return points_[keep]
 
 
 class Object3d:
@@ -100,25 +115,31 @@ def get_obj_level(self):
         return 4
 
 
-class Calibration(object):
+keys = ['V2C', 'C2V', 'R0', 'P2', 'WH']
+Calib = namedtuple('Calib', keys)
+
+
+class CalibObject:
     """ 3d XYZ in