* bug in multitracker - img_size param passed to Darknet init but no img_size parameter expected.

* Various minor bug fixes
This commit is contained in:
Denver Dash 2020-03-20 05:45:22 -06:00 committed by GitHub
parent 7ec84619b1
commit 5618eb5390
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 4 deletions

View file

@ -2,6 +2,7 @@ from numba import jit
from collections import deque
import torch
from utils.kalman_filter import KalmanFilter
from utils.log import logger
from models import *
from tracker import matching
from .basetrack import BaseTrack, TrackState
@ -42,14 +43,15 @@ class STrack(BaseTrack):
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
@staticmethod
def multi_predict(stracks):
def multi_predict(stracks, kalman_filter):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
for i, st in enumerate(stracks):
if st.state != TrackState.Tracked:
multi_mean[i][7] = 0
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
# multi_mean, multi_covariance = STrack.kalman_filter.multi_predict(multi_mean, multi_covariance)
multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
stracks[i].mean = mean
stracks[i].covariance = cov
@ -158,7 +160,7 @@ class STrack(BaseTrack):
class JDETracker(object):
def __init__(self, opt, frame_rate=30):
self.opt = opt
self.model = Darknet(opt.cfg, opt.img_size, nID=30)
self.model = Darknet(opt.cfg, nID=14455)
# load_darknet_weights(self.model, opt.weights)
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
self.model.cuda().eval()
@ -240,7 +242,7 @@ class JDETracker(object):
# Combining currently tracked_stracks and lost_stracks
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool)
STrack.multi_predict(strack_pool, self.kalman_filter)
dists = matching.embedding_distance(strack_pool, detections)

View file

@ -78,6 +78,9 @@ class LoadImages: # for inference
class LoadVideo: # for inference
def __init__(self, path, img_size=(1088, 608)):
if not os.path.isfile(path):
raise FileExistsError
self.cap = cv2.VideoCapture(path)
self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))
self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))