PyTorchでオートエンコーダー


PyTorchでオートエンコーダー

PyTorchでオートエンコーダーを作ります。

おそらく、これでオートエンコーダーになっていると思います。

オートエンコーダーについては現在調査中です。
構造は下の図のようになっているはずです。入力画像（28×28=784次元）をエンコーダーで256次元、64次元と圧縮し、デコーダーで逆に784次元へ復元します。

結果

入力画像（input）

出力画像（output：オートエンコーダーによる復元結果）

プログラム

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader,TensorDataset
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST

# Load MNIST into ./MNIST (downloaded on first run). ToTensor() turns each
# image into a 1x28x28 float tensor with values in [0, 1].
mnist_train = MNIST("MNIST", train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST("MNIST", train=False, download=True, transform=transforms.ToTensor())
batch_size = 32
# Shuffle only the training data; SGD benefits from a random order.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling: the validation loss is order-independent,
# and a fixed order makes the sample images saved during training show the
# same digits every time, so reconstructions can be compared across epochs.
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

class autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps 784 features through a 256-unit ReLU hidden layer down
    to a 64-dimensional code; the decoder mirrors that path back to 784
    features. The decoder has no final activation, so outputs are unbounded.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 784 -> 256 (ReLU) -> 64.
        self.encode_net = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
        )

        # Decoder: 64 -> 256 (ReLU) -> 784, mirroring the encoder.
        self.decode_net = nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
        )

    def forward(self, x):
        """Compress ``x`` to a 64-dim code, then reconstruct it.

        ``x`` must have 784 features in its last dimension (nn.Linear acts
        on the last dimension); the result has the same shape as ``x``.
        """
        code = self.encode_net(x)
        reconstruction = self.decode_net(code)
        return reconstruction
            
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = autoencoder().to(device)

# Reconstruction loss: mean squared error between the output and the input.
criterion = nn.MSELoss()

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

num_epochs = 30

# Per-epoch average (per-sample) losses, e.g. for plotting learning curves.
train_loss_list = []

val_loss_list = []


for epoch in range(num_epochs):
    train_loss = 0.0
    val_loss = 0.0
    # Sample reconstructions are written after epoch 1 and every 10th epoch.
    save_samples = epoch == 0 or (epoch + 1) % 10 == 0

    net.train()
    for images, _ in train_loader:
        # Flatten (B, 1, 28, 28) -> (B, 784) to match the first Linear layer.
        images = images.to(device).view(images.size(0), -1)
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, images)  # target is the input itself
        # MSELoss returns the batch mean; weight it by the batch size so that
        # dividing by the dataset size below gives a true per-sample average.
        train_loss += loss.item() * images.size(0)
        loss.backward()
        optimizer.step()

    avg_train_loss = train_loss / len(train_loader.dataset)

    net.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            images = images.to(device).view(images.size(0), -1)
            outputs = net(images)
            loss = criterion(outputs, images)
            val_loss += loss.item() * images.size(0)

            # One batch is enough for the sample images; writing them for
            # every batch would only overwrite the same files repeatedly.
            if save_samples and i == 0:
                save_image(images.view(images.size(0), 1, 28, 28),
                           'p3inputs_{}_{}.png'.format(epoch+1, num_epochs))
                save_image(outputs.view(outputs.size(0), 1, 28, 28),
                           'p3autoencode_{}_{}.png'.format(epoch+1, num_epochs))
    avg_val_loss = val_loss / len(test_loader.dataset)

    train_loss_list.append(avg_train_loss)
    val_loss_list.append(avg_val_loss)
    print('Epoch[{}], train_loss:{:.4f}, val_loss:{:.4f}'
          .format(epoch+1, avg_train_loss, avg_val_loss))
    

参考