Bugfix (#122)
* bug in multitracker - img_size param passed to Darknet init but no img_size parameter expected. * Various minor bug fixes
This commit is contained in:
parent
7ec84619b1
commit
5618eb5390
2 changed files with 9 additions and 4 deletions
|
@ -2,6 +2,7 @@ from numba import jit
|
|||
from collections import deque
|
||||
import torch
|
||||
from utils.kalman_filter import KalmanFilter
|
||||
from utils.log import logger
|
||||
from models import *
|
||||
from tracker import matching
|
||||
from .basetrack import BaseTrack, TrackState
|
||||
|
@ -42,14 +43,15 @@ class STrack(BaseTrack):
|
|||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
def multi_predict(stracks, kalman_filter):
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
# multi_mean, multi_covariance = STrack.kalman_filter.multi_predict(multi_mean, multi_covariance)
|
||||
multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
@ -158,7 +160,7 @@ class STrack(BaseTrack):
|
|||
class JDETracker(object):
|
||||
def __init__(self, opt, frame_rate=30):
|
||||
self.opt = opt
|
||||
self.model = Darknet(opt.cfg, opt.img_size, nID=30)
|
||||
self.model = Darknet(opt.cfg, nID=14455)
|
||||
# load_darknet_weights(self.model, opt.weights)
|
||||
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
|
||||
self.model.cuda().eval()
|
||||
|
@ -240,7 +242,7 @@ class JDETracker(object):
|
|||
# Combining currently tracked_stracks and lost_stracks
|
||||
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
|
||||
# Predict the current location with KF
|
||||
STrack.multi_predict(strack_pool)
|
||||
STrack.multi_predict(strack_pool, self.kalman_filter)
|
||||
|
||||
|
||||
dists = matching.embedding_distance(strack_pool, detections)
|
||||
|
|
|
@ -78,6 +78,9 @@ class LoadImages: # for inference
|
|||
|
||||
class LoadVideo: # for inference
|
||||
def __init__(self, path, img_size=(1088, 608)):
|
||||
if not os.path.isfile(path):
|
||||
raise FileExistsError
|
||||
|
||||
self.cap = cv2.VideoCapture(path)
|
||||
self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))
|
||||
self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
|
|
Loading…
Reference in a new issue