# save
torch.save(model.state_dict(), PATH)  # use .pt or .pth as file extension

# load
model = Net(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
# must call it to set dropout and batch normalization layers to evaluation mode
# before running inference
model.eval()
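A minimal runnable sketch of the state_dict round trip, assuming a hypothetical toy `Net` (one `nn.Linear` layer followed by dropout) and a temporary file in place of `PATH`. Calling `eval()` on both copies disables dropout, so their outputs can be compared directly.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical toy model: one linear layer followed by dropout."""

    def __init__(self, in_features: int = 4, out_features: int = 2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.drop(self.fc(x))


def save_and_reload(path: str) -> tuple[Net, Net]:
    model = Net()
    # Save only the learned parameters and buffers, not the class itself.
    torch.save(model.state_dict(), path)

    # Re-create the architecture, then copy the saved tensors into it.
    restored = Net()
    restored.load_state_dict(torch.load(path))
    # Switch dropout (and any batch norm layers) to inference behaviour.
    restored.eval()
    return model, restored


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        original, restored = save_and_reload(os.path.join(tmp, "net.pt"))
        original.eval()
        x = torch.randn(3, 4)
        # Identical weights and dropout disabled, so the outputs match.
        print(torch.allclose(original(x), restored(x)))  # prints True
```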
Method 2: Saving & loading the entire model
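# save
torch.save(model, PATH)

# load
# the model class must be defined (importable) somewhere
# weights_only=False: needed on PyTorch 2.6+, whose default (True) refuses
# to unpickle full modules; only use it on files you trust
model = torch.load(PATH, weights_only=False)
model.eval()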
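A comparable sketch for the whole-model approach, assuming PyTorch 1.13 or newer (where `torch.load` accepts `weights_only`) and a hypothetical `TinyNet` class. Because `torch.save(model, PATH)` pickles a reference to the class rather than its source code, `TinyNet` has to be defined or importable wherever the file is loaded.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model. It must be importable at load time, because
    torch.save(model, ...) stores a pickled reference to this class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def roundtrip_entire_model(path: str) -> tuple[nn.Module, nn.Module]:
    model = TinyNet()
    # Pickle the whole module object: class reference plus parameters.
    torch.save(model, path)

    # weights_only=False lets torch.load unpickle an arbitrary nn.Module;
    # only do this for files from a trusted source.
    restored = torch.load(path, weights_only=False)
    restored.eval()
    return model, restored
```

This coupling to the class definition (and the module path it lives under) is the usual reason to prefer the state_dict approach for anything long-lived.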
Saving & loading a general checkpoint for inference and/or resuming training
# save
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ...
}, PATH)  # use .tar as file extension
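The checkpoint pattern can be sketched end to end with a toy training loop. Here `nn.Linear(4, 2)` and `torch.optim.SGD(lr=0.1, momentum=0.9)` are hypothetical stand-ins for `Net` and `TheOptimizerClass`; momentum is used so the optimizer actually carries per-parameter state worth saving. The resumed model and optimizer must be constructed with the same arguments as the originals before their state is loaded.

```python
import torch
import torch.nn as nn


def train_steps(model: nn.Module, optimizer: torch.optim.Optimizer, steps: int) -> float:
    """Run a few optimizer steps on fixed random data; return the last loss."""
    torch.manual_seed(0)
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss_fn = nn.MSELoss()
    loss = torch.tensor(0.0)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


def save_checkpoint(path: str, epoch: int, model: nn.Module,
                    optimizer: torch.optim.Optimizer, loss: float) -> None:
    # Bundle everything needed to resume training into one dict.
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }, path)


def resume(path: str):
    # Re-create model and optimizer with the same constructor arguments,
    # then overwrite their freshly initialised state with the saved one.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    model.train()  # continue training; call model.eval() instead for inference
    return checkpoint["epoch"], model, optimizer, checkpoint["loss"]
```

A call such as `save_checkpoint("ckpt.tar", epoch, model, optimizer, loss)` followed later by `resume("ckpt.tar")` returns the stored epoch and loss together with the restored model and optimizer, including the optimizer's momentum buffers.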
# load
model = Net(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)