Advanced Features & Utilities

XI. Advanced Features & Utilities (高级特性与实用工具)
1. torch.nn.utils.clip_grad_norm_()
Clips the global L2 norm of all parameter gradients to prevent Gradient Explosion (梯度爆炸). Essential for RNN/Transformer training.
```python
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```
Note: `max_norm=1.0` is standard for Transformers. Must be called after `backward()` and before `step()`.

2. torch.nn.utils.weight_norm()
Decomposes parameters into direction and magnitude, accelerating convergence. Used in WaveNet and generative models.
```python
from torch.nn.utils import weight_norm, remove_weight_norm

wn_conv = weight_norm(nn.Conv1d(64, 64, 3, padding=1))
remove_weight_norm(wn_conv)  # merge before deployment
```
Note: Always call `remove_weight_norm` before deployment to merge the decomposed parameters back into a single weight.

3. torch.nn.functional.interpolate()
Upsamples or downsamples feature maps using interpolation modes such as bilinear, nearest, and bicubic.
```python
x = torch.rand(1, 64, 28, 28)
up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
print(up.shape)  # [1, 64, 56, 56]
```
Note: `align_corners=False` matches TensorFlow's default behavior, which matters when migrating models.

4. torch.nn.functional.grid_sample()
Samples from a feature map at normalized grid coordinates. The core of Spatial Transformer Networks (空间变换网络, STN).
```python
theta = torch.eye(2, 3, dtype=torch.float).unsqueeze(0)  # identity affine transform
grid = F.affine_grid(theta, x.size())
out = F.grid_sample(x, grid, mode='bilinear')
```
Note: Grid coordinates are normalized to [-1, 1]. `padding_mode='reflection'` gives more natural results at image borders.

5. torch.einsum()
Einstein Summation (爱因斯坦求和): expresses complex tensor operations as a concise string equation.
```python
c = torch.einsum('ij,jk->ik', a, b)  # matrix multiplication

# Attention scores: Q:[B,H,L,D], K:[B,H,M,D]
scores = torch.einsum('bhld,bhmd->bhlm', Q, K)  # [B,H,L,M]
```
Note: Mastering `einsum` dramatically simplifies Transformer and graph neural network code.

6. torch.profiler.profile()
Performance profiler that records per-operator CPU/GPU time and memory usage to locate bottlenecks.
```python
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(x)
print(prof.key_averages().table(sort_by='cuda_time_total'))
```
Note: Export a Chrome trace with `prof.export_chrome_trace()` for visual bottleneck analysis in a browser.

7. torch.nn.init.*
Provides Xavier, Kaiming, Orthogonal and other Parameter Initialization (参数初始化) strategies. Directly impacts training stability.
```python
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        nn.init.zeros_(m.bias)

model.apply(init_weights)
```
Note: Sigmoid/Tanh → Xavier; ReLU family → Kaiming; RNN/LSTM recurrent weights → Orthogonal initialization.
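As a quick check of the rules of thumb above, a minimal sketch (hypothetical 256-unit linear layer) showing that fan_out Kaiming gives weights with std ≈ √(2/fan_out), while orthogonal init produces an orthonormal matrix:

```python
import torch
import torch.nn as nn

lin = nn.Linear(256, 256, bias=False)  # hypothetical layer size

nn.init.kaiming_normal_(lin.weight, mode='fan_out', nonlinearity='relu')
# Kaiming (fan_out) samples from N(0, 2 / fan_out): std ≈ sqrt(2 / 256) ≈ 0.088
kaiming_std = lin.weight.std().item()

nn.init.orthogonal_(lin.weight)
# An orthogonal matrix satisfies W @ W.T = I
eye_err = (lin.weight @ lin.weight.T - torch.eye(256)).abs().max().item()
print(f'kaiming std: {kaiming_std:.3f}, orthogonality error: {eye_err:.1e}')
```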
8. torch.Tensor.item()
Converts a single-element Tensor to a Python scalar. Commonly used to log loss values.
```python
loss = criterion(output, label)
loss_val = loss.item()  # detaches from the graph
print(f'Loss: {loss_val:.4f}')

# WRONG:   total_loss += loss        (graph keeps growing → OOM)
# CORRECT: total_loss += loss.item()
```
Note: Accumulating loss tensors directly keeps every iteration's computation graph alive, eventually causing OOM. Always use `.item()`.

9. torch.Tensor.clone()
Creates a deep copy (深拷贝) of a Tensor with fully independent data, while preserving gradient propagation.
```python
x = torch.rand(3, requires_grad=True)
y = x.clone()           # gradients still propagate back to x
z = x.detach().clone()  # detached from the graph

buf = torch.empty(3)
buf.copy_(x)            # in-place copy
```
Note: Use `clone().detach()` when copying parameters to a Target Network (目标网络) in RL or for momentum (EMA) updates.

10. torch.where()
Conditional selection: returns values from `x` where the condition is True, otherwise from `y`. A vectorized if-else.
```python
x = torch.tensor([-1., 2., -3., 4.])
y = torch.zeros_like(x)
out = torch.where(x > 0, x, y)  # tensor([0., 2., 0., 4.]), a manual ReLU
```
Note: Both branches participate in gradient computation; the unselected branch receives a zero gradient.
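To illustrate the gradient behavior just noted, a small sketch: positions where a branch was not selected receive exactly zero gradient.

```python
import torch

x = torch.tensor([-1., 2., -3., 4.], requires_grad=True)
out = torch.where(x > 0, x, torch.zeros_like(x))
out.sum().backward()

# Gradient is 1 where x was selected (x > 0) and 0 elsewhere
print(x.grad)  # tensor([0., 1., 0., 1.])
```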
11. torch.gather()
Gathers values from a source Tensor along a dimension according to an index Tensor, enabling irregular indexing (不规则索引) such as in NMS.
```python
logits = torch.rand(4, 10)
targets = torch.tensor([3, 7, 1, 5]).unsqueeze(1)  # [4, 1]
scores = logits.gather(dim=1, index=targets)       # [4, 1]
```
Note: A core tool for sequence decoding, top-k sampling, and Q-value selection in Reinforcement Learning (强化学习).
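The Q-value selection mentioned in the note can be sketched like this (hypothetical batch of Q-values and actions):

```python
import torch

# Hypothetical batch: Q-values for 3 states x 4 actions, and the actions taken
q_values = torch.tensor([[1., 2., 3., 4.],
                         [5., 6., 7., 8.],
                         [9., 0., 1., 2.]])
actions = torch.tensor([2, 0, 3])

# Pick each row's Q-value at the taken action: Q(s_i, a_i)
q_taken = q_values.gather(dim=1, index=actions.unsqueeze(1)).squeeze(1)
print(q_taken)  # tensor([3., 5., 2.])
```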
12. torch.scatter_()
Scatters values from `src` into `self` at the positions specified by `index` (in-place).
```python
y = torch.zeros(4, 5)
labels = torch.tensor([[2], [0], [4], [1]])
y.scatter_(dim=1, index=labels, value=1.0)  # one-hot encoding
```
Note: `scatter_add_` implements segment sum, the fundamental primitive for Graph Neural Network (图神经网络) message aggregation.

13. torch.masked_fill()
Fills positions where the mask is True with a specified value. Essential for Attention Masking (注意力掩码).
```python
L = 5
mask = torch.triu(torch.ones(L, L), diagonal=1).bool()
scores = torch.rand(L, L)
scores = scores.masked_fill(mask, float('-inf'))  # causal mask
```
Note: The Transformer decoder's Causal Self-Attention (因果自注意力) relies on this to block information from future positions.
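To complete the picture, a sketch applying softmax to the masked scores: the `-inf` entries become exactly zero attention weight, so each position attends only to itself and earlier positions.

```python
import torch
import torch.nn.functional as F

L = 5
mask = torch.triu(torch.ones(L, L), diagonal=1).bool()
scores = torch.rand(L, L).masked_fill(mask, float('-inf'))
attn = F.softmax(scores, dim=-1)

# Row 0 can only attend to position 0, so attn[0] is [1, 0, 0, 0, 0];
# every row still sums to 1, and all masked (future) positions are exactly 0
print(attn)
```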
14. torch.nn.utils.rnn.pad_sequence()
Pads a list of variable-length sequences to a uniform-length tensor for NLP batch processing.
```python
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
padded = pad_sequence(seqs, batch_first=True)  # shape [3, 3]
```
Note: Combine with `pack_padded_sequence` so the LSTM skips computation at padding positions.

15. nn.TransformerEncoder()
Multi-layer Transformer Encoder with built-in Multi-head Attention, FFN, residual connections, and LayerNorm.
```python
enc_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048, batch_first=True)
encoder = nn.TransformerEncoder(enc_layer, num_layers=6)
out = encoder(src, src_key_padding_mask=m)
```
Note: `norm_first=True` (Pre-LN) trains more stably and is recommended for large models.

16. torch.Tensor.contiguous()
Returns a Contiguous Memory (连续内存) copy of the Tensor. Returns itself (zero overhead) if already contiguous.
```python
x = torch.rand(4, 5, 6)
y = x.permute(2, 0, 1)    # non-contiguous view
print(y.is_contiguous())  # False
z = y.contiguous()
w = z.view(6, -1)         # view is safe now
```
Note: In most cases `reshape` handles this automatically; call `contiguous()` manually only when contiguous memory is truly required.

17. torch.Tensor.type() / .to(dtype)
Converts the Tensor's Data Type (数据类型): float32 ↔ float16 ↔ int64, etc.
```python
x = torch.tensor([1, 2, 3])
f = x.float()                  # int64 → float32
h = f.half()                   # float32 → float16
l = x.long()                   # → int64
y = x.to(dtype=torch.float32)  # recommended form
```
Note: Cross-entropy labels need `long()`; normalized image pixels need `float()`; inference acceleration uses `half()`.

💡 One-line Takeaway
The advanced toolkit:
`clip_grad_norm_` (stability), `einsum` (clarity), `gather`/`scatter` (indexing), `masked_fill` (attention), `item()` (memory safety).