This commit is contained in:
Zhongdao 2019-10-11 17:26:59 +08:00
parent 2042751580
commit 59419367fd
3 changed files with 17 additions and 50 deletions

View file

@ -2,10 +2,11 @@ import argparse
import json import json
import time import time
import test # Import test.py to get mAP after each epoch import test
from models import * from models import *
from utils.datasets import JointDataset, collate_fn from utils.datasets import JointDataset, collate_fn
from utils.utils import * from utils.utils import *
from utils.log import logger
from torchvision.transforms import transforms as T from torchvision.transforms import transforms as T
@ -18,13 +19,11 @@ def train(
batch_size=16, batch_size=16,
accumulated_batches=1, accumulated_batches=1,
freeze_backbone=False, freeze_backbone=False,
var=0,
opt=None, opt=None,
): ):
weights = 'weights' + os.sep weights = 'weights'
latest = weights + 'latest.pt' mkdir_if_missing(weights)
best = weights + 'best.pt' latest = osp.join(weights, 'latest.pt')
device = torch_utils.select_device()
torch.backends.cudnn.benchmark = True # unsuitable for multiscale torch.backends.cudnn.benchmark = True # unsuitable for multiscale
@ -45,40 +44,37 @@ def train(
# Initialize model # Initialize model
model = Darknet(cfg, img_size, dataset.nID) model = Darknet(cfg, img_size, dataset.nID)
lr0 = opt.lr
cutoff = -1 # backbone reaches to cutoff layer cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0 start_epoch = 0
best_loss = float('inf')
if resume: if resume:
checkpoint = torch.load(latest, map_location='cpu') checkpoint = torch.load(latest, map_location='cpu')
# Load weights to resume from # Load weights to resume from
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
model.to(device).train() model.cuda().train()
# Set optimizer # Set optimizer
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9) optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9)
start_epoch = checkpoint['epoch'] + 1 start_epoch = checkpoint['epoch'] + 1
if checkpoint['optimizer'] is not None: if checkpoint['optimizer'] is not None:
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
best_loss = checkpoint['best_loss']
del checkpoint # current, saved del checkpoint # current, saved
else: else:
# Initialize model with backbone (optional) # Initialize model with backbone (optional)
if cfg.endswith('yolov3.cfg'): if cfg.endswith('yolov3.cfg'):
load_darknet_weights(model, weights + 'darknet53.conv.74') load_darknet_weights(model, osp.join(weights ,'darknet53.conv.74'))
cutoff = 75 cutoff = 75
elif cfg.endswith('yolov3-tiny.cfg'): elif cfg.endswith('yolov3-tiny.cfg'):
load_darknet_weights(model, weights + 'yolov3-tiny.conv.15') load_darknet_weights(model, osp.join(weights , 'yolov3-tiny.conv.15'))
cutoff = 15 cutoff = 15
model.to(device).train() model.cuda().train()
# Set optimizer # Set optimizer
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9, weight_decay=1e-4) optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=opt.lr, momentum=.9, weight_decay=1e-4)
model = torch.nn.DataParallel(model) model = torch.nn.DataParallel(model)
# Set scheduler # Set scheduler
@ -95,7 +91,7 @@ def train(
for epoch in range(epochs): for epoch in range(epochs):
epoch += start_epoch epoch += start_epoch
print(('%8s%12s' + '%10s' * 6) % ( logger.info(('%8s%12s' + '%10s' * 6) % (
'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time')) 'Epoch', 'Batch', 'box', 'conf', 'id', 'total', 'nTargets', 'time'))
# Update scheduler (automatic) # Update scheduler (automatic)
@ -118,7 +114,7 @@ def train(
# SGD burn-in # SGD burn-in
burnin = min(1000, len(dataloader)) burnin = min(1000, len(dataloader))
if (epoch == 0) & (i <= burnin): if (epoch == 0) & (i <= burnin):
lr = lr0 * (i / burnin) **4 lr = opt.lr * (i / burnin) **4
for g in optimizer.param_groups: for g in optimizer.param_groups:
g['lr'] = lr g['lr'] = lr
@ -148,12 +144,11 @@ def train(
rloss['nT'], time.time() - t0) rloss['nT'], time.time() - t0)
t0 = time.time() t0 = time.time()
if i % opt.print_interval == 0: if i % opt.print_interval == 0:
print(s) logger.info(s)
# Save latest checkpoint # Save latest checkpoint
checkpoint = {'epoch': epoch, checkpoint = {'epoch': epoch,
# 'best_loss': best_loss,
'model': model.module.state_dict(), 'model': model.module.state_dict(),
'optimizer': optimizer.state_dict()} 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest) torch.save(checkpoint, latest)
@ -176,14 +171,11 @@ if __name__ == '__main__':
parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path') parser.add_argument('--data-cfg', type=str, default='cfg/ccmcpe.json', help='coco.data file path')
parser.add_argument('--img-size', type=int, default=(1088, 608), help='pixels') parser.add_argument('--img-size', type=int, default=(1088, 608), help='pixels')
parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--var', type=float, default=0, help='test variable')
parser.add_argument('--print-interval', type=int, default=40, help='print interval') parser.add_argument('--print-interval', type=int, default=40, help='print interval')
parser.add_argument('--test-interval', type=int, default=9, help='test interval') parser.add_argument('--test-interval', type=int, default=9, help='test interval')
parser.add_argument('--lr', type=float, default=1e-2, help='init lr') parser.add_argument('--lr', type=float, default=1e-2, help='init lr')
parser.add_argument('--idw', type=float, default=0.1, help='loss id weight')
parser.add_argument('--unfreeze-bn', action='store_true', help='unfreeze bn') parser.add_argument('--unfreeze-bn', action='store_true', help='unfreeze bn')
opt = parser.parse_args() opt = parser.parse_args()
print(opt, end='\n\n')
init_seeds() init_seeds()
@ -195,6 +187,5 @@ if __name__ == '__main__':
epochs=opt.epochs, epochs=opt.epochs,
batch_size=opt.batch_size, batch_size=opt.batch_size,
accumulated_batches=opt.accumulated_batches, accumulated_batches=opt.accumulated_batches,
var=opt.var,
opt=opt, opt=opt,
) )

View file

@ -1,25 +0,0 @@
import torch
def init_seeds(seed=0):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def select_device(force_cpu=False):
if force_cpu:
cuda = False
device = torch.device('cpu')
else:
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
if torch.cuda.device_count() > 1:
print('WARNING Using GPU0 Only: https://github.com/ultralytics/yolov3/issues/21')
torch.cuda.set_device(0) # OPTIONAL: Set your GPU if multiple available
# print('Using ', torch.cuda.device_count(), ' GPUs')
print('Using %s %s\n' % (device.type, torch.cuda.get_device_properties(0) if cuda else ''))
print(device)
return device

View file

@ -10,7 +10,6 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from utils import torch_utils
import maskrcnn_benchmark.layers.nms as nms import maskrcnn_benchmark.layers.nms as nms
# Set printoptions # Set printoptions
torch.set_printoptions(linewidth=1320, precision=5, profile='long') torch.set_printoptions(linewidth=1320, precision=5, profile='long')
@ -28,7 +27,9 @@ def float3(x): # format floats to 3 decimals
def init_seeds(seed=0): def init_seeds(seed=0):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch_utils.init_seeds(seed=seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def load_classes(path): def load_classes(path):