PyTorch

less than 1 minute read

Published:

This lesson follows the official PyTorch tutorial: https://pytorch.org/tutorials/beginner/basics/intro.html

Save and Load Model

# Lesson metadata; neither name is referenced within the visible cells.
topic = "pytorch"
lesson = 8

# NOTE(review): `n` is a project-local helper module not shown here; this
# star import is what brings `get_project_dir` and `print_` into scope.
# It runs after `topic`/`lesson` are assigned, so if `n` happens to export
# names with those spellings it would silently rebind them — confirm.
from n import *
# get_project_dir returns a pair, unpacked into the dataset home directory and
# a models directory (judging by this notebook's outputs, presumably
# `<home>/models`). print_ is presumably a print wrapper from `n` — verify.
home, models_path = get_project_dir("FashionMNIST")
print_(home)

/home/naneja/datasets/n/FashionMNIST

import torch
import torchvision.models as models
models_path
'/home/naneja/datasets/n/FashionMNIST/models'
# Build VGG16 initialised with torchvision's default pretrained weights.
pretrained_weights = models.VGG16_Weights.DEFAULT
model = models.vgg16(weights=pretrained_weights)

# Persist only the learned parameters/buffers (the state_dict), not the
# pickled module object, to <models_path>/vgg16.pth, then report the path.
model_path = "{}/vgg16.pth".format(models_path)
torch.save(model.state_dict(), model_path)
print_("model saved: " + model_path)

model saved: /home/naneja/datasets/n/FashionMNIST/models/vgg16.pth

# Build a bare VGG16 architecture: no `weights` argument is passed, so no
# default pretrained weights are loaded here.
model = models.vgg16()

# Restore the parameters from the saved state_dict file at `model_path`.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. As a result, a tampered .pth file cannot run
# arbitrary code, and newer PyTorch releases stop warning about the implicit
# default. map_location="cpu" lets a checkpoint written from GPU tensors load
# on a CPU-only machine; the model built above is on the CPU anyway.
# NOTE: the weights_only argument requires torch >= 1.13.
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)

# Put the model in inference mode, so layers such as dropout use their
# evaluation behaviour. The trailing semicolon only suppresses the notebook's
# display of the returned module.
model.eval();