Documentation (#95)
* Added documentation * Added docstrings and comments * Removed unused imports * Removed unused imports * Added functionality of saving checkpoints during training process * Update train.py * Update multitracker.py
This commit is contained in:
parent
0a0665e682
commit
24f351d1b5
8 changed files with 184 additions and 89 deletions
14
demo.py
14
demo.py
|
@ -1,4 +1,4 @@
|
|||
"""Demo file for running the JDE tracker on custom video sequences for pedestrian tracking.
|
||||
"""Demo file for running the JDE tracker on custom video sequences for pedestrian tracking.
|
||||
|
||||
This file is the entry point to running the tracker on custom video sequences. It loads images from the provided video sequence, uses the JDE tracker for inference and outputs the video with bounding boxes indicating pedestrians. The bounding boxes also have associated ids (shown in different colours) to keep track of the movement of each individual.
|
||||
|
||||
|
@ -24,29 +24,19 @@ Todo:
|
|||
* More documentation
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import logging
|
||||
import argparse
|
||||
import motmetrics as mm
|
||||
|
||||
from tracker.multitracker import JDETracker
|
||||
from utils import visualization as vis
|
||||
from utils.utils import *
|
||||
from utils.io import read_results
|
||||
from utils.log import logger
|
||||
from utils.timer import Timer
|
||||
from utils.evaluation import Evaluator
|
||||
from utils.parse_config import parse_model_cfg
|
||||
import utils.datasets as datasets
|
||||
import torch
|
||||
from track import eval_seq
|
||||
|
||||
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
def track(opt):
|
||||
def track(opt):
|
||||
result_root = opt.output_root if opt.output_root!='' else '.'
|
||||
mkdir_if_missing(result_root)
|
||||
|
||||
|
|
35
track.py
35
track.py
|
@ -39,6 +39,41 @@ def write_results(filename, results, data_type):
|
|||
|
||||
|
||||
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
|
||||
'''
|
||||
Processes the video sequence given and provides the output of tracking result (write the results in video file)
|
||||
|
||||
It uses JDE model for getting information about the online targets present.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
opt : Namespace
|
||||
Contains information passed as commandline arguments.
|
||||
|
||||
dataloader : LoadVideo
|
||||
Instance of LoadVideo class used for fetching the image sequence and associated data.
|
||||
|
||||
data_type : String
|
||||
Type of dataset corresponding(similar) to the given video.
|
||||
|
||||
result_filename : String
|
||||
The name(path) of the file for storing results.
|
||||
|
||||
save_dir : String
|
||||
Path to the folder for storing the frames containing bounding box information (Result frames).
|
||||
|
||||
show_image : bool
|
||||
Option for showing individual frames during run-time.
|
||||
|
||||
frame_rate : int
|
||||
Frame-rate of the given video.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(Returns are not significant here)
|
||||
frame_id : int
|
||||
Sequence number of the last sequence
|
||||
'''
|
||||
|
||||
if save_dir:
|
||||
mkdir_if_missing(save_dir)
|
||||
tracker = JDETracker(opt, frame_rate=frame_rate)
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
import cv2
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import scipy
|
||||
from scipy.spatial.distance import cdist
|
||||
|
@ -8,7 +5,6 @@ import lap
|
|||
|
||||
from cython_bbox import bbox_overlaps as bbox_ious
|
||||
from utils import kalman_filter
|
||||
import time
|
||||
|
||||
def merge_matches(m1, m2, shape):
|
||||
O,P,Q = shape
|
||||
|
|
|
@ -1,15 +1,6 @@
|
|||
import numpy as np
|
||||
from numba import jit
|
||||
from collections import deque
|
||||
import itertools
|
||||
import os
|
||||
import os.path as osp
|
||||
import time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.utils import *
|
||||
from utils.log import logger
|
||||
from utils.kalman_filter import KalmanFilter
|
||||
from models import *
|
||||
from tracker import matching
|
||||
|
@ -17,7 +8,6 @@ from .basetrack import BaseTrack, TrackState
|
|||
|
||||
|
||||
class STrack(BaseTrack):
|
||||
shared_kalman = KalmanFilter()
|
||||
|
||||
def __init__(self, tlwh, score, temp_feat, buffer_size=30):
|
||||
|
||||
|
@ -50,13 +40,13 @@ class STrack(BaseTrack):
|
|||
if self.state != TrackState.Tracked:
|
||||
mean_state[7] = 0
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i,st in enumerate(stracks):
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
|
@ -64,7 +54,6 @@ class STrack(BaseTrack):
|
|||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
|
||||
def activate(self, kalman_filter, frame_id):
|
||||
"""Start a new tracklet"""
|
||||
self.kalman_filter = kalman_filter
|
||||
|
@ -112,7 +101,7 @@ class STrack(BaseTrack):
|
|||
self.update_features(new_track.curr_feat)
|
||||
|
||||
@property
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
|
@ -125,7 +114,7 @@ class STrack(BaseTrack):
|
|||
return ret
|
||||
|
||||
@property
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
|
@ -135,7 +124,7 @@ class STrack(BaseTrack):
|
|||
return ret
|
||||
|
||||
@staticmethod
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
|
@ -149,14 +138,14 @@ class STrack(BaseTrack):
|
|||
return self.tlwh_to_xyah(self.tlwh)
|
||||
|
||||
@staticmethod
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlbr_to_tlwh(tlbr):
|
||||
ret = np.asarray(tlbr).copy()
|
||||
ret[2:] -= ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlwh_to_tlbr(tlwh):
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[2:] += ret[:2]
|
||||
|
@ -166,11 +155,10 @@ class STrack(BaseTrack):
|
|||
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
|
||||
|
||||
|
||||
|
||||
class JDETracker(object):
|
||||
def __init__(self, opt, frame_rate=30):
|
||||
self.opt = opt
|
||||
self.model = Darknet(opt.cfg)
|
||||
self.model = Darknet(opt.cfg, opt.img_size, nID=30)
|
||||
# load_darknet_weights(self.model, opt.weights)
|
||||
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
|
||||
self.model.cuda().eval()
|
||||
|
@ -187,61 +175,106 @@ class JDETracker(object):
|
|||
self.kalman_filter = KalmanFilter()
|
||||
|
||||
def update(self, im_blob, img0):
|
||||
"""
|
||||
Processes the image frame and finds bounding box(detections).
|
||||
|
||||
Associates the detection with corresponding tracklets and also handles lost, removed, refound and active tracklets
|
||||
|
||||
Parameters
|
||||
----------
|
||||
im_blob : torch.float32
|
||||
Tensor of shape depending upon the size of image. By default, shape of this tensor is [1, 3, 608, 1088]
|
||||
|
||||
img0 : ndarray
|
||||
ndarray of shape depending on the input image sequence. By default, shape is [608, 1080, 3]
|
||||
|
||||
Returns
|
||||
-------
|
||||
output_stracks : list of STrack instances
|
||||
The list contains information regarding the online_tracklets for the received image tensor.
|
||||
|
||||
"""
|
||||
|
||||
self.frame_id += 1
|
||||
activated_starcks = []
|
||||
refind_stracks = []
|
||||
lost_stracks = []
|
||||
activated_starcks = [] # for storing active tracks, for the current frame
|
||||
refind_stracks = [] # Lost Tracks whose detections are obtained in the current frame
|
||||
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
|
||||
removed_stracks = []
|
||||
|
||||
t1 = time.time()
|
||||
''' Step 1: Network forward, get detections & embeddings'''
|
||||
with torch.no_grad():
|
||||
pred = self.model(im_blob)
|
||||
# pred is tensor of all the proposals (default number of proposals: 54264). Proposals have information associated with the bounding box and embeddings
|
||||
pred = pred[pred[:, :, 4] > self.opt.conf_thres]
|
||||
# pred now has fewer proposals. Proposals are rejected on the basis of object confidence score
|
||||
if len(pred) > 0:
|
||||
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres,
|
||||
self.opt.nms_thres)[0]
|
||||
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu()
|
||||
# Final proposals are obtained in dets. Information of bounding box and embeddings also included
|
||||
# Next step changes the detection scales
|
||||
scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
|
||||
dets, embs = dets[:, :5].cpu().numpy(), dets[:, 6:].cpu().numpy()
|
||||
'''Detections'''
|
||||
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
|
||||
(tlbrs, f) in zip(dets, embs)]
|
||||
'''Detections is list of (x1, y1, x2, y2, object_conf, class_score, class_pred)'''
|
||||
# class_pred is the embeddings.
|
||||
|
||||
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
|
||||
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
|
||||
else:
|
||||
detections = []
|
||||
|
||||
t2 = time.time()
|
||||
# print('Forward: {} s'.format(t2-t1))
|
||||
|
||||
''' Add newly detected tracklets to tracked_stracks'''
|
||||
unconfirmed = []
|
||||
tracked_stracks = [] # type: list[STrack]
|
||||
for track in self.tracked_stracks:
|
||||
if not track.is_activated:
|
||||
# previous tracks which are not active in the current frame are added in unconfirmed list
|
||||
unconfirmed.append(track)
|
||||
# print("Should not be here, in unconfirmed")
|
||||
else:
|
||||
# Active tracks are added to the local list 'tracked_stracks'
|
||||
tracked_stracks.append(track)
|
||||
|
||||
''' Step 2: First association, with embedding'''
|
||||
# Combining currently tracked_stracks and lost_stracks
|
||||
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
|
||||
# Predict the current location with KF
|
||||
STrack.multi_predict(strack_pool)
|
||||
|
||||
|
||||
dists = matching.embedding_distance(strack_pool, detections)
|
||||
# dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
|
||||
dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
|
||||
# The dists is the list of distances of the detection with the tracks in strack_pool
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
||||
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
|
||||
|
||||
for itracked, idet in matches:
|
||||
# itracked is the id of the track and idet is the detection
|
||||
track = strack_pool[itracked]
|
||||
det = detections[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
# If the track is active, add the detection to the track
|
||||
track.update(detections[idet], self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
else:
|
||||
# We have obtained a detection from a track which is not active, hence put the track in refind_stracks list
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
|
||||
# None of the steps below happen if there are no undetected tracks.
|
||||
''' Step 3: Second association, with IOU'''
|
||||
detections = [detections[i] for i in u_detection]
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state==TrackState.Tracked ]
|
||||
# detections is now a list of the unmatched detections
|
||||
r_tracked_stracks = [] # This is container for stracks which were tracked till the
|
||||
# previous frame but no detection was found for it in the current frame
|
||||
for i in u_track:
|
||||
if strack_pool[i].state == TrackState.Tracked:
|
||||
r_tracked_stracks.append(strack_pool[i])
|
||||
dists = matching.iou_distance(r_tracked_stracks, detections)
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
|
||||
|
||||
# matches is the list of detections which matched with corresponding tracks by IOU distance method
|
||||
for itracked, idet in matches:
|
||||
track = r_tracked_stracks[itracked]
|
||||
det = detections[idet]
|
||||
|
@ -251,12 +284,14 @@ class JDETracker(object):
|
|||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
# Same process done for some unmatched detections, but now considering IOU_distance as measure
|
||||
|
||||
for it in u_track:
|
||||
track = r_tracked_stracks[it]
|
||||
if not track.state == TrackState.Lost:
|
||||
track.mark_lost()
|
||||
lost_stracks.append(track)
|
||||
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
|
||||
|
||||
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
|
||||
detections = [detections[i] for i in u_detection]
|
||||
|
@ -265,11 +300,14 @@ class JDETracker(object):
|
|||
for itracked, idet in matches:
|
||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
||||
activated_starcks.append(unconfirmed[itracked])
|
||||
|
||||
# The tracks which are yet not matched
|
||||
for it in u_unconfirmed:
|
||||
track = unconfirmed[it]
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
|
||||
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
|
||||
""" Step 4: Init new stracks"""
|
||||
for inew in u_detection:
|
||||
track = detections[inew]
|
||||
|
@ -279,14 +317,18 @@ class JDETracker(object):
|
|||
activated_starcks.append(track)
|
||||
|
||||
""" Step 5: Update state"""
|
||||
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
|
||||
for track in self.lost_stracks:
|
||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
# print('Remained match {} s'.format(t4-t3))
|
||||
|
||||
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
|
||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
|
||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
|
||||
# self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
|
||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
|
||||
self.lost_stracks.extend(lost_stracks)
|
||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
|
||||
|
@ -301,6 +343,7 @@ class JDETracker(object):
|
|||
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
|
||||
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
||||
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
||||
# print('Final {} s'.format(t5-t4))
|
||||
return output_stracks
|
||||
|
||||
def joint_stracks(tlista, tlistb):
|
||||
|
|
110
train.py
110
train.py
|
@ -1,9 +1,10 @@
|
|||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import test
|
||||
from time import gmtime, strftime
|
||||
import test
|
||||
from models import *
|
||||
from shutil import copyfile
|
||||
from utils.datasets import JointDataset, collate_fn
|
||||
from utils.utils import *
|
||||
from utils.log import logger
|
||||
|
@ -13,6 +14,10 @@ from torchvision.transforms import transforms as T
|
|||
def train(
|
||||
cfg,
|
||||
data_cfg,
|
||||
weights_from="",
|
||||
weights_to="",
|
||||
save_every=10,
|
||||
img_size=(1088, 608),
|
||||
resume=False,
|
||||
epochs=100,
|
||||
batch_size=16,
|
||||
|
@ -20,9 +25,16 @@ def train(
|
|||
freeze_backbone=False,
|
||||
opt=None,
|
||||
):
|
||||
weights = 'weights'
|
||||
mkdir_if_missing(weights)
|
||||
latest = osp.join(weights, 'latest.pt')
|
||||
# The function starts
|
||||
|
||||
timme = strftime("%Y-%d-%m %H:%M:%S", gmtime())
|
||||
timme = timme[5:-3].replace('-', '_')
|
||||
timme = timme.replace(' ', '_')
|
||||
timme = timme.replace(':', '_')
|
||||
weights_to = osp.join(weights_to, 'run' + timme)
|
||||
mkdir_if_missing(weights_to)
|
||||
if resume:
|
||||
latest_resume = osp.join(weights_from, 'latest.pt')
|
||||
|
||||
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
|
||||
|
||||
|
@ -32,24 +44,19 @@ def train(
|
|||
trainset_paths = data_config['train']
|
||||
dataset_root = data_config['root']
|
||||
f.close()
|
||||
cfg_dict = parse_model_cfg(cfg)
|
||||
img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
|
||||
|
||||
# Get dataloader
|
||||
transforms = T.Compose([T.ToTensor()])
|
||||
# Get dataloader
|
||||
dataset = JointDataset(dataset_root, trainset_paths, img_size, augment=True, transforms=transforms)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
|
||||
num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn)
|
||||
|
||||
num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn)
|
||||
# Initialize model
|
||||
model = Darknet(cfg_dict, dataset.nID)
|
||||
|
||||
|
||||
model = Darknet(cfg, dataset.nID)
|
||||
|
||||
cutoff = -1 # backbone reaches to cutoff layer
|
||||
start_epoch = 0
|
||||
if resume:
|
||||
checkpoint = torch.load(latest, map_location='cpu')
|
||||
checkpoint = torch.load(latest_resume, map_location='cpu')
|
||||
|
||||
# Load weights to resume from
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
|
@ -67,36 +74,36 @@ def train(
|
|||
else:
|
||||
# Initialize model with backbone (optional)
|
||||
if cfg.endswith('yolov3.cfg'):
|
||||
load_darknet_weights(model, osp.join(weights ,'darknet53.conv.74'))
|
||||
load_darknet_weights(model, osp.join(weights_from, 'darknet53.conv.74'))
|
||||
cutoff = 75
|
||||
elif cfg.endswith('yolov3-tiny.cfg'):
|
||||
load_darknet_weights(model, osp.join(weights , 'yolov3-tiny.conv.15'))
|
||||
load_darknet_weights(model, osp.join(weights_from, 'yolov3-tiny.conv.15'))
|
||||
cutoff = 15
|
||||
|
||||
model.cuda().train()
|
||||
|
||||
# Set optimizer
|
||||
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9, weight_decay=1e-4)
|
||||
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9,
|
||||
weight_decay=1e-4)
|
||||
|
||||
model = torch.nn.DataParallel(model)
|
||||
# Set scheduler
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
|
||||
milestones=[int(0.5*opt.epochs), int(0.75*opt.epochs)], gamma=0.1)
|
||||
|
||||
# An important trick for detection: freeze bn during fine-tuning
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
|
||||
milestones=[int(0.5 * opt.epochs), int(0.75 * opt.epochs)],
|
||||
gamma=0.1)
|
||||
|
||||
# An important trick for detection: freeze bn during fine-tuning
|
||||
if not opt.unfreeze_bn:
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
p.requires_grad = False if 'batch_norm' in name else True
|
||||
|
||||
model_info(model)
|
||||
|
||||
# model_info(model)
|
||||
t0 = time.time()
|
||||
for epoch in range(epochs):
|
||||
epoch += start_epoch
|
||||
|
||||
logger.info(('%8s%12s' + '%10s' * 6) % (
|
||||
'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time'))
|
||||
|
||||
|
||||
# Freeze darknet53.conv.74 for first epoch
|
||||
if freeze_backbone and (epoch < 2):
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
|
@ -109,18 +116,17 @@ def train(
|
|||
for i, (imgs, targets, _, _, targets_len) in enumerate(dataloader):
|
||||
if sum([len(x) for x in targets]) < 1: # if no targets continue
|
||||
continue
|
||||
|
||||
|
||||
# SGD burn-in
|
||||
burnin = min(1000, len(dataloader))
|
||||
if (epoch == 0) & (i <= burnin):
|
||||
lr = opt.lr * (i / burnin) **4
|
||||
lr = opt.lr * (i / burnin) ** 4
|
||||
for g in optimizer.param_groups:
|
||||
g['lr'] = lr
|
||||
|
||||
|
||||
# Compute loss, compute gradient, update parameters
|
||||
loss, components = model(imgs.cuda(), targets.cuda(), targets_len.cuda())
|
||||
components = torch.mean(components.view(-1, 5),dim=0)
|
||||
|
||||
components = torch.mean(components.view(-1, 5), dim=0)
|
||||
loss = torch.mean(loss)
|
||||
loss.backward()
|
||||
|
||||
|
@ -131,44 +137,64 @@ def train(
|
|||
|
||||
# Running epoch-means of tracked metrics
|
||||
ui += 1
|
||||
|
||||
|
||||
for ii, key in enumerate(model.module.loss_names):
|
||||
rloss[key] = (rloss[key] * ui + components[ii]) / (ui + 1)
|
||||
|
||||
# rloss indicates running loss values with mean updated at every epoch
|
||||
s = ('%8s%12s' + '%10.3g' * 6) % (
|
||||
'%g/%g' % (epoch, epochs - 1),
|
||||
'%g/%g' % (i, len(dataloader) - 1),
|
||||
rloss['box'], rloss['conf'],
|
||||
rloss['id'],rloss['loss'],
|
||||
rloss['id'], rloss['loss'],
|
||||
rloss['nT'], time.time() - t0)
|
||||
t0 = time.time()
|
||||
if i % opt.print_interval == 0:
|
||||
logger.info(s)
|
||||
|
||||
|
||||
# Save latest checkpoint
|
||||
checkpoint = {'epoch': epoch,
|
||||
'model': model.module.state_dict(),
|
||||
'optimizer': optimizer.state_dict()}
|
||||
torch.save(checkpoint, latest)
|
||||
|
||||
copyfile(cfg, weights_to + '/cfg/yolo3.cfg')
|
||||
copyfile(data_cfg, weights_to + '/cfg/ccmcpe.json')
|
||||
|
||||
latest = osp.join(weights_to, 'latest.pt')
|
||||
torch.save(checkpoint, latest)
|
||||
if epoch % save_every == 0 and epoch != 0:
|
||||
# making the checkpoint lite
|
||||
checkpoint["optimizer"] = []
|
||||
torch.save(checkpoint, osp.join(weights_to, "weights_epoch_" + str(epoch) + ".pt"))
|
||||
|
||||
# Calculate mAP
|
||||
if epoch % opt.test_interval ==0:
|
||||
if epoch % opt.test_interval == 0:
|
||||
with torch.no_grad():
|
||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
|
||||
test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
|
||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
|
||||
print_interval=40, nID=dataset.nID)
|
||||
test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
|
||||
print_interval=40, nID=dataset.nID)
|
||||
|
||||
|
||||
# Call scheduler.step() after optimizer.step() with pytorch > 1.1.0
|
||||
# Call scheduler.step() after optimizer.step() with pytorch > 1.1.0
|
||||
scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--epochs', type=int, default=30, help='number of epochs')
|
||||
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
|
||||
parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step')
|
||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path')
|
||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
||||
parser.add_argument('--weights-from', type=str, default='weights/',
|
||||
help='Path for getting the trained model for resuming training (Should only be used with '
|
||||
'--resume)')
|
||||
parser.add_argument('--weights-to', type=str, default='weights/',
|
||||
help='Store the trained weights after resuming training session. It will create a new folder '
|
||||
'with timestamp in the given path')
|
||||
parser.add_argument('--save-model-after', type=int, default=10,
|
||||
help='Save a checkpoint of model at given interval of epochs')
|
||||
parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path')
|
||||
parser.add_argument('--img-size', type=int, default=[1088, 608], nargs='+', help='pixels')
|
||||
parser.add_argument('--resume', action='store_true', help='resume training flag')
|
||||
parser.add_argument('--print-interval', type=int, default=40, help='print interval')
|
||||
parser.add_argument('--test-interval', type=int, default=9, help='test interval')
|
||||
|
@ -181,6 +207,10 @@ if __name__ == '__main__':
|
|||
train(
|
||||
opt.cfg,
|
||||
opt.data_cfg,
|
||||
weights_from=opt.weights_from,
|
||||
weights_to=opt.weights_to,
|
||||
save_every=opt.save_model_after,
|
||||
img_size=opt.img_size,
|
||||
resume=opt.resume,
|
||||
epochs=opt.epochs,
|
||||
batch_size=opt.batch_size,
|
||||
|
|
|
@ -393,6 +393,9 @@ class JointDataset(LoadImagesAndLabels): # for training
|
|||
|
||||
|
||||
def __getitem__(self, files_index):
|
||||
"""
|
||||
Iterator function for train dataset
|
||||
"""
|
||||
for i, c in enumerate(self.cds):
|
||||
if files_index >= c:
|
||||
ds = list(self.label_files.keys())[i]
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# vim: expandtab:ts=4:sw=4
|
||||
import numba
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import glob
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
|
|
Loading…
Reference in a new issue