PyTorchのDataloader -samplerとclass_weightなども-

ディープラーニング




PyTorchのDataloader

torch.utils.data.DataLoader()

流れ

  • Datasetを用意する
  • DataLoaderにDatasetを渡す
  • DataLoaderからBatchごとのデータをもらって学習する

Datasetは下のようなもの

TensorDataset(X_train, y_train)

画像のDatasetの作り方は下のリンクから
PyTorchのdatasetsで画像データセットを作る

dataloader

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

指定するargumentsは、だいたい下の4つだと思います。
samplerはbatch内の重みを決めるためのものなので、使わない場合も多いです。

  1. dataset: Datasetをを指定
  2. batch_size: batchサイズを指定
  3. shuffle: シャッフルするかどうか
  4. sampler: batchの中身の配分

sampler

datasetsのBatchを決めるための機能
samplerを渡すときは、shuffleがFalseだとエラーが出ます。

PytorchのDataloaderとSamplerの使い方 - Qiita
Dataloaderとは datasetsからバッチごとに取り出すことを目的に使われます。 基本的にtorch.utils.data.DataLoaderを使います。 イメージとしてはdatasetsはデータすべてのリスト、Da...

classのweightで

batch_size = 20
class_sample_count = [10, 1, 20, 3, 4] # dataset has 10 class-1 samples, 1 class-2 samples, etc.
weights = 1 / (torch.Tensor(class_sample_count)*1e-5)
weights = weights.double()
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size)
trainloader = torch.utils.data.DataLoader(trainDataset, batch_size = batch_size, sampler = sampler)

参考

タイトルとURLをコピーしました