Data Loading & Preprocessing

VIII. Data Loading & Preprocessing (数据加载与预处理)#

1. torch.utils.data.Dataset#

Abstract base class for custom datasets. Must implement __len__ and __getitem__.
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
Note: Put all preprocessing / augmentation inside __getitem__ for Lazy Loading (懒加载).
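As a sketch of that lazy-loading pattern, a hypothetical `transform` callable (names here are illustrative) can be applied inside `__getitem__` so the work runs per access rather than up front:

```python
from torch.utils.data import Dataset

class LazyDataset(Dataset):
    """Applies preprocessing per sample, at access time (lazy loading)."""
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform  # optional callable, e.g. a transforms.Compose

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if self.transform is not None:
            x = self.transform(x)  # runs only when this sample is requested
        return x, self.labels[idx]

# Toy usage: the "transform" is just a scaling function here
ds = LazyDataset([1.0, 2.0, 3.0], [0, 1, 0], transform=lambda v: v * 10)
x, y = ds[1]
```

Nothing is precomputed in `__init__`; each call to `ds[i]` does its own preprocessing, which keeps memory flat for large datasets.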

2. torch.utils.data.DataLoader#

Wraps a Dataset into an iterable batch loader with parallel reading (并行读取) and data shuffling (数据打乱).
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset=train_ds, batch_size=32,
    shuffle=True, num_workers=4, pin_memory=True
)
for x, y in loader:
    ...
Note: On Windows, num_workers > 0 requires if __name__ == '__main__': guard.

3. torchvision.transforms#

Image preprocessing and data augmentation (数据增强) library. Chain multiple transforms with Compose.
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
Note: The Normalize parameters are ImageNet statistics. Keep them consistent when using Transfer Learning (迁移学习).

4. torchvision.datasets.ImageFolder#

Automatically builds an image classification dataset from directory structure — subdirectory names become class labels (类别标签).
from torchvision.datasets import ImageFolder
# data/train/cat/*.jpg, data/train/dog/*.jpg
ds = ImageFolder(root='data/train', transform=transform)
print(ds.classes) # ['cat', 'dog']
Note: Save the class_to_idx dictionary alongside the model checkpoint.
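One way to do that is to bundle the mapping into the checkpoint dict itself. A sketch (the empty state dict and the file name `ckpt.pt` are placeholders; in practice use `model.state_dict()` and `ds.class_to_idx`):

```python
import torch

class_to_idx = {'cat': 0, 'dog': 1}   # ds.class_to_idx from ImageFolder
checkpoint = {
    'model_state': {},                # model.state_dict() in practice
    'class_to_idx': class_to_idx,
}
torch.save(checkpoint, 'ckpt.pt')

# At inference time, recover the index -> label mapping from the same file
restored = torch.load('ckpt.pt')
idx_to_class = {v: k for k, v in restored['class_to_idx'].items()}
```

Keeping the mapping in the checkpoint guards against ImageFolder assigning different indices if the directory contents change between training and deployment.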

5. torch.utils.data.random_split()#

Randomly splits a dataset into train/validation subsets by specified lengths.
from torch.utils.data import random_split
n_val = int(len(dataset) * 0.2)
train_ds, val_ds = random_split(dataset, [len(dataset) - n_val, n_val])
Note: Pass generator=torch.Generator().manual_seed(42) for reproducible splits.
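For instance, two calls with identically seeded generators pick the same indices (toy 10-sample dataset below):

```python
import torch
from torch.utils.data import random_split, TensorDataset

dataset = TensorDataset(torch.arange(10.0))
n_val = int(len(dataset) * 0.2)

g1 = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [len(dataset) - n_val, n_val], generator=g1)

# A second split with the same seed reproduces the exact same partition
g2 = torch.Generator().manual_seed(42)
train2, val2 = random_split(dataset, [len(dataset) - n_val, n_val], generator=g2)
```

Without an explicit generator, the split depends on global RNG state, so resuming a run can silently leak validation samples into training.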

6. torchvision.models (pretrained)#

Provides many pre-trained models: ResNet, VGG, ViT, etc. Enables rapid Transfer Learning (迁移学习).
import torch.nn as nn
import torchvision.models as models

model = models.resnet50(weights='IMAGENET1K_V2')
model.fc = nn.Linear(2048, 10)  # replace head for fine-tuning
Note: To freeze the backbone, run for p in model.parameters(): p.requires_grad = False before replacing the head, so the newly created head stays trainable.
💡 One-line Takeaway
The data pipeline is: Dataset (what) → transforms (how to augment) → DataLoader (how to batch).

Data Loading & Preprocessing
https://lxy-alexander.github.io/blog/posts/pytorch/api/08data-loading--preprocessing/
Author
Alexander Lee
Published at
2026-03-12
License
CC BY-NC-SA 4.0