
IX. Model Saving, Loading & Deployment

1. torch.save() / torch.load()

Serializes / deserializes any Python object (model, tensor, dict) to/from a file.
# Recommended: save only the weights dict
torch.save(model.state_dict(), 'model_weights.pth')
# Load
state = torch.load('model_weights.pth', map_location='cpu')
model.load_state_dict(state)
Note: Saving the entire model object pickles references to your class definitions and module paths, so loading breaks whenever that code moves or changes. Prefer saving only the state_dict.
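A minimal round-trip sketch (the tiny `nn.Linear` model and temp path are illustrative; `weights_only=True` is available in recent PyTorch releases and restricts unpickling to plain tensors, which is the safe way to load untrusted checkpoints):

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), 'model_weights.pth')

# Save only the weights dict
torch.save(model.state_dict(), path)

# Reload into a freshly constructed model of the same architecture
restored = nn.Linear(4, 2)
state = torch.load(path, map_location='cpu', weights_only=True)
restored.load_state_dict(state)

# The round trip preserves parameters exactly
assert torch.equal(model.weight, restored.weight)
```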

2. model.state_dict() / load_state_dict()

Gets / loads an ordered dictionary of model parameters. The core interface for transfer learning and for resuming training from checkpoints.
checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optim': optimizer.state_dict(),
    'loss': best_loss
}
torch.save(checkpoint, 'ckpt.pth')
Note: strict=False allows partial loading — missing and unexpected keys are returned instead of raising — which is common in transfer learning. Shape mismatches on shared keys still raise.
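A minimal sketch of partial loading, assuming a fine-tuning model that shares a `backbone` submodule with the pretrained one but swaps in a new head (all module names and sizes here are illustrative):

```python
import torch
import torch.nn as nn

class Pretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, 10)      # original 10-class head

class FineTune(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.new_head = nn.Linear(16, 3)   # new 3-class head

src, dst = Pretrained(), FineTune()

# strict=False reports mismatched keys instead of raising
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)     # ['new_head.weight', 'new_head.bias']
print(result.unexpected_keys)  # ['head.weight', 'head.bias']

# The shared backbone weights were copied over
assert torch.equal(dst.backbone.weight, src.backbone.weight)
```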

3. torch.jit.script() / torch.jit.trace()

Compiles a model to TorchScript for deployment in Python-free environments (C++, mobile).
# trace: follow execution path (no control flow)
traced = torch.jit.trace(model, torch.rand(1, 3, 224, 224))
traced.save('traced.pt')
# script: supports dynamic control flow
scripted = torch.jit.script(model)
Note: Models whose forward pass branches on input values (if/for) need script; trace records only the single path taken by the example input and silently drops the other branches. Pure forward-pass models can use trace, which is simpler and often produces a leaner graph.
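A small sketch of why the distinction matters, using a toy module whose branch depends on the input's values (illustrative):

```python
import torch
import torch.nn as nn

class Gated(nn.Module):
    def forward(self, x):
        # Branch depends on the *values* in x, not just its shape
        if x.sum() > 0:
            return x * 2
        return x - 1

scripted = torch.jit.script(Gated())

# script compiles both branches, so behavior matches eager mode
assert torch.equal(scripted(torch.ones(3)), torch.full((3,), 2.0))
assert torch.equal(scripted(-torch.ones(3)), torch.full((3,), -2.0))

# trace would record only the branch taken by the example input
# (emitting a TracerWarning) and silently drop the untaken branch.
```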

4. torch.onnx.export()

Exports a PyTorch model to ONNX format for cross-framework deployment (TensorRT, OpenVINO).
torch.onnx.export(
    model, torch.rand(1, 3, 224, 224), 'model.onnx',
    opset_version=17,
    input_names=['input'], output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}}
)
Note: dynamic_axes enables a dynamic batch size, which is essential for production deployment. Validate the exported graph with onnxruntime before shipping.

5. model.parameters() / named_parameters()

Iterates over all learnable parameters. The named_ version also returns parameter names.
total = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Params: {total/1e6:.1f}M')
for name, p in model.named_parameters():
    print(name, p.shape)
Note: Set layer-specific learning rates by passing the optimizer a list of dicts like [{'params': p, 'lr': lr}].
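For example, a lower learning rate for a pretrained backbone than for a fresh head (module names and values here are illustrative):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

model = Net()
optimizer = torch.optim.SGD(
    [
        {'params': model.backbone.parameters(), 'lr': 1e-4},  # gentle updates
        {'params': model.head.parameters()},                  # uses the default lr
    ],
    lr=1e-2,  # default for any group that does not set its own
)
assert optimizer.param_groups[0]['lr'] == 1e-4
assert optimizer.param_groups[1]['lr'] == 1e-2
```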

6. model.train() / model.eval()

Switches between training and evaluation modes — affects Dropout and BatchNorm behavior.
model.train()
for x, y in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    for x, y in val_loader:
        pred = model(x)
Note: Forgetting model.eval() is the most common reason for unstable inference results.
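The effect is easy to see with Dropout alone (a sketch; the tensor size is arbitrary):

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)   # ~half the entries zeroed, survivors scaled by 1/(1-p) = 2

drop.eval()
y_eval = drop(x)    # identity: dropout is disabled

assert torch.equal(y_eval, x)
assert (y_train == 0).any()                      # some units were dropped
assert torch.all(y_train[y_train != 0] == 2.0)   # survivors rescaled
```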
💡 One-line Takeaway
Always save state_dict (not the model object), and remember the deploy path: PyTorch → TorchScript / ONNX → Runtime.

https://lxy-alexander.github.io/blog/posts/pytorch/api/09model-saving-loading--deployment/
Author
Alexander Lee
Published at
2026-03-12
License
CC BY-NC-SA 4.0