Pytorchの画像データセット-torchvision.datasets

ディープラーニング




Pytorchのデータセットを見てみる

PyTorchのtorchvision.datasetsにはMNISTなどのデータセットが用意されています。

これらのデータセットは、datasetとして保持しているので、画像やラベルの確認も容易にできます。

https://betashort-lab.com/データサイエンス/pytorchのdatasetとdataloaderからbatchサイズ・target・etc-を知りたい/”

datasetなので、dataloaderに渡して使います。

MNIST


MNISTは、0~9までの手書き数字画像のデータセットです。

01
02
03
04
05
06
07
08
09
10
from torchvision.datasets import MNIST
from torchvision import transforms
 
mnist_train = MNIST("MNIST", train=True, download=True, transform=transforms.ToTensor())
 
mnist_test = MNIST("MNIST", train=False, download=True, transform=transforms.ToTensor())
 
batch_size = 128
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

FASHION-MNIST

1
2
3
4
5
6
from torchvision.datasets import FashionMNIST
from torchvision import transforms
 
fashion_mnist_train = FashionMNIST("FashionMNIST", train=True, download=True, transform=transforms.ToTensor())
 
fashion_mnist_test = FashionMNIST("FashionMNIST", train=False, download=True, transform=transforms.ToTensor())

EMNIST

1
2
3
4
5
6
from torchvision.datasets import EMNIST
from torchvision import transforms
 
emnist_train = EMNIST("EMNIST", split="byclass", train=True, download=True, transform=transforms.ToTensor())
 
emnist_test = EMNIST("EMNIST", split="byclass", train=False, download=True, transform=transforms.ToTensor())

CIFAR10

1
2
3
4
5
6
from torchvision.datasets import CIFAR10
from torchvision import transforms
 
CIFAR10_train = CIFAR10("CIFAR10", train=True, download=True, transform=transforms.ToTensor())
 
CIFAR10_test = CIFAR10("CIFAR10", train=False, download=True, transform=transforms.ToTensor())

CIFAR100

1
2
3
4
5
6
from torchvision.datasets import CIFAR100
from torchvision import transforms
 
CIFAR100_train = CIFAR100("CIFAR100", train=True, download=True, transform=transforms.ToTensor())
 
CIFAR100_test = CIFAR100("CIFAR100", train=False, download=True, transform=transforms.ToTensor())

STL10

1
2
3
4
5
6
from torchvision.datasets import STL10
from torchvision import transforms
 
STL10_train = STL10("STL10", split='train', download=True, transform=transforms.ToTensor())
 
STL10_test = STL10("STL10", split='test', download=True, transform=transforms.ToTensor())

SVHN

1
2
3
4
5
6
from torchvision.datasets import SVHN
from torchvision import transforms
 
SVHN_train = SVHN("SVHN", split='train', download=True, transform=transforms.ToTensor())
 
SVHN_test = SVHN("SVHN", split='test', download=True, transform=transforms.ToTensor())

参考

https://amzn.to/2VNZs3U
タイトルとURLをコピーしました