PyTorch
Published:
This lesson follows the official PyTorch tutorial: https://pytorch.org/tutorials/beginner/basics/intro.html
Save and Load Model
# Lesson metadata for this notebook: the topic tag and the lesson number.
# NOTE(review): neither name is referenced again in the visible cells —
# presumably consumed by the helpers star-imported from `n`; confirm.
topic = "pytorch"
lesson = 8
from n import *
# Resolve the working directories for the FashionMNIST project. The helper's
# result is unpacked into the project home directory and the directory that
# holds saved model files (used below to build the checkpoint path).
# NOTE(review): get_project_dir / print_ are presumably provided by the
# star-import of `n`; whether the helper also creates the directories is not
# visible here — confirm.
home, models_path = get_project_dir("FashionMNIST")
print_(home)
/home/naneja/datasets/n/FashionMNIST
import torch
import torchvision.models as models
# Bare expression: echoes the models directory when run as a notebook cell.
models_path
'/home/naneja/datasets/n/FashionMNIST/models'
# Build VGG16 initialised with torchvision's default pretrained weights.
model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# Persist only the learned parameters (the state_dict) rather than the whole
# module object; the architecture is rebuilt in code when loading.
model_path = f"{models_path}/vgg16.pth"
torch.save(obj=model.state_dict(), f=model_path)
print_(f"model saved: {model_path}")
model saved: /home/naneja/datasets/n/FashionMNIST/models/vgg16.pth
# Rebuild the VGG16 architecture with no pretrained weights; its parameters
# are restored from the checkpoint file at model_path instead.
model = models.vgg16()

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# A state_dict needs nothing more than that, so the safe mode is sufficient.
model.load_state_dict(torch.load(model_path, weights_only=True))

# Put the model in evaluation mode before inference (affects layers such as
# dropout). The trailing semicolon suppresses the notebook's echo of the
# module repr that eval() returns.
model.eval();