Featured image of post 手撸一下Resnet18

手撸一下Resnet18

回忆一下准备涉猎nlp,pytorch要忘光光了

手撸一下Resnet18

模型结构参考
模型结构参考

模型细节可以参考CSDN上的一篇博客

import torch
import torch.nn as nn


class Identity(nn.Module):
    """A no-op module: ``forward`` hands its input back unchanged.

    Used as the skip-path of a residual block when no projection is
    needed, so ``forward`` can call ``self.downsample`` unconditionally
    instead of branching on an ``if``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Return the input object itself — no copy, no transformation.
        return x


class Block(nn.Module):
    """Two stacked basic residual units (ResNet "BasicBlock" x 2).

    Each unit is conv-bn-relu-conv-bn with a skip connection added
    *before* a final ReLU, following the original ResNet design.

    Args:
        in_dim: number of input channels.
        dim: number of output channels of every conv in this block.
        kernal_size: conv kernel size (the original author's misspelled
            parameter name is kept for backward compatibility).
        stride: stride of the first conv; 2 halves the spatial size.

    BUG FIXES vs. the original version:
      * the ReLU after the first residual addition was missing;
      * the second conv pair (conv3/conv4) had no skip connection and
        ended on a bare BatchNorm with no activation.
    """

    def __init__(self, in_dim, dim, kernal_size, stride):
        super(Block, self).__init__()
        # ---- first residual unit (may downsample / change width) ----
        self.conv1 = nn.Conv2d(in_channels=in_dim,
                               out_channels=dim,
                               kernel_size=kernal_size,
                               stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(dim)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=dim,
                               out_channels=dim,
                               kernel_size=kernal_size,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(dim)
        # ---- second residual unit (keeps resolution and width) ----
        self.conv3 = nn.Conv2d(in_channels=dim,
                               out_channels=dim,
                               kernel_size=kernal_size,
                               stride=1,
                               padding=1)
        self.bn3 = nn.BatchNorm2d(dim)
        self.relu2 = nn.ReLU()
        self.conv4 = nn.Conv2d(in_channels=dim,
                               out_channels=dim,
                               kernel_size=kernal_size,
                               stride=1,
                               padding=1)
        self.bn4 = nn.BatchNorm2d(dim)

        # The skip path needs a 1x1 projection when either the spatial
        # size (stride == 2) or the channel count (in_dim != dim) changes.
        if stride == 2 or in_dim != dim:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_dim, dim, kernel_size=1, stride=stride),
                nn.BatchNorm2d(dim),
            )
        else:
            self.downsample = Identity()  # no-op keeps forward() branch-free

    def forward(self, x):
        # First unit. BUG FIX: apply ReLU *after* the residual addition
        # (the original added the skip and went straight into conv3).
        identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu1(out + identity)

        # Second unit. BUG FIX: the original ran conv3/conv4 with no skip
        # connection and no final activation; add both, per ResNet.
        shortcut = out
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu2(out)
        out = self.conv4(out)
        out = self.bn4(out)
        out = self.relu2(out + shortcut)

        return out


class ResNet18(nn.Module):
    """A ResNet-18-style image classifier built from ``Block``.

    Args:
        in_dim: number of stem output channels (64 in the paper).
        num_class: number of output classes of the final linear head.

    BUG FIXES vs. the original version:
      * ``num_class`` was accepted but ignored — the head hardcoded 10;
      * the stem was missing its BatchNorm (canonical stem is
        conv-bn-relu-maxpool);
      * the downsampling schedule was inverted (layer1 had stride 2,
        layers 2-4 stride 1); the canonical schedule is 1, 2, 2, 2.
    """

    def __init__(self, in_dim=64, num_class=10):
        super(ResNet18, self).__init__()
        self.in_dim = in_dim
        # stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool -> 1/4 resolution
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=in_dim,
                               kernel_size=7,  # large kernel: feed reasonably big images
                               stride=2,
                               padding=3)
        self.bn1 = nn.BatchNorm2d(in_dim)  # BUG FIX: was missing
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3,
                                    stride=2,
                                    padding=1)
        # residual stages — BUG FIX: canonical strides 1, 2, 2, 2
        self.layer1 = self.generateLayer(dim=64,
                                         kernal_size=3,
                                         stride=1)
        self.layer2 = self.generateLayer(dim=128,
                                         kernal_size=3,
                                         stride=2)
        self.layer3 = self.generateLayer(dim=256,
                                         kernal_size=3,
                                         stride=2)
        self.layer4 = self.generateLayer(dim=512,
                                         kernal_size=3,
                                         stride=2)

        # head: global average pool + linear classifier
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # BUG FIX: honor num_class (out_features was hardcoded to 10)
        self.fc = nn.Linear(in_features=512, out_features=num_class)

    def generateLayer(self, dim, kernal_size, stride):
        """Build one stage: a (possibly striding) Block then a stride-1 Block."""
        layer_list = [Block(self.in_dim, dim, kernal_size, stride)]
        self.in_dim = dim  # later blocks (and the next stage) start from `dim`
        layer_list.append(Block(self.in_dim, dim, kernal_size, 1))

        return nn.Sequential(*layer_list)

    def forward(self, x):
        # stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # head
        x = self.avgpool(x)  # Batch * 512 * 1 * 1
        x = x.flatten(1)  # Batch * 512
        x = self.fc(x)  # Batch * num_class

        return x


def main():
    """Smoke test: run one random batch through the network and print the
    output shape (expected: torch.Size([4, 10]))."""
    sample = torch.randn([4, 3, 224, 224])
    net = ResNet18()
    logits = net(sample)
    print(logits.shape)
    return


# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()


值得注意的是使用downsample处理维度问题,利用identity类搞一个什么都不干的东西,可以避免在forward中进行if-else判断

之后写一写train的部分吧,然后上数据集测一下效果

Licensed under CC BY-NC-SA 4.0
comments powered by Disqus
Built with Hugo
Theme Stack designed by Jimmy