PyTorchのデータセットを見てみる
PyTorchのtorchvision.datasetsにはMNISTなどのデータセットが用意されています。
これらのデータセットはDatasetオブジェクトとして提供されているので、画像やラベルの確認も容易にできます。
https://betashort-lab.com/データサイエンス/pytorchのdatasetとdataloaderからbatchサイズ・target・etc-を知りたい/
Datasetオブジェクトなので、DataLoaderに渡してミニバッチ単位で取り出して使います。
MNIST
MNISTは、0〜9までの手書き数字画像のデータセットです。
from torch.utils.data import DataLoader  # was missing: DataLoader is used below
from torchvision import transforms
from torchvision.datasets import MNIST

# Load the MNIST train/test splits under the root directory "MNIST",
# downloading them if missing, and convert each image with ToTensor.
mnist_train = MNIST("MNIST", train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST("MNIST", train=False, download=True, transform=transforms.ToTensor())

batch_size = 128
# Shuffle only the training data. Evaluation does not need a random order,
# and a fixed order keeps test-time results reproducible.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
Fashion-MNIST
from torchvision import transforms
from torchvision.datasets import FashionMNIST

# Build the training split first, then the test split. Both live under the
# root directory "FashionMNIST" and are downloaded if missing.
fashion_mnist_train, fashion_mnist_test = (
    FashionMNIST(
        root="FashionMNIST",
        train=is_train,
        download=True,
        transform=transforms.ToTensor(),
    )
    for is_train in (True, False)
)
EMNIST
from torchvision import transforms
from torchvision.datasets import EMNIST

# Options shared by both splits: the "byclass" split variant, download if
# missing, and ToTensor conversion of each image.
emnist_kwargs = dict(split="byclass", download=True, transform=transforms.ToTensor())

emnist_train = EMNIST("EMNIST", train=True, **emnist_kwargs)
emnist_test = EMNIST("EMNIST", train=False, **emnist_kwargs)
CIFAR10
from torchvision import transforms
from torchvision.datasets import CIFAR10

# Directory the CIFAR-10 files are stored in (downloaded if missing).
cifar10_root = "CIFAR10"
# Image -> tensor conversion applied to every sample of both splits.
to_tensor = transforms.ToTensor()

CIFAR10_train = CIFAR10(root=cifar10_root, train=True, download=True, transform=to_tensor)
CIFAR10_test = CIFAR10(root=cifar10_root, train=False, download=True, transform=to_tensor)
CIFAR100
from torchvision import transforms
from torchvision.datasets import CIFAR100


def _load_cifar100(train):
    """Return the CIFAR-100 split selected by *train*, downloading it if needed."""
    return CIFAR100("CIFAR100", train=train, download=True, transform=transforms.ToTensor())


CIFAR100_train = _load_cifar100(train=True)
CIFAR100_test = _load_cifar100(train=False)
STL10
from torchvision import transforms
from torchvision.datasets import STL10

# Map each split name to its dataset. The "train" split is created first,
# then "test"; both are downloaded into "STL10" if missing.
stl10_splits = {
    split_name: STL10("STL10", split=split_name, download=True, transform=transforms.ToTensor())
    for split_name in ("train", "test")
}
STL10_train = stl10_splits["train"]
STL10_test = stl10_splits["test"]
SVHN
from torchvision import transforms
from torchvision.datasets import SVHN

# One ToTensor transform reused for both splits.
to_tensor = transforms.ToTensor()

# Create the "train" split, then the "test" split, under the "SVHN" directory.
SVHN_train, SVHN_test = [
    SVHN("SVHN", split=name, download=True, transform=to_tensor)
    for name in ("train", "test")
]
参考
https://amzn.to/2VNZs3U