Documentation (#95)
* Added documentation * Added docstrings and comments * Removed unused imports * Removed unused imports * Added functionality of saving checkpoints during training process * Update train.py * Update multitracker.py
This commit is contained in:
parent
0a0665e682
commit
24f351d1b5
8 changed files with 184 additions and 89 deletions
10
demo.py
10
demo.py
|
@ -24,23 +24,13 @@ Todo:
|
|||
* More documentation
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import logging
|
||||
import argparse
|
||||
import motmetrics as mm
|
||||
|
||||
from tracker.multitracker import JDETracker
|
||||
from utils import visualization as vis
|
||||
from utils.utils import *
|
||||
from utils.io import read_results
|
||||
from utils.log import logger
|
||||
from utils.timer import Timer
|
||||
from utils.evaluation import Evaluator
|
||||
from utils.parse_config import parse_model_cfg
|
||||
import utils.datasets as datasets
|
||||
import torch
|
||||
from track import eval_seq
|
||||
|
||||
|
||||
|
|
35
track.py
35
track.py
|
@ -39,6 +39,41 @@ def write_results(filename, results, data_type):
|
|||
|
||||
|
||||
def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
|
||||
'''
|
||||
Processes the video sequence given and provides the output of tracking result (write the results in video file)
|
||||
|
||||
It uses JDE model for getting information about the online targets present.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
opt : Namespace
|
||||
Contains information passed as commandline arguments.
|
||||
|
||||
dataloader : LoadVideo
|
||||
Instance of LoadVideo class used for fetching the image sequence and associated data.
|
||||
|
||||
data_type : String
|
||||
Type of dataset corresponding(similar) to the given video.
|
||||
|
||||
result_filename : String
|
||||
The name(path) of the file for storing results.
|
||||
|
||||
save_dir : String
|
||||
Path to the folder for storing the frames containing bounding box information (Result frames).
|
||||
|
||||
show_image : bool
|
||||
Option for shhowing individial frames during run-time.
|
||||
|
||||
frame_rate : int
|
||||
Frame-rate of the given video.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(Returns are not significant here)
|
||||
frame_id : int
|
||||
Sequence number of the last sequence
|
||||
'''
|
||||
|
||||
if save_dir:
|
||||
mkdir_if_missing(save_dir)
|
||||
tracker = JDETracker(opt, frame_rate=frame_rate)
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
import cv2
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import scipy
|
||||
from scipy.spatial.distance import cdist
|
||||
|
@ -8,7 +5,6 @@ import lap
|
|||
|
||||
from cython_bbox import bbox_overlaps as bbox_ious
|
||||
from utils import kalman_filter
|
||||
import time
|
||||
|
||||
def merge_matches(m1, m2, shape):
|
||||
O,P,Q = shape
|
||||
|
|
|
@ -1,15 +1,6 @@
|
|||
import numpy as np
|
||||
from numba import jit
|
||||
from collections import deque
|
||||
import itertools
|
||||
import os
|
||||
import os.path as osp
|
||||
import time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utils.utils import *
|
||||
from utils.log import logger
|
||||
from utils.kalman_filter import KalmanFilter
|
||||
from models import *
|
||||
from tracker import matching
|
||||
|
@ -17,7 +8,6 @@ from .basetrack import BaseTrack, TrackState
|
|||
|
||||
|
||||
class STrack(BaseTrack):
|
||||
shared_kalman = KalmanFilter()
|
||||
|
||||
def __init__(self, tlwh, score, temp_feat, buffer_size=30):
|
||||
|
||||
|
@ -64,7 +54,6 @@ class STrack(BaseTrack):
|
|||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
|
||||
def activate(self, kalman_filter, frame_id):
|
||||
"""Start a new tracklet"""
|
||||
self.kalman_filter = kalman_filter
|
||||
|
@ -112,7 +101,7 @@ class STrack(BaseTrack):
|
|||
self.update_features(new_track.curr_feat)
|
||||
|
||||
@property
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
|
@ -125,7 +114,7 @@ class STrack(BaseTrack):
|
|||
return ret
|
||||
|
||||
@property
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
|
@ -135,7 +124,7 @@ class STrack(BaseTrack):
|
|||
return ret
|
||||
|
||||
@staticmethod
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
|
@ -149,14 +138,14 @@ class STrack(BaseTrack):
|
|||
return self.tlwh_to_xyah(self.tlwh)
|
||||
|
||||
@staticmethod
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlbr_to_tlwh(tlbr):
|
||||
ret = np.asarray(tlbr).copy()
|
||||
ret[2:] -= ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
#@jit(nopython=True)
|
||||
@jit
|
||||
def tlwh_to_tlbr(tlwh):
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[2:] += ret[:2]
|
||||
|
@ -166,11 +155,10 @@ class STrack(BaseTrack):
|
|||
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
|
||||
|
||||
|
||||
|
||||
class JDETracker(object):
|
||||
def __init__(self, opt, frame_rate=30):
|
||||
self.opt = opt
|
||||
self.model = Darknet(opt.cfg)
|
||||
self.model = Darknet(opt.cfg, opt.img_size, nID=30)
|
||||
# load_darknet_weights(self.model, opt.weights)
|
||||
self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
|
||||
self.model.cuda().eval()
|
||||
|
@ -187,61 +175,106 @@ class JDETracker(object):
|
|||
self.kalman_filter = KalmanFilter()
|
||||
|
||||
def update(self, im_blob, img0):
|
||||
"""
|
||||
Processes the image frame and finds bounding box(detections).
|
||||
|
||||
Associates the detection with corresponding tracklets and also handles lost, removed, refound and active tracklets
|
||||
|
||||
Parameters
|
||||
----------
|
||||
im_blob : torch.float32
|
||||
Tensor of shape depending upon the size of image. By default, shape of this tensor is [1, 3, 608, 1088]
|
||||
|
||||
img0 : ndarray
|
||||
ndarray of shape depending on the input image sequence. By default, shape is [608, 1080, 3]
|
||||
|
||||
Returns
|
||||
-------
|
||||
output_stracks : list of Strack(instances)
|
||||
The list contains information regarding the online_tracklets for the recieved image tensor.
|
||||
|
||||
"""
|
||||
|
||||
self.frame_id += 1
|
||||
activated_starcks = []
|
||||
refind_stracks = []
|
||||
lost_stracks = []
|
||||
activated_starcks = [] # for storing active tracks, for the current frame
|
||||
refind_stracks = [] # Lost Tracks whose detections are obtained in the current frame
|
||||
lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
|
||||
removed_stracks = []
|
||||
|
||||
t1 = time.time()
|
||||
''' Step 1: Network forward, get detections & embeddings'''
|
||||
with torch.no_grad():
|
||||
pred = self.model(im_blob)
|
||||
# pred is tensor of all the proposals (default number of proposals: 54264). Proposals have information associated with the bounding box and embeddings
|
||||
pred = pred[pred[:, :, 4] > self.opt.conf_thres]
|
||||
# pred now has lesser number of proposals. Proposals rejected on basis of object confidence score
|
||||
if len(pred) > 0:
|
||||
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres,
|
||||
self.opt.nms_thres)[0]
|
||||
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu()
|
||||
# Final proposals are obtained in dets. Information of bounding box and embeddings also included
|
||||
# Next step changes the detection scales
|
||||
scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
|
||||
dets, embs = dets[:, :5].cpu().numpy(), dets[:, 6:].cpu().numpy()
|
||||
'''Detections'''
|
||||
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
|
||||
(tlbrs, f) in zip(dets, embs)]
|
||||
'''Detections is list of (x1, y1, x2, y2, object_conf, class_score, class_pred)'''
|
||||
# class_pred is the embeddings.
|
||||
|
||||
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
|
||||
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
|
||||
else:
|
||||
detections = []
|
||||
|
||||
t2 = time.time()
|
||||
# print('Forward: {} s'.format(t2-t1))
|
||||
|
||||
''' Add newly detected tracklets to tracked_stracks'''
|
||||
unconfirmed = []
|
||||
tracked_stracks = [] # type: list[STrack]
|
||||
for track in self.tracked_stracks:
|
||||
if not track.is_activated:
|
||||
# previous tracks which are not active in the current frame are added in unconfirmed list
|
||||
unconfirmed.append(track)
|
||||
# print("Should not be here, in unconfirmed")
|
||||
else:
|
||||
# Active tracks are added to the local list 'tracked_stracks'
|
||||
tracked_stracks.append(track)
|
||||
|
||||
''' Step 2: First association, with embedding'''
|
||||
# Combining currently tracked_stracks and lost_stracks
|
||||
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
|
||||
# Predict the current location with KF
|
||||
STrack.multi_predict(strack_pool)
|
||||
|
||||
|
||||
dists = matching.embedding_distance(strack_pool, detections)
|
||||
# dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
|
||||
dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
|
||||
# The dists is the list of distances of the detection with the tracks in strack_pool
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
||||
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
|
||||
|
||||
for itracked, idet in matches:
|
||||
# itracked is the id of the track and idet is the detection
|
||||
track = strack_pool[itracked]
|
||||
det = detections[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
# If the track is active, add the detection to the track
|
||||
track.update(detections[idet], self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
else:
|
||||
# We have obtained a detection from a track which is not active, hence put the track in refind_stracks list
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
|
||||
# None of the steps below happen if there are no undetected tracks.
|
||||
''' Step 3: Second association, with IOU'''
|
||||
detections = [detections[i] for i in u_detection]
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state==TrackState.Tracked ]
|
||||
# detections is now a list of the unmatched detections
|
||||
r_tracked_stracks = [] # This is container for stracks which were tracked till the
|
||||
# previous frame but no detection was found for it in the current frame
|
||||
for i in u_track:
|
||||
if strack_pool[i].state == TrackState.Tracked:
|
||||
r_tracked_stracks.append(strack_pool[i])
|
||||
dists = matching.iou_distance(r_tracked_stracks, detections)
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
|
||||
|
||||
# matches is the list of detections which matched with corresponding tracks by IOU distance method
|
||||
for itracked, idet in matches:
|
||||
track = r_tracked_stracks[itracked]
|
||||
det = detections[idet]
|
||||
|
@ -251,12 +284,14 @@ class JDETracker(object):
|
|||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
# Same process done for some unmatched detections, but now considering IOU_distance as measure
|
||||
|
||||
for it in u_track:
|
||||
track = r_tracked_stracks[it]
|
||||
if not track.state == TrackState.Lost:
|
||||
track.mark_lost()
|
||||
lost_stracks.append(track)
|
||||
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
|
||||
|
||||
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
|
||||
detections = [detections[i] for i in u_detection]
|
||||
|
@ -265,11 +300,14 @@ class JDETracker(object):
|
|||
for itracked, idet in matches:
|
||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
||||
activated_starcks.append(unconfirmed[itracked])
|
||||
|
||||
# The tracks which are yet not matched
|
||||
for it in u_unconfirmed:
|
||||
track = unconfirmed[it]
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
|
||||
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
|
||||
""" Step 4: Init new stracks"""
|
||||
for inew in u_detection:
|
||||
track = detections[inew]
|
||||
|
@ -279,14 +317,18 @@ class JDETracker(object):
|
|||
activated_starcks.append(track)
|
||||
|
||||
""" Step 5: Update state"""
|
||||
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
|
||||
for track in self.lost_stracks:
|
||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
# print('Remained match {} s'.format(t4-t3))
|
||||
|
||||
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
|
||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
|
||||
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
|
||||
# self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
|
||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
|
||||
self.lost_stracks.extend(lost_stracks)
|
||||
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
|
||||
|
@ -301,6 +343,7 @@ class JDETracker(object):
|
|||
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
|
||||
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
|
||||
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
|
||||
# print('Final {} s'.format(t5-t4))
|
||||
return output_stracks
|
||||
|
||||
def joint_stracks(tlista, tlistb):
|
||||
|
|
80
train.py
80
train.py
|
@ -1,9 +1,10 @@
|
|||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
from time import gmtime, strftime
|
||||
import test
|
||||
from models import *
|
||||
from shutil import copyfile
|
||||
from utils.datasets import JointDataset, collate_fn
|
||||
from utils.utils import *
|
||||
from utils.log import logger
|
||||
|
@ -13,6 +14,10 @@ from torchvision.transforms import transforms as T
|
|||
def train(
|
||||
cfg,
|
||||
data_cfg,
|
||||
weights_from="",
|
||||
weights_to="",
|
||||
save_every=10,
|
||||
img_size=(1088, 608),
|
||||
resume=False,
|
||||
epochs=100,
|
||||
batch_size=16,
|
||||
|
@ -20,9 +25,16 @@ def train(
|
|||
freeze_backbone=False,
|
||||
opt=None,
|
||||
):
|
||||
weights = 'weights'
|
||||
mkdir_if_missing(weights)
|
||||
latest = osp.join(weights, 'latest.pt')
|
||||
# The function starts
|
||||
|
||||
timme = strftime("%Y-%d-%m %H:%M:%S", gmtime())
|
||||
timme = timme[5:-3].replace('-', '_')
|
||||
timme = timme.replace(' ', '_')
|
||||
timme = timme.replace(':', '_')
|
||||
weights_to = osp.join(weights_to, 'run' + timme)
|
||||
mkdir_if_missing(weights_to)
|
||||
if resume:
|
||||
latest_resume = osp.join(weights_from, 'latest.pt')
|
||||
|
||||
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
|
||||
|
||||
|
@ -32,24 +44,19 @@ def train(
|
|||
trainset_paths = data_config['train']
|
||||
dataset_root = data_config['root']
|
||||
f.close()
|
||||
cfg_dict = parse_model_cfg(cfg)
|
||||
img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
|
||||
|
||||
# Get dataloader
|
||||
transforms = T.Compose([T.ToTensor()])
|
||||
# Get dataloader
|
||||
dataset = JointDataset(dataset_root, trainset_paths, img_size, augment=True, transforms=transforms)
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
|
||||
num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate_fn)
|
||||
|
||||
# Initialize model
|
||||
model = Darknet(cfg_dict, dataset.nID)
|
||||
|
||||
|
||||
model = Darknet(cfg, dataset.nID)
|
||||
|
||||
cutoff = -1 # backbone reaches to cutoff layer
|
||||
start_epoch = 0
|
||||
if resume:
|
||||
checkpoint = torch.load(latest, map_location='cpu')
|
||||
checkpoint = torch.load(latest_resume, map_location='cpu')
|
||||
|
||||
# Load weights to resume from
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
|
@ -67,33 +74,33 @@ def train(
|
|||
else:
|
||||
# Initialize model with backbone (optional)
|
||||
if cfg.endswith('yolov3.cfg'):
|
||||
load_darknet_weights(model, osp.join(weights ,'darknet53.conv.74'))
|
||||
load_darknet_weights(model, osp.join(weights_from, 'darknet53.conv.74'))
|
||||
cutoff = 75
|
||||
elif cfg.endswith('yolov3-tiny.cfg'):
|
||||
load_darknet_weights(model, osp.join(weights , 'yolov3-tiny.conv.15'))
|
||||
load_darknet_weights(model, osp.join(weights_from, 'yolov3-tiny.conv.15'))
|
||||
cutoff = 15
|
||||
|
||||
model.cuda().train()
|
||||
|
||||
# Set optimizer
|
||||
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9, weight_decay=1e-4)
|
||||
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9,
|
||||
weight_decay=1e-4)
|
||||
|
||||
model = torch.nn.DataParallel(model)
|
||||
# Set scheduler
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
|
||||
milestones=[int(0.5*opt.epochs), int(0.75*opt.epochs)], gamma=0.1)
|
||||
milestones=[int(0.5 * opt.epochs), int(0.75 * opt.epochs)],
|
||||
gamma=0.1)
|
||||
|
||||
# An important trick for detection: freeze bn during fine-tuning
|
||||
if not opt.unfreeze_bn:
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
p.requires_grad = False if 'batch_norm' in name else True
|
||||
|
||||
model_info(model)
|
||||
|
||||
# model_info(model)
|
||||
t0 = time.time()
|
||||
for epoch in range(epochs):
|
||||
epoch += start_epoch
|
||||
|
||||
logger.info(('%8s%12s' + '%10s' * 6) % (
|
||||
'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time'))
|
||||
|
||||
|
@ -120,7 +127,6 @@ def train(
|
|||
# Compute loss, compute gradient, update parameters
|
||||
loss, components = model(imgs.cuda(), targets.cuda(), targets_len.cuda())
|
||||
components = torch.mean(components.view(-1, 5), dim=0)
|
||||
|
||||
loss = torch.mean(loss)
|
||||
loss.backward()
|
||||
|
||||
|
@ -135,6 +141,7 @@ def train(
|
|||
for ii, key in enumerate(model.module.loss_names):
|
||||
rloss[key] = (rloss[key] * ui + components[ii]) / (ui + 1)
|
||||
|
||||
# rloss indicates running loss values with mean updated at every epoch
|
||||
s = ('%8s%12s' + '%10.3g' * 6) % (
|
||||
'%g/%g' % (epoch, epochs - 1),
|
||||
'%g/%g' % (i, len(dataloader) - 1),
|
||||
|
@ -149,26 +156,45 @@ def train(
|
|||
checkpoint = {'epoch': epoch,
|
||||
'model': model.module.state_dict(),
|
||||
'optimizer': optimizer.state_dict()}
|
||||
torch.save(checkpoint, latest)
|
||||
|
||||
copyfile(cfg, weights_to + '/cfg/yolo3.cfg')
|
||||
copyfile(data_cfg, weights_to + '/cfg/ccmcpe.json')
|
||||
|
||||
latest = osp.join(weights_to, 'latest.pt')
|
||||
torch.save(checkpoint, latest)
|
||||
if epoch % save_every == 0 and epoch != 0:
|
||||
# making the checkpoint lite
|
||||
checkpoint["optimizer"] = []
|
||||
torch.save(checkpoint, osp.join(weights_to, "weights_epoch_" + str(epoch) + ".pt"))
|
||||
|
||||
# Calculate mAP
|
||||
if epoch % opt.test_interval == 0:
|
||||
with torch.no_grad():
|
||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
|
||||
test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, print_interval=40)
|
||||
|
||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
|
||||
print_interval=40, nID=dataset.nID)
|
||||
test.test_emb(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size,
|
||||
print_interval=40, nID=dataset.nID)
|
||||
|
||||
# Call scheduler.step() after opimizer.step() with pytorch > 1.1.0
|
||||
scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--epochs', type=int, default=30, help='number of epochs')
|
||||
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
|
||||
parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step')
|
||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path')
|
||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
||||
parser.add_argument('--weights-from', type=str, default='weights/',
|
||||
help='Path for getting the trained model for resuming training (Should only be used with '
|
||||
'--resume)')
|
||||
parser.add_argument('--weights-to', type=str, default='weights/',
|
||||
help='Store the trained weights after resuming training session. It will create a new folder '
|
||||
'with timestamp in the given path')
|
||||
parser.add_argument('--save-model-after', type=int, default=10,
|
||||
help='Save a checkpoint of model at given interval of epochs')
|
||||
parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path')
|
||||
parser.add_argument('--img-size', type=int, default=[1088, 608], nargs='+', help='pixels')
|
||||
parser.add_argument('--resume', action='store_true', help='resume training flag')
|
||||
parser.add_argument('--print-interval', type=int, default=40, help='print interval')
|
||||
parser.add_argument('--test-interval', type=int, default=9, help='test interval')
|
||||
|
@ -181,6 +207,10 @@ if __name__ == '__main__':
|
|||
train(
|
||||
opt.cfg,
|
||||
opt.data_cfg,
|
||||
weights_from=opt.weights_from,
|
||||
weights_to=opt.weights_to,
|
||||
save_every=opt.save_model_after,
|
||||
img_size=opt.img_size,
|
||||
resume=opt.resume,
|
||||
epochs=opt.epochs,
|
||||
batch_size=opt.batch_size,
|
||||
|
|
|
@ -393,6 +393,9 @@ class JointDataset(LoadImagesAndLabels): # for training
|
|||
|
||||
|
||||
def __getitem__(self, files_index):
|
||||
"""
|
||||
Iterator function for train dataset
|
||||
"""
|
||||
for i, c in enumerate(self.cds):
|
||||
if files_index >= c:
|
||||
ds = list(self.label_files.keys())[i]
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# vim: expandtab:ts=4:sw=4
|
||||
import numba
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import glob
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
|
|
Loading…
Reference in a new issue