A Look at PyTorch's Datasets
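PyTorch's torchvision.datasets module provides ready-made datasets such as MNIST.
These are provided as Dataset objects, so their images and labels are easy to inspect.
See also: https://betashort-lab.com/データサイエンス/pytorchのdatasetとdataloaderからbatchサイズ・target・etc-を知りたい/
Being Dataset objects, they are used by wrapping them in a DataLoader.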
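The workflow described above has two steps: index a Dataset, then batch it with a DataLoader. Here is a minimal sketch of both. So that it runs offline, it does not download anything. Instead it uses a hypothetical stand-in: a TensorDataset of 256 random 1×28×28 tensors with labels 0 to 9, shaped like MNIST. A real torchvision dataset such as `mnist_train` below is indexed the same way. The one difference is that it returns the label as a plain `int` rather than a tensor.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 256 random 1x28x28 "images" with labels 0-9.
images = torch.rand(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
dataset = TensorDataset(images, labels)

# A Dataset supports len() and integer indexing; each item is an (image, label) pair.
print(len(dataset))        # 256
image, label = dataset[0]
print(image.shape)         # torch.Size([1, 28, 28])
print(int(label))          # an integer in 0..9

# A DataLoader wraps the Dataset and yields mini-batches stacked along dimension 0.
loader = DataLoader(dataset, batch_size=128, shuffle=True)
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([128, 1, 28, 28])
print(batch_labels.shape)  # torch.Size([128])
```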
MNIST
MNIST is a dataset of images of handwritten digits from 0 to 9.
```python
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# Download MNIST into ./MNIST and convert each image to a tensor
mnist_train = MNIST("MNIST", train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST("MNIST", train=False, download=True, transform=transforms.ToTensor())

batch_size = 128
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
# Shuffling is unnecessary for evaluation, so the test loader keeps a fixed order
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
```
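With `batch_size = 128`, the number of mini-batches `train_loader` yields per epoch follows directly from the dataset size. MNIST has 60,000 training images and 10,000 test images. The sketch below spells out that arithmetic with two hypothetical helpers, `batches_per_epoch` and `last_batch_size`, which are not part of PyTorch. It then cross-checks them against `len()` of a DataLoader built on a small dummy dataset.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: number of mini-batches a DataLoader yields per epoch."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Hypothetical helper: size of the final batch when drop_last=False."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


print(batches_per_epoch(60000, 128))  # 469 (MNIST training set)
print(last_batch_size(60000, 128))    # 96
print(batches_per_epoch(10000, 128))  # 79 (MNIST test set)

# Cross-check the formula against a real DataLoader on a small dummy dataset.
dummy = TensorDataset(torch.zeros(1000, 1))
print(len(DataLoader(dummy, batch_size=128)))                  # 8
print(len(DataLoader(dummy, batch_size=128, drop_last=True)))  # 7
```

By this arithmetic, `len(train_loader)` should be 469, and the final training batch holds 96 images unless `drop_last=True` is passed.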
Fashion-MNIST
```python
from torchvision.datasets import FashionMNIST
from torchvision import transforms

fashion_mnist_train = FashionMNIST("FashionMNIST", train=True, download=True, transform=transforms.ToTensor())
fashion_mnist_test = FashionMNIST("FashionMNIST", train=False, download=True, transform=transforms.ToTensor())
```
EMNIST
```python
from torchvision.datasets import EMNIST
from torchvision import transforms

emnist_train = EMNIST("EMNIST", split="byclass", train=True, download=True, transform=transforms.ToTensor())
emnist_test = EMNIST("EMNIST", split="byclass", train=False, download=True, transform=transforms.ToTensor())
```
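EMNIST requires the `split` argument. torchvision accepts `"byclass"`, `"bymerge"`, `"balanced"`, `"letters"`, `"digits"` and `"mnist"`. The `"byclass"` split used above has 62 classes: 10 digits, 26 uppercase letters and 26 lowercase letters. The sketch below maps a label to its character. It assumes that labels follow ASCII order: 0 to 9, then A to Z, then a to z. `byclass_label_to_char` is a hypothetical helper, so check it against a few displayed samples before relying on it.

```python
import string

# Assumption: "byclass" label i corresponds to the i-th character in ASCII order.
BYCLASS_CHARS = string.digits + string.ascii_uppercase + string.ascii_lowercase


def byclass_label_to_char(label: int) -> str:
    """Hypothetical helper: map an EMNIST "byclass" label (0..61) to its character."""
    return BYCLASS_CHARS[label]


print(len(BYCLASS_CHARS))         # 62
print(byclass_label_to_char(10))  # A
print(byclass_label_to_char(36))  # a
```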
CIFAR10
```python
from torchvision.datasets import CIFAR10
from torchvision import transforms

CIFAR10_train = CIFAR10("CIFAR10", train=True, download=True, transform=transforms.ToTensor())
CIFAR10_test = CIFAR10("CIFAR10", train=False, download=True, transform=transforms.ToTensor())
```
CIFAR100
```python
from torchvision.datasets import CIFAR100
from torchvision import transforms

CIFAR100_train = CIFAR100("CIFAR100", train=True, download=True, transform=transforms.ToTensor())
CIFAR100_test = CIFAR100("CIFAR100", train=False, download=True, transform=transforms.ToTensor())
```
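CIFAR100 has 100 classes with 600 images each: 500 for training and 100 for testing. Its labels are stored in the `targets` list, and counting them is a quick way to check the labels. The sketch below uses a hypothetical `class_histogram` helper. In place of `CIFAR100_train.targets` it uses a synthetic target list, so it runs without downloading anything.

```python
from collections import Counter


def class_histogram(targets, num_classes):
    """Hypothetical helper: count samples per label 0..num_classes-1."""
    counts = Counter(int(t) for t in targets)
    return [counts.get(c, 0) for c in range(num_classes)]


# Synthetic stand-in for CIFAR100_train.targets: 100 classes x 500 images each.
fake_targets = [c for c in range(100) for _ in range(500)]
hist = class_histogram(fake_targets, num_classes=100)

print(len(fake_targets))     # 50000
print(min(hist), max(hist))  # 500 500
```

With the real data, `class_histogram(CIFAR100_train.targets, 100)` should likewise report 500 per class.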
STL10
```python
from torchvision.datasets import STL10
from torchvision import transforms

STL10_train = STL10("STL10", split='train', download=True, transform=transforms.ToTensor())
STL10_test = STL10("STL10", split='test', download=True, transform=transforms.ToTensor())
```
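Unlike MNIST and CIFAR, STL10 selects its subset with a `split` string instead of a boolean `train` flag. STL10 also offers `'unlabeled'` and `'train+unlabeled'` splits of its 96×96 color images. When looping over several datasets, a small helper can hide the difference between the two styles. The `split_kwargs` function and the `SPLIT_STYLE` set below are hypothetical and cover only the datasets in this article.

```python
from typing import Any, Dict

# Datasets in this article that choose train/test via a `split` string.
# EMNIST is not listed: its `split` picks the variant (e.g. "byclass") and it still takes `train`.
SPLIT_STYLE = {"STL10", "SVHN"}


def split_kwargs(dataset_name: str, train: bool) -> Dict[str, Any]:
    """Hypothetical helper: build the train/test keyword argument for a dataset class."""
    if dataset_name in SPLIT_STYLE:
        return {"split": "train" if train else "test"}
    return {"train": train}


print(split_kwargs("MNIST", train=False))  # {'train': False}
print(split_kwargs("STL10", train=True))   # {'split': 'train'}
```

For example: `STL10("STL10", download=True, transform=transforms.ToTensor(), **split_kwargs("STL10", train=True))`.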
SVHN
```python
from torchvision.datasets import SVHN
from torchvision import transforms

SVHN_train = SVHN("SVHN", split='train', download=True, transform=transforms.ToTensor())
SVHN_test = SVHN("SVHN", split='test', download=True, transform=transforms.ToTensor())
```
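SVHN's `split` also accepts `'extra'`, and reading its `.mat` files requires SciPy. In the original files, the digit 0 is labeled 10. torchvision's `SVHN` remaps it to 0 so that labels fall in 0 to 9, which is the range loss functions such as cross-entropy expect. The sketch below reproduces that remapping on a hypothetical label array using NumPy's `np.place`. This is useful if you ever read the raw files yourself.

```python
import numpy as np

# Hypothetical raw SVHN labels: in the original .mat files the digit 0 is stored as 10.
raw_labels = np.array([1, 10, 3, 10, 9])

# Copy, then replace every 10 with 0 in place so labels fall in 0..9.
labels = raw_labels.copy()
np.place(labels, labels == 10, 0)
print(labels)  # [1 0 3 0 9]
```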