Bugfix (#122)
* Fix a bug in the multitracker: img_size was passed to the Darknet constructor, which does not accept an img_size parameter.
* Various minor bug fixes.
parent 7ec84619b1
commit 5618eb5390
2 changed files with 9 additions and 4 deletions
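For context on the first fix, the failure mode looks roughly like the sketch below. The Darknet signature used here is an assumption inferred from the diff (a config path plus an nID embedding count, with no img_size parameter), not the repository's actual definition.

# Minimal sketch of the bug, under an ASSUMED Darknet signature inferred from the diff.
class Darknet:
    def __init__(self, cfg_path, nID=30):   # assumed: no img_size parameter
        self.cfg_path = cfg_path
        self.nID = nID

class Opt:                                   # stand-in for the parsed options object
    cfg = 'cfg/yolov3.cfg'
    img_size = (1088, 608)

opt = Opt()

try:
    Darknet(opt.cfg, opt.img_size, nID=30)   # old call: img_size lands in a slot it shouldn't
except TypeError as exc:
    print('old call fails:', exc)

model = Darknet(opt.cfg, nID=14455)          # fixed call, as in the commit
print('fixed call works:', model.cfg_path, model.nID)

With the assumed signature the old call raises a TypeError when the tracker is constructed; dropping img_size and passing only cfg and nID is exactly what the JDETracker.__init__ hunk below does.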
File 1 of 2:

@@ -2,6 +2,7 @@ from numba import jit
 from collections import deque
 import torch
 from utils.kalman_filter import KalmanFilter
+from utils.log import logger
 from models import *
 from tracker import matching
 from .basetrack import BaseTrack, TrackState
@@ -42,14 +43,15 @@ class STrack(BaseTrack):
         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

     @staticmethod
-    def multi_predict(stracks):
+    def multi_predict(stracks, kalman_filter):
         if len(stracks) > 0:
             multi_mean = np.asarray([st.mean.copy() for st in stracks])
             multi_covariance = np.asarray([st.covariance for st in stracks])
             for i, st in enumerate(stracks):
                 if st.state != TrackState.Tracked:
                     multi_mean[i][7] = 0
-            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
+            # multi_mean, multi_covariance = STrack.kalman_filter.multi_predict(multi_mean, multi_covariance)
+            multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
             for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                 stracks[i].mean = mean
                 stracks[i].covariance = cov
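For the multi_predict refactor, here is a hedged, self-contained sketch of the new call shape: the Kalman filter is supplied by the caller instead of being read from a class attribute. StubKalmanFilter and StubTrack are illustrative stand-ins, not the repository's classes; only the (stracks, kalman_filter) signature mirrors the diff.

import numpy as np

class StubKalmanFilter:                      # stand-in for utils.kalman_filter.KalmanFilter
    def multi_predict(self, means, covariances):
        return means, covariances            # identity 'prediction', just to exercise the interface

class StubTrack:                             # stand-in for STrack
    def __init__(self):
        self.mean = np.zeros(8)              # 8-dimensional state, as in the tracker code
        self.covariance = np.eye(8)

def multi_predict(stracks, kalman_filter):   # same shape as the refactored staticmethod
    if len(stracks) > 0:
        multi_mean = np.asarray([st.mean.copy() for st in stracks])
        multi_covariance = np.asarray([st.covariance for st in stracks])
        multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
        for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
            stracks[i].mean = mean
            stracks[i].covariance = cov

multi_predict([StubTrack(), StubTrack()], StubKalmanFilter())

Passing the filter explicitly keeps the static method free of hidden class state, which matches the call-site change further down (STrack.multi_predict(strack_pool, self.kalman_filter)).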
@@ -158,7 +160,7 @@ class STrack(BaseTrack):
 class JDETracker(object):
     def __init__(self, opt, frame_rate=30):
         self.opt = opt
-        self.model = Darknet(opt.cfg, opt.img_size, nID=30)
+        self.model = Darknet(opt.cfg, nID=14455)
         # load_darknet_weights(self.model, opt.weights)
         self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
         self.model.cuda().eval()
@@ -240,7 +242,7 @@ class JDETracker(object):
         # Combining currently tracked_stracks and lost_stracks
         strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
         # Predict the current location with KF
-        STrack.multi_predict(strack_pool)
+        STrack.multi_predict(strack_pool, self.kalman_filter)


         dists = matching.embedding_distance(strack_pool, detections)
File 2 of 2:

@@ -78,6 +78,9 @@ class LoadImages: # for inference

 class LoadVideo: # for inference
     def __init__(self, path, img_size=(1088, 608)):
+        if not os.path.isfile(path):
+            raise FileExistsError
+
         self.cap = cv2.VideoCapture(path)
         self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))
         self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
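The added guard makes a bad input path fail immediately instead of letting cv2.VideoCapture open nothing and yield empty reads later. A minimal usage sketch, assuming the class lives at utils.datasets.LoadVideo (the module path is not shown in the diff):

from utils.datasets import LoadVideo   # assumed import path, for illustration only

try:
    loader = LoadVideo('videos/does_not_exist.mp4')
except FileExistsError:
    print('input video not found, aborting')

The commit raises FileExistsError as written; FileNotFoundError would express the same intent for a missing file.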