# File-viewer metadata: Python source, 113 lines, 4.0 KiB.
import torch
|
|
import torch.nn as nn
|
|
|
|
class ShuffleV2Block(nn.Module):
    """Two-branch block whose branch outputs are concatenated on the channel axis.

    * ``stride == 1``: the input channels are split into the even-indexed and
      the odd-indexed half (see :meth:`channel_shuffle`).  The even half is
      passed through untouched, the odd half goes through ``branch_main``.
      The input must therefore carry ``2 * inp`` channels.
    * ``stride == 2``: the full input is fed to both ``branch_proj`` and
      ``branch_main``; both branches downsample with the given stride.

    In both cases the result has ``oup`` channels: ``inp`` from the
    pass-through / projection branch plus ``oup - inp`` from the main branch.

    Args:
        inp: Number of channels entering ``branch_main`` (and ``branch_proj``).
        oup: Number of channels of the concatenated output.
        mid_channels: Width of the hidden layers of ``branch_main``.
        ksize: Kernel size of the depthwise convolutions (keyword-only).
            Padding is ``ksize // 2``.
        stride: ``1`` or ``2`` (keyword-only).

    Raises:
        AssertionError: If ``stride`` is neither 1 nor 2.
    """

    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2], "stride must be 1 or 2"

        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp

        # The other branch contributes `inp` channels, so the main branch
        # only has to produce the remainder.
        outputs = oup - inp

        # NOTE: the order of the layers fixes the state_dict keys
        # ("branch_main.0.weight", ...); do not reorder.
        branch_main = [
            # pw
            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # dw (groups == channels)
            nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad,
                      groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            # pw-linear
            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
            nn.BatchNorm2d(outputs),
            nn.ReLU(inplace=True),
        ]
        self.branch_main = nn.Sequential(*branch_main)

        if stride == 2:
            branch_proj = [
                # dw (groups == channels)
                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                # pw-linear
                nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
            ]
            self.branch_proj = nn.Sequential(*branch_proj)
        else:
            # stride 1 forwards half of the input unchanged instead.
            self.branch_proj = None

    def forward(self, old_x):
        """Apply the block to a 4-D tensor and return the concatenated branches.

        Raises:
            ValueError: If ``self.stride`` is neither 1 nor 2 (previously the
                method silently returned ``None`` in that case).
        """
        if self.stride == 1:
            x_proj, x = self.channel_shuffle(old_x)
            return torch.cat((x_proj, self.branch_main(x)), 1)
        if self.stride == 2:
            # Both branches consume the complete input.
            return torch.cat(
                (self.branch_proj(old_x), self.branch_main(old_x)), 1
            )
        raise ValueError(
            "unsupported stride %r; expected 1 or 2" % (self.stride,)
        )

    def channel_shuffle(self, x):
        """Split a 4-D tensor into its even- and odd-indexed channels.

        Returns:
            A pair ``(even, odd)``; each has half of the channels of ``x`` and
            the same batch, height and width.

        Raises:
            AssertionError: If the channel count is not a multiple of 4.
        """
        # Shape is read through ``x.data`` exactly as the original code did.
        batchsize, num_channels, height, width = x.data.size()
        assert num_channels % 4 == 0, "channel count must be a multiple of 4"
        # Pair up neighbouring channels: axis 1 selects even (0) / odd (1).
        x = x.reshape(batchsize * num_channels // 2, 2, height * width)
        # Bring the even/odd selector to the front ...
        x = x.permute(1, 0, 2)
        # ... and restore the (batch, channels / 2, height, width) layout.
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]
|
|
|
|
class ShuffleNetV2(nn.Module):
    """Backbone built from a stem, a max-pool and stages of ``ShuffleV2Block``.

    Args:
        stage_repeats: Number of blocks per stage.  Stages are registered as
            ``stage2``, ``stage3`` and ``stage4``; ``forward`` uses all three.
        stage_out_channels: Channel widths.  Index 1 is the width of the stem
            convolution and index ``i + 2`` the width of stage ``i``; index 0
            is not read by this class.
        load_param: When equal to ``False`` the weights are loaded from
            ``./module/shufflenetv2.pth`` right after construction.  Otherwise
            the constructor only prints ``"load param..."`` and leaves the
            parameters at their default initialisation.
    """

    def __init__(self, stage_repeats, stage_out_channels, load_param):
        super().__init__()

        self.stage_repeats = stage_repeats
        self.stage_out_channels = stage_out_channels

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(inplace=True),
        )

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_names = ["stage2", "stage3", "stage4"]
        for idxstage, numrepeat in enumerate(self.stage_repeats):
            output_channel = self.stage_out_channels[idxstage + 2]
            stage_blocks = []
            for i in range(numrepeat):
                if i == 0:
                    # First block of a stage downsamples and changes width.
                    stage_blocks.append(
                        ShuffleV2Block(input_channel, output_channel,
                                       mid_channels=output_channel // 2,
                                       ksize=3, stride=2)
                    )
                else:
                    # Stride-1 blocks process only half of the channels, the
                    # other half is passed through by the block itself.
                    stage_blocks.append(
                        ShuffleV2Block(input_channel // 2, output_channel,
                                       mid_channels=output_channel // 2,
                                       ksize=3, stride=1)
                    )
                input_channel = output_channel
            setattr(self, stage_names[idxstage], nn.Sequential(*stage_blocks))

        # Kept as an equality test on purpose: only values equal to False
        # (False, 0) trigger loading; None does not.
        if load_param == False:  # noqa: E712
            self._initialize_weights()
        else:
            print("load param...")

    def forward(self, x):
        """Return the outputs of ``stage2``, ``stage3`` and ``stage4`` as a tuple."""
        x = self.first_conv(x)
        x = self.maxpool(x)
        P1 = self.stage2(x)
        P2 = self.stage3(P1)
        P3 = self.stage4(P2)

        return P1, P2, P3

    def _initialize_weights(self, weights_path="./module/shufflenetv2.pth"):
        """Load a state dict from ``weights_path`` into this module.

        Args:
            weights_path: Checkpoint file to read.  Defaults to the path the
                original implementation had hard-coded.

        The checkpoint is mapped to the CPU while it is read, so a file saved
        from a CUDA device can also be loaded on a host without CUDA;
        ``load_state_dict`` then copies the values into the existing
        parameters.  Loading is strict: missing or unexpected keys raise.
        """
        print("Initialize params from:%s" % weights_path)
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        state = torch.load(weights_path, map_location="cpu")
        self.load_state_dict(state, strict=True)
|