PytorchのDatasetsで画像データセットを作る

ディープラーニング





Pytorchでデータセットの作り方

Pytorchでの、データの読み込みとデータセットの作り方を説明します。

KaggleやSIGNATEでは、画像データとは別に、画像のIdと画像のLabelとcsvファイルが用意されていることが多いです。そのデータを利用して、データセットを作る方法もあります。

また、フォルダー分けされた画像データを自動的に読み込み、データセットを作る方法もあります。すごく簡単なので、非常に便利です。

データセットクラスを作る

Datasetを継承します。

  1. def __init__(self):
  2. 初期化

  3. def __len__(self):
  4. データの個数を返す
    __getitem__で、取得する画像の数を表す

  5. def __getitem__(self, i):
  6. def __len__(self):で返却された数値分のiが処理される
    for i in range()のような感じ

class Create_Datasets(Dataset):
    
    def __init__(self):

    def __len__(self):
        
        return len()
    
    def __getitem__(self, i):
        
        return image, label

csvファイルのデータを利用する

csvファイルには、画像ファイルの名前(Id)と画像のラベル(Label)のペアで保存されているとします。
下のような感じに保存されています。

Idは、画像ファイルの名前ですが、拡張子がありません。

次のCreate_Datasets内の読み込み部分で、拡張子を追加する必要があります。

import numpy as np
import pandas as pd
import os
from torch.utils.data import Dataset

class Create_Datasets(Dataset):
    
    def __init__(self, dir_name, csv_file, data_transform):
        self.dir_name = dir_name
        self.df = pd.read_csv(self.dir_name + csv_file)
        self.data_transform = data_transform

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, i):
        file = df['id'][i]
        label = np.array(df['label'][i])
        image = Image.open(self.dir_name + "train/" + file + ".tif")
        image = self.data_transform(image)

        return image, label

フォルダーを利用する

dir_nameは、画像のフォルダを指定します。
ただし、下で紹介するプログラムは、ラベルなしです。

import numpy as np
import pandas as pd
import os
from torch.utils.data import Dataset

class Creat_Datasets(Dataset):
    def __init__(self, dir_name, csv_file, data_transform):
        self.dir_name = dir_name
        self.file = os.listdir(dir_name)
        self.data_transform = data_transform
        
    def __len__(self):
        return len(self.file)
    
    def __getitem__(self, i):
        
        image = Image.open(self.dir_name + self.file[i])
        image = self.data_transform(image)
        return image

データの処理

data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
タイトルとURLをコピーしました