ResNetをPyTorchで実装したいからメモする

画像処理とOpenCV




ResNetをPyTorchで実装したいからメモする

今更ながら、ResNetを勉強します。

PyTorchで実装（写経＋理解）します。

ResNet

resnet1

  • https://www.slideshare.net/KotaNagasato/resnet-82940994
  • 構造

    resnet2

  • https://arxiv.org/abs/1512.03385
  • ResNet by PyTorch

    BasicBlock

    class BasicBlock(nn.Module):
        """ResNet basic residual block: two 3x3 convolutions plus a skip connection.

        Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
        Only the first 3x3 convolution carries ``stride``, so the block
        changes the spatial size at most once. The second convolution maps
        ``out_ch`` -> ``out_ch`` with stride 1.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            stride: Stride of the first 3x3 convolution.
            downsample: Optional module applied to the input on the skip path.
                It must produce the same shape as the main path. A projection
                is needed whenever ``stride != 1`` or ``in_ch != out_ch``.
                ``None`` means identity.
        """

        def __init__(self, in_ch, out_ch, stride=1, downsample=None):
            super().__init__()
            # bias=False: each conv is followed by BatchNorm, whose learnable
            # shift makes a conv bias redundant.
            self.conv1 = nn.Conv2d(
                in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False
            )
            self.bn1 = nn.BatchNorm2d(out_ch)
            self.relu = nn.ReLU(inplace=True)
            # conv2 consumes conv1's output (out_ch channels) and must not
            # downsample a second time, hence stride=1.
            self.conv2 = nn.Conv2d(
                out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False
            )
            self.bn2 = nn.BatchNorm2d(out_ch)
            self.downsample = downsample

        def forward(self, x):
            """Apply the block to ``x`` and return the activated sum."""
            # Skip connection: identity by default.
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)

            # Project the shortcut when a downsample module was supplied.
            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual
            return self.relu(out)
    

    BottleNeck

    class BottleNeckBlock(nn.Module):
        """ResNet bottleneck residual block: 1x1 -> 3x3 -> 1x1 convs plus a skip connection.

        The first 1x1 conv reduces ``in_ch`` to ``mid_ch`` channels. The 3x3
        conv keeps ``mid_ch`` and carries ``stride``. The last 1x1 conv
        expands to ``out_ch``. Because only the 3x3 conv is strided (the
        placement used by torchvision), the block downsamples by ``stride``
        exactly once.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            stride: Stride of the 3x3 convolution.
            downsample: Optional module applied to the input on the skip path.
                It must produce the same shape as the main path. A projection
                is needed whenever ``stride != 1`` or ``in_ch != out_ch``.
                ``None`` means identity.
            mid_ch: Bottleneck width. Defaults to ``out_ch``. torchvision's
                ResNet-50/101/152 use ``out_ch == 4 * mid_ch``.
        """

        def __init__(self, in_ch, out_ch, stride=1, downsample=None, mid_ch=None):
            super().__init__()
            if mid_ch is None:
                mid_ch = out_ch

            # bias=False: each conv is followed by BatchNorm, whose learnable
            # shift makes a conv bias redundant.
            self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=1, stride=1, bias=False)
            self.bn1 = nn.BatchNorm2d(mid_ch)

            self.relu = nn.ReLU(inplace=True)

            # The only strided layer in the block.
            self.conv2 = nn.Conv2d(
                mid_ch, mid_ch, kernel_size=3, stride=stride, padding=1, bias=False
            )
            self.bn2 = nn.BatchNorm2d(mid_ch)

            self.conv3 = nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, bias=False)
            self.bn3 = nn.BatchNorm2d(out_ch)

            self.downsample = downsample

        def forward(self, x):
            """Apply the block to ``x`` and return the activated sum."""
            # Skip connection: identity by default.
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
            out = self.conv3(out)
            out = self.bn3(out)

            # Project the shortcut when a downsample module was supplied.
            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual
            return self.relu(out)
    
    def make_downsample(in_ch, out_ch, stride=1):
        """Build the projection shortcut for a residual block, or return None.

        A 1x1 convolution followed by BatchNorm is needed only when the block
        changes the spatial size (``stride != 1``) or the channel count
        (``in_ch != out_ch``). Otherwise the identity shortcut is used.

        Args:
            in_ch: Channels of the block input.
            out_ch: Channels of the block output.
            stride: Spatial stride of the block.

        Returns:
            An ``nn.Sequential`` of a 1x1 conv and BatchNorm, or ``None``
            when no projection is required.
        """
        if stride == 1 and in_ch == out_ch:
            return None
        return nn.Sequential(
            # bias=False: BatchNorm right after provides the shift.
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(out_ch),
        )
    

    参考

    1. https://www.kaggle.com/readilen/resnet-for-mnist-with-pytorch
    2. https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html
    3. https://qiita.com/supersaiakujin/items/935bbc9610d0f87607e8
    4. http://www.pabloruizruiz10.com/resources/CNNs/ResNet-PyTorch.html
    5. https://www.slideshare.net/KotaNagasato/resnet-82940994
    6. https://deepage.net/deep_learning/2016/11/30/resnet.html
    タイトルとURLをコピーしました