1.Accelerate the association step.

2.Provide more trained models with different input resoulution.
This commit is contained in:
Zhongdao 2020-01-29 21:45:07 +08:00
parent 7216bcaadf
commit c40826179b
12 changed files with 994 additions and 166 deletions

28
cfg/yolov3.cfg → cfg/yolov3_1088x608.cfg Executable file → Normal file
View File

@ -1,26 +1,10 @@
[net] [net]
# Testing
#batch=1
#subdivisions=1
# Training
batch=16 batch=16
subdivisions=1 subdivisions=1
width=608 width=1088
height=1088 height=608
embedding_dim=512
channels=3 channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1
[convolutional] [convolutional]
batch_normalize=1 batch_normalize=1
@ -611,7 +595,7 @@ layers = -3
size=3 size=3
stride=1 stride=1
pad=1 pad=1
filters=512 filters=$embedding_dim
activation=linear activation=linear
[route] [route]
@ -712,7 +696,7 @@ layers = -3
size=3 size=3
stride=1 stride=1
pad=1 pad=1
filters=512 filters=$embedding_dim
activation=linear activation=linear
[route] [route]
@ -815,7 +799,7 @@ layers = -3
size=3 size=3
stride=1 stride=1
pad=1 pad=1
filters=512 filters=$embedding_dim
activation=linear activation=linear
[route] [route]

817
cfg/yolov3_576x320.cfg Normal file
View File

@ -0,0 +1,817 @@
[net]
batch=16
subdivisions=1
width= 576
height=320
embedding_dim=512
channels=3
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
# Downsample
[convolutional]
batch_normalize=1
filters=64
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=128
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=256
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=512
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
######################
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=24
activation=linear
######### embedding ###########
[route]
layers = -3
[convolutional]
size=3
stride=1
pad=1
filters=$embedding_dim
activation=linear
[route]
layers = -3, -1
###############################
[yolo]
mask = 8,9,10,11
anchors = 6,16, 8,23, 11,32, 16,45, 21,64, 30,90, 43,128, 60,180, 85,255, 120,360, 170,420, 340, 320
classes=1
num=12
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -7
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 61
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=24
activation=linear
######### embedding ###########
[route]
layers = -3
[convolutional]
size=3
stride=1
pad=1
filters=$embedding_dim
activation=linear
[route]
layers = -3, -1
###############################
[yolo]
mask = 4,5,6,7
anchors = 6,16, 8,23, 11,32, 16,45, 21,64, 30,90, 43,128, 60,180, 85,255, 120,320, 170,320, 340,320
classes=1
num=12
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -7
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 36
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=24
activation=linear
######### embedding ###########
[route]
layers = -3
[convolutional]
size=3
stride=1
pad=1
filters=$embedding_dim
activation=linear
[route]
layers = -3, -1
###############################
[yolo]
mask = 0,1,2,3
anchors = 6,16, 8,23, 11,32, 16,45, 21,64, 30,90, 43,128, 60,180, 85,255, 120,320, 170,320, 340,320
classes=1
num=12
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

View File

@ -1,26 +1,10 @@
[net] [net]
# Testing
#batch=1
#subdivisions=1
# Training
batch=16 batch=16
subdivisions=1 subdivisions=1
width=480 width=864
height=864 height=480
embedding_dim=512
channels=3 channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1
[convolutional] [convolutional]
batch_normalize=1 batch_normalize=1
@ -611,7 +595,7 @@ layers = -3
size=3 size=3
stride=1 stride=1
pad=1 pad=1
filters=512 filters=$embedding_dim
activation=linear activation=linear
[route] [route]
@ -712,7 +696,7 @@ layers = -3
size=3 size=3
stride=1 stride=1
pad=1 pad=1
filters=512 filters=$embedding_dim
activation=linear activation=linear
[route] [route]
@ -815,7 +799,7 @@ layers = -3
size=3 size=3
stride=1 stride=1
pad=1 pad=1
filters=512 filters=$embedding_dim
activation=linear activation=linear
[route] [route]

View File

@ -74,7 +74,8 @@ def create_modules(module_defs):
nC = int(module_def['classes']) # number of classes nC = int(module_def['classes']) # number of classes
img_size = (int(hyperparams['width']),int(hyperparams['height'])) img_size = (int(hyperparams['width']),int(hyperparams['height']))
# Define detection layer # Define detection layer
yolo_layer = YOLOLayer(anchors, nC, hyperparams['nID'], img_size, yolo_layer_count, cfg=hyperparams['cfg']) yolo_layer = YOLOLayer(anchors, nC, int(hyperparams['nID']),
int(hyperparams['embedding_dim']), img_size, yolo_layer_count)
modules.add_module('yolo_%d' % i, yolo_layer) modules.add_module('yolo_%d' % i, yolo_layer)
yolo_layer_count += 1 yolo_layer_count += 1
@ -108,7 +109,7 @@ class Upsample(nn.Module):
class YOLOLayer(nn.Module): class YOLOLayer(nn.Module):
def __init__(self, anchors, nC, nID, img_size, yolo_layer, cfg): def __init__(self, anchors, nC, nID, nE, img_size, yolo_layer):
super(YOLOLayer, self).__init__() super(YOLOLayer, self).__init__()
self.layer = yolo_layer self.layer = yolo_layer
nA = len(anchors) nA = len(anchors)
@ -117,7 +118,7 @@ class YOLOLayer(nn.Module):
self.nC = nC # number of classes (80) self.nC = nC # number of classes (80)
self.nID = nID # number of identities self.nID = nID # number of identities
self.img_size = 0 self.img_size = 0
self.emb_dim = 512 self.emb_dim = nE
self.shift = [1, 3, 5] self.shift = [1, 3, 5]
self.SmoothL1Loss = nn.SmoothL1Loss() self.SmoothL1Loss = nn.SmoothL1Loss()
@ -127,7 +128,9 @@ class YOLOLayer(nn.Module):
self.s_c = nn.Parameter(-4.15*torch.ones(1)) # -4.15 self.s_c = nn.Parameter(-4.15*torch.ones(1)) # -4.15
self.s_r = nn.Parameter(-4.85*torch.ones(1)) # -4.85 self.s_r = nn.Parameter(-4.85*torch.ones(1)) # -4.85
self.s_id = nn.Parameter(-2.3*torch.ones(1)) # -2.3 self.s_id = nn.Parameter(-2.3*torch.ones(1)) # -2.3
self.emb_scale = math.sqrt(2) * math.log(self.nID-1)
self.emb_scale = math.sqrt(2) * math.log(self.nID-1) if self.nID>1 else 1
def forward(self, p_cat, img_size, targets=None, classifier=None, test_emb=False): def forward(self, p_cat, img_size, targets=None, classifier=None, test_emb=False):
@ -178,7 +181,7 @@ class YOLOLayer(nn.Module):
if test_emb: if test_emb:
if np.prod(embedding.shape)==0 or np.prod(tids.shape) == 0: if np.prod(embedding.shape)==0 or np.prod(tids.shape) == 0:
return torch.zeros(0, self. emb_dim+1).cuda() return torch.zeros(0, self.emb_dim+1).cuda()
emb_and_gt = torch.cat([embedding, tids.float()], dim=1) emb_and_gt = torch.cat([embedding, tids.float()], dim=1)
return emb_and_gt return emb_and_gt
@ -210,21 +213,23 @@ class YOLOLayer(nn.Module):
class Darknet(nn.Module): class Darknet(nn.Module):
"""YOLOv3 object detection model""" """YOLOv3 object detection model"""
def __init__(self, cfg_path, img_size=(1088, 608), nID=1591, test_emb=False): def __init__(self, cfg_dict, nID=0, test_emb=False):
super(Darknet, self).__init__() super(Darknet, self).__init__()
if isinstance(cfg_dict, str):
self.module_defs = parse_model_cfg(cfg_path) cfg_dict = parse_model_cfg(cfg_dict)
self.module_defs[0]['cfg'] = cfg_path self.module_defs = cfg_dict
self.module_defs[0]['nID'] = nID self.module_defs[0]['nID'] = nID
self.img_size = [int(self.module_defs[0]['width']), int(self.module_defs[0]['height'])]
self.emb_dim = int(self.module_defs[0]['embedding_dim'])
self.hyperparams, self.module_list = create_modules(self.module_defs) self.hyperparams, self.module_list = create_modules(self.module_defs)
self.img_size = img_size
self.loss_names = ['loss', 'box', 'conf', 'id', 'nT'] self.loss_names = ['loss', 'box', 'conf', 'id', 'nT']
self.losses = OrderedDict() self.losses = OrderedDict()
for ln in self.loss_names: for ln in self.loss_names:
self.losses[ln] = 0 self.losses[ln] = 0
self.emb_dim = 512 self.test_emb = test_emb
self.classifier = nn.Linear(self.emb_dim, nID)
self.test_emb=test_emb self.classifier = nn.Linear(self.emb_dim, nID) if nID>0 else None
def forward(self, x, targets=None, targets_len=None): def forward(self, x, targets=None, targets_len=None):
@ -256,7 +261,8 @@ class Darknet(nn.Module):
for name, loss in zip(self.loss_names, losses): for name, loss in zip(self.loss_names, losses):
self.losses[name] += loss self.losses[name] += loss
elif self.test_emb: elif self.test_emb:
targets = [targets[i][:int(l)] for i,l in enumerate(targets_len)] if targets is not None:
targets = [targets[i][:int(l)] for i,l in enumerate(targets_len)]
x = module[0](x, self.img_size, targets, self.classifier, self.test_emb) x = module[0](x, self.img_size, targets, self.classifier, self.test_emb)
else: # get detections else: # get detections
x = module[0](x, self.img_size) x = module[0](x, self.img_size)
@ -282,7 +288,8 @@ def shift_tensor_vertically(t, delta):
def create_grids(self, img_size, nGh, nGw): def create_grids(self, img_size, nGh, nGw):
self.stride = img_size[0]/nGw self.stride = img_size[0]/nGw
assert self.stride == img_size[1] / nGh assert self.stride == img_size[1] / nGh, \
"{} v.s. {}/{}".format(self.stride, img_size[1], nGh)
# build xy offsets # build xy offsets
grid_x = torch.arange(nGw).repeat((nGh, 1)).view((1, 1, nGh, nGw)).float() grid_x = torch.arange(nGw).repeat((nGh, 1)).view((1, 1, nGh, nGw)).float()

15
test.py
View File

@ -16,12 +16,10 @@ def test(
data_cfg, data_cfg,
weights, weights,
batch_size=16, batch_size=16,
img_size=416,
iou_thres=0.5, iou_thres=0.5,
conf_thres=0.3, conf_thres=0.3,
nms_thres=0.45, nms_thres=0.45,
print_interval=40, print_interval=40,
nID=14455,
): ):
# Configure run # Configure run
@ -32,9 +30,11 @@ def test(
nC = 1 nC = 1
test_path = data_cfg_dict['test'] test_path = data_cfg_dict['test']
dataset_root = data_cfg_dict['root'] dataset_root = data_cfg_dict['root']
cfg_dict = parse_model_cfg(cfg)
img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
# Initialize model # Initialize model
model = Darknet(cfg, img_size, nID) model = Darknet(cfg_dict, test_emb=False)
# Load weights # Load weights
if weights.endswith('.pt'): # pytorch format if weights.endswith('.pt'): # pytorch format
@ -149,12 +149,10 @@ def test_emb(
data_cfg, data_cfg,
weights, weights,
batch_size=16, batch_size=16,
img_size=416,
iou_thres=0.5, iou_thres=0.5,
conf_thres=0.3, conf_thres=0.3,
nms_thres=0.45, nms_thres=0.45,
print_interval=40, print_interval=40,
nID=14455,
): ):
# Configure run # Configure run
@ -163,9 +161,11 @@ def test_emb(
f.close() f.close()
test_paths = data_cfg_dict['test_emb'] test_paths = data_cfg_dict['test_emb']
dataset_root = data_cfg_dict['root'] dataset_root = data_cfg_dict['root']
cfg_dict = parse_model_cfg(cfg)
img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
# Initialize model # Initialize model
model = Darknet(cfg, img_size, nID, test_emb=True) model = Darknet(cfg_dict, test_emb=True)
# Load weights # Load weights
if weights.endswith('.pt'): # pytorch format if weights.endswith('.pt'): # pytorch format
@ -231,7 +231,6 @@ if __name__ == '__main__':
parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected') parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
parser.add_argument('--conf-thres', type=float, default=0.3, help='object confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.3, help='object confidence threshold')
parser.add_argument('--nms-thres', type=float, default=0.5, help='iou threshold for non-maximum suppression') parser.add_argument('--nms-thres', type=float, default=0.5, help='iou threshold for non-maximum suppression')
parser.add_argument('--img-size', type=int, default=(1088, 608), help='size of each image dimension')
parser.add_argument('--print-interval', type=int, default=10, help='size of each image dimension') parser.add_argument('--print-interval', type=int, default=10, help='size of each image dimension')
parser.add_argument('--test-emb', action='store_true', help='test embedding') parser.add_argument('--test-emb', action='store_true', help='test embedding')
opt = parser.parse_args() opt = parser.parse_args()
@ -244,7 +243,6 @@ if __name__ == '__main__':
opt.data_cfg, opt.data_cfg,
opt.weights, opt.weights,
opt.batch_size, opt.batch_size,
opt.img_size,
opt.iou_thres, opt.iou_thres,
opt.conf_thres, opt.conf_thres,
opt.nms_thres, opt.nms_thres,
@ -256,7 +254,6 @@ if __name__ == '__main__':
opt.data_cfg, opt.data_cfg,
opt.weights, opt.weights,
opt.batch_size, opt.batch_size,
opt.img_size,
opt.iou_thres, opt.iou_thres,
opt.conf_thres, opt.conf_thres,
opt.nms_thres, opt.nms_thres,

View File

@ -5,13 +5,14 @@ import logging
import argparse import argparse
import motmetrics as mm import motmetrics as mm
import torch
from tracker.multitracker import JDETracker from tracker.multitracker import JDETracker
from utils import visualization as vis from utils import visualization as vis
from utils.log import logger from utils.log import logger
from utils.timer import Timer from utils.timer import Timer
from utils.evaluation import Evaluator from utils.evaluation import Evaluator
from utils.parse_config import parse_model_cfg
import utils.datasets as datasets import utils.datasets as datasets
import torch
from utils.utils import * from utils.utils import *
@ -84,6 +85,10 @@ def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',),
mkdir_if_missing(result_root) mkdir_if_missing(result_root)
data_type = 'mot' data_type = 'mot'
# Read config
cfg_dict = parse_model_cfg(opt.cfg)
opt.img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
# run tracking # run tracking
accs = [] accs = []
n_frame = 0 n_frame = 0
@ -134,7 +139,6 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(prog='track.py') parser = argparse.ArgumentParser(prog='track.py')
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file') parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file')
parser.add_argument('--img-size', type=int, default=[1088, 608], nargs='+', help='pixels')
parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected') parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression') parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
@ -162,6 +166,8 @@ if __name__ == '__main__':
MOT17-11-SDP MOT17-11-SDP
MOT17-13-SDP MOT17-13-SDP
''' '''
seqs_str = '''MOT17-02-SDP
'''
data_root = '/home/wangzd/datasets/MOT/MOT17/images/train' data_root = '/home/wangzd/datasets/MOT/MOT17/images/train'
else: else:
seqs_str = '''MOT16-01 seqs_str = '''MOT16-01

View File

@ -1,8 +1,10 @@
import cv2 import cv2
import torch
import torch.nn.functional as F
import numpy as np import numpy as np
import scipy import scipy
from scipy.spatial.distance import cdist from scipy.spatial.distance import cdist
from sklearn.utils import linear_assignment_ import lap
from cython_bbox import bbox_overlaps as bbox_ious from cython_bbox import bbox_overlaps as bbox_ious
from utils import kalman_filter from utils import kalman_filter
@ -25,32 +27,19 @@ def merge_matches(m1, m2, shape):
return match, unmatched_O, unmatched_Q return match, unmatched_O, unmatched_Q
def _indices_to_matches(cost_matrix, indices, thresh):
matched_cost = cost_matrix[tuple(zip(*indices))]
matched_mask = (matched_cost <= thresh)
matches = indices[matched_mask]
unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
return matches, unmatched_a, unmatched_b
def linear_assignment(cost_matrix, thresh): def linear_assignment(cost_matrix, thresh):
"""
Simple linear assignment
:type cost_matrix: np.ndarray
:type thresh: float
:return: matches, unmatched_a, unmatched_b
"""
if cost_matrix.size == 0: if cost_matrix.size == 0:
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
matches, unmatched_a, unmatched_b = [], [], []
cost_matrix[cost_matrix > thresh] = thresh + 1e-4 cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
indices = linear_assignment_.linear_assignment(cost_matrix) for ix, mx in enumerate(x):
if mx >= 0:
return _indices_to_matches(cost_matrix, indices, thresh) matches.append([ix, mx])
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
matches = np.asarray(matches)
return matches, unmatched_a, unmatched_b
def ious(atlbrs, btlbrs): def ious(atlbrs, btlbrs):
""" """
@ -104,21 +93,9 @@ def embedding_distance(tracks, detections, metric='cosine'):
if cost_matrix.size == 0: if cost_matrix.size == 0:
return cost_matrix return cost_matrix
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float) det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float)
for i, track in enumerate(tracks): track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float)
cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) cost_matrix = np.maximum(0.0, cdist(track_features, det_features)) # Nomalized features
return cost_matrix
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
if cost_matrix.size == 0:
return cost_matrix
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray([det.to_xyah() for det in detections])
for row, track in enumerate(tracks):
gating_distance = kf.gating_distance(
track.mean, track.covariance, measurements, only_position)
cost_matrix[row, gating_distance > gating_threshold] = np.inf
return cost_matrix return cost_matrix
@ -130,10 +107,7 @@ def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda
measurements = np.asarray([det.to_xyah() for det in detections]) measurements = np.asarray([det.to_xyah() for det in detections])
for row, track in enumerate(tracks): for row, track in enumerate(tracks):
gating_distance = kf.gating_distance( gating_distance = kf.gating_distance(
track.mean, track.covariance, measurements, only_position) track.mean, track.covariance, measurements, only_position, metric='maha')
cost_matrix[row, gating_distance > gating_threshold] = np.inf cost_matrix[row, gating_distance > gating_threshold] = np.inf
#print(cost_matrix[row])
#print(gating_distance)
#print('-'*90)
cost_matrix[row] = lambda_ * cost_matrix[row] + (1-lambda_)* gating_distance cost_matrix[row] = lambda_ * cost_matrix[row] + (1-lambda_)* gating_distance
return cost_matrix return cost_matrix

View File

@ -6,6 +6,7 @@ import os
import os.path as osp import os.path as osp
import time import time
import torch import torch
import torch.nn.functional as F
from utils.utils import * from utils.utils import *
from utils.log import logger from utils.log import logger
@ -16,6 +17,7 @@ from .basetrack import BaseTrack, TrackState
class STrack(BaseTrack): class STrack(BaseTrack):
shared_kalman = KalmanFilter()
def __init__(self, tlwh, score, temp_feat, buffer_size=30): def __init__(self, tlwh, score, temp_feat, buffer_size=30):
@ -41,7 +43,7 @@ class STrack(BaseTrack):
else: else:
self.smooth_feat = self.alpha *self.smooth_feat + (1-self.alpha) * feat self.smooth_feat = self.alpha *self.smooth_feat + (1-self.alpha) * feat
self.features.append(feat) self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat) self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def predict(self): def predict(self):
mean_state = self.mean.copy() mean_state = self.mean.copy()
@ -49,6 +51,19 @@ class STrack(BaseTrack):
mean_state[7] = 0 mean_state[7] = 0
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
@staticmethod
def multi_predict(stracks):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
for i,st in enumerate(stracks):
if st.state != TrackState.Tracked:
multi_mean[i][7] = 0
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
stracks[i].mean = mean
stracks[i].covariance = cov
def activate(self, kalman_filter, frame_id): def activate(self, kalman_filter, frame_id):
"""Start a new tracklet""" """Start a new tracklet"""
@ -97,7 +112,7 @@ class STrack(BaseTrack):
self.update_features(new_track.curr_feat) self.update_features(new_track.curr_feat)
@property @property
@jit #@jit(nopython=True)
def tlwh(self): def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y, """Get current position in bounding box format `(top left x, top left y,
width, height)`. width, height)`.
@ -110,7 +125,7 @@ class STrack(BaseTrack):
return ret return ret
@property @property
@jit #@jit(nopython=True)
def tlbr(self): def tlbr(self):
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e., """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`. `(top left, bottom right)`.
@ -120,7 +135,7 @@ class STrack(BaseTrack):
return ret return ret
@staticmethod @staticmethod
@jit #@jit(nopython=True)
def tlwh_to_xyah(tlwh): def tlwh_to_xyah(tlwh):
"""Convert bounding box to format `(center x, center y, aspect ratio, """Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`. height)`, where the aspect ratio is `width / height`.
@ -134,14 +149,14 @@ class STrack(BaseTrack):
return self.tlwh_to_xyah(self.tlwh) return self.tlwh_to_xyah(self.tlwh)
@staticmethod @staticmethod
@jit #@jit(nopython=True)
def tlbr_to_tlwh(tlbr): def tlbr_to_tlwh(tlbr):
ret = np.asarray(tlbr).copy() ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2] ret[2:] -= ret[:2]
return ret return ret
@staticmethod @staticmethod
@jit #@jit(nopython=True)
def tlwh_to_tlbr(tlwh): def tlwh_to_tlbr(tlwh):
ret = np.asarray(tlwh).copy() ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2] ret[2:] += ret[:2]
@ -151,10 +166,11 @@ class STrack(BaseTrack):
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
class JDETracker(object): class JDETracker(object):
def __init__(self, opt, frame_rate=30): def __init__(self, opt, frame_rate=30):
self.opt = opt self.opt = opt
self.model = Darknet(opt.cfg, opt.img_size, nID=14455) self.model = Darknet(opt.cfg)
# load_darknet_weights(self.model, opt.weights) # load_darknet_weights(self.model, opt.weights)
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False) self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
self.model.cuda().eval() self.model.cuda().eval()
@ -183,17 +199,16 @@ class JDETracker(object):
pred = self.model(im_blob) pred = self.model(im_blob)
pred = pred[pred[:, :, 4] > self.opt.conf_thres] pred = pred[pred[:, :, 4] > self.opt.conf_thres]
if len(pred) > 0: if len(pred) > 0:
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu() dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres,
self.opt.nms_thres)[0]
scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round() scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
dets, embs = dets[:, :5].cpu().numpy(), dets[:, 6:].cpu().numpy()
'''Detections''' '''Detections'''
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])] (tlbrs, f) in zip(dets, embs)]
else: else:
detections = [] detections = []
t2 = time.time()
# print('Forward: {} s'.format(t2-t1))
''' Add newly detected tracklets to tracked_stracks''' ''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = [] unconfirmed = []
tracked_stracks = [] # type: list[STrack] tracked_stracks = [] # type: list[STrack]
@ -206,11 +221,8 @@ class JDETracker(object):
''' Step 2: First association, with embedding''' ''' Step 2: First association, with embedding'''
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF # Predict the current location with KF
for strack in strack_pool: STrack.multi_predict(strack_pool)
strack.predict()
dists = matching.embedding_distance(strack_pool, detections) dists = matching.embedding_distance(strack_pool, detections)
#dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections) dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7) matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
@ -271,13 +283,10 @@ class JDETracker(object):
if self.frame_id - track.end_frame > self.max_time_lost: if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed() track.mark_removed()
removed_stracks.append(track) removed_stracks.append(track)
t4 = time.time()
# print('Ramained match {} s'.format(t4-t3))
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks) self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks) self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
# self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks) self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks) self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
@ -292,8 +301,6 @@ class JDETracker(object):
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks])) logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks])) logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks])) logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
t5 = time.time()
# print('Final {} s'.format(t5-t4))
return output_stracks return output_stracks
def joint_stracks(tlista, tlistb): def joint_stracks(tlista, tlistb):

View File

@ -13,7 +13,6 @@ from torchvision.transforms import transforms as T
def train( def train(
cfg, cfg,
data_cfg, data_cfg,
img_size=(1088,608),
resume=False, resume=False,
epochs=100, epochs=100,
batch_size=16, batch_size=16,
@ -33,16 +32,19 @@ def train(
trainset_paths = data_config['train'] trainset_paths = data_config['train']
dataset_root = data_config['root'] dataset_root = data_config['root']
f.close() f.close()
cfg_dict = parse_model_cfg(cfg)
img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
transforms = T.Compose([T.ToTensor()])
# Get dataloader # Get dataloader
transforms = T.Compose([T.ToTensor()])
dataset = JointDataset(dataset_root, trainset_paths, img_size, augment=True, transforms=transforms) dataset = JointDataset(dataset_root, trainset_paths, img_size, augment=True, transforms=transforms)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn) num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn)
# Initialize model # Initialize model
model = Darknet(cfg, img_size, dataset.nID) model = Darknet(cfg_dict, dataset.nID)
cutoff = -1 # backbone reaches to cutoff layer cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0 start_epoch = 0
@ -87,14 +89,13 @@ def train(
p.requires_grad = False if 'batch_norm' in name else True p.requires_grad = False if 'batch_norm' in name else True
model_info(model) model_info(model)
t0 = time.time() t0 = time.time()
for epoch in range(epochs): for epoch in range(epochs):
epoch += start_epoch epoch += start_epoch
logger.info(('%8s%12s' + '%10s' * 6) % ( logger.info(('%8s%12s' + '%10s' * 6) % (
'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time')) 'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time'))
# Freeze darknet53.conv.74 for first epoch # Freeze darknet53.conv.74 for first epoch
if freeze_backbone and (epoch < 2): if freeze_backbone and (epoch < 2):
@ -108,7 +109,7 @@ def train(
for i, (imgs, targets, _, _, targets_len) in enumerate(dataloader): for i, (imgs, targets, _, _, targets_len) in enumerate(dataloader):
if sum([len(x) for x in targets]) < 1: # if no targets continue if sum([len(x) for x in targets]) < 1: # if no targets continue
continue continue
# SGD burn-in # SGD burn-in
burnin = min(1000, len(dataloader)) burnin = min(1000, len(dataloader))
if (epoch == 0) & (i <= burnin): if (epoch == 0) & (i <= burnin):
@ -154,8 +155,8 @@ def train(
# Calculate mAP # Calculate mAP
if epoch % opt.test_interval ==0: if epoch % opt.test_interval ==0:
with torch.no_grad(): with torch.no_grad():
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, print_interval=40, nID=dataset.nID) mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size, print_interval=40, nID=dataset.nID) test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
# Call scheduler.step() after opimizer.step() with pytorch > 1.1.0 # Call scheduler.step() after opimizer.step() with pytorch > 1.1.0
@ -166,9 +167,8 @@ if __name__ == '__main__':
parser.add_argument('--epochs', type=int, default=30, help='number of epochs') parser.add_argument('--epochs', type=int, default=30, help='number of epochs')
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch') parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step') parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step')
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path')
parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path') parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path')
parser.add_argument('--img-size', type=int, default=[1088, 608], nargs='+', help='pixels')
parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--print-interval', type=int, default=40, help='print interval') parser.add_argument('--print-interval', type=int, default=40, help='print interval')
parser.add_argument('--test-interval', type=int, default=9, help='test interval') parser.add_argument('--test-interval', type=int, default=9, help='test interval')
@ -181,7 +181,6 @@ if __name__ == '__main__':
train( train(
opt.cfg, opt.cfg,
opt.data_cfg, opt.data_cfg,
img_size=opt.img_size,
resume=opt.resume, resume=opt.resume,
epochs=opt.epochs, epochs=opt.epochs,
batch_size=opt.batch_size, batch_size=opt.batch_size,

View File

@ -2,7 +2,7 @@ import os
import numpy as np import numpy as np
import copy import copy
import motmetrics as mm import motmetrics as mm
mm.lap.default_solver = 'lap'
from utils.io import read_results, unzip_objs from utils.io import read_results, unzip_objs
@ -39,18 +39,20 @@ class Evaluator(object):
ignore_objs = self.gt_ignore_frame_dict.get(frame_id, []) ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
ignore_tlwhs = unzip_objs(ignore_objs)[0] ignore_tlwhs = unzip_objs(ignore_objs)[0]
# remove ignored results # remove ignored results
keep = np.ones(len(trk_tlwhs), dtype=bool) keep = np.ones(len(trk_tlwhs), dtype=bool)
iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5) iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
match_is, match_js = mm.lap.linear_sum_assignment(iou_distance) if len(iou_distance) > 0:
match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js]) match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
match_ious = iou_distance[match_is, match_js] match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
match_ious = iou_distance[match_is, match_js]
match_js = np.asarray(match_js, dtype=int) match_js = np.asarray(match_js, dtype=int)
match_js = match_js[np.logical_not(np.isnan(match_ious))] match_js = match_js[np.logical_not(np.isnan(match_ious))]
keep[match_js] = False keep[match_js] = False
trk_tlwhs = trk_tlwhs[keep] trk_tlwhs = trk_tlwhs[keep]
trk_ids = trk_ids[keep] trk_ids = trk_ids[keep]
# get distance matrix # get distance matrix
iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5) iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)

View File

@ -1,4 +1,5 @@
# vim: expandtab:ts=4:sw=4 # vim: expandtab:ts=4:sw=4
import numba
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg
@ -116,7 +117,7 @@ class KalmanFilter(object):
self._std_weight_velocity * mean[3]] self._std_weight_velocity * mean[3]]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
mean = np.dot(self._motion_mat, mean) mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot(( covariance = np.linalg.multi_dot((
self._motion_mat, covariance, self._motion_mat.T)) + motion_cov self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
@ -150,6 +151,48 @@ class KalmanFilter(object):
covariance = np.linalg.multi_dot(( covariance = np.linalg.multi_dot((
self._update_mat, covariance, self._update_mat.T)) self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""Run Kalman filter prediction step (Vectorized version).
Parameters
----------
mean : ndarray
The Nx8 dimensional mean matrix of the object states at the previous
time step.
covariance : ndarray
The Nx8x8 dimensional covariance matrics of the object states at the
previous time step.
Returns
-------
(ndarray, ndarray)
Returns the mean vector and covariance matrix of the predicted
state. Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 3],
self._std_weight_position * mean[:, 3],
1e-2 * np.ones_like(mean[:, 3]),
self._std_weight_position * mean[:, 3]]
std_vel = [
self._std_weight_velocity * mean[:, 3],
self._std_weight_velocity * mean[:, 3],
1e-5 * np.ones_like(mean[:, 3]),
self._std_weight_velocity * mean[:, 3]]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
motion_cov = np.asarray(motion_cov)
mean = np.dot(mean, self._motion_mat.T)
left = np.dot(self._motion_mat, covariance).transpose((1,0,2))
covariance = np.dot(left, self._motion_mat.T) + motion_cov
return mean, covariance
def update(self, mean, covariance, measurement): def update(self, mean, covariance, measurement):
"""Run Kalman filter correction step. """Run Kalman filter correction step.
@ -186,7 +229,7 @@ class KalmanFilter(object):
return new_mean, new_covariance return new_mean, new_covariance
def gating_distance(self, mean, covariance, measurements, def gating_distance(self, mean, covariance, measurements,
only_position=False): only_position=False, metric='maha'):
"""Compute gating distance between state distribution and measurements. """Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If A suitable distance threshold can be obtained from `chi2inv95`. If
@ -219,11 +262,17 @@ class KalmanFilter(object):
if only_position: if only_position:
mean, covariance = mean[:2], covariance[:2, :2] mean, covariance = mean[:2], covariance[:2, :2]
measurements = measurements[:, :2] measurements = measurements[:, :2]
cholesky_factor = np.linalg.cholesky(covariance)
d = measurements - mean d = measurements - mean
z = scipy.linalg.solve_triangular( if metric == 'gaussian':
cholesky_factor, d.T, lower=True, check_finite=False, return np.sum(d * d, axis=1)
overwrite_b=True) elif metric == 'maha':
squared_maha = np.sum(z * z, axis=0) cholesky_factor = np.linalg.cholesky(covariance)
return squared_maha z = scipy.linalg.solve_triangular(
cholesky_factor, d.T, lower=True, check_finite=False,
overwrite_b=True)
squared_maha = np.sum(z * z, axis=0)
return squared_maha
else:
raise ValueError('invalid distance metric')

View File

@ -14,7 +14,9 @@ def parse_model_cfg(path):
else: else:
key, value = line.split("=") key, value = line.split("=")
value = value.strip() value = value.strip()
module_defs[-1][key.rstrip()] = value.strip() if value[0] == '$':
value = module_defs[0].get(value.strip('$'), None)
module_defs[-1][key.rstrip()] = value
return module_defs return module_defs