Pytorchのモデル管理とパラメータ保存&ロード

ディープラーニング




Pytorchのモデル管理とパラメータ保存&ロード

ファイルは下のように管理することにします。

学習するときは、learning_unit.ipynbを使用
推論は、eval_unit.ipynbを使用することにします。

┝models
│   └model.py
┝param
|   └ model.pth
┝learning_unit.ipynb
┝eval_unit.ipynb

モデルの学習

PyTorchでは、モデルそのものを保存することはオススメされていません。
model.state_dict()を保存します。

#====== 保存 =======
torch.save(model.state_dict(), "./params/model.pth")

import torch
import torch.nn as nn
from torchvision import transforms

from models.model import Net

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#====== データセット =======

#====== モデル(ネットワーク) 設定======
model = Net().to(device)

#エポック数
num_epochs = 10
#損失関数
criterion = nn.CrossEntropyLoss()
#学習率
learning_rate = 0.05
#最適化
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#====== 学習 =======
def learning():

#====== 保存 =======
torch.save(model.state_dict(), "./params/model.pth")

学習済みモデルで推論

学習ユニットで、保存したmodel.state_dict()を読み込みます。

torch.load()でファイルを読み込みます。
model.load_state_dict()で、モデルに適用します。

この時、学習ユニットと推論ユニットで分けた場合でdeviceが依存してしまいます。
torch.load()の引数に”map_location=device”でdeviceを選択することで解決できます。

#====== ロード =======
model.load_state_dict(torch.load("model.pth", map_location=device))

import torch
import torch.nn as nn
from torchvision import transforms

from models.model import Net

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#====== データセット =======

#====== モデル(ネットワーク) 設定======
model = Net().to(device)

#====== ロード =======
model.load_state_dict(torch.load("model.pth", map_location=device))

#====== 推論 ======

参考

SAVING AND LOADING MODELS|PyTorch

タイトルとURLをコピーしました