Pytorchのモデル管理とパラメータ保存&ロード
ファイルは下のように管理することにします。
学習するときは、learning_unit.ipynbを使用
推論は、eval_unit.ipynbを使用することにします。
┝models │ └model.py ┝param | └ model.pth ┝learning_unit.ipynb ┝eval_unit.ipynb
モデルの学習
PyTorchでは、モデルそのものを保存することはオススメされていません。
model.state_dict()を保存します。
#====== 保存 ======= torch.save(model.state_dict(), "./params/model.pth")
例
import torch
import torch.nn as nn
from torchvision import transforms
from models.model import Net
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#====== データセット =======
#====== モデル(ネットワーク) 設定======
model = Net().to(device)
#エポック数
num_epochs = 10
#損失関数
criterion = nn.CrossEntropyLoss()
#学習率
learning_rate = 0.05
#最適化
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#====== 学習 =======
def learning():
#====== 保存 =======
torch.save(model.state_dict(), "./params/model.pth")
学習済みモデルで推論
学習ユニットで、保存したmodel.state_dict()を読み込みます。
torch.load()でファイルを読み込みます。
model.load_state_dict()で、モデルに適用します。
この時、学習ユニットと推論ユニットで分けた場合でdeviceが依存してしまいます。
torch.load()の引数に”map_location=device”でdeviceを選択することで解決できます。
#====== ロード =======
model.load_state_dict(torch.load("model.pth", map_location=device))
例
import torch
import torch.nn as nn
from torchvision import transforms
from models.model import Net
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#====== データセット =======
#====== モデル(ネットワーク) 設定======
model = Net().to(device)
#====== ロード =======
model.load_state_dict(torch.load("model.pth", map_location=device))
#====== 推論 ======
