Math & Statistical Operations

III. Math & Statistical Operations#

1) torch.matmul() / @#

General matrix multiplication. Supports 2D matrices, batched matrix multiplication, and broadcasting across batch dimensions.

```python
a = torch.rand(3, 4)
b = torch.rand(4, 5)
c = torch.matmul(a, b)  # [3, 5]
d = a @ b               # equivalent
# Batched matmul
x = torch.rand(8, 3, 4)
y = torch.rand(8, 4, 5)
z = x @ y               # [8, 3, 5]
```
Note: matmul is the core of Transformer attention computation. torch.mm is 2D-only; matmul is more general.
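The "mixed broadcasting" case deserves a concrete shape check. A minimal sketch (shapes are illustrative): a single 2D weight matrix applied across a whole batch via matmul's broadcasting rules.

```python
import torch

w = torch.rand(4, 5)      # shared 2D weight
x = torch.rand(8, 3, 4)   # batch of 8 matrices
out = torch.matmul(x, w)  # w broadcasts across the batch dim
print(out.shape)          # torch.Size([8, 3, 5])
```

This is the common pattern of applying one linear projection to every item in a batch without expanding the weight.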

2) torch.sum() / torch.mean()#

Computes the sum or mean over all elements or along a specified dimension. keepdim=True preserves the reduced dimension.

```python
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
print(x.sum())                    # 21.0
print(x.sum(dim=0))               # [5, 7, 9]
print(x.mean(dim=1, keepdim=True))  # [[2.], [5.]]
```
Note: keepdim=True avoids dimension-alignment issues during broadcasting.
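Why keepdim matters in practice: a quick sketch of per-row mean-centering, where the mean's shape [2, 1] (rather than [2]) lets it broadcast back against x's shape [2, 3].

```python
import torch

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
centered = x - x.mean(dim=1, keepdim=True)  # [2, 3] - [2, 1] broadcasts
print(centered)  # each row now sums to 0
```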

3) torch.max() / torch.min()#

Returns the maximum/minimum value. When dim is specified, returns both values and indices (see also argmax/argmin).

```python
x = torch.tensor([3., 1., 4., 1., 5.])
print(x.max())           # tensor(5.)
vals, idx = x.max(dim=0) # vals=5.0, idx=4
idx2 = x.argmax()        # tensor(4)
```
Note: the standard way to get classification prediction labels: preds = logits.argmax(dim=1).
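A minimal sketch of that prediction pattern; the logits tensor here is an illustrative placeholder (batch of 4, 3 classes).

```python
import torch

logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1],
                       [0.0, 0.1, 3.0],
                       [0.9, 0.8, 0.7]])
preds = logits.argmax(dim=1)  # index of the largest logit per row
print(preds)                  # tensor([1, 0, 2, 0])
```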

4) torch.abs() / torch.sqrt()#

Element-wise absolute value or square root. Used in loss computation and feature normalization.

```python
x = torch.tensor([-1., 4., -9.])
print(torch.abs(x))   # [1., 4., 9.]
y = torch.tensor([1., 4., 9.])
print(torch.sqrt(y))  # [1., 2., 3.]
```
Note: torch.sqrt returns NaN for negative inputs. Use torch.clamp(x, min=0) first.
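A quick sketch of that guard: clamp negatives to zero before sqrt so stray negative values (for example from floating-point error) do not produce NaN.

```python
import torch

x = torch.tensor([-1e-7, 4., 9.])
raw = torch.sqrt(x)                        # first element is nan
safe = torch.sqrt(torch.clamp(x, min=0.))  # tensor([0., 2., 3.])
print(raw, safe)
```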

5) torch.clamp()#

Clips values to the range [min, max]. Values outside the range are truncated to the boundary.

```python
x = torch.tensor([-2., 0., 3., 8.])
y = torch.clamp(x, min=0., max=5.)  # tensor([0., 0., 3., 5.])
z = torch.clamp(x, min=0.)          # equivalent to ReLU
```
Note: the go-to tool for gradient clipping, normalization, and avoiding log(0).
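The log(0) guard from the note, as a short sketch. The floor value 1e-8 is an illustrative choice, not a fixed convention.

```python
import torch

p = torch.tensor([0.0, 0.5, 1.0])
safe_log = torch.log(torch.clamp(p, min=1e-8))  # no -inf in the result
print(safe_log)
```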

6) torch.pow() / torch.exp() / torch.log()#

Element-wise power, natural exponent, and natural logarithm.

```python
x = torch.tensor([1., 2., 3.])
print(torch.pow(x, 2))  # [1., 4., 9.]
print(torch.exp(x))     # [e^1, e^2, e^3]
print(torch.log(x))     # [0., 0.693, 1.099]
print(torch.log1p(x))   # numerically stable log(1+x)
```
Note: cross-entropy already applies log_softmax internally. When computing losses manually, use log_softmax for numerical stability.
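A sketch of why log_softmax matters: with extreme logits, softmax saturates to exactly 0 for the small entries, so log(softmax(x)) yields -inf, while log_softmax computes the same quantity in one stable step.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1000., 0., -1000.])
naive = torch.log(torch.softmax(x, dim=0))  # contains -inf
stable = F.log_softmax(x, dim=0)            # tensor([0., -1000., -2000.])
print(naive, stable)
```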

7) torch.dot() / torch.cross()#

dot: inner product of 1D vectors. cross: cross product of 3D vectors (physics / 3D graphics).

```python
a = torch.tensor([1., 2., 3.])
b = torch.tensor([4., 5., 6.])
print(torch.dot(a, b))  # 32.0
```
Note: for batched inner products, use (a * b).sum(-1), which is far more efficient than looping over torch.dot.
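The batched pattern from the note, sketched: element-wise multiply, then sum over the last dimension, giving one dot product per row.

```python
import torch

a = torch.tensor([[1., 2., 3.], [0., 1., 0.]])
b = torch.tensor([[4., 5., 6.], [7., 8., 9.]])
dots = (a * b).sum(-1)  # one inner product per row
print(dots)             # tensor([32., 8.])
```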

8) torch.norm() / torch.linalg.norm()#

Computes vector/matrix norms: L1, L2, Frobenius, etc.

```python
x = torch.tensor([3., 4.])
print(torch.linalg.norm(x))         # L2: 5.0
print(torch.linalg.norm(x, ord=1))  # L1: 7.0
```
Note: torch.norm is deprecated. New code should use torch.linalg.norm.
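For a matrix input, torch.linalg.norm defaults to the Frobenius norm (the square root of the sum of squared entries). A short sketch:

```python
import torch

m = torch.tensor([[3., 0.], [0., 4.]])
fro = torch.linalg.norm(m)             # Frobenius by default: 5.0
fro2 = torch.linalg.norm(m, ord='fro') # same, stated explicitly
print(fro, fro2)
```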

9) torch.topk()#

Returns the top-k largest (or smallest) values and their indices from a tensor.

```python
x = torch.tensor([3., 1., 4., 1., 5., 9.])
vals, idx = torch.topk(x, k=3)
# vals: tensor([9., 5., 4.])
# idx:  tensor([5, 4, 2])
logits = torch.rand(32, 1000)    # e.g. a batch of class scores
_, top5 = logits.topk(5, dim=1)  # Top-5 accuracy evaluation
```
Note: the standard approach for Top-5 accuracy evaluation. Use largest=False to get the smallest k values.
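The full Top-5 accuracy pattern, as a hedged sketch; logits and labels here are random placeholders (batch of 32, 100 classes).

```python
import torch

logits = torch.rand(32, 100)
labels = torch.randint(0, 100, (32,))
_, top5 = logits.topk(5, dim=1)                     # [32, 5] class indices
correct = (top5 == labels.unsqueeze(1)).any(dim=1)  # hit if label in top 5
top5_acc = correct.float().mean()
print(top5_acc)
```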

10) torch.unique()#

Returns unique elements from a tensor, with optional sorting, counting, and inverse mapping.

```python
x = torch.tensor([1, 2, 2, 3, 1, 4])
u, cnt = torch.unique(x, return_counts=True)
# u:   tensor([1, 2, 3, 4])
# cnt: tensor([2, 2, 1, 1])
```
Note: commonly used for processing category labels and deduplicating tokens.
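The inverse mapping mentioned above, sketched: inv gives, for each original element, its index into the unique array, so u[inv] reconstructs x.

```python
import torch

x = torch.tensor([1, 2, 2, 3, 1, 4])
u, inv = torch.unique(x, return_inverse=True)
print(u)    # tensor([1, 2, 3, 4])
print(inv)  # tensor([0, 1, 1, 2, 0, 3])
print(torch.equal(u[inv], x))  # True: rebuilds the original
```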
💡 One-line Takeaway
matmul/@ powers Transformers, clamp guards numerical safety, and topk drives classification evaluation.

Math & Statistical Operations
https://lxy-alexander.github.io/blog/posts/pytorch/api/03math--statistical-operations/
Author: Alexander Lee
Published: 2026-03-12
License: CC BY-NC-SA 4.0