FastestDet/module/shufflenetv2.py
2022-07-01 21:22:39 +08:00

113 lines
4.0 KiB
Python

import torch
import torch.nn as nn
class ShuffleV2Block(nn.Module):
    """ShuffleNetV2 unit: a main branch concatenated with a shortcut branch.

    * ``stride == 1``: the input channels are de-interleaved into two halves
      (see :meth:`channel_shuffle`). The first half is passed through
      untouched, the second half goes through ``branch_main``.
    * ``stride == 2``: the full input feeds both ``branch_proj`` and
      ``branch_main``.

    In both cases the two results are concatenated along dimension 1.

    Args:
        inp: Number of channels fed to ``branch_main`` (and to ``branch_proj``
            when ``stride == 2``). For ``stride == 1`` this is half of the
            channel count of the tensor given to :meth:`forward`.
        oup: Target channel count; ``branch_main`` emits ``oup - inp``
            channels so that the concatenation with the ``inp``-channel
            shortcut yields ``oup`` channels.
        mid_channels: Channel width inside ``branch_main``.
        ksize: Kernel size of the depthwise convolutions; padding is
            ``ksize // 2``.
        stride: Stride of the depthwise convolutions, either 1 or 2.

    Raises:
        ValueError: If ``stride`` is neither 1 nor 2.
    """

    def __init__(self, inp: int, oup: int, mid_channels: int, *, ksize: int, stride: int) -> None:
        super().__init__()
        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O`` and ``forward`` would then have no valid branch.
        if stride not in (1, 2):
            raise ValueError(f"stride must be 1 or 2, got {stride!r}")

        self.stride = stride
        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp

        # Channels the main branch must produce so the concatenation with the
        # ``inp``-channel shortcut has ``oup`` channels.
        outputs = oup - inp

        self.branch_main = nn.Sequential(
            # pw: 1x1 convolution, inp -> mid_channels
            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # dw: depthwise convolution (groups == channels)
            nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            # pw-linear: 1x1 convolution, mid_channels -> outputs
            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
            nn.BatchNorm2d(outputs),
            nn.ReLU(inplace=True),
        )

        if stride == 2:
            self.branch_proj = nn.Sequential(
                # dw: depthwise convolution that performs the downsampling
                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                # pw-linear: 1x1 convolution, inp -> inp
                nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
            )
        else:
            # With stride 1 the shortcut is an identity on half the channels.
            self.branch_proj = None

    def forward(self, old_x: torch.Tensor) -> torch.Tensor:
        """Run the block and return the channel-wise concatenation of both branches."""
        if self.stride == 1:
            x_proj, x = self.channel_shuffle(old_x)
            return torch.cat((x_proj, self.branch_main(x)), 1)

        # stride == 2 (guaranteed by the constructor): both branches see the
        # full input and each one downsamples it.
        return torch.cat((self.branch_proj(old_x), self.branch_main(old_x)), 1)

    def channel_shuffle(self, x: torch.Tensor):
        """Split a 4-D tensor into its even-indexed and odd-indexed channels.

        Args:
            x: Tensor with four dimensions, unpacked here as
                ``(batch, channels, height, width)``.

        Returns:
            A pair ``(even, odd)``, each with ``channels // 2`` channels:
            ``even`` holds channels 0, 2, 4, ... and ``odd`` holds channels
            1, 3, 5, ... of ``x``.

        Raises:
            ValueError: If the channel count is not a multiple of 4.
        """
        # ``.data`` is kept exactly as in the original implementation.
        batchsize, num_channels, height, width = x.data.size()
        if num_channels % 4 != 0:
            raise ValueError(
                f"channel count must be divisible by 4, got {num_channels}"
            )
        # Group consecutive channel pairs: (batch * channels/2, 2, H*W) ...
        x = x.reshape(batchsize * num_channels // 2, 2, height * width)
        # ... bring the pair index to the front ...
        x = x.permute(1, 0, 2)
        # ... and restore (batch, channels/2, H, W) for each pair member.
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 backbone returning the outputs of its three stages.

    The network is a stem (3x3 stride-2 convolution followed by a stride-2
    max-pool) and three stages named ``stage2``, ``stage3`` and ``stage4``.
    Each stage starts with a stride-2 ``ShuffleV2Block`` and continues with
    stride-1 blocks.

    Args:
        stage_repeats: Number of blocks in each stage, one entry per stage
            (three entries are needed for ``forward`` to work).
        stage_out_channels: Channel widths. Index 1 is the stem width and
            index ``i + 2`` is the width of the i-th stage; index 0 is not
            read by this class.
        load_param: When it compares equal to ``False``, pretrained weights
            are loaded from ``./module/shufflenetv2.pth`` during construction.
            Any other value skips loading and only prints a notice.
    """

    def __init__(self, stage_repeats, stage_out_channels, load_param):
        super().__init__()

        self.stage_repeats = stage_repeats
        self.stage_out_channels = stage_out_channels

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_names = ["stage2", "stage3", "stage4"]
        for idxstage, numrepeat in enumerate(self.stage_repeats):
            output_channel = self.stage_out_channels[idxstage + 2]
            stage_blocks = []
            for i in range(numrepeat):
                if i == 0:
                    # First block of a stage downsamples and changes width.
                    stage_blocks.append(
                        ShuffleV2Block(input_channel, output_channel,
                                       mid_channels=output_channel // 2, ksize=3, stride=2)
                    )
                else:
                    # Stride-1 blocks only process half of the channels in
                    # their main branch, hence ``input_channel // 2``.
                    stage_blocks.append(
                        ShuffleV2Block(input_channel // 2, output_channel,
                                       mid_channels=output_channel // 2, ksize=3, stride=1)
                    )
                input_channel = output_channel
            setattr(self, stage_names[idxstage], nn.Sequential(*stage_blocks))

        # ``== False`` is intentional: ``None`` must keep taking the
        # "load param..." branch exactly as before.
        if load_param == False:  # noqa: E712
            self._initialize_weights()
        else:
            print("load param...")

    def forward(self, x):
        """Return ``(P1, P2, P3)``: the outputs of stage2, stage3 and stage4."""
        x = self.first_conv(x)
        x = self.maxpool(x)
        P1 = self.stage2(x)
        P2 = self.stage3(P1)
        P3 = self.stage4(P2)
        return P1, P2, P3

    def _initialize_weights(self):
        """Load pretrained weights from ``./module/shufflenetv2.pth`` (strict).

        The path is relative to the current working directory.
        """
        weights_path = "./module/shufflenetv2.pth"
        print("Initialize params from:%s" % weights_path)
        # map_location="cpu" lets a checkpoint that was serialized from CUDA
        # tensors load on a CPU-only host; load_state_dict then copies the
        # values into this module's own parameters.
        # NOTE: torch.load unpickles the file, so it must come from a trusted
        # source.
        state_dict = torch.load(weights_path, map_location="cpu")
        self.load_state_dict(state_dict, strict=True)