surveilling-surveillance/detection/models/classification.py
2021-05-20 13:22:04 -07:00

238 lines
7.5 KiB
Python

import pretrainedmodels
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from efficientnet_pytorch import EfficientNet
class PretrainedModel(nn.Module):
    """Pretrained model, either from Cadene or TorchVision.

    Base class only: it registers no layers itself and its ``forward``
    raises, so subclasses must build a backbone and override ``forward``.
    """

    def __init__(self):
        super(PretrainedModel, self).__init__()

    def forward(self, x):
        """Always raise; concrete subclasses provide the real forward pass.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError('Subclass of PretrainedModel ' +
                                  'must implement forward method.')

    def fine_tuning_parameters(self, boundary_layers, lrs):
        """Get a list of parameter groups that can be passed to an optimizer.

        Parameters are visited in ``self.named_parameters()`` order and cut
        into ``len(boundary_layers) + 1`` consecutive groups. Each boundary
        name is the first parameter of the group it opens and is excluded
        from the group before it.

        Args:
            boundary_layers: List of parameter names (exactly as yielded by
                ``named_parameters()``) at which a new group starts, ordered
                from earlier to later layers.
            lrs: List of learning rates for each parameter group, from
                earlier to later layers. Must contain one more entry than
                `boundary_layers`.

        Returns:
            param_groups: List of dictionaries, one per parameter group,
                with keys 'params' (a single-use generator of parameters)
                and 'lr'.

        Raises:
            ValueError: If the number of learning rates does not match the
                number of parameter groups.
        """
        def gen_params(start_layer, end_layer):
            # Yield parameters from `start_layer` (inclusive) up to
            # `end_layer` (exclusive). None means "from the very first
            # parameter" / "through the very last parameter".
            saw_start_layer = False
            for name, param in self.named_parameters():
                if end_layer is not None and name == end_layer:
                    # Reached the start of the next group -> done
                    return
                if start_layer is None or name == start_layer:
                    # Saw the first layer -> start yielding parameters
                    saw_start_layer = True
                if saw_start_layer:
                    yield param

        # Compare lengths: the original added an int to the list itself,
        # which raised TypeError before any group could be built.
        num_groups = len(boundary_layers) + 1
        if len(lrs) != num_groups:
            raise ValueError(f'Got {num_groups} param groups, ' +
                             f'but {len(lrs)} learning rates')

        # Pad with None sentinels so the first group starts at the first
        # parameter and the last group runs to the final parameter.
        bounds = [None] + list(boundary_layers) + [None]
        param_groups = []
        for (start, end), lr in zip(zip(bounds, bounds[1:]), lrs):
            # start/end are passed as arguments, so every generator keeps
            # its own pair of bounds.
            param_groups.append({'params': gen_params(start, end),
                                 'lr': lr})
        return param_groups
class EfficientNetModel(PretrainedModel):
    """EfficientNet models:
    https://github.com/lukemelas/EfficientNet-PyTorch

    Wraps an ``EfficientNet`` instance in ``self.model`` and forwards inputs
    straight through it.
    """

    def __init__(self, model_name, model_args=None):
        """Build the EfficientNet backbone.

        Args:
            model_name: Architecture name handed to EfficientNet
                (e.g. 'efficientnet-b0').
            model_args: Optional mapping read via ``.get``. Recognised keys:
                "pretrained" (default False) selects
                ``EfficientNet.from_pretrained`` over
                ``EfficientNet.from_name``; "num_classes" (default None)
                is forwarded only when set.
        """
        super().__init__()
        # The signature defaults model_args to None; treat that as "no
        # options" instead of failing on None.get(...).
        if model_args is None:
            model_args = {}
        num_classes = model_args.get("num_classes", None)
        pretrained = model_args.get("pretrained", False)

        # Forward num_classes only when the caller supplied one, so an
        # absent value falls back to the library's own default rather than
        # overriding it with None.
        overrides = {}
        if num_classes is not None:
            overrides["num_classes"] = num_classes

        if pretrained:
            self.model = EfficientNet.from_pretrained(model_name, **overrides)
        else:
            self.model = EfficientNet.from_name(model_name, **overrides)

    def forward(self, x):
        """Return the wrapped EfficientNet's output for `x`."""
        x = self.model(x)
        return x
class CadeneModel(PretrainedModel):
    """Backbones from Cadene's collection of pretrained networks:
    https://github.com/Cadene/pretrained-models.pytorch

    The wrapped network is used as a feature extractor; a freshly created
    linear layer (``self.fc``) on top of global average pooling produces
    the outputs.
    """

    def __init__(self, model_name, model_args=None):
        """Create the backbone, the pooling layer and the new linear head.

        Args:
            model_name: Key looked up in the ``pretrainedmodels`` module
                namespace to find the model constructor.
            model_args: Subscriptable container providing 'pretrained' and
                'num_classes'. Both are read with ``[...]``, so the None
                default is not actually usable.
        """
        super().__init__()
        constructor = pretrainedmodels.__dict__[model_name]
        if model_args['pretrained']:
            weights = "imagenet"
        else:
            weights = None
        # The backbone is always built with 1000 classes; its own
        # last_linear layer is only consulted for its input width.
        self.model = constructor(num_classes=1000, pretrained=weights)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(self.model.last_linear.in_features,
                            model_args['num_classes'])

    def forward(self, x):
        """Run features -> ReLU -> global average pool -> linear head."""
        feature_map = F.relu(self.model.features(x), inplace=False)
        batch_size = feature_map.size(0)
        # Collapse the pooled map to one row per batch element.
        pooled = self.pool(feature_map).view(batch_size, -1)
        return self.fc(pooled)
class TorchVisionModel(PretrainedModel):
    """Models from TorchVision's GitHub page of pretrained neural networks:
    https://github.com/pytorch/vision/tree/master/torchvision/models

    The backbone's final linear layer (``fc`` or ``classifier``) is replaced
    with a new one sized for ``model_args['num_classes']``.
    """

    def __init__(self, model_fn, model_args):
        """Build the backbone and swap in a new classification head.

        Args:
            model_fn: Callable invoked as ``model_fn(pretrained=...)`` that
                returns the backbone module.
            model_args: Container providing 'pretrained' (attribute or key)
                and 'num_classes' (key).
        """
        super(TorchVisionModel, self).__init__()
        # Keep the original attribute access when it works, but also accept
        # a plain mapping, matching how the other options are read.
        try:
            pretrained = model_args.pretrained
        except AttributeError:
            pretrained = model_args['pretrained']
        self.model = model_fn(pretrained=pretrained)
        self.pool = nn.AdaptiveAvgPool2d(1)
        num_outputs = model_args['num_classes']

        # nn.Module keeps submodules in `_modules`, not in the instance
        # `__dict__`, so a `'fc' in self.model.__dict__` test never matches
        # and the head would silently stay at its original size. hasattr
        # goes through Module.__getattr__ and finds registered submodules.
        if hasattr(self.model, 'fc'):
            # Read the width from the layer that is actually being replaced.
            num_ftrs = self.model.fc.in_features
            self.model.fc = nn.Linear(num_ftrs, num_outputs)
        elif hasattr(self.model, 'classifier'):
            # Requires `classifier` to be a single layer exposing
            # `in_features`.
            num_ftrs = self.model.classifier.in_features
            self.model.classifier = nn.Linear(num_ftrs, num_outputs)

    def forward(self, x):
        """Return the backbone's class scores for `x`.

        Backbones exposing both ``features`` and ``classifier`` go through
        features -> ReLU -> global average pool -> classifier. Any other
        backbone is called directly, so its own forward (ending in the
        replaced ``fc``) is used.
        """
        if hasattr(self.model, 'features') and \
                hasattr(self.model, 'classifier'):
            x = self.model.features(x)
            x = F.relu(x, inplace=False)
            x = self.pool(x).view(x.size(0), -1)
            x = self.model.classifier(x)
            return x
        # NOTE(review): torchvision's inception_v3 may return a
        # (logits, aux_logits) tuple in training mode and its auxiliary
        # head is not resized here -- confirm before training with it.
        return self.model(x)
class EfficientNetB0(EfficientNetModel):
    """EfficientNetModel fixed to the 'efficientnet-b0' architecture."""

    def __init__(self, model_args=None):
        super(EfficientNetB0, self).__init__(
            model_name='efficientnet-b0', model_args=model_args)
class EfficientNetB1(EfficientNetModel):
    """EfficientNetModel fixed to the 'efficientnet-b1' architecture."""

    def __init__(self, model_args=None):
        super(EfficientNetB1, self).__init__(
            model_name='efficientnet-b1', model_args=model_args)
class EfficientNetB2(EfficientNetModel):
    """EfficientNetModel fixed to the 'efficientnet-b2' architecture."""

    def __init__(self, model_args=None):
        super(EfficientNetB2, self).__init__(
            model_name='efficientnet-b2', model_args=model_args)
class EfficientNetB3(EfficientNetModel):
    """EfficientNetModel fixed to the 'efficientnet-b3' architecture."""

    def __init__(self, model_args=None):
        super(EfficientNetB3, self).__init__(
            model_name='efficientnet-b3', model_args=model_args)
class EfficientNetB4(EfficientNetModel):
    """EfficientNetModel fixed to the 'efficientnet-b4' architecture."""

    def __init__(self, model_args=None):
        super(EfficientNetB4, self).__init__(
            model_name='efficientnet-b4', model_args=model_args)
class EfficientNetB5(EfficientNetModel):
    """EfficientNetModel fixed to the 'efficientnet-b5' architecture."""

    def __init__(self, model_args=None):
        super(EfficientNetB5, self).__init__(
            model_name='efficientnet-b5', model_args=model_args)
class EfficientNetB6(EfficientNetModel):
    """EfficientNetModel fixed to the 'efficientnet-b6' architecture."""

    def __init__(self, model_args=None):
        super(EfficientNetB6, self).__init__(
            model_name='efficientnet-b6', model_args=model_args)
class EfficientNetB7(EfficientNetModel):
    """EfficientNetModel fixed to the 'efficientnet-b7' architecture."""

    def __init__(self, model_args=None):
        super(EfficientNetB7, self).__init__(
            model_name='efficientnet-b7', model_args=model_args)
class DenseNet121(TorchVisionModel):
    """TorchVisionModel built from ``torchvision.models.densenet121``."""

    def __init__(self, model_args=None):
        super().__init__(model_fn=models.densenet121, model_args=model_args)
class DenseNet161(TorchVisionModel):
    """TorchVisionModel built from ``torchvision.models.densenet161``."""

    def __init__(self, model_args=None):
        super().__init__(model_fn=models.densenet161, model_args=model_args)
class DenseNet201(TorchVisionModel):
    """TorchVisionModel built from ``torchvision.models.densenet201``."""

    def __init__(self, model_args=None):
        super().__init__(model_fn=models.densenet201, model_args=model_args)
class ResNet101(TorchVisionModel):
    """TorchVisionModel built from ``torchvision.models.resnet101``."""

    def __init__(self, model_args=None):
        super().__init__(model_fn=models.resnet101, model_args=model_args)
class ResNet152(TorchVisionModel):
    """TorchVisionModel built from ``torchvision.models.resnet152``."""

    def __init__(self, model_args=None):
        super().__init__(model_fn=models.resnet152, model_args=model_args)
class Inceptionv3(TorchVisionModel):
    """TorchVisionModel built from ``torchvision.models.inception_v3``."""

    def __init__(self, model_args=None):
        super().__init__(model_fn=models.inception_v3, model_args=model_args)
class Inceptionv4(CadeneModel):
    """CadeneModel using the 'inceptionv4' entry of ``pretrainedmodels``."""

    def __init__(self, model_args=None):
        super().__init__(model_name='inceptionv4', model_args=model_args)
class ResNet18(CadeneModel):
    """CadeneModel using the 'resnet18' entry of ``pretrainedmodels``."""

    def __init__(self, model_args=None):
        super().__init__(model_name='resnet18', model_args=model_args)
class ResNet34(CadeneModel):
    """CadeneModel using the 'resnet34' entry of ``pretrainedmodels``."""

    def __init__(self, model_args=None):
        super().__init__(model_name='resnet34', model_args=model_args)
class ResNeXt101(CadeneModel):
    """CadeneModel using the 'resnext101_64x4d' entry of
    ``pretrainedmodels``.
    """

    def __init__(self, model_args=None):
        super().__init__(model_name='resnext101_64x4d',
                         model_args=model_args)
class NASNetA(CadeneModel):
    """CadeneModel using the 'nasnetalarge' entry of ``pretrainedmodels``."""

    def __init__(self, model_args=None):
        super().__init__(model_name='nasnetalarge', model_args=model_args)
class MNASNet(CadeneModel):
    """CadeneModel using the 'nasnetamobile' entry of ``pretrainedmodels``.

    NOTE(review): the class is named MNASNet but the architecture key is
    'nasnetamobile' -- confirm that this naming is intentional.
    """

    def __init__(self, model_args=None):
        super().__init__(model_name='nasnetamobile', model_args=model_args)
class SENet154(CadeneModel):
    """CadeneModel using the 'senet154' entry of ``pretrainedmodels``."""

    def __init__(self, model_args=None):
        super().__init__(model_name='senet154', model_args=model_args)
class SEResNeXt101(CadeneModel):
    """CadeneModel using the 'se_resnext101_32x4d' entry of
    ``pretrainedmodels``.
    """

    def __init__(self, model_args=None):
        super().__init__(model_name='se_resnext101_32x4d',
                         model_args=model_args)