자율주행/Deep Q Network

ANN, CNN -> MNIST 분석

Tony Lim 2021. 2. 22. 14:52
728x90

 

CNN 을 먼저 해주고 그것을 일렬로 벡터로 늘려서 ANN을 해주고 비교한다.

 

MNIST dataset을 간단한 cnn + ann 으로 훈련시키고 평가해보자

#!/usr/bin/env python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import numpy as np

# Prefer the first CUDA device when available; otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch_size = 128        # samples per mini-batch
num_epochs = 10         # full passes over the training set

learning_rate = 0.00025  # Adam step size


# Images become float tensors in [0, 1]; no further augmentation.
_to_tensor = transforms.Compose([transforms.ToTensor()])

# MNIST training split (downloaded on first run).
trn_dataset = datasets.MNIST(
    "./mnist_data/",
    download=True,
    train=True,
    transform=_to_tensor)

# MNIST test split, used here as the validation set.
val_dataset = datasets.MNIST(
    "./mnist_data/",
    download=False,
    train=False,
    transform=_to_tensor)

# Loaders shuffle every epoch and group samples into mini-batches.
trn_loader = torch.utils.data.DataLoader(
    trn_dataset, batch_size=batch_size, shuffle=True)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=True)

#CNN Network

class CNNClassifier(nn.Module):
    """Small CNN for 28x28 single-channel MNIST images.

    Three conv layers downsample 28x28 -> 13x13 -> 6x6 -> 4x4, then three
    fully connected layers map the flattened 64*4*4 features to 10 class
    scores (logits).
    """

    def __init__(self):
        super(CNNClassifier, self).__init__()
        # Conv2d args: (in_channels, out_channels, kernel_size, stride)
        self.conv1 = nn.Conv2d(1, 16, 3, 2)   # 28x28 -> 13x13
        self.conv2 = nn.Conv2d(16, 32, 3, 2)  # 13x13 -> 6x6
        self.conv3 = nn.Conv2d(32, 64, 3, 1)  # 6x6  -> 4x4

        self.fc1 = nn.Linear(64 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10).

        The script trains this model with nn.CrossEntropyLoss, which
        applies log-softmax internally, so the network must output raw
        logits. The original code applied F.softmax before the loss,
        which double-normalizes and badly flattens the gradients; the
        softmax has been removed. Argmax-based accuracy is unaffected
        since softmax is monotonic.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten to (batch, 64*4*4)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def get_accuracy(y, label):
    """Return the fraction of predictions in `y` that match `label`.

    Args:
        y: (batch, num_classes) tensor of class scores or probabilities.
        label: (batch,) tensor of integer class indices.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    y_idx = torch.argmax(y, dim=1)
    # Vectorized comparison replaces the original per-element Python loop;
    # int/int division preserves the original return value exactly.
    num_correct = (y_idx == label).sum().item()
    return num_correct / y.shape[0]

# Model, loss, and optimizer setup.
# NOTE(review): nn.CrossEntropyLoss expects raw logits (it applies
# log-softmax internally) — confirm the model's forward() does not
# already apply softmax before returning.
cnn = CNNClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)

num_batches = len(trn_loader)  # steps per epoch, used for progress logging


# Training loop: every 100 steps, run a full validation pass and print
# running averages of loss/accuracy for the epoch so far.
for epoch in range(num_epochs):
    trn_loss_list = []  # per-step training losses this epoch
    trn_acc_list = []   # per-step training accuracies this epoch
    for i ,data in enumerate(trn_loader):
        cnn.train()  # training mode (matters for dropout/batchnorm layers)

        x,label = data
        x = x.to(device)
        label = label.to(device)

        # Forward pass and loss.
        model_output = cnn(x)
        loss = criterion(model_output, label)

        # Standard backprop step: clear old grads, backprop, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        trn_acc = get_accuracy(model_output,label)

        trn_loss_list.append(loss.item())
        trn_acc_list.append(trn_acc)

        # Periodic validation: eval mode, gradients disabled.
        if (i+1) % 100 == 0:
            cnn.eval()
            with torch.no_grad():
                val_loss_list = []
                val_acc_list = []
                
                for j ,val in enumerate(val_loader):
                    val_x ,val_label = val

                    val_x = val_x.to(device)
                    val_label = val_label.to(device)

                    val_output = cnn(val_x)

                    val_loss = criterion(val_output,val_label)
                    val_acc = get_accuracy(val_output, val_label)

                    val_loss_list.append(val_loss.item())
                    val_acc_list.append(val_acc)

            # Report epoch/step progress with mean train and validation metrics.
            print("epoch: {}/{} | step: {}/{} | trn loss: {:.4f} | val loss: {:.4f} | trn acc: {:.4f} | val acc: {:.4f}".format(epoch+1, num_epochs, i+1, num_batches, np.mean(trn_loss_list), np.mean(val_loss_list), np.mean(trn_acc_list), np.mean(val_acc_list)))

 

728x90

'자율주행 > Deep Q Network' 카테고리의 다른 글

DQN Upgrade  (0) 2021.02.25
DQN(Deep Q Network)  (0) 2021.02.24
Reinforcement Learning  (0) 2021.02.23