Hand-Writing ResNet18 from Scratch
The model details can be found in a blog post on CSDN.
import torch
import torch.nn as nn


class Identity(nn.Module):
    # A module that does nothing; used as the "no shortcut conv needed" branch.
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
class Block(nn.Module):
    # Basic residual block of ResNet18: two 3x3 convs plus a shortcut connection.
    def __init__(self, in_dim, dim, kernel_size, stride):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_dim,
                               out_channels=dim,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=dim,
                               out_channels=dim,
                               kernel_size=kernel_size,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(dim)
        # The shortcut and the main path can mismatch in two ways: the stride may have
        # changed H*W, or the conv layers may have changed the channel dimension.
        if stride == 2 or in_dim != dim:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_dim, dim, kernel_size=1, stride=stride),
                nn.BatchNorm2d(dim)
            )
        else:
            # A class that does nothing, so forward needs no if-else branch.
            self.downsample = Identity()

    def forward(self, x):
        h = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # h and x may differ in shape, since in_dim is not necessarily equal to dim (out_dim)
        identity = self.downsample(h)
        x = self.relu(x + identity)  # ReLU after the residual addition
        return x
class ResNet18(nn.Module):
    def __init__(self, in_dim=64, num_class=10):  # 10-way classification as an example
        super(ResNet18, self).__init__()
        self.in_dim = in_dim
        # stem
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=in_dim,
                               kernel_size=7,  # 7 is fairly large; the input images should not be too small
                               stride=2,
                               padding=3)
        self.bn1 = nn.BatchNorm2d(in_dim)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3,
                                    stride=2,
                                    padding=1)
        # res blocks: layer1 keeps the resolution, layers 2-4 halve it
        self.layer1 = self.generateLayer(dim=64, kernel_size=3, stride=1)
        self.layer2 = self.generateLayer(dim=128, kernel_size=3, stride=2)
        self.layer3 = self.generateLayer(dim=256, kernel_size=3, stride=2)
        self.layer4 = self.generateLayer(dim=512, kernel_size=3, stride=2)
        # head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_features=512, out_features=num_class)

    def generateLayer(self, dim, kernel_size, stride):
        layer_list = []
        # first block may downsample / change channels, second keeps the shape
        layer_list.append(Block(self.in_dim, dim, kernel_size, stride))
        self.in_dim = dim
        layer_list.append(Block(self.in_dim, dim, kernel_size, 1))
        return nn.Sequential(*layer_list)
    def forward(self, x):
        # stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # res blocks
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # head
        x = self.avgpool(x)  # Batch * 512 * 1 * 1
        x = x.flatten(1)     # Batch * 512
        x = self.fc(x)       # Batch * num_class
        return x
def main():
    t = torch.randn([4, 3, 224, 224])
    model = ResNet18()
    out = model(t)
    print(out.shape)  # expected: torch.Size([4, 10])
    return


if __name__ == "__main__":
    main()
It is worth noting how downsample handles the shape mismatch: by giving the "do nothing" case an Identity module, the choice is made once in __init__, and forward stays free of if-else checks.
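For comparison, here is a rough sketch of the alternative, with downsample set to None and the branch pushed into forward. The class name BlockWithBranch is made up for illustration and this is a single-conv toy block, not the full Block above:

import torch
import torch.nn as nn

class BlockWithBranch(nn.Module):
    # Same shortcut logic, but downsample may be None,
    # so forward has to carry an explicit if-else.
    def __init__(self, in_dim, dim, stride):
        super(BlockWithBranch, self).__init__()
        self.conv = nn.Conv2d(in_dim, dim, kernel_size=3, stride=stride, padding=1)
        self.bn = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()
        if stride == 2 or in_dim != dim:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_dim, dim, kernel_size=1, stride=stride),
                nn.BatchNorm2d(dim))
        else:
            self.downsample = None

    def forward(self, x):
        h = x
        x = self.bn(self.conv(x))
        if self.downsample is not None:  # the branch the Identity class avoids
            h = self.downsample(h)
        return self.relu(x + h)

print(BlockWithBranch(64, 128, 2)(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])

Incidentally, recent versions of PyTorch already ship nn.Identity, which can be dropped in directly instead of hand-writing the Identity class.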
Next up is writing the training part, and then testing it on a dataset to see how it performs.
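Before that, here is a minimal sketch of what a training step could look like, just to confirm the model trains end to end. Random tensors stand in for a real dataset, and the SGD hyperparameters are placeholders rather than tuned values:

import torch
import torch.nn as nn

model = ResNet18()                   # the class defined above is assumed to be in scope
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)  # placeholder hyperparameters

model.train()
for step in range(3):
    # fake batch: random images and random labels in place of a real dataset
    images = torch.randn(4, 3, 224, 224)
    labels = torch.randint(0, 10, (4,))
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    print(step, loss.item())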