diff --git a/tracker/multitracker.py b/tracker/multitracker.py index 4b6168d..eae4683 100644 --- a/tracker/multitracker.py +++ b/tracker/multitracker.py @@ -2,6 +2,7 @@ from numba import jit from collections import deque import torch from utils.kalman_filter import KalmanFilter +from utils.log import logger from models import * from tracker import matching from .basetrack import BaseTrack, TrackState @@ -42,14 +43,15 @@ class STrack(BaseTrack): self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) @staticmethod - def multi_predict(stracks): + def multi_predict(stracks, kalman_filter): if len(stracks) > 0: multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_covariance = np.asarray([st.covariance for st in stracks]) for i, st in enumerate(stracks): if st.state != TrackState.Tracked: multi_mean[i][7] = 0 - multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) +# multi_mean, multi_covariance = STrack.kalman_filter.multi_predict(multi_mean, multi_covariance) + multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance) for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): stracks[i].mean = mean stracks[i].covariance = cov @@ -158,7 +160,7 @@ class STrack(BaseTrack): class JDETracker(object): def __init__(self, opt, frame_rate=30): self.opt = opt - self.model = Darknet(opt.cfg, opt.img_size, nID=30) + self.model = Darknet(opt.cfg, nID=14455) # load_darknet_weights(self.model, opt.weights) self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False) self.model.cuda().eval() @@ -240,7 +242,7 @@ class JDETracker(object): # Combining currently tracked_stracks and lost_stracks strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) # Predict the current location with KF - STrack.multi_predict(strack_pool) + STrack.multi_predict(strack_pool, self.kalman_filter) dists = matching.embedding_distance(strack_pool, detections) diff --git a/utils/datasets.py b/utils/datasets.py index 736b8db..88546a3 100644 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -78,6 +78,9 @@ class LoadImages: # for inference class LoadVideo: # for inference def __init__(self, path, img_size=(1088, 608)): + if not os.path.isfile(path): + raise FileExistsError + self.cap = cv2.VideoCapture(path) self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS))) self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))