Pytorchでデータセットの作り方
Pytorchでの、データの読み込みとデータセットの作り方を説明します。
KaggleやSIGNATEでは、画像データとは別に、画像のIdと画像のLabelとcsvファイルが用意されていることが多いです。そのデータを利用して、データセットを作る方法もあります。
また、フォルダー分けされた画像データを自動的に読み込み、データセットを作る方法もあります。すごく簡単なので、非常に便利です。
データセットクラスを作る
Datasetを継承します。
-
- def __init__(self):
初期化
-
- def __len__(self):
データの個数を返す
__getitem__で、取得する画像の数を表す
-
- def __getitem__(self, i):
def __len__(self):で返却された数値分のiが処理される
for i in range()のような感じ
class Create_Datasets(Dataset): def __init__(self): def __len__(self): return len() def __getitem__(self, i): return image, label
csvファイルのデータを利用する
csvファイルには、画像ファイルの名前(Id)と画像のラベル(Label)のペアで保存されているとします。
下のような感じに保存されています。
Idは、画像ファイルの名前ですが、拡張子がありません。
次のCreate_Datasets内の読み込み部分で、拡張子を追加する必要があります。
import numpy as np import pandas as pd import os from torch.utils.data import Dataset class Create_Datasets(Dataset): def __init__(self, dir_name, csv_file, data_transform): self.dir_name = dir_name self.df = pd.read_csv(self.dir_name + csv_file) self.data_transform = data_transform def __len__(self): return len(self.df) def __getitem__(self, i): file = df['id'][i] label = np.array(df['label'][i]) image = Image.open(self.dir_name + "train/" + file + ".tif") image = self.data_transform(image) return image, label
フォルダーを利用する
dir_nameは、画像のフォルダを指定します。
ただし、下で紹介するプログラムは、ラベルなしです。
import numpy as np import pandas as pd import os from torch.utils.data import Dataset class Creat_Datasets(Dataset): def __init__(self, dir_name, csv_file, data_transform): self.dir_name = dir_name self.file = os.listdir(dir_name) self.data_transform = data_transform def __len__(self): return len(self.file) def __getitem__(self, i): image = Image.open(self.dir_name + self.file[i]) image = self.data_transform(image) return image
データの処理
data_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])