Documentation (#95)

* Added documentation
* Added docstrings and comments
* Removed unused imports
* Added functionality for saving checkpoints during the training process
* Update train.py
* Update multitracker.py

parent 0a0665e682
commit 24f351d1b5

8 changed files with 184 additions and 89 deletions
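The checkpoint-saving change is the one functional addition here (see `train.py` below): each run now gets its own timestamped folder under `--weights-to`, holding copies of the configs, a `latest.pt` written every epoch, and a periodic slimmed checkpoint. Assuming default arguments, the resulting layout looks roughly like this (timestamp illustrative):

```
weights/run15_03_10_04/
    cfg/yolo3.cfg          copy of the model cfg
    cfg/ccmcpe.json        copy of the data cfg
    latest.pt              full checkpoint, overwritten every epoch
    weights_epoch_10.pt    written every --save-model-after epochs, optimizer state stripped
```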
demo.py (14 changes)

@@ -1,4 +1,4 @@
 """Demo file for running the JDE tracker on custom video sequences for pedestrian tracking.

 This file is the entry point to running the tracker on custom video sequences. It loads images from the provided video sequence, uses the JDE tracker for inference and outputs the video with bounding boxes indicating pedestrians. The bounding boxes also have associated ids (shown in different colours) to keep track of the movement of each individual.

@@ -24,29 +24,19 @@ Todo:
     * More documentation
 """

-import os
-import os.path as osp
-import cv2
 import logging
 import argparse
-import motmetrics as mm

-from tracker.multitracker import JDETracker
-from utils import visualization as vis
 from utils.utils import *
-from utils.io import read_results
 from utils.log import logger
 from utils.timer import Timer
-from utils.evaluation import Evaluator
 from utils.parse_config import parse_model_cfg
 import utils.datasets as datasets
-import torch
 from track import eval_seq


 logger.setLevel(logging.INFO)


 def track(opt):
     result_root = opt.output_root if opt.output_root!='' else '.'
     mkdir_if_missing(result_root)
track.py (35 changes)

@@ -39,6 +39,41 @@ def write_results(filename, results, data_type):


 def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
+    '''
+    Processes the given video sequence and provides the tracking output (writes the results to a video file).
+
+    It uses the JDE model to get information about the online targets present.
+
+    Parameters
+    ----------
+    opt : Namespace
+        Contains information passed as command-line arguments.
+
+    dataloader : LoadVideo
+        Instance of the LoadVideo class used for fetching the image sequence and associated data.
+
+    data_type : String
+        Type of dataset corresponding (similar) to the given video.
+
+    result_filename : String
+        The name (path) of the file for storing results.
+
+    save_dir : String
+        Path to the folder for storing the frames containing bounding box information (result frames).
+
+    show_image : bool
+        Option for showing individual frames during run-time.
+
+    frame_rate : int
+        Frame rate of the given video.
+
+    Returns
+    -------
+    (Returns are not significant here)
+    frame_id : int
+        Sequence number of the last sequence
+    '''
     if save_dir:
         mkdir_if_missing(save_dir)
     tracker = JDETracker(opt, frame_rate=frame_rate)
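For context, a minimal driver for the documented `eval_seq` might look like the sketch below. The option names in the `Namespace` and the `LoadVideo` constructor arguments are assumptions based on how `demo.py` uses them, not something this diff pins down:

```python
import argparse
import utils.datasets as datasets
from track import eval_seq

# Hypothetical option set; the real one comes from demo.py's argument parser.
opt = argparse.Namespace(cfg='cfg/yolov3_1088x608.cfg', weights='weights/latest.pt',
                         img_size=(1088, 608), conf_thres=0.5, nms_thres=0.4,
                         track_buffer=30, output_root='results')

dataloader = datasets.LoadVideo('input.mp4', opt.img_size)   # assumed constructor
last_frame_id = eval_seq(opt, dataloader, data_type='mot',
                         result_filename='results/results.txt',
                         save_dir='results/frame', show_image=False,
                         frame_rate=dataloader.frame_rate)   # LoadVideo reads this from the video
```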
tracker/matching.py

@@ -1,6 +1,3 @@
-import cv2
-import torch
-import torch.nn.functional as F
 import numpy as np
 import scipy
 from scipy.spatial.distance import cdist

@@ -8,7 +5,6 @@ import lap
 from cython_bbox import bbox_overlaps as bbox_ious
 from utils import kalman_filter
-import time


 def merge_matches(m1, m2, shape):
     O,P,Q = shape
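Since the IoU association in the tracker leans on `cython_bbox`, here is a tiny self-contained example of that primitive with boxes in `(x1, y1, x2, y2)` form, using the usual 1 - IoU cost convention (a sketch, not code from this diff):

```python
import numpy as np
from cython_bbox import bbox_overlaps as bbox_ious

# Pairwise IoU between two sets of boxes in (x1, y1, x2, y2) form.
atlbrs = np.ascontiguousarray([[0., 0., 10., 10.]], dtype=np.float64)
btlbrs = np.ascontiguousarray([[5., 5., 15., 15.]], dtype=np.float64)
ious = bbox_ious(atlbrs, btlbrs)   # shape (1, 1)
cost = 1 - ious                    # IoU-distance cost used for the second association
print(cost)
```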
tracker/multitracker.py

@@ -1,15 +1,6 @@
-import numpy as np
 from numba import jit
 from collections import deque
-import itertools
-import os
-import os.path as osp
-import time
 import torch
-import torch.nn.functional as F

-from utils.utils import *
-from utils.log import logger
 from utils.kalman_filter import KalmanFilter
 from models import *
 from tracker import matching

@@ -17,7 +8,6 @@ from .basetrack import BaseTrack, TrackState


 class STrack(BaseTrack):
-    shared_kalman = KalmanFilter()

     def __init__(self, tlwh, score, temp_feat, buffer_size=30):

@@ -50,13 +40,13 @@ class STrack(BaseTrack):
         if self.state != TrackState.Tracked:
             mean_state[7] = 0
         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

     @staticmethod
     def multi_predict(stracks):
         if len(stracks) > 0:
             multi_mean = np.asarray([st.mean.copy() for st in stracks])
             multi_covariance = np.asarray([st.covariance for st in stracks])
-            for i,st in enumerate(stracks):
+            for i, st in enumerate(stracks):
                 if st.state != TrackState.Tracked:
                     multi_mean[i][7] = 0
             multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)

@@ -64,7 +54,6 @@ class STrack(BaseTrack):
                 stracks[i].mean = mean
                 stracks[i].covariance = cov

     def activate(self, kalman_filter, frame_id):
         """Start a new tracklet"""
         self.kalman_filter = kalman_filter

@@ -112,7 +101,7 @@ class STrack(BaseTrack):
         self.update_features(new_track.curr_feat)

     @property
-    #@jit(nopython=True)
+    @jit
     def tlwh(self):
         """Get current position in bounding box format `(top left x, top left y,
         width, height)`.

@@ -125,7 +114,7 @@ class STrack(BaseTrack):
         return ret

     @property
-    #@jit(nopython=True)
+    @jit
     def tlbr(self):
         """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
         `(top left, bottom right)`.

@@ -135,7 +124,7 @@ class STrack(BaseTrack):
         return ret

     @staticmethod
-    #@jit(nopython=True)
+    @jit
     def tlwh_to_xyah(tlwh):
         """Convert bounding box to format `(center x, center y, aspect ratio,
         height)`, where the aspect ratio is `width / height`.

@@ -149,14 +138,14 @@ class STrack(BaseTrack):
         return self.tlwh_to_xyah(self.tlwh)

     @staticmethod
-    #@jit(nopython=True)
+    @jit
     def tlbr_to_tlwh(tlbr):
         ret = np.asarray(tlbr).copy()
         ret[2:] -= ret[:2]
         return ret

     @staticmethod
-    #@jit(nopython=True)
+    @jit
     def tlwh_to_tlbr(tlwh):
         ret = np.asarray(tlwh).copy()
         ret[2:] += ret[:2]
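As a quick numeric illustration of the box formats these helpers move between (plain NumPy, mirroring the arithmetic above):

```python
import numpy as np

# tlbr = (min x, min y, max x, max y); tlwh = (top-left x, top-left y, width, height)
tlbr = np.array([100., 50., 160., 170.])

tlwh = tlbr.copy()
tlwh[2:] -= tlwh[:2]        # same arithmetic as STrack.tlbr_to_tlwh
print(tlwh)                 # [100.  50.  60. 120.]

xyah = tlwh.copy()
xyah[:2] += xyah[2:] / 2    # centre of the box
xyah[2] /= xyah[3]          # aspect ratio = width / height
print(xyah)                 # [130. 110.   0.5 120.]
```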
@@ -166,11 +155,10 @@ class STrack(BaseTrack):
         return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)


 class JDETracker(object):
     def __init__(self, opt, frame_rate=30):
         self.opt = opt
-        self.model = Darknet(opt.cfg)
+        self.model = Darknet(opt.cfg, opt.img_size, nID=30)
         # load_darknet_weights(self.model, opt.weights)
         self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
         self.model.cuda().eval()
@@ -187,61 +175,106 @@ class JDETracker(object):
         self.kalman_filter = KalmanFilter()

     def update(self, im_blob, img0):
+        """
+        Processes the image frame and finds bounding boxes (detections).
+
+        Associates the detections with corresponding tracklets, and also handles lost, removed, refound and active tracklets.
+
+        Parameters
+        ----------
+        im_blob : torch.float32
+            Tensor of shape depending upon the size of image. By default, shape of this tensor is [1, 3, 608, 1088].
+
+        img0 : ndarray
+            ndarray of shape depending on the input image sequence. By default, shape is [608, 1080, 3].
+
+        Returns
+        -------
+        output_stracks : list of STrack (instances)
+            The list contains information regarding the online tracklets for the received image tensor.
+        """
+
         self.frame_id += 1
-        activated_starcks = []
-        refind_stracks = []
-        lost_stracks = []
+        activated_starcks = []  # for storing active tracks, for the current frame
+        refind_stracks = []  # lost tracks whose detections are obtained in the current frame
+        lost_stracks = []  # tracks not obtained in the current frame but not yet removed (lost for less time than the removal threshold)
         removed_stracks = []

         t1 = time.time()
         ''' Step 1: Network forward, get detections & embeddings'''
         with torch.no_grad():
             pred = self.model(im_blob)
+        # pred is the tensor of all the proposals (default number of proposals: 54264); proposals carry bounding box information and embeddings
         pred = pred[pred[:, :, 4] > self.opt.conf_thres]
+        # pred now has fewer proposals; proposals are rejected on the basis of the object confidence score
         if len(pred) > 0:
-            dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres,
-                                       self.opt.nms_thres)[0]
+            dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu()
+            # Final proposals are obtained in dets; bounding box and embedding information is included
+            # The next step changes the detection scales
             scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
-            dets, embs = dets[:, :5].cpu().numpy(), dets[:, 6:].cpu().numpy()
-            '''Detections'''
-            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
-                          (tlbrs, f) in zip(dets, embs)]
+            '''Detections is list of (x1, y1, x2, y2, object_conf, class_score, class_pred)'''
+            # class_pred is the embeddings.
+            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
+                          (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
         else:
             detections = []

+        t2 = time.time()
+        # print('Forward: {} s'.format(t2-t1))

         ''' Add newly detected tracklets to tracked_stracks'''
         unconfirmed = []
         tracked_stracks = []  # type: list[STrack]
         for track in self.tracked_stracks:
             if not track.is_activated:
+                # previous tracks which are not active in the current frame are added to the unconfirmed list
                 unconfirmed.append(track)
+                # print("Should not be here, in unconfirmed")
             else:
+                # active tracks are added to the local list 'tracked_stracks'
                 tracked_stracks.append(track)

         ''' Step 2: First association, with embedding'''
+        # Combine currently tracked_stracks and lost_stracks
         strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
         # Predict the current location with KF
         STrack.multi_predict(strack_pool)


         dists = matching.embedding_distance(strack_pool, detections)
+        # dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
         dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
+        # dists is the matrix of distances between the detections and the tracks in strack_pool
         matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
+        # matches is the array of matched pairs: each detection with its corresponding track in strack_pool

         for itracked, idet in matches:
+            # itracked is the id of the track and idet is the detection
             track = strack_pool[itracked]
             det = detections[idet]
             if track.state == TrackState.Tracked:
+                # If the track is active, add the detection to the track
                 track.update(detections[idet], self.frame_id)
                 activated_starcks.append(track)
             else:
+                # We have obtained a detection from a track which is not active; hence put the track in the refind_stracks list
                 track.re_activate(det, self.frame_id, new_id=False)
                 refind_stracks.append(track)

+        # None of the steps below happen if there are no undetected tracks.
         ''' Step 3: Second association, with IOU'''
         detections = [detections[i] for i in u_detection]
-        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state==TrackState.Tracked ]
+        # detections is now a list of the unmatched detections
+        r_tracked_stracks = []  # Container for stracks which were tracked till the
+        # previous frame but for which no detection was found in the current frame
+        for i in u_track:
+            if strack_pool[i].state == TrackState.Tracked:
+                r_tracked_stracks.append(strack_pool[i])
         dists = matching.iou_distance(r_tracked_stracks, detections)
         matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
+        # matches is the list of detections which matched with corresponding tracks by the IOU-distance method
         for itracked, idet in matches:
             track = r_tracked_stracks[itracked]
             det = detections[idet]

@@ -251,12 +284,14 @@ class JDETracker(object):
             else:
                 track.re_activate(det, self.frame_id, new_id=False)
                 refind_stracks.append(track)
+        # Same process done for some unmatched detections, but now considering IOU distance as the measure

         for it in u_track:
             track = r_tracked_stracks[it]
             if not track.state == TrackState.Lost:
                 track.mark_lost()
                 lost_stracks.append(track)
+        # If no detections are obtained for tracks (u_track), the tracks are added to the lost_stracks list and are marked lost

         '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
         detections = [detections[i] for i in u_detection]

@@ -265,11 +300,14 @@ class JDETracker(object):
         for itracked, idet in matches:
             unconfirmed[itracked].update(detections[idet], self.frame_id)
             activated_starcks.append(unconfirmed[itracked])

+        # The tracks which are not yet matched
         for it in u_unconfirmed:
             track = unconfirmed[it]
             track.mark_removed()
             removed_stracks.append(track)

+        # After all these confirmation steps, if a new detection is found, it is initialized as a new track
         """ Step 4: Init new stracks"""
         for inew in u_detection:
             track = detections[inew]

@@ -279,14 +317,18 @@ class JDETracker(object):
             activated_starcks.append(track)

         """ Step 5: Update state"""
+        # If the tracks are lost for more frames than the threshold number, the tracks are removed.
         for track in self.lost_stracks:
             if self.frame_id - track.end_frame > self.max_time_lost:
                 track.mark_removed()
                 removed_stracks.append(track)
+        # print('Remained match {} s'.format(t4-t3))

+        # Update self.tracked_stracks and self.lost_stracks using the updates in this step.
         self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
         self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
         self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
+        # self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost]  # type: list[STrack]
         self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
         self.lost_stracks.extend(lost_stracks)
         self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)

@@ -301,6 +343,7 @@ class JDETracker(object):
         logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
         logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
         logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
+        # print('Final {} s'.format(t5-t4))
         return output_stracks


 def joint_stracks(tlista, tlistb):
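Two sketches for orientation. First, the per-frame loop that drives the documented `update` method; this mirrors how `eval_seq` in `track.py` feeds the tracker, and assumes the dataloader yields `(path, img, img0)` with `img` the letterboxed CHW array and `img0` the original frame (not something this diff pins down):

```python
import torch

# Hypothetical driver loop around JDETracker.update; tracker and dataloader
# are assumed to be constructed as in track.py / demo.py.
for path, img, img0 in dataloader:
    blob = torch.from_numpy(img).cuda().unsqueeze(0)   # [1, 3, 608, 1088] by default
    online_targets = tracker.update(blob, img0)
    for t in online_targets:
        x, y, w, h = t.tlwh                            # top-left-based box
        print(t.track_id, x, y, w, h)
```

Second, the `lap` import in `tracker/matching.py` suggests `matching.linear_assignment` is backed by `lap.lapjv`; a self-contained toy run with the same semantics, where `thresh` plays the role of `cost_limit` (a sketch, not the repo's exact wrapper):

```python
import numpy as np
import lap

cost = np.array([[0.2, 0.9],
                 [0.8, 0.3],
                 [0.6, 0.7]])              # rows: tracks in strack_pool, cols: detections

# Pairs with cost above cost_limit stay unmatched, like thresh=0.7 above.
_, x, y = lap.lapjv(cost, extend_cost=True, cost_limit=0.7)
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
u_track = np.where(x < 0)[0]               # unmatched track indices
u_detection = np.where(y < 0)[0]           # unmatched detection indices
print(matches, u_track, u_detection)       # [[0, 0], [1, 1]] [2] []
```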
train.py (110 changes)

@@ -1,9 +1,10 @@
 import argparse
 import json
 import time
+from time import gmtime, strftime
 import test
 from models import *
+from shutil import copyfile
 from utils.datasets import JointDataset, collate_fn
 from utils.utils import *
 from utils.log import logger

@@ -13,6 +14,10 @@ from torchvision.transforms import transforms as T
 def train(
         cfg,
         data_cfg,
+        weights_from="",
+        weights_to="",
+        save_every=10,
+        img_size=(1088, 608),
         resume=False,
         epochs=100,
         batch_size=16,
@@ -20,9 +25,16 @@ def train(
         freeze_backbone=False,
         opt=None,
 ):
-    weights = 'weights'
-    mkdir_if_missing(weights)
-    latest = osp.join(weights, 'latest.pt')
+    # The function starts
+    timme = strftime("%Y-%d-%m %H:%M:%S", gmtime())
+    timme = timme[5:-3].replace('-', '_')
+    timme = timme.replace(' ', '_')
+    timme = timme.replace(':', '_')
+    weights_to = osp.join(weights_to, 'run' + timme)
+    mkdir_if_missing(weights_to)
+    if resume:
+        latest_resume = osp.join(weights_from, 'latest.pt')

     torch.backends.cudnn.benchmark = True  # unsuitable for multiscale

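The `timme` juggling above just turns the current UTC time into a filesystem-safe run name; reproducing the string operations:

```python
from time import gmtime, strftime

timme = strftime("%Y-%d-%m %H:%M:%S", gmtime())  # e.g. '2020-15-03 10:04:30'
timme = timme[5:-3].replace('-', '_')            # '15_03 10:04'
timme = timme.replace(' ', '_')                  # '15_03_10:04'
timme = timme.replace(':', '_')                  # '15_03_10_04'
print('run' + timme)                             # e.g. 'run15_03_10_04'
```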
@@ -32,24 +44,19 @@ def train(
     trainset_paths = data_config['train']
     dataset_root = data_config['root']
     f.close()
-    cfg_dict = parse_model_cfg(cfg)
-    img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]

-    # Get dataloader
     transforms = T.Compose([T.ToTensor()])
+    # Get dataloader
     dataset = JointDataset(dataset_root, trainset_paths, img_size, augment=True, transforms=transforms)
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                              num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn)

     # Initialize model
-    model = Darknet(cfg_dict, dataset.nID)
+    model = Darknet(cfg, dataset.nID)


     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
     if resume:
-        checkpoint = torch.load(latest, map_location='cpu')
+        checkpoint = torch.load(latest_resume, map_location='cpu')

         # Load weights to resume from
         model.load_state_dict(checkpoint['model'])
@@ -67,36 +74,36 @@ def train(
     else:
         # Initialize model with backbone (optional)
         if cfg.endswith('yolov3.cfg'):
-            load_darknet_weights(model, osp.join(weights ,'darknet53.conv.74'))
+            load_darknet_weights(model, osp.join(weights_from, 'darknet53.conv.74'))
             cutoff = 75
         elif cfg.endswith('yolov3-tiny.cfg'):
-            load_darknet_weights(model, osp.join(weights , 'yolov3-tiny.conv.15'))
+            load_darknet_weights(model, osp.join(weights_from, 'yolov3-tiny.conv.15'))
             cutoff = 15

     model.cuda().train()

     # Set optimizer
-    optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9, weight_decay=1e-4)
+    optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9,
+                                weight_decay=1e-4)

     model = torch.nn.DataParallel(model)
     # Set scheduler
     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
-                                                     milestones=[int(0.5*opt.epochs), int(0.75*opt.epochs)], gamma=0.1)
+                                                     milestones=[int(0.5 * opt.epochs), int(0.75 * opt.epochs)],
+                                                     gamma=0.1)

-    # An important trick for detection: freeze bn during fine-tuning
+
+    # An important trick for detection: freeze bn during fine-tuning
     if not opt.unfreeze_bn:
         for i, (name, p) in enumerate(model.named_parameters()):
             p.requires_grad = False if 'batch_norm' in name else True

-    model_info(model)
+    # model_info(model)

     t0 = time.time()
     for epoch in range(epochs):
         epoch += start_epoch

         logger.info(('%8s%12s' + '%10s' * 6) % (
             'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time'))

         # Freeze darknet53.conv.74 for first epoch
         if freeze_backbone and (epoch < 2):
             for i, (name, p) in enumerate(model.named_parameters()):
@@ -109,18 +116,17 @@ def train(
         for i, (imgs, targets, _, _, targets_len) in enumerate(dataloader):
             if sum([len(x) for x in targets]) < 1:  # if no targets continue
                 continue

             # SGD burn-in
             burnin = min(1000, len(dataloader))
             if (epoch == 0) & (i <= burnin):
-                lr = opt.lr * (i / burnin) **4
+                lr = opt.lr * (i / burnin) ** 4
                 for g in optimizer.param_groups:
                     g['lr'] = lr

             # Compute loss, compute gradient, update parameters
             loss, components = model(imgs.cuda(), targets.cuda(), targets_len.cuda())
-            components = torch.mean(components.view(-1, 5),dim=0)
+            components = torch.mean(components.view(-1, 5), dim=0)
             loss = torch.mean(loss)
             loss.backward()

@@ -131,44 +137,64 @@ def train(

             # Running epoch-means of tracked metrics
             ui += 1
             for ii, key in enumerate(model.module.loss_names):
                 rloss[key] = (rloss[key] * ui + components[ii]) / (ui + 1)

+            # rloss holds running loss values, with the mean updated at every iteration
             s = ('%8s%12s' + '%10.3g' * 6) % (
                 '%g/%g' % (epoch, epochs - 1),
                 '%g/%g' % (i, len(dataloader) - 1),
                 rloss['box'], rloss['conf'],
-                rloss['id'],rloss['loss'],
+                rloss['id'], rloss['loss'],
                 rloss['nT'], time.time() - t0)
             t0 = time.time()
             if i % opt.print_interval == 0:
                 logger.info(s)

         # Save latest checkpoint
         checkpoint = {'epoch': epoch,
                       'model': model.module.state_dict(),
                       'optimizer': optimizer.state_dict()}
-        torch.save(checkpoint, latest)
+
+        copyfile(cfg, weights_to + '/cfg/yolo3.cfg')
+        copyfile(data_cfg, weights_to + '/cfg/ccmcpe.json')
+
+        latest = osp.join(weights_to, 'latest.pt')
+        torch.save(checkpoint, latest)
+        if epoch % save_every == 0 and epoch != 0:
+            # making the checkpoint lite
+            checkpoint["optimizer"] = []
+            torch.save(checkpoint, osp.join(weights_to, "weights_epoch_" + str(epoch) + ".pt"))

         # Calculate mAP
-        if epoch % opt.test_interval ==0:
+        if epoch % opt.test_interval == 0:
             with torch.no_grad():
-                mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
-                test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
+                mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
+                                      print_interval=40, nID=dataset.nID)
+                test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
+                              print_interval=40, nID=dataset.nID)

-        # Call scheduler.step() after opimizer.step() with pytorch > 1.1.0
+        # Call scheduler.step() after optimizer.step() with pytorch > 1.1.0
         scheduler.step()

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--epochs', type=int, default=30, help='number of epochs')
     parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
     parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step')
-    parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path')
+    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
+    parser.add_argument('--weights-from', type=str, default='weights/',
+                        help='Path for getting the trained model for resuming training (Should only be used with '
+                             '--resume)')
+    parser.add_argument('--weights-to', type=str, default='weights/',
+                        help='Store the trained weights after resuming training session. It will create a new folder '
+                             'with timestamp in the given path')
+    parser.add_argument('--save-model-after', type=int, default=10,
+                        help='Save a checkpoint of model at given interval of epochs')
     parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path')
+    parser.add_argument('--img-size', type=int, default=[1088, 608], nargs='+', help='pixels')
     parser.add_argument('--resume', action='store_true', help='resume training flag')
     parser.add_argument('--print-interval', type=int, default=40, help='print interval')
     parser.add_argument('--test-interval', type=int, default=9, help='test interval')

@@ -181,6 +207,10 @@ if __name__ == '__main__':
     train(
         opt.cfg,
         opt.data_cfg,
+        weights_from=opt.weights_from,
+        weights_to=opt.weights_to,
+        save_every=opt.save_model_after,
+        img_size=opt.img_size,
         resume=opt.resume,
         epochs=opt.epochs,
         batch_size=opt.batch_size,
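Tying the new flags together, typical invocations and the consumer side of the saved checkpoint might look like the sketch below (paths illustrative; the checkpoint keys come from the diff above):

```python
# Typical invocations with the new flags:
#   python train.py --cfg cfg/yolov3_1088x608.cfg --data-cfg cfg/ccmcpe.json \
#       --weights-to weights/ --save-model-after 10
#   python train.py --resume --weights-from weights/run15_03_10_04/ --weights-to weights/
import torch

# The checkpoint dict carries 'epoch', 'model' and 'optimizer' entries; the
# periodic weights_epoch_*.pt files hold an empty optimizer entry, so check
# before restoring it.
checkpoint = torch.load('weights/run15_03_10_04/latest.pt', map_location='cpu')
start_epoch = checkpoint['epoch'] + 1
# model.load_state_dict(checkpoint['model'])              # model built as in train()
# if checkpoint['optimizer']:
#     optimizer.load_state_dict(checkpoint['optimizer'])
```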
utils/datasets.py

@@ -393,6 +393,9 @@ class JointDataset(LoadImagesAndLabels):  # for training


     def __getitem__(self, files_index):
+        """
+        Iterator function for train dataset
+        """
         for i, c in enumerate(self.cds):
             if files_index >= c:
                 ds = list(self.label_files.keys())[i]
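The `self.cds` lookup above maps a global sample index to the right sub-dataset; the idea in isolation (illustrative names and counts, not from this diff):

```python
# cds holds the cumulative starting index of each sub-dataset, e.g. three label
# files contributing 100, 50 and 200 samples give cds = [0, 100, 150].
cds = [0, 100, 150]
dataset_names = ['mot17', 'caltech', 'citypersons']   # hypothetical keys

files_index = 120
for i, c in enumerate(cds):
    if files_index >= c:
        ds = dataset_names[i]            # same pattern as __getitem__ above
        start_index = c
print(ds, files_index - start_index)     # 'caltech' 20 -> local index in that set
```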
utils/kalman_filter.py

@@ -1,5 +1,4 @@
 # vim: expandtab:ts=4:sw=4
-import numba
 import numpy as np
 import scipy.linalg

utils/utils.py

@@ -1,6 +1,5 @@
 import glob
 import random
-import time
 import os
 import os.path as osp
