MNIST-NN

2 minute read

Published:

This post is an introduction to neural networks using the MNIST dataset.

%matplotlib inline

import os

import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

import sklearn
from sklearn import metrics
# Directory where checkpoints are stored; create it if it does not exist.
ckpt_dir = './data/model'
os.makedirs(ckpt_dir, exist_ok=True)

# Full path of the checkpoint file for this model.
ckpt = os.path.join(ckpt_dir, 'mnist_nn.pth')
print(ckpt)

# Number of training epochs.
epochs = 10
# https://en.wikipedia.org/wiki/Standard_score
# Normalize pixel values from [0, 1] to [-1, 1] via (x - mean) / std.
mean = (0.5,)
std = (0.5,)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

# Download MNIST: 60k training and 10k test images of 28x28 digits.
ds_train = torchvision.datasets.MNIST(root='data', train=True, transform=transform, download=True)
ds_test = torchvision.datasets.MNIST(root='data', train=False, transform=transform, download=True)
print(len(ds_train))
print(len(ds_test))

# Shuffle only the training data. Shuffling the test loader serves no
# purpose (metrics are order-independent) and makes evaluation batches
# non-reproducible, so it is disabled here.
loader_train = torch.utils.data.DataLoader(ds_train, batch_size=32, shuffle=True)
loader_test = torch.utils.data.DataLoader(ds_test, batch_size=32, shuffle=False)
print(len(loader_train), len(loader_test))
# Pull one batch from the training loader and inspect its shapes.
images, labels = next(iter(loader_train))
print(images.shape)  # batch of single-channel images
print(labels.shape)  # one digit label per image

# First sample of the batch.
img, label = images[0], labels[0]
print(img.shape)
print(label)

# Drop the channel dimension so matplotlib can render the image as 2-D.
img = img.squeeze()
print(img.shape)
plt.imshow(img, cmap='gray_r');
fig = plt.figure()

# Show the first 25 samples of the batch in a 5x5 grid,
# each subplot titled with its ground-truth label.
for i in range(25):
    sample = images[i].squeeze()
    digit = labels[i].numpy()

    plt.subplot(5, 5, i + 1)
    plt.imshow(sample, cmap='gray_r')
    plt.title(digit)
    plt.axis('off')

fig.tight_layout()
# Fully-connected classifier: 784 -> 128 -> 64 -> 10.
# The last layer emits raw logits; CrossEntropyLoss applies
# log-softmax internally during training.
model = nn.Sequential(
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
print(model)

# Each Linear layer contributes a weight matrix and a bias vector.
for param in model.parameters():
    print(param.shape)

# Plain SGD with momentum.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Train One Epoch
def train(loader):
    """Run one training epoch over `loader`.

    Uses the module-level `model` and `optimizer`. Returns a tuple of
    (mean batch loss, mean batch accuracy) averaged over all batches.
    """
    model.train()  # ensure train-mode behavior (no-op for this model, but correct habit)
    criterion = nn.CrossEntropyLoss()  # hoisted: no need to rebuild per batch

    running_loss = 0.0
    running_acc = 0.0

    for inputs, labels in loader:
        # Flatten each image to a 784-dim vector for the linear layers.
        inputs = inputs.view(inputs.shape[0], -1)

        optimizer.zero_grad()
        # Gradients are enabled by default, so no set_grad_enabled(True)
        # context is needed (the original wrapped only the forward pass anyway).
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() extracts a python scalar, detached from the graph.
        running_loss += loss.item()

        # argmax over the class dimension gives the predicted digit.
        # Softmax is monotonic, so applying it before argmax is unnecessary.
        y_true = labels.numpy()
        y_pred = outputs.argmax(dim=1).numpy()
        running_acc += metrics.accuracy_score(y_true, y_pred)

    # Average the per-batch metrics over the number of batches.
    running_loss /= len(loader)
    running_acc /= len(loader)
    return running_loss, running_acc

#train(loader_train)
# Test One Epoch
def test(loader):
    """Evaluate the model over one pass of `loader`.

    Uses the module-level `model`; no parameters are updated. Returns a
    tuple of (mean batch loss, mean batch accuracy) over all batches.
    """
    model.eval()  # eval-mode behavior (no-op for this model, but correct habit)
    criterion = nn.CrossEntropyLoss()  # hoisted out of the batch loop

    running_loss = 0.0
    running_acc = 0.0

    # no_grad around the whole loop: the original context covered only the
    # forward pass, so the loss still built an autograd graph needlessly.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.view(inputs.shape[0], -1)  # flatten images

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            # Predicted digit = index of the largest logit; softmax is
            # monotonic and therefore unnecessary before argmax.
            y_true = labels.numpy()
            y_pred = outputs.argmax(dim=1).numpy()
            running_acc += metrics.accuracy_score(y_true, y_pred)

    # Average the per-batch metrics over the number of batches.
    running_loss /= len(loader)
    running_acc /= len(loader)

    return running_loss, running_acc

#train(loader_test)
# Per-epoch metric history for both splits, used for the curves below.
losses = {'train': [], 'test': []}
acces = {'train': [], 'test': []}

for epoch in range(epochs):
    # One pass over the training data (updates the weights).
    tr_loss, tr_acc = train(loader_train)
    losses['train'].append(tr_loss)
    acces['train'].append(tr_acc)

    print(f'Train Epoch: {epoch+1:2d}   Loss: {tr_loss:.3f}   Acc: {tr_acc:.3f}')

    # One pass over the held-out test data (no weight updates).
    te_loss, te_acc = test(loader_test)
    losses['test'].append(te_loss)
    acces['test'].append(te_acc)

    print(f'Test  Epoch: {epoch+1:2d}   Loss: {te_loss:.3f}   Acc: {te_acc:.3f}')

    print()

# Persist the trained weights together with the metric history.
if epochs > 0:
    ckpt_dict = {
        'state_dict': model.state_dict(),
        'losses': losses,
        'acces': acces,
    }
    torch.save(ckpt_dict, ckpt)
# Restore the checkpoint saved above.
ckpt_dict = torch.load(ckpt)

# BUG FIX: the original did `model = ckpt_dict['state_dict']`, which
# replaces the nn.Module with a plain OrderedDict of tensors (the model
# would no longer be callable). Load the weights back into the model.
model.load_state_dict(ckpt_dict['state_dict'])
losses = ckpt_dict['losses']
acces = ckpt_dict['acces']
print(losses)

# Loss curves for both splits.
plt.plot(losses['train'], label='train')
plt.plot(losses['test'], label='test')
plt.legend();

# Accuracy curves, labelled for consistency with the loss plot.
plt.plot(acces['train'], label='train')
plt.plot(acces['test'], label='test')
plt.legend();