PyTorch
Published:
This lesson follows the official PyTorch tutorial: https://pytorch.org/tutorials/beginner/basics/intro.html
Datasets and Dataloaders
PyTorch's data primitives allow you to use pre-loaded datasets as well as your own data.
torch.utils.data.Dataset
- stores the samples and their corresponding labels
torch.utils.data.DataLoader
- wraps an iterable around the Dataset to enable easy access to the samples
# Identifiers for this lesson: `lesson` is passed to get_img_name() and
# `topic` to insert_image() when figures are saved further down.
topic = "pytorch"
lesson = 3
# Wildcard import of the project helper module `n`. get_project_dir, print_,
# get_img_name and insert_image are not defined in this file, so they
# presumably come from here -- TODO confirm against module `n`.
from n import *
# `home` is later used as the dataset `root` directory; `models_path` is not
# used in the visible part of the file.
home, models_path = get_project_dir("FashionMNIST")
print(home)
/home/naneja/datasets/n/FashionMNIST
Loading a Dataset
#!conda install pandas -y
import os
import random
seed = 0  # fed to random.seed() before sampling images, so the sample is repeatable
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
# Load both FashionMNIST splits from the same root directory, downloading
# them if they are not already present. Every image is converted to a tensor
# by ToTensor(); only the `train` flag differs between the two calls.
training_data = datasets.FashionMNIST(
    home, train=True, transform=ToTensor(), download=True
)
test_data = datasets.FashionMNIST(
    home, train=False, transform=ToTensor(), download=True
)
Iterating and Visualizing the Dataset
# Human-readable class names keyed by the integer label (0-9); the position
# of each name in the list is its label.
labels_map = dict(enumerate([
    "T-Shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]))
# Inspect the first training sample and plot it.
# Dataset can be indexed
print_(f"len(training_data)={len(training_data)}")
print(type(training_data))
print(type(training_data[0]))  # each sample is a tuple
img, label = training_data[0]  # unpack the tuple into image tensor and integer label
print_(f"img.shape={img.shape}")
print_(f"label={label}")
print_(labels_map[label])  # translate the integer label into its class name
# Drop the leading size-1 dimension ([1, 28, 28] -> [28, 28]) so that imshow
# receives a 2-D array.
img = img.squeeze()
print_(f"img.shape={img.shape}")
plt.imshow(img, cmap="gray")
plt.axis("off")
plt.title(labels_map[label]);
# Save the current figure and hand it to the project helpers. get_img_name
# and insert_image are not defined in this file -- presumably they build a
# per-lesson file name and embed the saved image; TODO confirm.
img_name = get_img_name(lesson)
plt.savefig(img_name)
insert_image(img_name, topic)
len(training_data) = 60000
<class 'torchvision.datasets.mnist.FashionMNIST'>
<class 'tuple'>
img.shape = torch.Size([1, 28, 28])
label = 9
Ankle Boot
img.shape = torch.Size([28, 28])
# Reseed so that the same nine training samples are drawn on every run.
random.seed(seed)
sample_idx = random.sample(range(len(training_data)), 9)
print_(sample_idx)

# Lay the nine samples out on a 3x3 grid, one subplot per sample.
rows, cols = 3, 3
figure = plt.figure(figsize=(8, 8))
for i, idx in enumerate(sample_idx):
    img, label = training_data[idx]
    axes = figure.add_subplot(rows, cols, i + 1)  # subplot positions are 1-based
    axes.set_title(labels_map[label])
    axes.axis("off")
    axes.imshow(img.squeeze(), cmap="gray")

# Save the finished grid and pass it to the project image helpers.
img_name = get_img_name(lesson)
plt.savefig(img_name)
insert_image(img_name, topic)
[55340, 25247, 49673, 58343, 27562, 2653, 16968, 33506, 31845]
Creating a Custom Dataset for your files
- __init__
- __len__
- __getitem__
import os
import pandas as pd
from torchvision.io import read_image
class ImageDataset(Dataset):
    """Map-style dataset backed by an annotations CSV and an image directory.

    The annotations CSV has two columns: the first holds each image's path
    relative to ``image_dir``, the second holds its target label. The file is
    parsed with ``pandas.read_csv`` defaults, so its first row is treated as
    the header.
    """

    def __init__(self, annotations_file, image_dir,
                 transform=None, target_transform=None):
        """Read the annotations table and store the image location.

        Args:
            annotations_file: Path (or anything ``pandas.read_csv`` accepts)
                of the two-column annotations CSV.
            image_dir: Directory that the image paths in the CSV are
                relative to.
            transform: Optional callable applied to every loaded image.
            target_transform: Optional callable applied to every label.
        """
        # NOTE: despite its name this attribute holds the parsed DataFrame,
        # not the path; the name is kept so existing users keep working.
        self.annotations_file = pd.read_csv(annotations_file)
        self.img_dir = image_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated samples (rows in the CSV)."""
        return len(self.annotations_file)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for row ``idx``.

        The image is read from disk with ``read_image`` on every access; the
        optional transforms are applied before returning.
        """
        img_path = os.path.join(self.img_dir, self.annotations_file.iloc[idx, 0])
        image = read_image(img_path)
        label = self.annotations_file.iloc[idx, 1]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
Preparing your data for training with DataLoaders
- samples in “minibatches”
- reshuffle the data at every epoch to reduce model overfitting
- use Python’s multiprocessing to speed up data retrieval
from torch.utils.data import DataLoader
# Wrap each split in a DataLoader that yields shuffled mini-batches of 64.
train_dataloader = DataLoader(dataset=training_data, shuffle=True, batch_size=64)
test_dataloader = DataLoader(dataset=test_data, shuffle=True, batch_size=64)
Iterate through the DataLoader
- Each iteration returns a batch of train_features and train_labels
- shuffle=True
- once we have iterated over all batches, the data is reshuffled for the next pass
# Display image and label
# Pull a single batch from a fresh iterator over the training DataLoader.
train_features, train_labels = next(iter(train_dataloader))
print_(f"Feature batch shape: {train_features.size()}")
print_(f"Labels batch shape: {train_labels.size()}")
# Take the first sample of the batch; squeeze() drops its size-1 dimension
# so that imshow receives a 2-D array.
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
# .item() converts the label tensor into a plain Python number, which is
# what labels_map is keyed by.
plt.title(labels_map[label.item()]);
# Save the current figure and hand it to the project helpers (not defined in
# this file; presumably they name and embed the image -- TODO confirm).
img_name = get_img_name(lesson)
plt.savefig(img_name)
insert_image(img_name, topic)
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])