PyTorchでオートエンコーダー

ディープラーニング




PyTorchでオートエンコーダー

PyTorchでオートエンコーダーを作ります。

おそらくオートエンコーダーになっていると思います（28×28=784次元の入力画像をエンコーダーで64次元に圧縮し、デコーダーで再び784次元に復元する構成です）。

オートエンコーダーについてはまだ調査中です。
構造は下の図のようになるはずです。

結果

input

output

プログラム

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader,TensorDataset
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST

# Load MNIST into the local "MNIST" directory (downloaded on first run).
# ToTensor converts each image to a float tensor; later code reshapes the
# flattened vectors back to 1x28x28 for saving.
to_tensor = transforms.ToTensor()
mnist_train = MNIST("MNIST", train=True, download=True, transform=to_tensor)
mnist_test = MNIST("MNIST", train=False, download=True, transform=to_tensor)
batch_size = 32
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps the saved sample
# input/reconstruction images comparable across epochs and runs.
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

class autoencoder(nn.Module):
    """Fully connected autoencoder for flattened images.

    The encoder maps ``in_features -> hidden_features -> latent_features``
    and the decoder mirrors it back to ``in_features``. The defaults
    (784 -> 256 -> 64) correspond to 28x28 images flattened to vectors.
    The decoder has no final activation, so its outputs are unbounded.

    Args:
        in_features: Length of each flattened input vector.
        hidden_features: Width of the hidden layer in encoder and decoder.
        latent_features: Size of the bottleneck (code) vector.
    """

    def __init__(self, in_features=28 * 28, hidden_features=256, latent_features=64):
        super().__init__()

        # Submodule names (encode_net / decode_net) are kept so existing
        # state_dict checkpoints still load.
        self.encode_net = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, latent_features),
        )

        self.decode_net = nn.Sequential(
            nn.Linear(latent_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, in_features),
        )

    def forward(self, x):
        """Encode ``x`` to the latent code and decode it back.

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Reconstruction with the same shape as ``x``.
        """
        return self.decode_net(self.encode_net(x))
            
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = autoencoder().to(device)

# Reconstruction loss: the target is the input image itself, not the label.
criterion = nn.MSELoss()

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

num_epochs = 30

# Per-epoch average losses, kept for inspection/plotting after training.
train_loss_list = []

val_loss_list = []


for epoch in range(num_epochs):
    # Sums of per-sample losses; divided by the dataset size after each pass.
    train_loss = 0.0
    val_loss = 0.0

    net.train()
    for images, _ in train_loader:
        # Flatten each image to a vector to match the Linear input layer.
        images = images.to(device).view(images.size(0), -1)
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, images)
        # MSELoss returns the batch mean; weight it by the batch size so the
        # epoch average is a true per-sample mean (the last batch may be smaller).
        train_loss += loss.item() * images.size(0)
        loss.backward()
        optimizer.step()

    avg_train_loss = train_loss / len(train_loader.dataset)

    # Save sample images on the first epoch and on every 10th epoch.
    save_samples = epoch == 0 or (epoch + 1) % 10 == 0

    net.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            images = images.to(device).view(images.size(0), -1)
            outputs = net(images)
            loss = criterion(outputs, images)
            val_loss += loss.item() * images.size(0)

            # Write only the first batch rather than rewriting the same
            # files for every batch of the test set.
            if save_samples and i == 0:
                save_image(images.view(-1, 1, 28, 28),
                           'p3inputs_{}_{}.png'.format(epoch + 1, num_epochs))
                save_image(outputs.view(-1, 1, 28, 28),
                           'p3autoencode_{}_{}.png'.format(epoch + 1, num_epochs))

    avg_val_loss = val_loss / len(test_loader.dataset)

    train_loss_list.append(avg_train_loss)
    val_loss_list.append(avg_val_loss)
    print('Epoch[{}], train_loss:{:.4f}, val_loss:{:.4f}'
          .format(epoch + 1, avg_train_loss, avg_val_loss))
    

参考

[amazonjs asin="4798055476" locale="JP" title="PyTorchニューラルネットワーク実装ハンドブック (Pythonライブラリ定番セレクション)"]

タイトルとURLをコピーしました