Skip to content

Commit

Permalink
[Feature] Support iou-based tracking (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tau-J committed Feb 8, 2024
1 parent 224b7c7 commit f688151
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 12 deletions.
128 changes: 118 additions & 10 deletions rtmlib/tools/solution/pose_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,41 @@
cv2.imshow('img', img_show)
cv2.waitKey(10)
'''
import warnings

import numpy as np


def compute_iou(bboxA, bboxB):
"""Compute the Intersection over Union (IoU) between two boxes .
Args:
bboxA (list): The first bbox info (left, top, right, bottom, score).
bboxB (list): The second bbox info (left, top, right, bottom, score).
Returns:
float: The IoU value.
"""

x1 = max(bboxA[0], bboxB[0])
y1 = max(bboxA[1], bboxB[1])
x2 = min(bboxA[2], bboxB[2])
y2 = min(bboxA[3], bboxB[3])

inter_area = max(0, x2 - x1) * max(0, y2 - y1)

bboxA_area = (bboxA[2] - bboxA[0]) * (bboxA[3] - bboxA[1])
bboxB_area = (bboxB[2] - bboxB[0]) * (bboxB[3] - bboxB[1])
union_area = float(bboxA_area + bboxB_area - inter_area)
if union_area == 0:
union_area = 1e-5
warnings.warn('union_area=0 is unexpected')

iou = inter_area / union_area

return iou


def pose_to_bbox(keypoints: np.ndarray, expansion: float = 1.25) -> np.ndarray:
"""Get bounding box from keypoints.
Expand Down Expand Up @@ -74,15 +106,18 @@ class PoseTracker:
backend (str): Backend of pose estimation model.
device (str): Device of pose estimation model.
"""
MIN_AREA = 1000

def __init__(self,
solution: type,
det_frequency: int = 1,
tracking: bool = True,
tracking_thr: float = 0.3,
mode: str = 'balanced',
to_openpose: bool = False,
backend: str = 'onnxruntime',
device: str = 'cpu'):
print('pose tracker', backend, device)

model = solution(mode=mode,
to_openpose=to_openpose,
backend=backend,
Expand All @@ -92,27 +127,100 @@ def __init__(self,
self.pose_model = model.pose_model

self.det_frequency = det_frequency
self.tracking = tracking
self.tracking_thr = tracking_thr
self.reset()

if self.tracking:
print('Tracking is on, you can get higher FPS by turning it off:'
'`PoseTracker(tracking=False)`')

def reset(self):
"""Reset pose tracker."""
self.cnt = 0
self.instance_list = []
self.frame_cnt = 0
self.next_id = 0
self.bboxes_last_frame = []
self.track_ids_last_frame = []

def __call__(self, image: np.ndarray):

if self.cnt % self.det_frequency == 0:
if self.frame_cnt % self.det_frequency == 0:
bboxes = self.det_model(image)
else:
bboxes = self.instance_list
bboxes = self.bboxes_last_frame

keypoints, scores = self.pose_model(image, bboxes=bboxes)

instances = []
for kpts in keypoints:
instances.append(pose_to_bbox(kpts))
if not self.tracking:
# without tracking

bboxes_current_frame = []
for kpts in keypoints:
bbox = pose_to_bbox(kpts)
bboxes_current_frame.append(bbox)
else:
# with tracking

self.instance_list = instances
self.cnt += 1
if len(self.track_ids_last_frame) == 0:
self.next_id = len(self.bboxes_last_frame)
self.track_ids_last_frame = list(range(self.next_id))

bboxes_current_frame = []
track_ids_current_frame = []
for kpts in keypoints:
bbox = pose_to_bbox(kpts)

track_id, _ = self.track_by_iou(bbox)

if track_id > -1:
track_ids_current_frame.append(track_id)
bboxes_current_frame.append(bbox)

self.track_ids_last_frame = track_ids_current_frame

self.bboxes_last_frame = bboxes_current_frame
self.frame_cnt += 1

return keypoints, scores

def track_by_iou(self, bbox):
"""Get track id using IoU tracking greedily.
Args:
bbox (list): The bbox info (left, top, right, bottom, score).
next_id (int): The next track id.
Returns:
track_id (int): The track id.
match_result (list): The matched bbox.
next_id (int): The updated next track id.
"""

area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

max_iou_score = -1
max_index = -1
match_result = None
for index, each_bbox in enumerate(self.bboxes_last_frame):

iou_score = compute_iou(bbox, each_bbox)
if iou_score > max_iou_score:
max_iou_score = iou_score
max_index = index

if max_iou_score > self.tracking_thr:
# if the bbox has a match and the IoU is larger than threshold
track_id = self.track_ids_last_frame.pop(max_index)
match_result = self.bboxes_last_frame.pop(max_index)

elif area >= self.MIN_AREA:
# no match, but the bbox is large enough,
# assign a new track id
track_id = self.next_id
self.next_id += 1

else:
# if the bbox is too small, ignore it
track_id = -1

return track_id, match_result
4 changes: 2 additions & 2 deletions wholebody_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

wholebody = PoseTracker(
Wholebody,
det_frequency=100,
det_frequency=7,
to_openpose=openpose_skeleton,
mode='performance', # balanced, performance, lightweight
backend=backend,
Expand All @@ -38,7 +38,7 @@

# if you want to use black background instead of original image,
# img_show = np.zeros(img_show.shape, dtype=np.uint8)
print(scores)
# print(scores)
img_show = draw_skeleton(img_show,
keypoints,
scores,
Expand Down

0 comments on commit f688151

Please sign in to comment.