PyTorchのDataloader
torch.utils.data.DataLoader()
流れ
- Datasetを用意する
- DataLoaderにDatasetを渡す
- DataLoaderからBatchごとのデータをもらって学習する
Datasetは下のようなもの
TensorDataset(X_train, y_train)
画像のDatasetの作り方は下のリンクから
PyTorchのdatasetsで画像データセットを作る
dataloader
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
指定するargumentsは、だいたい下の4つだと思います。
samplerはbatch内の重みを決めるためのものなので、使わない場合も多いです。
- dataset: Datasetをを指定
- batch_size: batchサイズを指定
- shuffle: シャッフルするかどうか
- sampler: batchの中身の配分
sampler
datasetsのBatchを決めるための機能
samplerを渡すときは、shuffleがFalseだとエラーが出ます。
![](https://qiita-user-contents.imgix.net/https%3A%2F%2Fcdn.qiita.com%2Fassets%2Fpublic%2Farticle-ogp-background-412672c5f0600ab9a64263b751f1bc81.png?ixlib=rb-4.0.0&w=1200&mark64=aHR0cHM6Ly9xaWl0YS11c2VyLWNvbnRlbnRzLmltZ2l4Lm5ldC9-dGV4dD9peGxpYj1yYi00LjAuMCZ3PTk3MiZoPTM3OCZ0eHQ9UHl0b3JjaCVFMyU4MSVBRURhdGFsb2FkZXIlRTMlODElQThTYW1wbGVyJUUzJTgxJUFFJUU0JUJEJUJGJUUzJTgxJTg0JUU2JTk2JUI5JnR4dC1hbGlnbj1sZWZ0JTJDdG9wJnR4dC1jb2xvcj0lMjMyMTIxMjEmdHh0LWZvbnQ9SGlyYWdpbm8lMjBTYW5zJTIwVzYmdHh0LXNpemU9NTYmcz03ZWZiZDMxYmQyOWU2YzNkOTRjOGVmNDA2OWMyMTg5Nw&mark-x=142&mark-y=57&blend64=aHR0cHM6Ly9xaWl0YS11c2VyLWNvbnRlbnRzLmltZ2l4Lm5ldC9-dGV4dD9peGxpYj1yYi00LjAuMCZoPTc2Jnc9NzcwJnR4dD0lNDB0b21wJnR4dC1jb2xvcj0lMjMyMTIxMjEmdHh0LWZvbnQ9SGlyYWdpbm8lMjBTYW5zJTIwVzYmdHh0LXNpemU9MzYmdHh0LWFsaWduPWxlZnQlMkN0b3Amcz1jNjNmY2U0YjRmZTkwNTQyNjQ2N2I0OGM1ZWQxYjUyOA&blend-x=142&blend-y=486&blend-mode=normal&s=934a33b976f1ba12cfce62aa69d86efb)
PytorchのDataloaderとSamplerの使い方 - Qiita
...
classのweightで
batch_size = 20 class_sample_count = [10, 1, 20, 3, 4] # dataset has 10 class-1 samples, 1 class-2 samples, etc. weights = 1 / (torch.Tensor(class_sample_count)*1e-5) weights = weights.double() sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size) trainloader = torch.utils.data.DataLoader(trainDataset, batch_size = batch_size, sampler = sampler)