238 lines
7.5 KiB
Python
238 lines
7.5 KiB
Python
import pretrainedmodels
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchvision import models
|
|
from efficientnet_pytorch import EfficientNet
|
|
|
|
|
|
class PretrainedModel(nn.Module):
|
|
"""Pretrained model, either from Cadene or TorchVision."""
|
|
|
|
def __init__(self):
|
|
super(PretrainedModel, self).__init__()
|
|
|
|
def forward(self, x):
|
|
raise NotImplementedError('Subclass of PretrainedModel ' +
|
|
'must implement forward method.')
|
|
|
|
def fine_tuning_parameters(self, boundary_layers, lrs):
|
|
"""Get a list of parameter groups that can be passed to an optimizer.
|
|
|
|
Args:
|
|
boundary_layers: List of names for the boundary layers.
|
|
lrs: List of learning rates for each parameter group, from earlier
|
|
to later layers.
|
|
|
|
Returns:
|
|
param_groups: List of dictionaries, one per parameter group.
|
|
"""
|
|
|
|
def gen_params(start_layer, end_layer):
|
|
saw_start_layer = False
|
|
for name, param in self.named_parameters():
|
|
if end_layer is not None and name == end_layer:
|
|
# Saw the last layer -> done
|
|
return
|
|
if start_layer is None or name == start_layer:
|
|
# Saw the first layer -> Start returning layers
|
|
saw_start_layer = True
|
|
|
|
if saw_start_layer:
|
|
yield param
|
|
|
|
if len(lrs) != boundary_layers + 1:
|
|
raise ValueError(f'Got {boundary_layers + 1} param groups, ' +
|
|
f'but {lrs} learning rates')
|
|
|
|
# Fine-tune the network's layers from encoder.2 onwards
|
|
boundary_layers = [None] + boundary_layers + [None]
|
|
param_groups = []
|
|
for i in range(len(boundary_layers) - 1):
|
|
start, end = boundary_layers[i:i + 2]
|
|
param_groups.append({'params': gen_params(start, end),
|
|
'lr': lrs[i]})
|
|
return param_groups
|
|
|
|
|
|
class EfficientNetModel(PretrainedModel):
|
|
"""EfficientNet models:
|
|
https://github.com/lukemelas/EfficientNet-PyTorch
|
|
"""
|
|
def __init__(self, model_name, model_args=None):
|
|
super().__init__()
|
|
num_classes = model_args.get("num_classes", None)
|
|
pretrained = model_args.get("pretrained", False)
|
|
|
|
if pretrained:
|
|
self.model = EfficientNet.from_pretrained(
|
|
model_name, num_classes=num_classes)
|
|
else:
|
|
self.model = EfficientNet.from_name(
|
|
model_name, num_classes=num_classes)
|
|
|
|
def forward(self, x):
|
|
x = self.model(x)
|
|
return x
|
|
|
|
|
|
class CadeneModel(PretrainedModel):
|
|
"""Models from Cadene's GitHub page of pretrained networks:
|
|
https://github.com/Cadene/pretrained-models.pytorch
|
|
"""
|
|
|
|
def __init__(self, model_name, model_args=None):
|
|
super(CadeneModel, self).__init__()
|
|
|
|
model_class = pretrainedmodels.__dict__[model_name]
|
|
pretrained = "imagenet" if model_args['pretrained'] else None
|
|
self.model = model_class(num_classes=1000,
|
|
pretrained=pretrained)
|
|
self.pool = nn.AdaptiveAvgPool2d(1)
|
|
|
|
num_ftrs = self.model.last_linear.in_features
|
|
self.fc = nn.Linear(num_ftrs, model_args['num_classes'])
|
|
|
|
def forward(self, x):
|
|
x = self.model.features(x)
|
|
x = F.relu(x, inplace=False)
|
|
x = self.pool(x).view(x.size(0), -1)
|
|
x = self.fc(x)
|
|
return x
|
|
|
|
|
|
class TorchVisionModel(PretrainedModel):
|
|
"""Models from TorchVision's GitHub page of pretrained neural networks:
|
|
https://github.com/pytorch/vision/tree/master/torchvision/models
|
|
"""
|
|
|
|
def __init__(self, model_fn, model_args):
|
|
super(TorchVisionModel, self).__init__()
|
|
|
|
self.model = model_fn(pretrained=model_args.pretrained)
|
|
self.pool = nn.AdaptiveAvgPool2d(1)
|
|
|
|
num_outputs = model_args['num_classes']
|
|
|
|
if 'fc' in self.model.__dict__:
|
|
num_ftrs = self.model.classifier.in_features
|
|
self.model.fc = nn.Linear(num_ftrs, num_outputs)
|
|
elif 'classifier' in self.model.__dict__:
|
|
num_ftrs = self.model.classifier.in_features
|
|
self.model.classifier = nn.Linear(num_ftrs, num_outputs)
|
|
|
|
def forward(self, x):
|
|
x = self.model.features(x)
|
|
x = F.relu(x, inplace=False)
|
|
x = self.pool(x).view(x.size(0), -1)
|
|
x = self.model.classifier(x)
|
|
return x
|
|
|
|
|
|
class EfficientNetB0(EfficientNetModel):
|
|
def __init__(self, model_args=None):
|
|
super().__init__('efficientnet-b0', model_args)
|
|
|
|
|
|
class EfficientNetB1(EfficientNetModel):
|
|
def __init__(self, model_args=None):
|
|
super().__init__('efficientnet-b1', model_args)
|
|
|
|
|
|
class EfficientNetB2(EfficientNetModel):
|
|
def __init__(self, model_args=None):
|
|
super().__init__('efficientnet-b2', model_args)
|
|
|
|
|
|
class EfficientNetB3(EfficientNetModel):
|
|
def __init__(self, model_args=None):
|
|
super().__init__('efficientnet-b3', model_args)
|
|
|
|
|
|
class EfficientNetB4(EfficientNetModel):
|
|
def __init__(self, model_args=None):
|
|
super().__init__('efficientnet-b4', model_args)
|
|
|
|
|
|
class EfficientNetB5(EfficientNetModel):
|
|
def __init__(self, model_args=None):
|
|
super().__init__('efficientnet-b5', model_args)
|
|
|
|
|
|
class EfficientNetB6(EfficientNetModel):
|
|
def __init__(self, model_args=None):
|
|
super().__init__('efficientnet-b6', model_args)
|
|
|
|
|
|
class EfficientNetB7(EfficientNetModel):
|
|
def __init__(self, model_args=None):
|
|
super().__init__('efficientnet-b7', model_args)
|
|
|
|
|
|
class DenseNet121(TorchVisionModel):
|
|
def __init__(self, model_args=None):
|
|
super(DenseNet121, self).__init__(models.densenet121, model_args)
|
|
|
|
|
|
class DenseNet161(TorchVisionModel):
|
|
def __init__(self, model_args=None):
|
|
super(DenseNet161, self).__init__(models.densenet161, model_args)
|
|
|
|
|
|
class DenseNet201(TorchVisionModel):
|
|
def __init__(self, model_args=None):
|
|
super(DenseNet201, self).__init__(models.densenet201, model_args)
|
|
|
|
|
|
class ResNet101(TorchVisionModel):
|
|
def __init__(self, model_args=None):
|
|
super(ResNet101, self).__init__(models.resnet101, model_args)
|
|
|
|
|
|
class ResNet152(TorchVisionModel):
|
|
def __init__(self, model_args=None):
|
|
super(ResNet152, self).__init__(models.resnet152, model_args)
|
|
|
|
|
|
class Inceptionv3(TorchVisionModel):
|
|
def __init__(self, model_args=None):
|
|
super(Inceptionv3, self).__init__(models.inception_v3, model_args)
|
|
|
|
|
|
class Inceptionv4(CadeneModel):
|
|
def __init__(self, model_args=None):
|
|
super(Inceptionv4, self).__init__('inceptionv4', model_args)
|
|
|
|
|
|
class ResNet18(CadeneModel):
|
|
def __init__(self, model_args=None):
|
|
super(ResNet18, self).__init__('resnet18', model_args)
|
|
|
|
|
|
class ResNet34(CadeneModel):
|
|
def __init__(self, model_args=None):
|
|
super(ResNet34, self).__init__('resnet34', model_args)
|
|
|
|
|
|
class ResNeXt101(CadeneModel):
|
|
def __init__(self, model_args=None):
|
|
super(ResNeXt101, self).__init__('resnext101_64x4d', model_args)
|
|
|
|
|
|
class NASNetA(CadeneModel):
|
|
def __init__(self, model_args=None):
|
|
super(NASNetA, self).__init__('nasnetalarge', model_args)
|
|
|
|
|
|
class MNASNet(CadeneModel):
|
|
def __init__(self, model_args=None):
|
|
super(MNASNet, self).__init__('nasnetamobile', model_args)
|
|
|
|
|
|
class SENet154(CadeneModel):
|
|
def __init__(self, model_args=None):
|
|
super(SENet154, self).__init__('senet154', model_args)
|
|
|
|
|
|
class SEResNeXt101(CadeneModel):
|
|
def __init__(self, model_args=None):
|
|
super(SEResNeXt101, self).__init__('se_resnext101_32x4d', model_args)
|