# surveilling-surveillance/detection/models/detection/efficientdet/model.py

import torch.nn as nn
import torch
from torchvision.ops.boxes import nms as nms_torch
from .efficientnet import EfficientNet as EffNet
from .efficientnet.utils import MemoryEfficientSwish, Swish
from .efficientnet.utils_extra import Conv2dStaticSamePadding, MaxPool2dStaticSamePadding


def nms(dets, thresh):
return nms_torch(dets[:, :4], dets[:, 4], thresh)
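
# Usage sketch (illustrative, not part of the original file): rows of `dets`
# are [x1, y1, x2, y2, score]; `nms` returns the indices of the boxes kept
# after suppressing overlaps above the IoU threshold.
#   dets = torch.tensor([[0., 0., 10., 10., 0.9],
#                        [1., 1., 10., 10., 0.8]])   # pairwise IoU ~0.81
#   keep = nms(dets, thresh=0.5)                     # -> tensor([0])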


class SeparableConvBlock(nn.Module):
"""
created by Zylo117
"""
def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False):
super(SeparableConvBlock, self).__init__()
if out_channels is None:
out_channels = in_channels
        # Q: should the depthwise and pointwise convs share a bias,
        #    or should only the pointwise conv apply one?
        # A: Confirmed: only the pointwise conv applies a bias; the depthwise conv has none.
self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels,
kernel_size=3, stride=1, groups=in_channels, bias=False)
self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1)
self.norm = norm
if self.norm:
            # Warning: PyTorch's BatchNorm momentum differs from TensorFlow's:
            # momentum_pytorch = 1 - momentum_tensorflow (e.g. TF 0.99 -> PyTorch 0.01).
            self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3)
self.activation = activation
if self.activation:
self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
def forward(self, x):
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
if self.norm:
x = self.bn(x)
if self.activation:
x = self.swish(x)
return x
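
# Shape sketch (illustrative, not part of the original file): the depthwise
# 3x3 conv keeps the channel count, the pointwise 1x1 conv changes it, and
# static "same" padding preserves the spatial size at stride 1.
#   block = SeparableConvBlock(64, 128, norm=True, activation=True)
#   y = block(torch.randn(1, 64, 32, 32))   # y.shape == (1, 128, 32, 32)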


class BiFPN(nn.Module):
"""
modified by Zylo117
"""
def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4, onnx_export=False, attention=True,
use_p8=False):
"""
Args:
num_channels:
conv_channels:
first_time: whether the input comes directly from the efficientnet,
if True, downchannel it first, and downsample P5 to generate P6 then P7
epsilon: epsilon of fast weighted attention sum of BiFPN, not the BN's epsilon
onnx_export: if True, use Swish instead of MemoryEfficientSwish
"""
super(BiFPN, self).__init__()
self.epsilon = epsilon
self.use_p8 = use_p8
# Conv layers
self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
self.conv4_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
self.conv3_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
self.conv4_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
self.conv7_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
if use_p8:
self.conv7_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
self.conv8_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
# Feature scaling layers
self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.p4_downsample = MaxPool2dStaticSamePadding(3, 2)
self.p5_downsample = MaxPool2dStaticSamePadding(3, 2)
self.p6_downsample = MaxPool2dStaticSamePadding(3, 2)
self.p7_downsample = MaxPool2dStaticSamePadding(3, 2)
if use_p8:
self.p7_upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.p8_downsample = MaxPool2dStaticSamePadding(3, 2)
self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
self.first_time = first_time
if self.first_time:
self.p5_down_channel = nn.Sequential(
Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
)
self.p4_down_channel = nn.Sequential(
Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
)
self.p3_down_channel = nn.Sequential(
Conv2dStaticSamePadding(conv_channels[0], num_channels, 1),
nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
)
self.p5_to_p6 = nn.Sequential(
Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
MaxPool2dStaticSamePadding(3, 2)
)
self.p6_to_p7 = nn.Sequential(
MaxPool2dStaticSamePadding(3, 2)
)
if use_p8:
self.p7_to_p8 = nn.Sequential(
MaxPool2dStaticSamePadding(3, 2)
)
self.p4_down_channel_2 = nn.Sequential(
Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
)
self.p5_down_channel_2 = nn.Sequential(
Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
)
        # Learnable fusion weights for fast normalized attention
self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
self.p6_w1_relu = nn.ReLU()
self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
self.p5_w1_relu = nn.ReLU()
self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
self.p4_w1_relu = nn.ReLU()
self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
self.p3_w1_relu = nn.ReLU()
self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
self.p4_w2_relu = nn.ReLU()
self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
self.p5_w2_relu = nn.ReLU()
self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
self.p6_w2_relu = nn.ReLU()
self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
self.p7_w2_relu = nn.ReLU()
self.attention = attention
def forward(self, inputs):
"""
        Illustration of a minimal BiFPN unit:
P7_0 -------------------------> P7_2 -------->
|-------------| ↑
↓ |
P6_0 ---------> P6_1 ---------> P6_2 -------->
|-------------|--------------↑ ↑
↓ |
P5_0 ---------> P5_1 ---------> P5_2 -------->
|-------------|--------------↑ ↑
↓ |
P4_0 ---------> P4_1 ---------> P4_2 -------->
|-------------|--------------↑ ↑
|--------------↓ |
P3_0 -------------------------> P3_2 -------->
"""
        # Match each input to the target level before fusing:
        # if it is at the same level, pass it through;
        # if it is at an earlier (finer) level, downsample to the target by max pooling;
        # if it is at a later (coarser) level, upsample to the target by nearest interpolation.
        # Channel counts are matched with same-padding 1x1 convs where they differ.
if self.attention:
outs = self._forward_fast_attention(inputs)
else:
outs = self._forward(inputs)
return outs
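
    # Worked sketch of the fast normalized fusion used below (illustrative,
    # not part of the original file): ReLU keeps the learned weights
    # non-negative, and the epsilon-stabilized sum normalizes them.
    #   w = nn.ReLU()(torch.tensor([1.0, 3.0]))
    #   w = w / (w.sum(dim=0) + 1e-4)            # ~[0.25, 0.75]
    #   fused = w[0] * feat_a + w[1] * feat_b    # feat_a, feat_b: hypothetical same-shape maps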
def _forward_fast_attention(self, inputs):
if self.first_time:
p3, p4, p5 = inputs
p6_in = self.p5_to_p6(p5)
p7_in = self.p6_to_p7(p6_in)
p3_in = self.p3_down_channel(p3)
p4_in = self.p4_down_channel(p4)
p5_in = self.p5_down_channel(p5)
else:
# P3_0, P4_0, P5_0, P6_0 and P7_0
p3_in, p4_in, p5_in, p6_in, p7_in = inputs
# P7_0 to P7_2
# Weights for P6_0 and P7_0 to P6_1
p6_w1 = self.p6_w1_relu(self.p6_w1)
weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
# Connections for P6_0 and P7_0 to P6_1 respectively
p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)))
# Weights for P5_0 and P6_1 to P5_1
p5_w1 = self.p5_w1_relu(self.p5_w1)
weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
# Connections for P5_0 and P6_1 to P5_1 respectively
p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up)))
# Weights for P4_0 and P5_1 to P4_1
p4_w1 = self.p4_w1_relu(self.p4_w1)
weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
# Connections for P4_0 and P5_1 to P4_1 respectively
p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up)))
# Weights for P3_0 and P4_1 to P3_2
p3_w1 = self.p3_w1_relu(self.p3_w1)
weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
# Connections for P3_0 and P4_1 to P3_2 respectively
p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up)))
if self.first_time:
p4_in = self.p4_down_channel_2(p4)
p5_in = self.p5_down_channel_2(p5)
# Weights for P4_0, P4_1 and P3_2 to P4_2
p4_w2 = self.p4_w2_relu(self.p4_w2)
weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
# Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
p4_out = self.conv4_down(
self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out)))
# Weights for P5_0, P5_1 and P4_2 to P5_2
p5_w2 = self.p5_w2_relu(self.p5_w2)
weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
# Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
p5_out = self.conv5_down(
self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out)))
# Weights for P6_0, P6_1 and P5_2 to P6_2
p6_w2 = self.p6_w2_relu(self.p6_w2)
weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
# Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
p6_out = self.conv6_down(
self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out)))
# Weights for P7_0 and P6_2 to P7_2
p7_w2 = self.p7_w2_relu(self.p7_w2)
weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
# Connections for P7_0 and P6_2 to P7_2
p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)))
return p3_out, p4_out, p5_out, p6_out, p7_out
def _forward(self, inputs):
if self.first_time:
p3, p4, p5 = inputs
p6_in = self.p5_to_p6(p5)
p7_in = self.p6_to_p7(p6_in)
if self.use_p8:
p8_in = self.p7_to_p8(p7_in)
p3_in = self.p3_down_channel(p3)
p4_in = self.p4_down_channel(p4)
p5_in = self.p5_down_channel(p5)
else:
if self.use_p8:
# P3_0, P4_0, P5_0, P6_0, P7_0 and P8_0
p3_in, p4_in, p5_in, p6_in, p7_in, p8_in = inputs
else:
# P3_0, P4_0, P5_0, P6_0 and P7_0
p3_in, p4_in, p5_in, p6_in, p7_in = inputs
if self.use_p8:
# P8_0 to P8_2
# Connections for P7_0 and P8_0 to P7_1 respectively
p7_up = self.conv7_up(self.swish(p7_in + self.p7_upsample(p8_in)))
# Connections for P6_0 and P7_0 to P6_1 respectively
p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_up)))
else:
# P7_0 to P7_2
# Connections for P6_0 and P7_0 to P6_1 respectively
p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))
# Connections for P5_0 and P6_1 to P5_1 respectively
p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up)))
# Connections for P4_0 and P5_1 to P4_1 respectively
p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up)))
# Connections for P3_0 and P4_1 to P3_2 respectively
p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up)))
if self.first_time:
p4_in = self.p4_down_channel_2(p4)
p5_in = self.p5_down_channel_2(p5)
# Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
p4_out = self.conv4_down(
self.swish(p4_in + p4_up + self.p4_downsample(p3_out)))
# Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
p5_out = self.conv5_down(
self.swish(p5_in + p5_up + self.p5_downsample(p4_out)))
# Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
p6_out = self.conv6_down(
self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))
if self.use_p8:
# Connections for P7_0, P7_1 and P6_2 to P7_2 respectively
p7_out = self.conv7_down(
self.swish(p7_in + p7_up + self.p7_downsample(p6_out)))
# Connections for P8_0 and P7_2 to P8_2
p8_out = self.conv8_down(self.swish(p8_in + self.p8_downsample(p7_out)))
return p3_out, p4_out, p5_out, p6_out, p7_out, p8_out
else:
# Connections for P7_0 and P6_2 to P7_2
p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))
return p3_out, p4_out, p5_out, p6_out, p7_out
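
# Usage sketch (illustrative; D0-style sizes assumed, not part of the original
# file): the first BiFPN layer consumes backbone features C3-C5 and builds
# P3-P7; subsequent layers consume the five P-levels directly.
#   fpn_first = BiFPN(64, [40, 112, 320], first_time=True)
#   fpn_next = BiFPN(64, [40, 112, 320], first_time=False)
#   c3, c4, c5 = (torch.randn(1, c, s, s) for c, s in
#                 [(40, 64), (112, 32), (320, 16)])
#   p3, p4, p5, p6, p7 = fpn_next(fpn_first((c3, c4, c5)))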


class Regressor(nn.Module):
"""
modified by Zylo117
"""
def __init__(self, in_channels, num_anchors, num_layers, pyramid_levels=5, onnx_export=False):
super(Regressor, self).__init__()
self.num_layers = num_layers
self.conv_list = nn.ModuleList(
[SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
self.bn_list = nn.ModuleList(
[nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
range(pyramid_levels)])
self.header = SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False)
self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
def forward(self, inputs):
feats = []
for feat, bn_list in zip(inputs, self.bn_list):
for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
feat = conv(feat)
feat = bn(feat)
feat = self.swish(feat)
feat = self.header(feat)
feat = feat.permute(0, 2, 3, 1)
feat = feat.contiguous().view(feat.shape[0], -1, 4)
feats.append(feat)
feats = torch.cat(feats, dim=1)
return feats
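
# Output sketch (illustrative, not part of the original file): each pyramid
# level is projected to num_anchors * 4 channels, flattened to
# (batch, anchors_on_level, 4), then concatenated across levels.
#   reg = Regressor(in_channels=64, num_anchors=9, num_layers=3)
#   feats = [torch.randn(1, 64, s, s) for s in (64, 32, 16, 8, 4)]
#   out = reg(feats)   # out.shape == (1, 9 * 5456, 4) == (1, 49104, 4)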


class Classifier(nn.Module):
"""
modified by Zylo117
"""
def __init__(self, in_channels, num_anchors, num_classes, num_layers, pyramid_levels=5, onnx_export=False):
super(Classifier, self).__init__()
self.num_anchors = num_anchors
self.num_classes = num_classes
self.num_layers = num_layers
self.conv_list = nn.ModuleList(
[SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
self.bn_list = nn.ModuleList(
[nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
range(pyramid_levels)])
self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False)
self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
def forward(self, inputs):
feats = []
for feat, bn_list in zip(inputs, self.bn_list):
for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
feat = conv(feat)
feat = bn(feat)
feat = self.swish(feat)
feat = self.header(feat)
feat = feat.permute(0, 2, 3, 1)
feat = feat.contiguous().view(feat.shape[0], feat.shape[1], feat.shape[2], self.num_anchors,
self.num_classes)
feat = feat.contiguous().view(feat.shape[0], -1, self.num_classes)
feats.append(feat)
feats = torch.cat(feats, dim=1)
feats = feats.sigmoid()
return feats
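
# Output sketch (illustrative, not part of the original file): analogous to
# Regressor, but producing per-class sigmoid scores instead of box deltas.
#   cls = Classifier(in_channels=64, num_anchors=9, num_classes=90, num_layers=3)
#   feats = [torch.randn(1, 64, s, s) for s in (64, 32, 16, 8, 4)]
#   out = cls(feats)   # out.shape == (1, 49104, 90), values in (0, 1)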


class EfficientNet(nn.Module):
"""
modified by Zylo117
"""
def __init__(self, compound_coef, load_weights=False):
super(EfficientNet, self).__init__()
model = EffNet.from_pretrained(f'efficientnet-b{compound_coef}', load_weights)
del model._conv_head
del model._bn1
del model._avg_pooling
del model._dropout
del model._fc
self.model = model
def forward(self, x):
x = self.model._conv_stem(x)
x = self.model._bn0(x)
x = self.model._swish(x)
feature_maps = []
        # TODO: temporarily storing the extra tensor last_x and deleting it later
        #  may not be a good idea; consider recording the stride changes when the
        #  EfficientNet is created, and applying them here instead.
last_x = None
for idx, block in enumerate(self.model._blocks):
drop_connect_rate = self.model._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self.model._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
if block._depthwise_conv.stride == [2, 2]:
feature_maps.append(last_x)
elif idx == len(self.model._blocks) - 1:
feature_maps.append(x)
last_x = x
del last_x
return feature_maps[1:]
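
# Feature-map sketch (illustrative, not part of the original file): last_x
# holds the map from just before each 2x downsampling, so the slice [1:]
# drops the earliest (stride-2) map and keeps the stride-4/8/16/32 levels;
# the last three are the C3-C5 inputs expected by the first BiFPN layer.
#   backbone = EfficientNet(compound_coef=0)
#   feats = backbone(torch.randn(1, 3, 512, 512))   # 4 maps, strides 4..32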


if __name__ == '__main__':
from tensorboardX import SummaryWriter
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
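
    # Hypothetical smoke test (illustrative, not part of the original file):
    # parameter count of a D0-sized BiFPN built with the classes above.
    print(count_parameters(BiFPN(64, [40, 112, 320], first_time=True)))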