728x90
먼저 CNN으로 이미지의 특징을 추출하고, 그 출력을 일렬의 벡터로 펼친(flatten) 뒤 ANN(완전연결층)에 넣어 분류한다.
MNIST 데이터셋을 간단한 CNN + ANN 모델로 훈련시키고 평가해 보자.
#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 128
num_epochs = 10
learning_rate = 0.00025

# Training split of MNIST, converted to tensors by ToTensor().
trn_dataset = datasets.MNIST('./mnist_data/',
                             download=True,
                             train=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor()]))
# Validation (test) split. download=True is a no-op when the files are
# already present, and it keeps this line independent of the one above.
val_dataset = datasets.MNIST("./mnist_data/",
                             download=True,
                             train=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor()]))

# Reshuffle the training data every epoch and serve it in mini-batches.
trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)
# Sample order does not matter for evaluation; a fixed order keeps the
# validation pass deterministic and reproducible between runs.
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False)
#CNN Network
class CNNClassifier(nn.Module):
    """Small CNN feature extractor followed by a fully connected (ANN) head.

    Three 3x3 convolutions (each followed by ReLU) turn a single-channel
    image into a 64-channel feature map. That map is flattened and passed
    through three linear layers that produce 10 class scores.

    ``fc1`` expects a 64 x 4 x 4 feature map, which is what the conv stack
    yields for 28 x 28 inputs:
    28 -(3x3, stride 2)-> 13 -(3x3, stride 2)-> 6 -(3x3, stride 1)-> 4.
    """

    def __init__(self):
        super().__init__()
        # Conv2d(in_channels, out_channels, kernel_size, stride)
        self.conv1 = nn.Conv2d(1, 16, 3, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, 2)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(64 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return raw class scores (logits) of shape (batch, 10).

        Args:
            x: image batch; assumed (batch, 1, 28, 28) so that the flattened
               feature size matches ``fc1`` (other sizes fail inside ``fc1``).

        Returns:
            Unnormalized logits. No softmax is applied here because
            ``nn.CrossEntropyLoss`` already applies log-softmax internally.
            A softmax in the model would normalize twice and flatten the
            gradients. Use ``F.softmax(logits, dim=1)`` when probabilities
            are needed; ``argmax`` gives the same prediction either way.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten to (batch, 64*4*4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
def get_accuracy(y, label):
    """Return the fraction of rows of ``y`` whose argmax equals ``label``.

    Args:
        y: 2-D tensor of per-class scores, one row per sample. Logits or
           probabilities both work, because only the argmax is used.
        label: 1-D tensor of integer class indices, one per row of ``y``.

    Returns:
        Accuracy as a Python float in [0, 1]. ``y`` must have at least one
        row.
    """
    predicted = torch.argmax(y, dim=1)
    # One vectorized comparison on the tensor's own device. The former
    # per-element Python loop forced a device->host sync for every sample.
    num_correct = (predicted == label).sum().item()
    return num_correct / y.shape[0]
cnn = CNNClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)
num_batches = len(trn_loader)


def _evaluate(model, loader, loss_fn, device):
    """Return (mean loss, accuracy) of ``model`` over every sample in ``loader``.

    Both metrics are weighted by batch size, so a smaller final batch does
    not skew the result the way an unweighted mean of per-batch values
    would. Leaves ``model`` in eval mode; the caller switches it back.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            n = targets.size(0)
            # Assumes loss_fn averages over the batch (reduction='mean', as
            # criterion above does); rescale back to a per-batch sum.
            total_loss += loss_fn(outputs, targets).item() * n
            total_correct += (outputs.argmax(dim=1) == targets).sum().item()
            total_seen += n
    return total_loss / total_seen, total_correct / total_seen


for epoch in range(num_epochs):
    # Running training metrics, reset at the start of each epoch.
    trn_loss_list = []
    trn_acc_list = []
    cnn.train()
    for i, (x, label) in enumerate(trn_loader):
        x = x.to(device)
        label = label.to(device)

        model_output = cnn(x)
        loss = criterion(model_output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        trn_loss_list.append(loss.item())
        trn_acc_list.append(get_accuracy(model_output, label))

        # Every 100 steps, run a full validation pass and report progress.
        if (i + 1) % 100 == 0:
            val_loss, val_acc = _evaluate(cnn, val_loader, criterion, device)
            cnn.train()  # _evaluate left the model in eval mode
            print("epoch: {}/{} | step: {}/{} | trn loss: {:.4f} | val loss: {:.4f} | trn acc: {:.4f} | val acc: {:.4f}".format(epoch+1, num_epochs, i+1, num_batches, np.mean(trn_loss_list), val_loss, np.mean(trn_acc_list), val_acc))
728x90
'자율주행 > Deep Q Network' 카테고리의 다른 글
DQN Upgrade (0) | 2021.02.25 |
---|---|
DQN(Deep Q Network) (0) | 2021.02.24 |
Reinforcement Learning (0) | 2021.02.23 |