Documentation (#95)
* Added documentation
* Added docstrings and comments
* Removed unused imports
* Added functionality for saving checkpoints during the training process
* Update train.py
* Update multitracker.py

parent 0a0665e682
commit 24f351d1b5

8 changed files with 184 additions and 89 deletions
demo.py (10 changes)

@@ -24,23 +24,13 @@ Todo:
     * More documentation
 """
 
-import os
-import os.path as osp
-import cv2
 import logging
 import argparse
-import motmetrics as mm
-
-from tracker.multitracker import JDETracker
-from utils import visualization as vis
 from utils.utils import *
-from utils.io import read_results
 from utils.log import logger
 from utils.timer import Timer
-from utils.evaluation import Evaluator
 from utils.parse_config import parse_model_cfg
 import utils.datasets as datasets
-import torch
 from track import eval_seq
track.py (35 changes)

@@ -39,6 +39,41 @@ def write_results(filename, results, data_type):
 
 
 def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
+    '''
+    Processes the given video sequence and writes the tracking results to the result file.
+
+    It uses the JDE model to obtain information about the online targets present.
+
+    Parameters
+    ----------
+    opt : Namespace
+        Contains information passed as command-line arguments.
+
+    dataloader : LoadVideo
+        Instance of the LoadVideo class used for fetching the image sequence and associated data.
+
+    data_type : str
+        Type of dataset corresponding to the given video.
+
+    result_filename : str
+        Name (path) of the file for storing the results.
+
+    save_dir : str
+        Path to the folder for storing the result frames (frames annotated with bounding boxes).
+
+    show_image : bool
+        Option for showing individual frames during run-time.
+
+    frame_rate : int
+        Frame rate of the given video.
+
+    Returns
+    -------
+    frame_id : int
+        Sequence number of the last processed frame. (The return values are not otherwise significant here.)
+    '''
     if save_dir:
         mkdir_if_missing(save_dir)
     tracker = JDETracker(opt, frame_rate=frame_rate)
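A minimal sketch of how eval_seq might be driven, in the spirit of demo.py. The option values and the video path are hypothetical; LoadVideo's signature and frame_rate attribute are assumptions about utils.datasets:

    import argparse
    import utils.datasets as datasets
    from track import eval_seq

    # Hypothetical option namespace; real runs build this with argparse.
    opt = argparse.Namespace(cfg='cfg/yolov3_1088x608.cfg', weights='weights/latest.pt',
                             img_size=(1088, 608), conf_thres=0.5, nms_thres=0.4,
                             track_buffer=30, min_box_area=200)
    dataloader = datasets.LoadVideo('input.mp4', opt.img_size)
    # Writes MOT-format results and, if save_dir is given, annotated frames.
    last_frame_id = eval_seq(opt, dataloader, 'mot', 'results.txt',
                             save_dir='outputs/frames', show_image=False,
                             frame_rate=dataloader.frame_rate)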
tracker/matching.py

@@ -1,6 +1,3 @@
-import cv2
-import torch
-import torch.nn.functional as F
 import numpy as np
 import scipy
 from scipy.spatial.distance import cdist
@@ -8,7 +5,6 @@ import lap
 
 from cython_bbox import bbox_overlaps as bbox_ious
 from utils import kalman_filter
-import time
 
 
 def merge_matches(m1, m2, shape):
     O,P,Q = shape
tracker/multitracker.py

@@ -1,15 +1,6 @@
-import numpy as np
 from numba import jit
 from collections import deque
-import itertools
-import os
-import os.path as osp
-import time
 import torch
-import torch.nn.functional as F
-
-from utils.utils import *
-from utils.log import logger
 from utils.kalman_filter import KalmanFilter
 from models import *
 from tracker import matching
@@ -17,7 +8,6 @@ from .basetrack import BaseTrack, TrackState
 
 
 class STrack(BaseTrack):
-    shared_kalman = KalmanFilter()
 
     def __init__(self, tlwh, score, temp_feat, buffer_size=30):
@@ -64,7 +54,6 @@ class STrack(BaseTrack):
             stracks[i].mean = mean
             stracks[i].covariance = cov
 
-
     def activate(self, kalman_filter, frame_id):
         """Start a new tracklet"""
         self.kalman_filter = kalman_filter
@@ -112,7 +101,7 @@ class STrack(BaseTrack):
         self.update_features(new_track.curr_feat)
 
     @property
-    #@jit(nopython=True)
+    @jit
     def tlwh(self):
         """Get current position in bounding box format `(top left x, top left y,
                 width, height)`.
@@ -125,7 +114,7 @@ class STrack(BaseTrack):
         return ret
 
     @property
-    #@jit(nopython=True)
+    @jit
     def tlbr(self):
         """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
         `(top left, bottom right)`.
@@ -135,7 +124,7 @@ class STrack(BaseTrack):
         return ret
 
     @staticmethod
-    #@jit(nopython=True)
+    @jit
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
@@ -149,14 +138,14 @@ class STrack(BaseTrack):
         return self.tlwh_to_xyah(self.tlwh)
 
     @staticmethod
-    #@jit(nopython=True)
+    @jit
     def tlbr_to_tlwh(tlbr):
         ret = np.asarray(tlbr).copy()
         ret[2:] -= ret[:2]
         return ret
 
     @staticmethod
-    #@jit(nopython=True)
+    @jit
     def tlwh_to_tlbr(tlwh):
         ret = np.asarray(tlwh).copy()
         ret[2:] += ret[:2]
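For reference, a quick worked example of the box-format conversions above (plain NumPy, matching the arithmetic in tlbr_to_tlwh and tlwh_to_xyah; the numbers are made up):

    import numpy as np

    tlbr = np.array([100., 50., 180., 250.])  # (min x, min y, max x, max y)
    tlwh = tlbr.copy()
    tlwh[2:] -= tlwh[:2]                      # -> [100, 50, 80, 200]: top-left, width, height
    xyah = tlwh.copy()
    xyah[:2] += xyah[2:] / 2                  # center -> [140, 150]
    xyah[2] /= xyah[3]                        # aspect ratio w/h -> 0.4
    # xyah == [140, 150, 0.4, 200]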
@@ -166,11 +155,10 @@ class STrack(BaseTrack):
         return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
 
 
-
 class JDETracker(object):
     def __init__(self, opt, frame_rate=30):
         self.opt = opt
-        self.model = Darknet(opt.cfg)
+        self.model = Darknet(opt.cfg, opt.img_size, nID=30)
         # load_darknet_weights(self.model, opt.weights)
         self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
         self.model.cuda().eval()
@@ -187,61 +175,106 @@ class JDETracker(object):
         self.kalman_filter = KalmanFilter()
 
     def update(self, im_blob, img0):
+        """
+        Processes the image frame and finds bounding boxes (detections).
+
+        Associates the detections with the corresponding tracklets, and also handles lost, removed, refound and active tracklets.
+
+        Parameters
+        ----------
+        im_blob : torch.float32
+            Tensor whose shape depends on the size of the image. By default, the shape of this tensor is [1, 3, 608, 1088].
+
+        img0 : ndarray
+            ndarray whose shape depends on the input image sequence. By default, the shape is [608, 1080, 3].
+
+        Returns
+        -------
+        output_stracks : list of STrack instances
+            Information about the online tracklets for the received image tensor.
+        """
+
         self.frame_id += 1
-        activated_starcks = []
-        refind_stracks = []
-        lost_stracks = []
+        activated_starcks = []  # for storing active tracks, for the current frame
+        refind_stracks = []  # lost tracks whose detections are obtained in the current frame
+        lost_stracks = []  # tracks missing in the current frame but not yet removed (lost for less time than the removal threshold)
         removed_stracks = []
 
         t1 = time.time()
         ''' Step 1: Network forward, get detections & embeddings'''
         with torch.no_grad():
             pred = self.model(im_blob)
+        # pred is a tensor of all the proposals (default number of proposals: 54264). Proposals carry bounding-box and embedding information.
         pred = pred[pred[:, :, 4] > self.opt.conf_thres]
+        # pred now holds fewer proposals; the rest were rejected on the basis of the object confidence score
         if len(pred) > 0:
-            dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres,
-                                       self.opt.nms_thres)[0]
+            dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu()
+            # Final proposals are obtained in dets; bounding-box and embedding information is also included
+            # The next step changes the detection scales
             scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
-            dets, embs = dets[:, :5].cpu().numpy(), dets[:, 6:].cpu().numpy()
-            '''Detections'''
-            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
-                          (tlbrs, f) in zip(dets, embs)]
+            '''Detections is a list of (x1, y1, x2, y2, object_conf, class_score, class_pred)'''
+            # class_pred holds the embeddings.
+            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
+                          (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
         else:
             detections = []
 
+        t2 = time.time()
+        # print('Forward: {} s'.format(t2-t1))
+
         ''' Add newly detected tracklets to tracked_stracks'''
         unconfirmed = []
         tracked_stracks = []  # type: list[STrack]
         for track in self.tracked_stracks:
             if not track.is_activated:
+                # previous tracks which are not active in the current frame are added to the unconfirmed list
                 unconfirmed.append(track)
+                # print("Should not be here, in unconfirmed")
             else:
+                # active tracks are added to the local list 'tracked_stracks'
                 tracked_stracks.append(track)
 
         ''' Step 2: First association, with embedding'''
+        # Combine currently tracked_stracks and lost_stracks
         strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
         # Predict the current location with KF
         STrack.multi_predict(strack_pool)
 
         dists = matching.embedding_distance(strack_pool, detections)
+        # dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
         dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
+        # dists is the matrix of distances between the detections and the tracks in strack_pool
         matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
+        # matches is the array of matched (track, detection) index pairs over strack_pool
 
         for itracked, idet in matches:
+            # itracked is the index of the track and idet is the index of the detection
             track = strack_pool[itracked]
             det = detections[idet]
             if track.state == TrackState.Tracked:
+                # If the track is active, add the detection to the track
                 track.update(detections[idet], self.frame_id)
                 activated_starcks.append(track)
             else:
+                # A detection was obtained for a track which is not active, so put the track in the refind_stracks list
                 track.re_activate(det, self.frame_id, new_id=False)
                 refind_stracks.append(track)
 
+        # None of the steps below happen if there are no undetected tracks.
         ''' Step 3: Second association, with IOU'''
         detections = [detections[i] for i in u_detection]
-        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state==TrackState.Tracked ]
+        # detections is now a list of the unmatched detections
+        r_tracked_stracks = []  # container for stracks which were tracked till the
+        # previous frame but for which no detection was found in the current frame
+        for i in u_track:
+            if strack_pool[i].state == TrackState.Tracked:
+                r_tracked_stracks.append(strack_pool[i])
         dists = matching.iou_distance(r_tracked_stracks, detections)
         matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
+        # matches is the list of detections which matched with the corresponding tracks by the IOU distance method
         for itracked, idet in matches:
             track = r_tracked_stracks[itracked]
             det = detections[idet]
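As a toy illustration of the assignment step above: this sketch uses scipy's Hungarian solver as a stand-in (the repo's matching.linear_assignment wraps lap.lapjv, but the cost-threshold idea is the same); the matrix values are made up:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    dists = np.array([[0.2, 0.9],   # track 0 vs detections 0, 1
                      [0.8, 0.3]])  # track 1 vs detections 0, 1
    rows, cols = linear_sum_assignment(dists)
    matches = [(r, c) for r, c in zip(rows, cols) if dists[r, c] <= 0.7]
    # -> [(0, 0), (1, 1)]; any pair costlier than the threshold stays unmatched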
@@ -251,12 +284,14 @@ class JDETracker(object):
             else:
                 track.re_activate(det, self.frame_id, new_id=False)
                 refind_stracks.append(track)
+        # The same process is applied to the unmatched detections, but now with IOU distance as the measure
 
         for it in u_track:
             track = r_tracked_stracks[it]
             if not track.state == TrackState.Lost:
                 track.mark_lost()
                 lost_stracks.append(track)
+        # If no detection is obtained for a track (u_track), the track is added to lost_stracks and marked lost
 
         '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
         detections = [detections[i] for i in u_detection]
@@ -265,11 +300,14 @@ class JDETracker(object):
         for itracked, idet in matches:
             unconfirmed[itracked].update(detections[idet], self.frame_id)
             activated_starcks.append(unconfirmed[itracked])
 
+        # The tracks which are still not matched
         for it in u_unconfirmed:
             track = unconfirmed[it]
             track.mark_removed()
             removed_stracks.append(track)
 
+        # After all these association steps, any remaining new detection is initialized as a new track
         """ Step 4: Init new stracks"""
         for inew in u_detection:
             track = detections[inew]
@@ -279,14 +317,18 @@ class JDETracker(object):
                 activated_starcks.append(track)
 
         """ Step 5: Update state"""
+        # If a track has been lost for more frames than the threshold, it is removed.
         for track in self.lost_stracks:
             if self.frame_id - track.end_frame > self.max_time_lost:
                 track.mark_removed()
                 removed_stracks.append(track)
+        # print('Remained match {} s'.format(t4-t3))
 
+        # Update self.tracked_stracks and self.lost_stracks with the results of this step.
         self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
         self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
         self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
+        # self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost]  # type: list[STrack]
         self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
         self.lost_stracks.extend(lost_stracks)
         self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
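A note on the removal threshold used above: the sketch below mirrors how JDETracker.__init__ appears to scale the track buffer with the video frame rate (the concrete numbers are illustrative, and the scaling formula is an assumption, not shown in this diff):

    frame_rate, track_buffer = 25, 30
    max_time_lost = int(frame_rate / 30.0 * track_buffer)  # 25 frames, i.e. one second at 25 fps
    # a lost track is removed once frame_id - track.end_frame > max_time_lost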
@@ -301,6 +343,7 @@ class JDETracker(object):
         logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
         logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
         logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
+        # print('Final {} s'.format(t5-t4))
         return output_stracks
 
 def joint_stracks(tlista, tlistb):
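Putting the update() flow together, a hedged per-frame tracking loop (it assumes a dataloader that yields the preprocessed image alongside the original frame, the way utils.datasets.LoadVideo does; the unpacked names are illustrative):

    tracker = JDETracker(opt, frame_rate=30)
    results = []
    for frame_id, (path, img, img0) in enumerate(dataloader):
        blob = torch.from_numpy(img).cuda().unsqueeze(0)  # [1, 3, H, W] network input
        online_targets = tracker.update(blob, img0)       # list of active STrack instances
        for t in online_targets:
            x, y, w, h = t.tlwh                           # top-left x/y, width, height
            results.append((frame_id + 1, t.track_id, x, y, w, h))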
train.py (80 changes)

@@ -1,9 +1,10 @@
 import argparse
 import json
 import time
+from time import gmtime, strftime
 import test
 from models import *
+from shutil import copyfile
 from utils.datasets import JointDataset, collate_fn
 from utils.utils import *
 from utils.log import logger
@@ -13,6 +14,10 @@ from torchvision.transforms import transforms as T
 def train(
         cfg,
         data_cfg,
+        weights_from="",
+        weights_to="",
+        save_every=10,
+        img_size=(1088, 608),
         resume=False,
         epochs=100,
         batch_size=16,
@@ -20,9 +25,16 @@ def train(
         freeze_backbone=False,
         opt=None,
 ):
-    weights = 'weights'
-    mkdir_if_missing(weights)
-    latest = osp.join(weights, 'latest.pt')
+    # The function starts
+    timme = strftime("%Y-%d-%m %H:%M:%S", gmtime())
+    timme = timme[5:-3].replace('-', '_')
+    timme = timme.replace(' ', '_')
+    timme = timme.replace(':', '_')
+    weights_to = osp.join(weights_to, 'run' + timme)
+    mkdir_if_missing(weights_to)
+    if resume:
+        latest_resume = osp.join(weights_from, 'latest.pt')
 
     torch.backends.cudnn.benchmark = True  # unsuitable for multiscale
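For illustration, the timestamp mangling above turns a gmtime of 2020-07-15 09:30:00 into a run folder such as weights/run15_07_09_30. A standalone re-run of the same string operations:

    from time import gmtime, strftime

    timme = strftime("%Y-%d-%m %H:%M:%S", gmtime())  # e.g. '2020-15-07 09:30:00'
    timme = timme[5:-3].replace('-', '_')            # drop year and seconds -> '15_07 09:30'
    timme = timme.replace(' ', '_').replace(':', '_')
    print('weights/run' + timme)                     # e.g. 'weights/run15_07_09_30'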
@@ -32,24 +44,19 @@ def train(
     trainset_paths = data_config['train']
     dataset_root = data_config['root']
     f.close()
-    cfg_dict = parse_model_cfg(cfg)
-    img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
 
-    # Get dataloader
     transforms = T.Compose([T.ToTensor()])
+    # Get dataloader
     dataset = JointDataset(dataset_root, trainset_paths, img_size, augment=True, transforms=transforms)
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                              num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn)
 
     # Initialize model
-    model = Darknet(cfg_dict, dataset.nID)
+    model = Darknet(cfg, dataset.nID)
 
     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
     if resume:
-        checkpoint = torch.load(latest, map_location='cpu')
+        checkpoint = torch.load(latest_resume, map_location='cpu')
 
         # Load weights to resume from
         model.load_state_dict(checkpoint['model'])
@@ -67,33 +74,33 @@ def train(
     else:
         # Initialize model with backbone (optional)
         if cfg.endswith('yolov3.cfg'):
-            load_darknet_weights(model, osp.join(weights, 'darknet53.conv.74'))
+            load_darknet_weights(model, osp.join(weights_from, 'darknet53.conv.74'))
             cutoff = 75
         elif cfg.endswith('yolov3-tiny.cfg'):
-            load_darknet_weights(model, osp.join(weights, 'yolov3-tiny.conv.15'))
+            load_darknet_weights(model, osp.join(weights_from, 'yolov3-tiny.conv.15'))
             cutoff = 15
 
     model.cuda().train()
 
     # Set optimizer
-    optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9, weight_decay=1e-4)
+    optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9,
+                                weight_decay=1e-4)
 
     model = torch.nn.DataParallel(model)
     # Set scheduler
     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
-                                                     milestones=[int(0.5*opt.epochs), int(0.75*opt.epochs)], gamma=0.1)
+                                                     milestones=[int(0.5 * opt.epochs), int(0.75 * opt.epochs)],
+                                                     gamma=0.1)
 
     # An important trick for detection: freeze bn during fine-tuning
     if not opt.unfreeze_bn:
         for i, (name, p) in enumerate(model.named_parameters()):
             p.requires_grad = False if 'batch_norm' in name else True
 
-    model_info(model)
+    # model_info(model)
 
     t0 = time.time()
     for epoch in range(epochs):
         epoch += start_epoch
 
         logger.info(('%8s%12s' + '%10s' * 6) % (
             'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time'))
@@ -120,7 +127,6 @@ def train(
             # Compute loss, compute gradient, update parameters
             loss, components = model(imgs.cuda(), targets.cuda(), targets_len.cuda())
             components = torch.mean(components.view(-1, 5), dim=0)
-
             loss = torch.mean(loss)
             loss.backward()
 
@@ -135,6 +141,7 @@ def train(
             for ii, key in enumerate(model.module.loss_names):
                 rloss[key] = (rloss[key] * ui + components[ii]) / (ui + 1)
 
+            # rloss holds running loss values whose mean is updated at every iteration
             s = ('%8s%12s' + '%10.3g' * 6) % (
                 '%g/%g' % (epoch, epochs - 1),
                 '%g/%g' % (i, len(dataloader) - 1),
@@ -149,26 +156,45 @@ def train(
             checkpoint = {'epoch': epoch,
                           'model': model.module.state_dict(),
                           'optimizer': optimizer.state_dict()}
-            torch.save(checkpoint, latest)
+
+            copyfile(cfg, weights_to + '/cfg/yolo3.cfg')
+            copyfile(data_cfg, weights_to + '/cfg/ccmcpe.json')
+
+            latest = osp.join(weights_to, 'latest.pt')
+            torch.save(checkpoint, latest)
+            if epoch % save_every == 0 and epoch != 0:
+                # making the checkpoint lite
+                checkpoint["optimizer"] = []
+                torch.save(checkpoint, osp.join(weights_to, "weights_epoch_" + str(epoch) + ".pt"))
 
         # Calculate mAP
         if epoch % opt.test_interval == 0:
             with torch.no_grad():
-                mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
-                test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
+                mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
+                                      print_interval=40, nID=dataset.nID)
+                test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
+                              print_interval=40, nID=dataset.nID)
 
         # Call scheduler.step() after optimizer.step() with pytorch > 1.1.0
         scheduler.step()
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--epochs', type=int, default=30, help='number of epochs')
     parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
     parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step')
-    parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path')
+    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
+    parser.add_argument('--weights-from', type=str, default='weights/',
+                        help='Path for getting the trained model for resuming training (Should only be used with '
+                             '--resume)')
+    parser.add_argument('--weights-to', type=str, default='weights/',
+                        help='Store the trained weights after resuming training session. It will create a new folder '
+                             'with timestamp in the given path')
+    parser.add_argument('--save-model-after', type=int, default=10,
+                        help='Save a checkpoint of model at given interval of epochs')
     parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path')
+    parser.add_argument('--img-size', type=int, default=[1088, 608], nargs='+', help='pixels')
     parser.add_argument('--resume', action='store_true', help='resume training flag')
     parser.add_argument('--print-interval', type=int, default=40, help='print interval')
     parser.add_argument('--test-interval', type=int, default=9, help='test interval')
@@ -181,6 +207,10 @@ if __name__ == '__main__':
     train(
         opt.cfg,
         opt.data_cfg,
+        weights_from=opt.weights_from,
+        weights_to=opt.weights_to,
+        save_every=opt.save_model_after,
+        img_size=opt.img_size,
         resume=opt.resume,
         epochs=opt.epochs,
         batch_size=opt.batch_size,
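The checkpoint logic added above keeps a full latest.pt for --resume while the periodic per-epoch snapshots drop the optimizer state to stay small. A self-contained sketch of the same idea (the function name, paths and save_every value are illustrative):

    import os.path as osp
    import torch

    def save_checkpoints(model, optimizer, epoch, weights_to, save_every=10):
        checkpoint = {'epoch': epoch,
                      'model': model.state_dict(),
                      'optimizer': optimizer.state_dict()}
        # Full checkpoint, overwritten every epoch; used when resuming.
        torch.save(checkpoint, osp.join(weights_to, 'latest.pt'))
        if epoch % save_every == 0 and epoch != 0:
            checkpoint['optimizer'] = []  # "lite" checkpoint: weights only
            torch.save(checkpoint, osp.join(weights_to, 'weights_epoch_' + str(epoch) + '.pt'))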
utils/datasets.py

@@ -393,6 +393,9 @@ class JointDataset(LoadImagesAndLabels):  # for training
 
 
     def __getitem__(self, files_index):
+        """
+        Iterator function for the train dataset
+        """
         for i, c in enumerate(self.cds):
             if files_index >= c:
                 ds = list(self.label_files.keys())[i]
utils/kalman_filter.py

@@ -1,5 +1,4 @@
 # vim: expandtab:ts=4:sw=4
-import numba
 import numpy as np
 import scipy.linalg
@@ -1,6 +1,5 @@
 import glob
 import random
-import time
 import os
 import os.path as osp