Neural Network Modules

V. Neural Network Modules — nn.Module
1. nn.Module
The base class for all neural networks in PyTorch. It manages parameters and sub-modules, and defines the forward-pass logic.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
```

Note: implement only `__init__` and `forward`; the backward pass is handled automatically by autograd.
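Attribute assignment is what registers sub-modules, which in turn makes their parameters visible to optimizers and `state_dict()`. A minimal sketch (the `TinyNet` class name is hypothetical):

```python
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning modules as attributes registers them automatically.
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

net = TinyNet()
# Two Linear layers contribute a weight and a bias each: 4 tensors total.
for name, p in net.named_parameters():
    print(name, tuple(p.shape))
```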
2. nn.Linear()
Fully Connected Layer / Affine Transformation: y = xWᵀ + b. The most fundamental learnable layer.

```python
fc = nn.Linear(in_features=128, out_features=64, bias=True)
x = torch.rand(32, 128)
out = fc(x)  # shape [32, 64]
```

Note: the weight is stored as [out_features, in_features]. `bias=False` is commonly paired with BatchNorm, whose learnable shift makes the bias redundant.
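The weight layout and the affine formula can be checked directly; a small sketch:

```python
import torch
import torch.nn as nn

fc = nn.Linear(in_features=128, out_features=64)
x = torch.rand(32, 128)

# W is stored as [out_features, in_features], so y = x @ W.T + b
manual = x @ fc.weight.T + fc.bias
assert fc.weight.shape == (64, 128)
assert torch.allclose(fc(x), manual, atol=1e-6)
```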
3. nn.Conv2d()
2D Convolutional Layer. Extracts local spatial features; the core building block of CNNs.

```python
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
x = torch.rand(8, 3, 224, 224)
out = conv(x)  # [8, 64, 224, 224]
```

Note: with stride 1 and an odd kernel, `padding=kernel_size//2` preserves the feature-map size ("same" padding).
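The output spatial size follows floor((n + 2·padding − kernel) / stride) + 1; a quick sketch with stride 2 (the numbers here are illustrative):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
x = torch.rand(8, 3, 224, 224)
out = conv(x)

# (224 + 2*1 - 3) // 2 + 1 = 112 -> the stride halves each spatial dimension
assert out.shape == (8, 64, 112, 112)
```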
4. nn.BatchNorm2d()
Normalizes each channel over the mini-batch. Accelerates training and mitigates vanishing gradients.

```python
bn = nn.BatchNorm2d(num_features=64)
x = torch.rand(8, 64, 28, 28)
out = bn(x)
# Standard order: Conv → BN → ReLU
```

Note: BN statistics are unreliable when batch_size=1. Switch to GroupNorm or LayerNorm in that case.
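GroupNorm normalizes over channel groups within each sample, so it is independent of batch size; a sketch of the batch_size=1 swap (the group count of 8 is an arbitrary choice):

```python
import torch
import torch.nn as nn

# Drop-in replacement when batches can be tiny: per-sample statistics.
gn = nn.GroupNorm(num_groups=8, num_channels=64)
x = torch.rand(1, 64, 28, 28)  # batch_size = 1 is fine here
out = gn(x)
assert out.shape == (1, 64, 28, 28)
```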
5. nn.Dropout()
During training, randomly zeroes a fraction of activations, a regularization technique to prevent overfitting.

```python
dropout = nn.Dropout(p=0.5)
x = torch.rand(4, 128)
out = dropout(x)  # ~50% of elements zeroed in train mode

dropout.eval()
out_eval = dropout(x)  # identical to x in eval mode
```

Note: forgetting `model.eval()` is among the most common bugs behind non-deterministic inference results.
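PyTorch uses inverted dropout: surviving activations are scaled by 1/(1−p) during training so the expected value is preserved, and eval mode is an exact identity. A small sketch:

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()
y = dropout(x)
# Kept entries are scaled to 1 / (1 - 0.5) = 2.0; the rest are zeroed.
assert set(y.unique().tolist()) <= {0.0, 2.0}

dropout.eval()
assert torch.equal(dropout(x), x)  # identity at inference
```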
6. nn.Sequential()
Chains layers in order, calling each one's forward sequentially. Simplifies model definition.

```python
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 10),
)
out = model(x)
```

Note: use an OrderedDict to name layers: `nn.Sequential(OrderedDict([('fc', nn.Linear(...))]))`.
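Named layers become attributes and show up in `state_dict()` keys, which makes checkpoints easier to read; a sketch (the layer names are arbitrary):

```python
from collections import OrderedDict
import torch
import torch.nn as nn

model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(784, 256)),
    ('act', nn.ReLU()),
    ('fc2', nn.Linear(256, 10)),
]))

out = model(torch.rand(4, 784))
assert out.shape == (4, 10)
assert model.fc1.out_features == 256      # accessible by name
assert 'fc2.weight' in model.state_dict() # readable checkpoint keys
```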
7. nn.ModuleList() / nn.ModuleDict()
Register sub-modules as a list or dictionary so that their parameters are correctly tracked and saved.

```python
layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(6)])
for layer in layers:
    x = torch.relu(layer(x))

heads = nn.ModuleDict({
    'cls': nn.Linear(64, 10),
    'reg': nn.Linear(64, 1),
})
```

Note: a plain Python `list` or `dict` is not registered; `parameters()` will miss its contents!
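The difference is easy to demonstrate by counting parameters (the `Good`/`Bad` class names are illustrative):

```python
import torch.nn as nn

class Good(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(6)])

class Bad(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(64, 64) for _ in range(6)]  # plain list!

assert len(list(Good().parameters())) == 12  # 6 weights + 6 biases
assert len(list(Bad().parameters())) == 0    # silently untracked
```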
8. nn.Embedding()
Maps integer indices to dense vectors. The standard word-embedding lookup table in NLP.

```python
vocab_size, embed_dim = 10000, 128
emb = nn.Embedding(vocab_size, embed_dim)
ids = torch.randint(0, vocab_size, (16, 50))  # [batch, seq_len]
out = emb(ids)  # [16, 50, 128]
```

Note: `padding_idx` designates a padding token whose embedding is excluded from gradient updates.
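The `padding_idx` row is initialized to zeros and its gradient is zeroed on the backward pass; a sketch (index 0 as the padding token is a convention, not a requirement):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=100, embedding_dim=8, padding_idx=0)
assert torch.all(emb.weight[0] == 0)  # padding row starts at zero

ids = torch.tensor([[0, 5, 7]])
emb(ids).sum().backward()
assert torch.all(emb.weight.grad[0] == 0)  # no update for the padding row
assert torch.any(emb.weight.grad[5] != 0)  # real tokens do get gradients
```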
9. nn.LSTM() / nn.GRU()
Long Short-Term Memory and Gated Recurrent Unit: classic recurrent layers for sequence data.

```python
lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=2,
               batch_first=True, dropout=0.2)
x = torch.rand(8, 50, 128)  # [batch, seq, feat]
out, (h, c) = lstm(x)
```

Note: `batch_first=True` sets the input format to [B, T, F], which is more intuitive; the default is [T, B, F].
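One common gotcha worth spelling out: `batch_first=True` affects the input and output sequences, but the hidden/cell states keep the [num_layers, batch, hidden] layout. A sketch:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=2,
               batch_first=True)
x = torch.rand(8, 50, 128)            # [batch, seq, feat]
out, (h, c) = lstm(x)

assert out.shape == (8, 50, 256)      # per-step states of the LAST layer
assert h.shape == (2, 8, 256)         # [num_layers, batch, hidden], not batch-first
assert c.shape == (2, 8, 256)
```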
10. nn.MultiheadAttention()
Multi-head Self-Attention: the core component of the Transformer architecture.

```python
attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.rand(4, 100, 512)
out, weights = attn(query=x, key=x, value=x)
```

Note: use `key_padding_mask` to mask padding tokens; use `attn_mask` for causal masking in decoders.
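A causal mask can be built with `torch.triu`; with a boolean `attn_mask`, True marks positions that must not be attended to, so each query only sees itself and earlier positions. A sketch (the sizes are arbitrary):

```python
import torch
import torch.nn as nn

T = 6
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.rand(2, T, 32)

# Strictly upper-triangular True entries block attention to future positions.
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
out, weights = attn(x, x, x, attn_mask=causal)

assert out.shape == (2, T, 32)
# weights: [batch, tgt, src], averaged over heads by default.
assert torch.all(weights[:, 0, 1:] == 0)  # position 0 attends only to itself
```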
💡 One-line Takeaway
Every custom network inherits from nn.Module; use ModuleList/ModuleDict (not plain Python containers) to ensure parameters are tracked.