Math & Statistical Operations

III. Math & Statistical Operations
1) torch.matmul() / @

General matrix multiplication. Supports 2D matrices, batched matrix multiplication, and broadcasting between batched and unbatched operands.

```python
a = torch.rand(3, 4)
b = torch.rand(4, 5)
c = torch.matmul(a, b)  # [3, 5]
d = a @ b               # equivalent

# Batched matmul
x = torch.rand(8, 3, 4)
y = torch.rand(8, 4, 5)
z = x @ y  # [8, 3, 5]
```

Note: the core of Transformer attention computation.
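Broadcasting also covers the mixed case, where a single 2D matrix multiplies a batched tensor, e.g. one weight matrix shared across a whole batch; a quick sketch:

```python
import torch

# A batch of 8 feature matrices multiplied by one shared [4, 5] matrix:
# matmul broadcasts the 2D operand across the batch dimension.
x = torch.rand(8, 3, 4)
w = torch.rand(4, 5)
out = torch.matmul(x, w)
print(out.shape)  # torch.Size([8, 3, 5])
```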
Also note that torch.mm is 2D-only; torch.matmul is more general.

2) torch.sum() / torch.mean()
Computes the sum or mean over all elements or along a specified axis. keepdim=True preserves the reduced dimension.

```python
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
print(x.sum())                      # tensor(21.)
print(x.sum(dim=0))                 # tensor([5., 7., 9.])
print(x.mean(dim=1, keepdim=True))  # tensor([[2.], [5.]])
```
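A sketch of why keepdim matters: a [2, 1] row mean broadcasts cleanly against the original [2, 3] tensor, e.g. when mean-centering each row:

```python
import torch

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

# With keepdim=True the mean keeps shape [2, 1], so it lines up
# against x's shape [2, 3] when subtracting.
row_mean = x.mean(dim=1, keepdim=True)  # shape [2, 1]
centered = x - row_mean
print(centered)  # each row now sums to 0
```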
Note: keepdim=True avoids dimension-alignment issues during broadcasting.

3) torch.max() / torch.min()
Returns the maximum/minimum value. When a dim is specified, returns both values and indices (argmax/argmin).

```python
x = torch.tensor([3., 1., 4., 1., 5.])
print(x.max())            # tensor(5.)
vals, idx = x.max(dim=0)  # vals=5.0, idx=4
idx2 = x.argmax()         # tensor(4)
```
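On a matrix, dim picks the axis to reduce, and both the values and their indices come back; a per-row sketch:

```python
import torch

scores = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.9]])

# Reduce over dim=1 (columns): one (value, index) pair per row.
vals, idx = scores.max(dim=1)
print(vals)  # per-row maxima: 2.0 and 1.5
print(idx)   # tensor([1, 0])
```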
Note: classification networks predict labels with `preds = logits.argmax(dim=1)`.

4) torch.abs() / torch.sqrt()
Element-wise absolute value and square root. Used in loss computation and feature normalization.

```python
x = torch.tensor([-1., 4., -9.])
print(torch.abs(x))   # tensor([1., 4., 9.])
y = torch.tensor([1., 4., 9.])
print(torch.sqrt(y))  # tensor([1., 2., 3.])
```
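Two common uses, sketched with illustrative values: mean absolute error (L1) and Euclidean (L2) distance built from abs and sqrt:

```python
import torch

# pred/target are made-up placeholder values.
pred   = torch.tensor([2.5, 0.0, 2.0])
target = torch.tensor([3.0, -0.5, 2.0])

mae  = torch.abs(pred - target).mean()           # L1 / mean absolute error
dist = torch.sqrt(((pred - target) ** 2).sum())  # Euclidean (L2) distance
print(mae)   # mean of |diff| = 1/3
print(dist)  # sqrt(0.5)
```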
Note: torch.sqrt returns NaN for negative inputs; apply torch.clamp(x, min=0) first.

5) torch.clamp()
Clips values to the range [min, max]; values outside the range are truncated to the boundary.

```python
x = torch.tensor([-2., 0., 3., 8.])
y = torch.clamp(x, min=0., max=5.)  # tensor([0., 0., 3., 5.])
z = torch.clamp(x, min=0.)          # equivalent to ReLU
```
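A sketch of the log(0) guard: floor probabilities with clamp before taking the log (the eps value here is an arbitrary choice):

```python
import torch

p = torch.tensor([0.9, 0.1, 0.0])

# log(0) yields -inf and poisons downstream sums and gradients;
# clamping to a tiny floor keeps everything finite.
eps = 1e-12
safe_log = torch.log(p.clamp(min=eps))
print(torch.log(p))  # last entry is -inf
print(safe_log)      # last entry is log(1e-12), finite
```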
Note: the go-to tool for gradient clipping, normalization, and avoiding log(0).

6) torch.pow() / torch.exp() / torch.log()
Element-wise power, natural exponential, and natural logarithm.

```python
x = torch.tensor([1., 2., 3.])
print(torch.pow(x, 2))  # tensor([1., 4., 9.])
print(torch.exp(x))     # [e^1, e^2, e^3]
print(torch.log(x))     # tensor([0.0000, 0.6931, 1.0986])
print(torch.log1p(x))   # numerically stable log(1 + x)
```
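Why the stability point matters: with large logits, log(softmax(x)) underflows to -inf, while F.log_softmax stays finite via the log-sum-exp trick. A small sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1000., 0., -1000.])

naive  = torch.log(torch.softmax(logits, dim=0))  # softmax underflows to 0, log gives -inf
stable = F.log_softmax(logits, dim=0)             # shifted computation, stays finite
print(naive)   # contains -inf
print(stable)  # [0., -1000., -2000.]
```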
Note: cross-entropy already uses log_softmax internally; when computing it manually, use log_softmax for numerical stability.

7) torch.dot() / torch.cross()
dot: inner product of 1D vectors. cross: cross product of 3D vectors (physics / 3D graphics).

```python
a = torch.tensor([1., 2., 3.])
b = torch.tensor([4., 5., 6.])
print(torch.dot(a, b))  # tensor(32.)
```
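A sketch of the batched inner-product pattern, one dot product per row without a Python loop:

```python
import torch

a = torch.rand(8, 16)
b = torch.rand(8, 16)

# Elementwise multiply, then sum the last axis: one inner product per row.
dots = (a * b).sum(dim=-1)
print(dots.shape)  # torch.Size([8])

# Same result as looping torch.dot over the rows:
looped = torch.stack([torch.dot(a[i], b[i]) for i in range(8)])
print(torch.allclose(dots, looped))  # True
```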
Note: for batched inner products, use (a * b).sum(-1), which is more efficient than looping torch.dot.

8) torch.norm() / torch.linalg.norm()
Computes vector and matrix norms: L1, L2, the Frobenius norm, etc.

```python
x = torch.tensor([3., 4.])
print(torch.linalg.norm(x))         # L2: tensor(5.)
print(torch.linalg.norm(x, ord=1))  # L1: tensor(7.)
```
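For a 2D input, torch.linalg.norm defaults to the Frobenius norm, the square root of the sum of squared entries; a quick check:

```python
import torch

m = torch.tensor([[1., 2.],
                  [3., 4.]])

# Frobenius norm = sqrt(1 + 4 + 9 + 16) = sqrt(30)
fro = torch.linalg.norm(m)
print(fro)                         # tensor(5.4772)
print(torch.sqrt((m ** 2).sum()))  # same value
```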
Note: torch.norm is deprecated; new code should use torch.linalg.norm.

9) torch.topk()
Returns the k largest (or smallest) values and their indices from a tensor.

```python
x = torch.tensor([3., 1., 4., 1., 5., 9.])
vals, idx = torch.topk(x, k=3)
# vals: tensor([9., 5., 4.])
# idx:  tensor([5, 4, 2])

# e.g., for a [batch, num_classes] logits tensor:
_, top5 = logits.topk(5, dim=1)  # Top-5 accuracy evaluation
```

Note: the standard approach for Top-5 accuracy evaluation.
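The Top-5 accuracy pattern in full, as a sketch (logits and labels here are random placeholders):

```python
import torch

batch, num_classes = 4, 10
logits = torch.randn(batch, num_classes)
labels = torch.randint(0, num_classes, (batch,))

# A sample is "top-5 correct" if its true label appears among the
# 5 highest-scoring classes for that sample.
_, top5 = logits.topk(5, dim=1)                    # [batch, 5]
correct = (top5 == labels.unsqueeze(1)).any(dim=1)  # [batch] bool
acc = correct.float().mean()
print(acc)  # fraction of samples whose label is in the top 5
```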
Use largest=False to get the smallest k values.

10) torch.unique()
Returns the unique elements of a tensor, with optional sorting, counting, and inverse mapping.

```python
x = torch.tensor([1, 2, 2, 3, 1, 4])
u, cnt = torch.unique(x, return_counts=True)
# u:   tensor([1, 2, 3, 4])
# cnt: tensor([2, 2, 1, 1])
```

Note: commonly used for processing category labels and deduplicating tokens.
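return_inverse maps every original element to its index in the unique list, which is useful for remapping raw labels to a compact 0..N-1 range; a sketch:

```python
import torch

# Raw labels with gaps (10, 25, 40) remapped to 0..N-1.
labels = torch.tensor([10, 25, 10, 40, 25])
uniq, inv = torch.unique(labels, return_inverse=True)
print(uniq)  # tensor([10, 25, 40])
print(inv)   # tensor([0, 1, 0, 2, 1])

# uniq[inv] reconstructs the original tensor.
print(torch.equal(uniq[inv], labels))  # True
```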
💡 One-line Takeaway
matmul/@ powers Transformers, clamp guards numerical safety, and topk drives classification evaluation.