Documentation (#95)

* Added documentation

* Added docstrings and comments

* Removed unused imports

* Removed unused imports

* Added functionality of saving checkpoints during training process

* Update train.py

* Update multitracker.py
This commit is contained in:
Parthesh Soni 2020-03-14 07:54:27 +05:30 committed by GitHub
parent 0a0665e682
commit 24f351d1b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 184 additions and 89 deletions

10
demo.py
View file

@ -24,23 +24,13 @@ Todo:
* More documentation * More documentation
""" """
import os
import os.path as osp
import cv2
import logging import logging
import argparse import argparse
import motmetrics as mm
from tracker.multitracker import JDETracker
from utils import visualization as vis
from utils.utils import * from utils.utils import *
from utils.io import read_results
from utils.log import logger from utils.log import logger
from utils.timer import Timer from utils.timer import Timer
from utils.evaluation import Evaluator
from utils.parse_config import parse_model_cfg from utils.parse_config import parse_model_cfg
import utils.datasets as datasets import utils.datasets as datasets
import torch
from track import eval_seq from track import eval_seq

View file

@ -39,6 +39,41 @@ def write_results(filename, results, data_type):
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30): def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
'''
Processes the video sequence given and provides the output of tracking result (write the results in video file)
It uses JDE model for getting information about the online targets present.
Parameters
----------
opt : Namespace
Contains information passed as commandline arguments.
dataloader : LoadVideo
Instance of LoadVideo class used for fetching the image sequence and associated data.
data_type : String
Type of dataset corresponding(similar) to the given video.
result_filename : String
The name(path) of the file for storing results.
save_dir : String
Path to the folder for storing the frames containing bounding box information (Result frames).
show_image : bool
Option for shhowing individial frames during run-time.
frame_rate : int
Frame-rate of the given video.
Returns
-------
(Returns are not significant here)
frame_id : int
Sequence number of the last sequence
'''
if save_dir: if save_dir:
mkdir_if_missing(save_dir) mkdir_if_missing(save_dir)
tracker = JDETracker(opt, frame_rate=frame_rate) tracker = JDETracker(opt, frame_rate=frame_rate)

View file

@ -1,6 +1,3 @@
import cv2
import torch
import torch.nn.functional as F
import numpy as np import numpy as np
import scipy import scipy
from scipy.spatial.distance import cdist from scipy.spatial.distance import cdist
@ -8,7 +5,6 @@ import lap
from cython_bbox import bbox_overlaps as bbox_ious from cython_bbox import bbox_overlaps as bbox_ious
from utils import kalman_filter from utils import kalman_filter
import time
def merge_matches(m1, m2, shape): def merge_matches(m1, m2, shape):
O,P,Q = shape O,P,Q = shape

View file

@ -1,15 +1,6 @@
import numpy as np
from numba import jit from numba import jit
from collections import deque from collections import deque
import itertools
import os
import os.path as osp
import time
import torch import torch
import torch.nn.functional as F
from utils.utils import *
from utils.log import logger
from utils.kalman_filter import KalmanFilter from utils.kalman_filter import KalmanFilter
from models import * from models import *
from tracker import matching from tracker import matching
@ -17,7 +8,6 @@ from .basetrack import BaseTrack, TrackState
class STrack(BaseTrack): class STrack(BaseTrack):
shared_kalman = KalmanFilter()
def __init__(self, tlwh, score, temp_feat, buffer_size=30): def __init__(self, tlwh, score, temp_feat, buffer_size=30):
@ -64,7 +54,6 @@ class STrack(BaseTrack):
stracks[i].mean = mean stracks[i].mean = mean
stracks[i].covariance = cov stracks[i].covariance = cov
def activate(self, kalman_filter, frame_id): def activate(self, kalman_filter, frame_id):
"""Start a new tracklet""" """Start a new tracklet"""
self.kalman_filter = kalman_filter self.kalman_filter = kalman_filter
@ -112,7 +101,7 @@ class STrack(BaseTrack):
self.update_features(new_track.curr_feat) self.update_features(new_track.curr_feat)
@property @property
#@jit(nopython=True) @jit
def tlwh(self): def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y, """Get current position in bounding box format `(top left x, top left y,
width, height)`. width, height)`.
@ -125,7 +114,7 @@ class STrack(BaseTrack):
return ret return ret
@property @property
#@jit(nopython=True) @jit
def tlbr(self): def tlbr(self):
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e., """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`. `(top left, bottom right)`.
@ -135,7 +124,7 @@ class STrack(BaseTrack):
return ret return ret
@staticmethod @staticmethod
#@jit(nopython=True) @jit
def tlwh_to_xyah(tlwh): def tlwh_to_xyah(tlwh):
"""Convert bounding box to format `(center x, center y, aspect ratio, """Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`. height)`, where the aspect ratio is `width / height`.
@ -149,14 +138,14 @@ class STrack(BaseTrack):
return self.tlwh_to_xyah(self.tlwh) return self.tlwh_to_xyah(self.tlwh)
@staticmethod @staticmethod
#@jit(nopython=True) @jit
def tlbr_to_tlwh(tlbr): def tlbr_to_tlwh(tlbr):
ret = np.asarray(tlbr).copy() ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2] ret[2:] -= ret[:2]
return ret return ret
@staticmethod @staticmethod
#@jit(nopython=True) @jit
def tlwh_to_tlbr(tlwh): def tlwh_to_tlbr(tlwh):
ret = np.asarray(tlwh).copy() ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2] ret[2:] += ret[:2]
@ -166,11 +155,10 @@ class STrack(BaseTrack):
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
class JDETracker(object): class JDETracker(object):
def __init__(self, opt, frame_rate=30): def __init__(self, opt, frame_rate=30):
self.opt = opt self.opt = opt
self.model = Darknet(opt.cfg) self.model = Darknet(opt.cfg, opt.img_size, nID=30)
# load_darknet_weights(self.model, opt.weights) # load_darknet_weights(self.model, opt.weights)
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False) self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
self.model.cuda().eval() self.model.cuda().eval()
@ -187,61 +175,106 @@ class JDETracker(object):
self.kalman_filter = KalmanFilter() self.kalman_filter = KalmanFilter()
def update(self, im_blob, img0): def update(self, im_blob, img0):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles lost, removed, refound and active tracklets
Parameters
----------
im_blob : torch.float32
Tensor of shape depending upon the size of image. By default, shape of this tensor is [1, 3, 608, 1088]
img0 : ndarray
ndarray of shape depending on the input image sequence. By default, shape is [608, 1080, 3]
Returns
-------
output_stracks : list of Strack(instances)
The list contains information regarding the online_tracklets for the recieved image tensor.
"""
self.frame_id += 1 self.frame_id += 1
activated_starcks = [] activated_starcks = [] # for storing active tracks, for the current frame
refind_stracks = [] refind_stracks = [] # Lost Tracks whose detections are obtained in the current frame
lost_stracks = [] lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
removed_stracks = [] removed_stracks = []
t1 = time.time() t1 = time.time()
''' Step 1: Network forward, get detections & embeddings''' ''' Step 1: Network forward, get detections & embeddings'''
with torch.no_grad(): with torch.no_grad():
pred = self.model(im_blob) pred = self.model(im_blob)
# pred is tensor of all the proposals (default number of proposals: 54264). Proposals have information associated with the bounding box and embeddings
pred = pred[pred[:, :, 4] > self.opt.conf_thres] pred = pred[pred[:, :, 4] > self.opt.conf_thres]
# pred now has lesser number of proposals. Proposals rejected on basis of object confidence score
if len(pred) > 0: if len(pred) > 0:
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu()
self.opt.nms_thres)[0] # Final proposals are obtained in dets. Information of bounding box and embeddings also included
# Next step changes the detection scales
scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round() scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
dets, embs = dets[:, :5].cpu().numpy(), dets[:, 6:].cpu().numpy() '''Detections is list of (x1, y1, x2, y2, object_conf, class_score, class_pred)'''
'''Detections''' # class_pred is the embeddings.
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
(tlbrs, f) in zip(dets, embs)] detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
else: else:
detections = [] detections = []
t2 = time.time()
# print('Forward: {} s'.format(t2-t1))
''' Add newly detected tracklets to tracked_stracks''' ''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = [] unconfirmed = []
tracked_stracks = [] # type: list[STrack] tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks: for track in self.tracked_stracks:
if not track.is_activated: if not track.is_activated:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed.append(track) unconfirmed.append(track)
# print("Should not be here, in unconfirmed")
else: else:
# Active tracks are added to the local list 'tracked_stracks'
tracked_stracks.append(track) tracked_stracks.append(track)
''' Step 2: First association, with embedding''' ''' Step 2: First association, with embedding'''
# Combining currently tracked_stracks and lost_stracks
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF # Predict the current location with KF
STrack.multi_predict(strack_pool) STrack.multi_predict(strack_pool)
dists = matching.embedding_distance(strack_pool, detections) dists = matching.embedding_distance(strack_pool, detections)
# dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections) dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
# The dists is the list of distances of the detection with the tracks in strack_pool
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7) matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
for itracked, idet in matches: for itracked, idet in matches:
# itracked is the id of the track and idet is the detection
track = strack_pool[itracked] track = strack_pool[itracked]
det = detections[idet] det = detections[idet]
if track.state == TrackState.Tracked: if track.state == TrackState.Tracked:
# If the track is active, add the detection to the track
track.update(detections[idet], self.frame_id) track.update(detections[idet], self.frame_id)
activated_starcks.append(track) activated_starcks.append(track)
else: else:
# We have obtained a detection from a track which is not active, hence put the track in refind_stracks list
track.re_activate(det, self.frame_id, new_id=False) track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track) refind_stracks.append(track)
# None of the steps below happen if there are no undetected tracks.
''' Step 3: Second association, with IOU''' ''' Step 3: Second association, with IOU'''
detections = [detections[i] for i in u_detection] detections = [detections[i] for i in u_detection]
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state==TrackState.Tracked ] # detections is now a list of the unmatched detections
r_tracked_stracks = [] # This is container for stracks which were tracked till the
# previous frame but no detection was found for it in the current frame
for i in u_track:
if strack_pool[i].state == TrackState.Tracked:
r_tracked_stracks.append(strack_pool[i])
dists = matching.iou_distance(r_tracked_stracks, detections) dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5) matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
# matches is the list of detections which matched with corresponding tracks by IOU distance method
for itracked, idet in matches: for itracked, idet in matches:
track = r_tracked_stracks[itracked] track = r_tracked_stracks[itracked]
det = detections[idet] det = detections[idet]
@ -251,12 +284,14 @@ class JDETracker(object):
else: else:
track.re_activate(det, self.frame_id, new_id=False) track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track) refind_stracks.append(track)
# Same process done for some unmatched detections, but now considering IOU_distance as measure
for it in u_track: for it in u_track:
track = r_tracked_stracks[it] track = r_tracked_stracks[it]
if not track.state == TrackState.Lost: if not track.state == TrackState.Lost:
track.mark_lost() track.mark_lost()
lost_stracks.append(track) lost_stracks.append(track)
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame''' '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection] detections = [detections[i] for i in u_detection]
@ -265,11 +300,14 @@ class JDETracker(object):
for itracked, idet in matches: for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id) unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked]) activated_starcks.append(unconfirmed[itracked])
# The tracks which are yet not matched
for it in u_unconfirmed: for it in u_unconfirmed:
track = unconfirmed[it] track = unconfirmed[it]
track.mark_removed() track.mark_removed()
removed_stracks.append(track) removed_stracks.append(track)
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
""" Step 4: Init new stracks""" """ Step 4: Init new stracks"""
for inew in u_detection: for inew in u_detection:
track = detections[inew] track = detections[inew]
@ -279,14 +317,18 @@ class JDETracker(object):
activated_starcks.append(track) activated_starcks.append(track)
""" Step 5: Update state""" """ Step 5: Update state"""
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
for track in self.lost_stracks: for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost: if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed() track.mark_removed()
removed_stracks.append(track) removed_stracks.append(track)
# print('Remained match {} s'.format(t4-t3))
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks) self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks) self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
# self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks) self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks) self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
@ -301,6 +343,7 @@ class JDETracker(object):
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks])) logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks])) logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks])) logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
# print('Final {} s'.format(t5-t4))
return output_stracks return output_stracks
def joint_stracks(tlista, tlistb): def joint_stracks(tlista, tlistb):

View file

@ -1,9 +1,10 @@
import argparse import argparse
import json import json
import time import time
from time import gmtime, strftime
import test import test
from models import * from models import *
from shutil import copyfile
from utils.datasets import JointDataset, collate_fn from utils.datasets import JointDataset, collate_fn
from utils.utils import * from utils.utils import *
from utils.log import logger from utils.log import logger
@ -13,6 +14,10 @@ from torchvision.transforms import transforms as T
def train( def train(
cfg, cfg,
data_cfg, data_cfg,
weights_from="",
weights_to="",
save_every=10,
img_size=(1088, 608),
resume=False, resume=False,
epochs=100, epochs=100,
batch_size=16, batch_size=16,
@ -20,9 +25,16 @@ def train(
freeze_backbone=False, freeze_backbone=False,
opt=None, opt=None,
): ):
weights = 'weights' # The function starts
mkdir_if_missing(weights)
latest = osp.join(weights, 'latest.pt') timme = strftime("%Y-%d-%m %H:%M:%S", gmtime())
timme = timme[5:-3].replace('-', '_')
timme = timme.replace(' ', '_')
timme = timme.replace(':', '_')
weights_to = osp.join(weights_to, 'run' + timme)
mkdir_if_missing(weights_to)
if resume:
latest_resume = osp.join(weights_from, 'latest.pt')
torch.backends.cudnn.benchmark = True # unsuitable for multiscale torch.backends.cudnn.benchmark = True # unsuitable for multiscale
@ -32,24 +44,19 @@ def train(
trainset_paths = data_config['train'] trainset_paths = data_config['train']
dataset_root = data_config['root'] dataset_root = data_config['root']
f.close() f.close()
cfg_dict = parse_model_cfg(cfg)
img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
# Get dataloader
transforms = T.Compose([T.ToTensor()]) transforms = T.Compose([T.ToTensor()])
# Get dataloader
dataset = JointDataset(dataset_root, trainset_paths, img_size, augment=True, transforms=transforms) dataset = JointDataset(dataset_root, trainset_paths, img_size, augment=True, transforms=transforms)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn) num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn)
# Initialize model # Initialize model
model = Darknet(cfg_dict, dataset.nID) model = Darknet(cfg, dataset.nID)
cutoff = -1 # backbone reaches to cutoff layer cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0 start_epoch = 0
if resume: if resume:
checkpoint = torch.load(latest, map_location='cpu') checkpoint = torch.load(latest_resume, map_location='cpu')
# Load weights to resume from # Load weights to resume from
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
@ -67,33 +74,33 @@ def train(
else: else:
# Initialize model with backbone (optional) # Initialize model with backbone (optional)
if cfg.endswith('yolov3.cfg'): if cfg.endswith('yolov3.cfg'):
load_darknet_weights(model, osp.join(weights ,'darknet53.conv.74')) load_darknet_weights(model, osp.join(weights_from, 'darknet53.conv.74'))
cutoff = 75 cutoff = 75
elif cfg.endswith('yolov3-tiny.cfg'): elif cfg.endswith('yolov3-tiny.cfg'):
load_darknet_weights(model, osp.join(weights , 'yolov3-tiny.conv.15')) load_darknet_weights(model, osp.join(weights_from, 'yolov3-tiny.conv.15'))
cutoff = 15 cutoff = 15
model.cuda().train() model.cuda().train()
# Set optimizer # Set optimizer
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9, weight_decay=1e-4) optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9,
weight_decay=1e-4)
model = torch.nn.DataParallel(model) model = torch.nn.DataParallel(model)
# Set scheduler # Set scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[int(0.5*opt.epochs), int(0.75*opt.epochs)], gamma=0.1) milestones=[int(0.5 * opt.epochs), int(0.75 * opt.epochs)],
gamma=0.1)
# An important trick for detection: freeze bn during fine-tuning # An important trick for detection: freeze bn during fine-tuning
if not opt.unfreeze_bn: if not opt.unfreeze_bn:
for i, (name, p) in enumerate(model.named_parameters()): for i, (name, p) in enumerate(model.named_parameters()):
p.requires_grad = False if 'batch_norm' in name else True p.requires_grad = False if 'batch_norm' in name else True
model_info(model) # model_info(model)
t0 = time.time() t0 = time.time()
for epoch in range(epochs): for epoch in range(epochs):
epoch += start_epoch epoch += start_epoch
logger.info(('%8s%12s' + '%10s' * 6) % ( logger.info(('%8s%12s' + '%10s' * 6) % (
'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time')) 'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time'))
@ -120,7 +127,6 @@ def train(
# Compute loss, compute gradient, update parameters # Compute loss, compute gradient, update parameters
loss, components = model(imgs.cuda(), targets.cuda(), targets_len.cuda()) loss, components = model(imgs.cuda(), targets.cuda(), targets_len.cuda())
components = torch.mean(components.view(-1, 5), dim=0) components = torch.mean(components.view(-1, 5), dim=0)
loss = torch.mean(loss) loss = torch.mean(loss)
loss.backward() loss.backward()
@ -135,6 +141,7 @@ def train(
for ii, key in enumerate(model.module.loss_names): for ii, key in enumerate(model.module.loss_names):
rloss[key] = (rloss[key] * ui + components[ii]) / (ui + 1) rloss[key] = (rloss[key] * ui + components[ii]) / (ui + 1)
# rloss indicates running loss values with mean updated at every epoch
s = ('%8s%12s' + '%10.3g' * 6) % ( s = ('%8s%12s' + '%10.3g' * 6) % (
'%g/%g' % (epoch, epochs - 1), '%g/%g' % (epoch, epochs - 1),
'%g/%g' % (i, len(dataloader) - 1), '%g/%g' % (i, len(dataloader) - 1),
@ -149,26 +156,45 @@ def train(
checkpoint = {'epoch': epoch, checkpoint = {'epoch': epoch,
'model': model.module.state_dict(), 'model': model.module.state_dict(),
'optimizer': optimizer.state_dict()} 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest)
copyfile(cfg, weights_to + '/cfg/yolo3.cfg')
copyfile(data_cfg, weights_to + '/cfg/ccmcpe.json')
latest = osp.join(weights_to, 'latest.pt')
torch.save(checkpoint, latest)
if epoch % save_every == 0 and epoch != 0:
# making the checkpoint lite
checkpoint["optimizer"] = []
torch.save(checkpoint, osp.join(weights_to, "weights_epoch_" + str(epoch) + ".pt"))
# Calculate mAP # Calculate mAP
if epoch % opt.test_interval == 0: if epoch % opt.test_interval == 0:
with torch.no_grad(): with torch.no_grad():
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40) mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40) print_interval=40, nID=dataset.nID)
test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
print_interval=40, nID=dataset.nID)
# Call scheduler.step() after opimizer.step() with pytorch > 1.1.0 # Call scheduler.step() after opimizer.step() with pytorch > 1.1.0
scheduler.step() scheduler.step()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=30, help='number of epochs') parser.add_argument('--epochs', type=int, default=30, help='number of epochs')
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch') parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step') parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step')
parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path') parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
parser.add_argument('--weights-from', type=str, default='weights/',
help='Path for getting the trained model for resuming training (Should only be used with '
'--resume)')
parser.add_argument('--weights-to', type=str, default='weights/',
help='Store the trained weights after resuming training session. It will create a new folder '
'with timestamp in the given path')
parser.add_argument('--save-model-after', type=int, default=10,
help='Save a checkpoint of model at given interval of epochs')
parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path') parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path')
parser.add_argument('--img-size', type=int, default=[1088, 608], nargs='+', help='pixels')
parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--print-interval', type=int, default=40, help='print interval') parser.add_argument('--print-interval', type=int, default=40, help='print interval')
parser.add_argument('--test-interval', type=int, default=9, help='test interval') parser.add_argument('--test-interval', type=int, default=9, help='test interval')
@ -181,6 +207,10 @@ if __name__ == '__main__':
train( train(
opt.cfg, opt.cfg,
opt.data_cfg, opt.data_cfg,
weights_from=opt.weights_from,
weights_to=opt.weights_to,
save_every=opt.save_model_after,
img_size=opt.img_size,
resume=opt.resume, resume=opt.resume,
epochs=opt.epochs, epochs=opt.epochs,
batch_size=opt.batch_size, batch_size=opt.batch_size,

View file

@ -393,6 +393,9 @@ class JointDataset(LoadImagesAndLabels): # for training
def __getitem__(self, files_index): def __getitem__(self, files_index):
"""
Iterator function for train dataset
"""
for i, c in enumerate(self.cds): for i, c in enumerate(self.cds):
if files_index >= c: if files_index >= c:
ds = list(self.label_files.keys())[i] ds = list(self.label_files.keys())[i]

View file

@ -1,5 +1,4 @@
# vim: expandtab:ts=4:sw=4 # vim: expandtab:ts=4:sw=4
import numba
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg

View file

@ -1,6 +1,5 @@
import glob import glob
import random import random
import time
import os import os
import os.path as osp import os.path as osp