PyTorchで重みの確認と、畳み込み層のカーネルの可視化

ディープラーニング




PyTorchで重みの確認と、畳み込み層のカーネルの可視化

pytorchのモデル管理とパラメータ保存ロード

実行したjupyternotebookはGitHubで公開しています。
https://github.com/betashort/python/blob/master/NN/PyTorch_CNN_Visualize_ConvWeight.ipynb

重みの確認

model.state_dict()

ネットワークを次のように定義したとして、

# ネットワークの定義
class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet,self).__init__()
        #畳み込み層
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = 16, kernel_size = 3, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 3, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        #全結合層
        self.dence = nn.Sequential(
            nn.Linear(32 * 6 * 6, 128),
            nn.ReLU(inplace=True),
            #nn.Dropout(p=0.2),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            #nn.Dropout(p=0.2),
            nn.Linear(64, 10),
        )
         
    #順伝播
    def forward(self,x):
         
        out = self.conv_layers(x)
        #Flatten
        out = out.view(out.size(0), -1)
        #全結合層
        out = self.dence(out)
         
        return out

model = CNNNet()

このネットワークを学習した後、model.state_dict()を実行すると、層の実行順にOrderedDictで格納されたパラメータが取得できます。

参照
pyTorchのNetworkのパラメータの閲覧と書き換え - Qiita
2019/9/29 投稿 2019/11/8 やや見やすく編集(主観) 2020/2/1 SGDの解説Link追加 2020/4/22 パラメータを途中で書き換える方法を追加した 0. この記事の対象者 pythonを触ったこ...

畳み込み層のカーネルの可視化

上で、model.state_dict()で、層の重み含め、パラメータを取得できることが確認できました。

試しに、model.state_dict()を実行すると以下のような結果が出ます。

OrderedDict([('conv_layers.0.weight', tensor([[[[-0.4207, -0.4480, -0.6588],
                        [-0.6090, -0.5266, -0.3595],
                        [-0.2160, -0.0956, -0.4485]]],
              
              
                      [[[-0.6202, -0.5363, -0.3624],
                        [-0.6650, -0.8477, -0.5149],
                        [-0.8891, -0.6004, -0.5109]]],
・・・・省略・・・・

1層目の畳み込み層のカーネルを可視化する場合は、’conv_layers.0.weight’のvalueを取得して、可視化させます。

1層目の畳み込み層の1つ目のカーネルの場合は下のようになります。

conv1_1 = np.array(model.state_dict()['conv_layers.0.weight'])[0]

plt.imshow(conv1_1.reshape(3,3), cmap='gray')

1層目の畳み込み層のカーネル全てを可視化する場合は下のようになります。

plt.figure(figsize=(16,16))
for i in range(16):
    kernel_weight = np.array(model.state_dict()['conv_layers.0.weight'])[i].reshape(3,3)
    
    plt.subplot(4, 4, i+1)
    plt.title(f"kernel_weight : {i+1}")
    plt.imshow(kernel_weight, cmap="gray")

plt.savefig("kernel_visual.png")
plt.show()

参考

  1. https://github.com/betashort/python/blob/master/NN/PyTorch_CNN_Visualize_ConvWeight.ipynb
タイトルとURLをコピーしました