PyTorchのデータセットを見てみる
PyTorchのtorchvision.datasetsには、MNISTなどの代表的なデータセットが用意されています。
これらはPyTorchのDatasetオブジェクトとして提供されるので、インデックスでアクセスするだけで画像やラベルを簡単に確認できます。
https://betashort-lab.com/データサイエンス/pytorchのdatasetとdataloaderからbatchサイズ・target・etc-を知りたい/
Datasetなので、DataLoaderに渡してミニバッチ単位で取り出して使います。
MNIST

MNISTは、0~9までの手書き数字画像のデータセットです。
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Download MNIST into the "MNIST" directory (if not already present) and
# convert each image to a tensor with ToTensor.
mnist_train = MNIST("MNIST", train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST("MNIST", train=False, download=True, transform=transforms.ToTensor())

batch_size = 128
# Shuffle the training set every epoch; keep the test set in a fixed order
# so evaluation results are reproducible from run to run.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
Fashion-MNIST

from torchvision import transforms
from torchvision.datasets import FashionMNIST

# Build the training split first, then the test split; both are downloaded
# into the "FashionMNIST" directory and converted to tensors.
fashion_mnist_train, fashion_mnist_test = (
    FashionMNIST("FashionMNIST", train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)
EMNIST
from torchvision import transforms
from torchvision.datasets import EMNIST

# EMNIST requires a split name; this uses "byclass" for both the training
# and test sets, downloading into the "EMNIST" directory.
emnist_train, emnist_test = (
    EMNIST("EMNIST", split="byclass", train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)
CIFAR10

from torchvision import transforms
from torchvision.datasets import CIFAR10

# Training split first, then test split; stored under the "CIFAR10" directory.
CIFAR10_train, CIFAR10_test = (
    CIFAR10("CIFAR10", train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)
CIFAR100

from torchvision import transforms
from torchvision.datasets import CIFAR100

# Training split first, then test split; stored under the "CIFAR100" directory.
CIFAR100_train, CIFAR100_test = (
    CIFAR100("CIFAR100", train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)
STL10

from torchvision import transforms
from torchvision.datasets import STL10

# STL10 selects its subset with a string `split` instead of a `train` flag.
STL10_train, STL10_test = (
    STL10("STL10", split=subset, download=True, transform=transforms.ToTensor())
    for subset in ("train", "test")
)
SVHN

from torchvision import transforms
from torchvision.datasets import SVHN

# SVHN, like STL10, picks its subset via a string `split` argument.
SVHN_train, SVHN_test = (
    SVHN("SVHN", split=subset, download=True, transform=transforms.ToTensor())
    for subset in ("train", "test")
)
参考
https://amzn.to/2VNZs3U
