PyTorch Basics


Basic knowledge about PyTorch.

Saving & loading models

import torch
import torch.nn as nn
import torch.optim as optim

# model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)  # placeholder layer; replace with your own architecture

    def forward(self, x):
        return self.fc(x)

model = Net()

# optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Method 1: Saving & loading the model’s parameters

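Use state_dict, a dictionary that maps each parameter and buffer name to its tensor. This is the approach the PyTorch documentation recommends.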

PATH = "model.pth"  # example path; use .pt or .pth as the file extension

# save
torch.save(model.state_dict(), PATH)

# load
model = Net()  # create the model first, with the same constructor arguments as before
model.load_state_dict(torch.load(PATH))
# call eval() before running inference so that dropout and
# batch normalization layers switch to evaluation mode
model.eval()
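To see what Method 1 actually stores, here is a minimal sketch that saves a small model's state_dict to an in-memory buffer, loads it into a fresh instance, and checks that the parameters match. The TinyNet class, its single nn.Linear(4, 2) layer, and the helper state_dict_round_trip are hypothetical, chosen only for illustration; io.BytesIO stands in for a file path.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical example model with one linear layer."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def state_dict_round_trip():
    torch.manual_seed(0)
    original = TinyNet()

    # state_dict() maps parameter/buffer names to tensors
    keys = list(original.state_dict().keys())

    # save only the state_dict, to an in-memory buffer instead of a file
    buffer = io.BytesIO()
    torch.save(original.state_dict(), buffer)
    buffer.seek(0)

    # build a fresh model (different random weights), then load the saved values
    restored = TinyNet()
    restored.load_state_dict(torch.load(buffer))
    restored.eval()

    same = all(
        torch.equal(a, b)
        for a, b in zip(original.state_dict().values(), restored.state_dict().values())
    )
    return keys, same


if __name__ == "__main__":
    keys, same = state_dict_round_trip()
    print(keys)  # ['fc.weight', 'fc.bias']
    print(same)  # True
```

Because only tensors are saved, the loading side must rebuild the model from its class before calling load_state_dict; this is what makes the format robust to refactoring elsewhere in the code.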

Method 2: Saving & loading the entire model

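# save
torch.save(model, PATH)

# load
# the Net class must be defined (or importable) here, because pickle
# stores only a reference to the class, not its code
# weights_only=False is needed on PyTorch >= 2.6; only load files you trust
model = torch.load(PATH, weights_only=False)
model.eval()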
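A minimal sketch of Method 2, assuming a PyTorch version that accepts the weights_only argument to torch.load (1.13 or later). It pickles a whole module into an in-memory buffer, loads it back, and checks that the loaded object is the same class and produces the same outputs. TinyNet and whole_model_round_trip are hypothetical names used only for this example.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical example model; its class must be importable when loading."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def whole_model_round_trip():
    torch.manual_seed(0)
    model = TinyNet()
    model.eval()
    x = torch.randn(3, 4)

    # save the entire module object (pickled, with a reference to TinyNet)
    buffer = io.BytesIO()
    torch.save(model, buffer)
    buffer.seek(0)

    # weights_only=False allows unpickling arbitrary objects; trusted files only
    loaded = torch.load(buffer, weights_only=False)
    loaded.eval()

    with torch.no_grad():
        return isinstance(loaded, TinyNet), torch.equal(model(x), loaded(x))


if __name__ == "__main__":
    print(whole_model_round_trip())  # (True, True)
```

The convenience comes at a cost: the saved file is tied to the exact module path and class name, so moving or renaming TinyNet would break loading.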

Saving & loading a general checkpoint for inference and/or resuming training

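# save (epoch and loss come from your training loop)
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ... any other items you want to save
}, PATH)  # a .tar file extension is the common convention for checkpoints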
# load: recreate the model and the same optimizer used for training
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# for inference:
model.eval()
# - or -, to resume training:
model.train()
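The point of also saving the optimizer state is that training can continue exactly where it stopped. The sketch below checks this under stated assumptions: a hypothetical nn.Linear(4, 1) model, random data, SGD with momentum as in the setup code, and an in-memory buffer in place of a .tar file. The helpers make_model_and_optimizer, train_steps and resume_from_checkpoint are illustrative names, not part of PyTorch.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


def make_model_and_optimizer():
    # hypothetical tiny model with the optimizer settings from the setup code
    model = nn.Linear(4, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    return model, optimizer


def train_steps(model, optimizer, x, y, steps):
    model.train()
    loss_fn = nn.MSELoss()
    loss = None
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


def resume_from_checkpoint():
    torch.manual_seed(0)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    model, optimizer = make_model_and_optimizer()
    loss = train_steps(model, optimizer, x, y, steps=3)
    epoch = 3

    buffer = io.BytesIO()  # stands in for a .tar checkpoint file
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, buffer)
    buffer.seek(0)

    # recreate the objects, then restore their state from the checkpoint
    new_model, new_optimizer = make_model_and_optimizer()
    checkpoint = torch.load(buffer)
    new_model.load_state_dict(checkpoint['model_state_dict'])
    new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']

    # both runs continue identically because the momentum buffers were restored too
    loss_a = train_steps(model, optimizer, x, y, steps=2)
    loss_b = train_steps(new_model, new_optimizer, x, y, steps=2)
    return start_epoch, checkpoint['loss'] == loss, loss_a == loss_b


if __name__ == "__main__":
    print(resume_from_checkpoint())  # (3, True, True)
```

If only the model weights were restored, the fresh optimizer would start with empty momentum buffers and the resumed run would diverge from the original one.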
