PyTorch

2 minute read

Published:

This lesson follows the official PyTorch tutorial (Learn the Basics): https://pytorch.org/tutorials/beginner/basics/intro.html

Datasets and Dataloaders

  • PyTorch provides two data primitives that let you use pre-loaded datasets as well as your own data:

  • torch.utils.data.Dataset
    • stores the samples and their corresponding labels
  • torch.utils.data.DataLoader
    • wraps an iterable around the Dataset to enable easy access to the samples
# Identifiers for this notebook, used when naming and inserting saved figures
lesson = 3
topic = "pytorch"

from n import *

# Project helper (from module `n`): resolve where FashionMNIST lives on disk
home, models_path = get_project_dir("FashionMNIST")
print(home)
/home/naneja/datasets/n/FashionMNIST

Loading a Dataset

# pandas is required for the custom Dataset example (e.g. `conda install pandas -y`)
import os
import random

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

seed = 0  # seed for the random sample drawn when visualizing the dataset


def _fashion_mnist_split(train):
    """Return one FashionMNIST split stored under ``home``.

    ``train=True`` selects the training split, ``train=False`` the test split.
    The data is downloaded if missing and each image is converted with
    ``ToTensor``.
    """
    return datasets.FashionMNIST(
        root=home,
        train=train,
        download=True,
        transform=ToTensor(),
    )


training_data = _fashion_mnist_split(train=True)
test_data = _fashion_mnist_split(train=False)

Iterating and Visualizing the Dataset

# FashionMNIST integer class ids 0-9 mapped to readable names
labels_map = dict(enumerate([
    "T-Shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]))

# A Dataset supports len() and integer indexing
print_(f"len(training_data)={len(training_data)}")

first_sample = training_data[0]
print(type(training_data))
print(type(first_sample))  # each item is an (image, label) tuple

img, label = first_sample
print_(f"img.shape={img.shape}")
print_(f"label={label}")
print_(labels_map[label])

# Remove the size-1 channel dimension so imshow receives a 2-D array
img = img.squeeze()
print_(f"img.shape={img.shape}")

plt.imshow(img, cmap="gray")
plt.title(labels_map[label])
plt.axis("off")

img_name = get_img_name(lesson)
plt.savefig(img_name)
insert_image(img_name, topic)

len(training_data) = 60000

<class 'torchvision.datasets.mnist.FashionMNIST'>
<class 'tuple'>

img.shape = torch.Size([1, 28, 28])

label = 9

Ankle Boot

img.shape = torch.Size([28, 28])

png

# Plot a reproducible random grid of training samples with their labels
rows, cols = 3, 3

random.seed(seed)
# Draw one distinct index per grid cell; deriving the count from the grid
# keeps the sample size and the layout from drifting apart.
sample_idx = random.sample(range(len(training_data)), rows * cols)
print_(sample_idx)

figure = plt.figure(figsize=(8, 8))

for i, idx in enumerate(sample_idx):
    img, label = training_data[idx]
    figure.add_subplot(rows, cols, i + 1)  # subplot positions are 1-based

    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")

img_name = get_img_name(lesson)
plt.savefig(img_name)
insert_image(img_name, topic)

[55340, 25247, 49673, 58343, 27562, 2653, 16968, 33506, 31845]

png

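Creating a Custom Dataset for your files

A custom Dataset class must implement three methods:

  • __init__ — runs once when the Dataset is created (load the annotations, store settings)
  • __len__ — returns the number of samples in the dataset
  • __getitem__ — loads and returns the sample at a given index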
import os
import pandas as pd
from torchvision.io import read_image

class ImageDataset(Dataset):
    """Dataset of image files described by an annotations CSV.

    The CSV's first column holds an image path relative to ``image_dir``; its
    second column holds the target label for that image.

    Args:
        annotations_file: path or file-like object accepted by
            ``pandas.read_csv``.
        image_dir: directory the image paths in the CSV are relative to.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each label.
        header: forwarded to ``pandas.read_csv``. The default ``"infer"``
            treats the first CSV line as a header row; pass ``None`` when the
            file has no header so the first sample is not swallowed.
    """

    def __init__(self, annotations_file, image_dir,
                 transform=None, target_transform=None, header="infer"):
        # Holds the parsed DataFrame (the attribute name is kept for
        # compatibility even though it is no longer the file path).
        self.annotations_file = pd.read_csv(annotations_file, header=header)
        self.img_dir = image_dir  # previously read undefined `img_dir`
        self.transform = transform
        # previously read undefined `label_transform`
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated samples (data rows in the CSV)."""
        return len(self.annotations_file)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for annotation row ``idx``.

        The image is loaded with ``read_image`` from ``img_dir`` joined with
        the row's first column; the label is the row's second column. The
        optional transforms are applied to each before returning.
        """
        img_path = os.path.join(self.img_dir, self.annotations_file.iloc[idx, 0])
        image = read_image(img_path)
        if self.transform:
            image = self.transform(image)

        label = self.annotations_file.iloc[idx, 1]
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

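Preparing your data for training with DataLoaders

When training a model, we typically want to:

  • pass samples to the model in “minibatches”
  • reshuffle the data at every epoch to reduce model overfitting
  • use Python’s multiprocessing to speed up data retrieval

DataLoader is an iterable that handles all of this behind a simple API.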
from torch.utils.data import DataLoader

# Shared settings: minibatches of 64 samples, reshuffled every epoch
loader_kwargs = {"batch_size": 64, "shuffle": True}

train_dataloader = DataLoader(training_data, **loader_kwargs)
test_dataloader = DataLoader(test_data, **loader_kwargs)

Iterate through the DataLoader

  • Each iteration returns a batch of train_features and train_labels (64 samples each, since batch_size=64)
  • Because shuffle=True, the data is reshuffled after we iterate over all batches (i.e. once per epoch)
# Display the first image and label of one shuffled training batch
train_features, train_labels = next(iter(train_dataloader))

print_(f"Feature batch shape: {train_features.size()}")
print_(f"Labels batch shape: {train_labels.size()}")

img = train_features[0].squeeze()  # drop the size-1 channel dimension
label = train_labels[0]            # 0-d tensor

# Start a fresh figure: outside a notebook, a figure created by an earlier
# cell may still be current, and plt.imshow would draw into its axes.
plt.figure()
plt.imshow(img, cmap="gray")
# .item() converts the 0-d tensor to a Python int usable as a dict key
plt.title(labels_map[label.item()])

img_name = get_img_name(lesson)
plt.savefig(img_name)
insert_image(img_name, topic)

Feature batch shape: torch.Size([64, 1, 28, 28])

Labels batch shape: torch.Size([64])

png