Pytorchのモデル管理とパラメータ保存&ロード
ファイルは下のように管理することにします。
学習するときは、learning_unit.ipynbを使用
推論は、eval_unit.ipynbを使用することにします。
┝models │ └model.py ┝param | └ model.pth ┝learning_unit.ipynb ┝eval_unit.ipynb
モデルの学習
PyTorchでは、モデルそのものを保存することはオススメされていません。
model.state_dict()を保存します。
#====== 保存 ======= torch.save(model.state_dict(), "./params/model.pth")
例
import torch import torch.nn as nn from torchvision import transforms from models.model import Net device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #====== データセット ======= #====== モデル(ネットワーク) 設定====== model = Net().to(device) #エポック数 num_epochs = 10 #損失関数 criterion = nn.CrossEntropyLoss() #学習率 learning_rate = 0.05 #最適化 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) #====== 学習 ======= def learning(): #====== 保存 ======= torch.save(model.state_dict(), "./params/model.pth")
学習済みモデルで推論
学習ユニットで、保存したmodel.state_dict()を読み込みます。
torch.load()でファイルを読み込みます。
model.load_state_dict()で、モデルに適用します。
この時、学習ユニットと推論ユニットで分けた場合でdeviceが依存してしまいます。
torch.load()の引数に”map_location=device”でdeviceを選択することで解決できます。
#====== ロード ======= model.load_state_dict(torch.load("model.pth", map_location=device))
例
import torch import torch.nn as nn from torchvision import transforms from models.model import Net device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #====== データセット ======= #====== モデル(ネットワーク) 設定====== model = Net().to(device) #====== ロード ======= model.load_state_dict(torch.load("model.pth", map_location=device)) #====== 推論 ======