Tensor Shape & Dimension Transforms

II. Tensor Shape & Dimension Transforms
1. Tensor.view() / Tensor.reshape()

Reshapes a Tensor without changing its data. view requires Contiguous Memory; reshape handles non-contiguous cases automatically.

```python
x = torch.arange(12)   # shape [12]
y = x.view(3, 4)       # shape [3, 4]
z = x.reshape(2, 6)    # shape [2, 6]
w = x.reshape(-1, 3)   # -1 auto-infers → shape [4, 3]
```

Note: Prefer reshape by default; switch to view only when you need to guarantee zero-copy memory sharing.
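The contiguity difference can be checked directly. A minimal sketch: view() raises on a non-contiguous tensor, while reshape() silently falls back to copying.

```python
import torch

# A transpose makes the tensor non-contiguous: view() refuses,
# while reshape() copies when it has to.
x = torch.arange(12).reshape(3, 4).t()   # shape [4, 3], non-contiguous
print(x.is_contiguous())                 # False

try:
    x.view(12)
except RuntimeError as e:
    print("view failed:", e)

y = x.reshape(12)                        # works: copies as needed
print(y.shape)                           # torch.Size([12])
```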
2. torch.squeeze() / torch.unsqueeze()

squeeze: Removes dimensions of size 1. unsqueeze: Inserts a size-1 dimension at a specified position.

```python
x = torch.zeros(1, 3, 1, 5)
y = x.squeeze()        # [3, 5]

z = torch.zeros(3, 5)
w = z.unsqueeze(0)     # [1, 3, 5]
v = z.unsqueeze(-1)    # [3, 5, 1]
```

Note: unsqueeze is frequently used for Broadcasting: add a batch dimension or channel dimension to a vector.
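A sketch of that broadcasting pattern (the batch and bias values here are illustrative): unsqueeze reshapes a per-channel bias vector so it aligns with the channel axis of an NCHW batch.

```python
import torch

# Add a per-channel bias to an NCHW batch. Reshaping the bias to
# [1, C, 1, 1] lets broadcasting align it with the channel dim.
imgs = torch.zeros(8, 3, 32, 32)                # N, C, H, W
bias = torch.tensor([0.1, 0.2, 0.3])            # shape [3]

bias4d = bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # [1, 3, 1, 1]
out = imgs + bias4d                             # broadcasts to [8, 3, 32, 32]
print(out.shape)
```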
3. torch.cat()

Concatenates multiple Tensors along an existing dimension. Does not create a new axis.

```python
a = torch.zeros(2, 3)
b = torch.ones(4, 3)
c = torch.cat([a, b], dim=0)  # shape [6, 3]

d = torch.cat([torch.zeros(2, 2), torch.ones(2, 4)], dim=1)  # shape [2, 6]
```

Note: All Tensors must have the same shape on every dimension except the concatenation axis.
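A common real-world use of cat is merging feature maps along the channel axis; this hypothetical U-Net-style skip connection (shapes are illustrative) shows the pattern.

```python
import torch

# U-Net-style skip connection: decoder and encoder feature maps
# concatenated along the channel axis (dim=1).
decoder_feat = torch.zeros(8, 64, 28, 28)
skip_feat = torch.zeros(8, 64, 28, 28)
merged = torch.cat([decoder_feat, skip_feat], dim=1)
print(merged.shape)   # torch.Size([8, 128, 28, 28])
```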
4. torch.stack()

Stacks Tensors along a new dimension. All input Tensors must be exactly the same shape.

```python
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.stack([a, b])         # [2, 3] — new dim=0
d = torch.stack([a, b], dim=1)  # [3, 2] — new dim=1
```

Note: Key difference from cat: stack requires identical shapes and always creates a new axis.
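The canonical use of stack is batching: a sketch of turning a list of equally-shaped samples into one batch tensor (e.g., inside a collate function), with cat shown for contrast.

```python
import torch

# Stack a list of equally-shaped samples into one batch tensor:
# a new batch axis appears at dim 0.
samples = [torch.rand(3, 32, 32) for _ in range(16)]
batch = torch.stack(samples)          # shape [16, 3, 32, 32]
print(batch.shape)

# cat would instead merge along an existing axis:
flat = torch.cat(samples)             # shape [48, 32, 32]
print(flat.shape)
```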
5. Tensor.permute()

Reorders dimensions according to a specified axis order. Equivalent to NumPy's transpose(axes).

```python
# Convert NCHW → NHWC
x = torch.zeros(8, 3, 224, 224)
y = x.permute(0, 2, 3, 1)  # shape [8, 224, 224, 3]
```

Note: After permute, the Tensor becomes Non-contiguous. Call .contiguous() if a subsequent op requires contiguous memory.
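A minimal sketch of that caveat in practice: after permute, view() needs a preceding .contiguous() call (or just use .reshape(), which handles it for you).

```python
import torch

# permute returns a non-contiguous view; .contiguous() materializes
# it so that view() (which requires contiguous memory) can run.
x = torch.zeros(8, 3, 224, 224)
y = x.permute(0, 2, 3, 1)
print(y.is_contiguous())            # False

flat = y.contiguous().view(8, -1)   # shape [8, 150528]
print(flat.shape)
```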
6. Tensor.transpose()

Swaps exactly two specified dimensions. A simplified version of permute for axis swapping.

```python
x = torch.zeros(4, 5, 6)
y = x.transpose(1, 2)  # shape [4, 6, 5]

m = torch.rand(3, 4)
mt = m.t()             # 2D matrix transpose → shape [4, 3]
```

Note: Use .transpose(-1, -2) for Batched Matrix Transpose; it works with any number of leading batch dimensions.
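A sketch of the batched transpose in action, using attention-score computation as an illustrative example (the batch, head, and sequence sizes are arbitrary).

```python
import torch

# transpose(-1, -2) swaps only the last two dims, regardless of how
# many leading batch dims exist (here: batch and head axes).
q = torch.rand(8, 4, 10, 64)          # [batch, heads, seq, dim]
k = torch.rand(8, 4, 10, 64)
scores = q @ k.transpose(-1, -2)      # [8, 4, 10, 64] @ [8, 4, 64, 10]
print(scores.shape)                   # torch.Size([8, 4, 10, 10])
```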
7. torch.split() / torch.chunk()

split: Splits into pieces of a specified size (or a list of sizes). chunk: Splits into a given number of pieces; the last chunk may be smaller.

```python
x = torch.arange(10)
parts = torch.split(x, 3)   # (tensor([0,1,2]), tensor([3,4,5]), tensor([6,7,8]), tensor([9]))
chunks = torch.chunk(x, 3)  # 3 chunks of sizes 4, 4, 2: [0–3], [4–7], [8–9]
```

Note: Multi-GPU Sharding and DataLoader batch splitting internally rely on chunk / split logic.
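split also accepts a list of section sizes, which must sum to the length of the dimension being split; a minimal sketch:

```python
import torch

# Uneven split by explicit section sizes (2 + 5 + 3 == 10).
x = torch.arange(10)
head, body, tail = torch.split(x, [2, 5, 3])
print(head.tolist())   # [0, 1]
print(body.tolist())   # [2, 3, 4, 5, 6]
print(tail.tolist())   # [7, 8, 9]
```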
8. Tensor.flatten()

Flattens a Tensor to 1D, or flattens a specific range of dimensions.

```python
x = torch.zeros(2, 3, 4)
y = x.flatten()      # shape [24]
z = x.flatten(1, 2)  # shape [2, 12] — only flattens dims 1–2
```

Note: Most commonly used at the CNN → Fully Connected transition: x.flatten(1) is equivalent to x.view(x.size(0), -1).
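A sketch of that CNN-to-FC handoff (the layer sizes here are illustrative, not from any particular model):

```python
import torch
import torch.nn as nn

# Keep the batch axis, flatten everything else, then feed a Linear.
feats = torch.rand(16, 64, 7, 7)     # conv output: [N, C, H, W]
flat = feats.flatten(1)              # [16, 64*7*7] = [16, 3136]
fc = nn.Linear(64 * 7 * 7, 10)
logits = fc(flat)                    # [16, 10]
print(logits.shape)
```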
9. torch.broadcast_to()

Broadcasts a Tensor to a target shape as a read-only view. No data is copied.

```python
x = torch.tensor([1, 2, 3])        # shape [3]
y = torch.broadcast_to(x, (4, 3))  # shape [4, 3]
```

Note: Broadcasting is the underlying mechanism of most PyTorch arithmetic operations. Understanding it helps avoid Shape Errors.
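The same rules fire implicitly in arithmetic: shapes are compared right-to-left, and size-1 (or missing) dimensions are stretched. A minimal sketch:

```python
import torch

# [4, 3] + [3]: the vector is treated as [1, 3] and stretched.
a = torch.zeros(4, 3)
b = torch.tensor([1.0, 2.0, 3.0])
c = a + b                          # shape [4, 3]
print(c[0].tolist())               # [1.0, 2.0, 3.0]

# [4, 1] + [3]: both sides stretch, giving [4, 3].
col = torch.tensor([[10.0], [20.0], [30.0], [40.0]])
d = col + b
print(d.shape)                     # torch.Size([4, 3])
```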
10. Tensor.expand()

Expands size-1 dimensions to a specified size. Shares storage (no memory copy), unlike repeat.

```python
x = torch.zeros(3, 1)
y = x.expand(3, 4)  # shape [3, 4] — zero memory copy
z = x.repeat(1, 4)  # shape [3, 4] — actual data copy
```

Note: expand returns a Non-contiguous Tensor; call .contiguous() or .clone() before writing to it.
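A minimal sketch of why that matters: because expand shares storage, a write through the base tensor shows up in every expanded "copy".

```python
import torch

# expand is a view: all four entries of row 0 alias x[0, 0].
x = torch.zeros(3, 1)
y = x.expand(3, 4)
x[0, 0] = 7.0
print(y[0].tolist())       # [7.0, 7.0, 7.0, 7.0]

# clone() gives independent memory that is safe to write in place.
safe = x.expand(3, 4).clone()
safe[0, 0] = 0.0           # x is unaffected
```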
💡 One-line Takeaway

Master reshape, squeeze/unsqueeze, cat/stack, and permute — together they cover 90% of all shape manipulation needs.
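Those four core ops compose naturally; a purely illustrative pipeline chaining them (shapes are arbitrary):

```python
import torch

# reshape → unsqueeze → cat → permute, tracking shapes at each step.
x = torch.arange(24).reshape(2, 3, 4)   # [2, 3, 4]
x = x.unsqueeze(1)                      # [2, 1, 3, 4]
x = torch.cat([x, x], dim=1)            # [2, 2, 3, 4]
x = x.permute(0, 2, 3, 1)               # [2, 3, 4, 2]
print(x.shape)
```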