Pytorchの画像データセット-torchvision.datasets

ディープラーニング




Pytorchのデータセットを見てみる

PyTorchのtorchvision.datasetsにはMNISTなどのデータセットが用意されています。

これらのデータセットは、datasetとして保持しているので、画像やラベルの確認も容易にできます。

https://betashort-lab.com/データサイエンス/pytorchのdatasetとdataloaderからbatchサイズ・target・etc-を知りたい/”

datasetなので、dataloaderに渡して使います。

MNIST


MNISTは、0~9までの手書き数字画像のデータセットです。

from torchvision.datasets import MNIST
from torchvision import transforms

mnist_train = MNIST("MNIST", train=True, download=True, transform=transforms.ToTensor())

mnist_test = MNIST("MNIST", train=False, download=True, transform=transforms.ToTensor())

batch_size = 128
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

FASHION-MNIST

from torchvision.datasets import FashionMNIST
from torchvision import transforms

fashion_mnist_train = FashionMNIST("FashionMNIST", train=True, download=True, transform=transforms.ToTensor())

fashion_mnist_test = FashionMNIST("FashionMNIST", train=False, download=True, transform=transforms.ToTensor())

EMNIST

from torchvision.datasets import EMNIST
from torchvision import transforms

emnist_train = EMNIST("EMNIST", split="byclass", train=True, download=True, transform=transforms.ToTensor())

emnist_test = EMNIST("EMNIST", split="byclass", train=False, download=True, transform=transforms.ToTensor())

CIFAR10

from torchvision.datasets import CIFAR10
from torchvision import transforms

CIFAR10_train = CIFAR10("CIFAR10", train=True, download=True, transform=transforms.ToTensor())

CIFAR10_test = CIFAR10("CIFAR10", train=False, download=True, transform=transforms.ToTensor())

CIFAR100

from torchvision.datasets import CIFAR100
from torchvision import transforms

CIFAR100_train = CIFAR100("CIFAR100", train=True, download=True, transform=transforms.ToTensor())

CIFAR100_test = CIFAR100("CIFAR100", train=False, download=True, transform=transforms.ToTensor())

STL10

from torchvision.datasets import STL10
from torchvision import transforms

STL10_train = STL10("STL10", split='train', download=True, transform=transforms.ToTensor())

STL10_test = STL10("STL10", split='test', download=True, transform=transforms.ToTensor())

SVHN

from torchvision.datasets import SVHN
from torchvision import transforms

SVHN_train = SVHN("SVHN", split='train', download=True, transform=transforms.ToTensor())

SVHN_test = SVHN("SVHN", split='test', download=True, transform=transforms.ToTensor())

参考

Bitly
タイトルとURLをコピーしました