# FastestDet/module/custom_layers.py
# Snapshot metadata: 2022-07-01 21:22:39 +08:00, 98 lines, 4.1 KiB, Python.

import torch
import torch.nn as nn
class Conv1x1(nn.Module):
    """Pointwise (1x1) convolution followed by batch norm and ReLU.

    Args:
        input_channels: number of channels in the incoming feature map.
        output_channels: number of channels produced.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        pointwise = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,  # BN supplies the affine shift, so no conv bias
        )
        # The attribute name and Sequential indices determine the state_dict
        # keys (``conv1x1.0.*`` / ``conv1x1.1.*``); keep them stable.
        self.conv1x1 = nn.Sequential(
            pointwise,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply conv -> BN -> ReLU to ``x``."""
        return self.conv1x1(x)
class Head(nn.Module):
    """Depthwise 5x5 conv + BN + ReLU, then a pointwise projection + BN.

    The output is left un-activated: there is no ReLU after the final
    batch norm.

    Args:
        input_channels: channels of the incoming feature map; the depthwise
            stage keeps this channel count.
        output_channels: channels produced by the pointwise projection.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        layers = [
            # depthwise: one 5x5 filter per channel; padding 2 with stride 1
            # preserves the spatial size
            nn.Conv2d(input_channels, input_channels, 5, 1, 2,
                      groups=input_channels, bias=False),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(inplace=True),
            # pointwise projection to the requested channel count
            nn.Conv2d(input_channels, output_channels, 1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
        ]
        # Attribute name fixes the state_dict prefix ``conv5x5.<idx>.*``.
        self.conv5x5 = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the depthwise-then-pointwise stack."""
        return self.conv5x5(x)
class SPP(nn.Module):
    """Multi-branch block built from stacked depthwise 5x5 convolutions.

    The input is first projected to ``output_channels`` with ``Conv1x1``.
    Three parallel branches then apply 1, 2 and 3 stacked depthwise 5x5
    conv-BN-ReLU units (stride 1, so receptive fields of 5x5, 9x9 and 13x13).
    The branch outputs are concatenated along dim 1 and fused back to
    ``output_channels`` by a 1x1 conv + BN. The result is added to the
    projected input as a residual and passed through ReLU.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels of the returned feature map.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.Conv1x1 = Conv1x1(input_channels, output_channels)
        # Branches are created in S1 -> S2 -> S3 order so that parameter
        # initialisation consumes the RNG in the same sequence as a
        # hand-written layer-by-layer definition.
        self.S1 = self._depthwise_stack(output_channels, 1)
        self.S2 = self._depthwise_stack(output_channels, 2)
        self.S3 = self._depthwise_stack(output_channels, 3)
        self.output = nn.Sequential(
            nn.Conv2d(output_channels * 3, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
        )
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _depthwise_stack(channels, depth):
        """Return ``depth`` repeats of depthwise 5x5 conv -> BN -> ReLU.

        The layers are flattened into a single ``nn.Sequential``. The conv
        and BN of the k-th unit therefore sit at indices ``3k`` and ``3k+1``,
        which fixes the state_dict key layout ``<branch>.<idx>.*``.
        """
        layers = []
        for _ in range(depth):
            layers += [
                nn.Conv2d(channels, channels, 5, 1, 2, groups=channels, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x``, run the three branches, fuse, and add the residual."""
        x = self.Conv1x1(x)
        # Each branch starts with a conv, so the in-place ReLUs inside the
        # branches never overwrite ``x`` itself; it is safe to reuse below.
        y = torch.cat((self.S1(x), self.S2(x), self.S3(x)), dim=1)
        return self.relu(x + self.output(y))
class DetectHead(nn.Module):
    """Prediction head with objectness, regression and classification outputs.

    A shared ``Conv1x1`` feeds three ``Head`` branches. Their outputs are
    concatenated along dim 1 in this order:

    * 1 objectness channel, passed through sigmoid;
    * 4 regression channels, left raw (presumably box parameters; confirm
      against the decoder/loss);
    * ``category_num`` class channels, softmax-normalised over dim 1.

    The result therefore has ``1 + 4 + category_num`` channels.

    Args:
        input_channels: channels of the incoming feature map.
        category_num: number of classes predicted by the classification branch.
    """

    def __init__(self, input_channels, category_num):
        super().__init__()
        self.conv1x1 = Conv1x1(input_channels, input_channels)
        self.obj_layers = Head(input_channels, 1)
        self.reg_layers = Head(input_channels, 4)
        self.cls_layers = Head(input_channels, category_num)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the concatenated ``[obj, reg, cls]`` prediction map."""
        feat = self.conv1x1(x)
        pieces = (
            self.sigmoid(self.obj_layers(feat)),
            self.reg_layers(feat),
            self.softmax(self.cls_layers(feat)),
        )
        return torch.cat(pieces, dim=1)