PyTorchのDataloader
torch.utils.data.DataLoader()
流れ
- Datasetを用意する
- DataLoaderにDatasetを渡す
- DataLoaderからBatchごとのデータをもらって学習する
Datasetは下のようなもの
TensorDataset(X_train, y_train)
画像のDatasetの作り方は下のリンクから
PyTorchのdatasetsで画像データセットを作る
dataloader
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
指定するargumentsは、だいたい下の4つだと思います。
samplerはbatch内の重みを決めるためのものなので、使わない場合も多いです。
- dataset: Datasetをを指定
- batch_size: batchサイズを指定
- shuffle: シャッフルするかどうか
- sampler: batchの中身の配分
sampler
datasetsのBatchを決めるための機能
samplerを渡すときは、shuffleがFalseだとエラーが出ます。
PytorchのDataloaderとSamplerの使い方 - Qiita
...
classのweightで
batch_size = 20 class_sample_count = [10, 1, 20, 3, 4] # dataset has 10 class-1 samples, 1 class-2 samples, etc. weights = 1 / (torch.Tensor(class_sample_count)*1e-5) weights = weights.double() sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size) trainloader = torch.utils.data.DataLoader(trainDataset, batch_size = batch_size, sampler = sampler)