Pytorchでデータセットの作り方
Pytorchでの、データの読み込みとデータセットの作り方を説明します。
KaggleやSIGNATEでは、画像データとは別に、画像のIdと画像のLabelとcsvファイルが用意されていることが多いです。そのデータを利用して、データセットを作る方法もあります。
また、フォルダー分けされた画像データを自動的に読み込み、データセットを作る方法もあります。すごく簡単なので、非常に便利です。
データセットクラスを作る
Datasetを継承します。
-
- def __init__(self):
初期化
-
- def __len__(self):
データの個数を返す
__getitem__で、取得する画像の数を表す
-
- def __getitem__(self, i):
def __len__(self):で返却された数値分のiが処理される
for i in range()のような感じ
class Create_Datasets(Dataset):
def __init__(self):
def __len__(self):
return len()
def __getitem__(self, i):
return image, label
csvファイルのデータを利用する
csvファイルには、画像ファイルの名前(Id)と画像のラベル(Label)のペアで保存されているとします。
下のような感じに保存されています。

Idは、画像ファイルの名前ですが、拡張子がありません。
次のCreate_Datasets内の読み込み部分で、拡張子を追加する必要があります。
import numpy as np
import pandas as pd
import os
from torch.utils.data import Dataset
class Create_Datasets(Dataset):
def __init__(self, dir_name, csv_file, data_transform):
self.dir_name = dir_name
self.df = pd.read_csv(self.dir_name + csv_file)
self.data_transform = data_transform
def __len__(self):
return len(self.df)
def __getitem__(self, i):
file = df['id'][i]
label = np.array(df['label'][i])
image = Image.open(self.dir_name + "train/" + file + ".tif")
image = self.data_transform(image)
return image, label
フォルダーを利用する
dir_nameは、画像のフォルダを指定します。
ただし、下で紹介するプログラムは、ラベルなしです。
import numpy as np
import pandas as pd
import os
from torch.utils.data import Dataset
class Creat_Datasets(Dataset):
def __init__(self, dir_name, csv_file, data_transform):
self.dir_name = dir_name
self.file = os.listdir(dir_name)
self.data_transform = data_transform
def __len__(self):
return len(self.file)
def __getitem__(self, i):
image = Image.open(self.dir_name + self.file[i])
image = self.data_transform(image)
return image
データの処理
data_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

