Data Loading & Preprocessing

VIII. Data Loading & Preprocessing
1. torch.utils.data.Dataset
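Abstract base class for custom map-style datasets. Subclasses must implement __getitem__ and should also implement __len__, which DataLoader's default samplers rely on.
```python
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data      # indexable samples, e.g. a tensor or list
        self.labels = labels  # one label per sample

    def __len__(self):
        return len(self.data)  # number of samples

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]  # one (sample, label) pair
```
Note: Put per-sample preprocessing and augmentation inside __getitem__ so it runs lazily (lazy loading): each sample is processed only when the DataLoader requests its index, not all up front.
2. torch.utils.data.DataLoader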
Wraps a Dataset in an iterable that yields mini-batches, with optional shuffling and parallel loading in worker processes.
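A minimal sketch of a custom Dataset feeding a DataLoader, assuming small in-memory tensors; PairDataset and add_noise are hypothetical names used only for illustration. The transform runs inside __getitem__, so it is applied lazily per sample, and the DataLoader's default collation stacks the samples of each batch into tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader


def add_noise(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-sample augmentation: add small Gaussian noise."""
    return x + 0.01 * torch.randn_like(x)


class PairDataset(Dataset):
    """Hypothetical map-style dataset over in-memory features and labels."""

    def __init__(self, data: torch.Tensor, labels: torch.Tensor, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        x = self.data[idx]
        if self.transform is not None:
            x = self.transform(x)  # applied only when this sample is requested
        return x, self.labels[idx]


ds = PairDataset(torch.randn(10, 3), torch.arange(10), transform=add_noise)
loader = DataLoader(ds, batch_size=4, shuffle=True)

batch_sizes = []
for x, y in loader:
    # default collation stacks samples: x has shape (B, 3), y has shape (B,)
    batch_sizes.append(x.shape[0])

print(len(loader), batch_sizes)  # 3 [4, 4, 2]
```

With 10 samples and batch_size=4 the last batch holds only 2 samples; passing drop_last=True would discard it and leave 2 batches.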
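```python
from torch.utils.data import DataLoader

# train_ds: any Dataset instance, e.g. a MyDataset from above
loader = DataLoader(
    dataset=train_ds,
    batch_size=32,    # samples per batch
    shuffle=True,     # reshuffle indices every epoch
    num_workers=4,    # worker processes loading batches in parallel
    pin_memory=True,  # page-locked host memory speeds up copies to the GPU
)
for x, y in loader:
    ...
```
Note: On Windows (and on macOS, where Python 3.8+ also defaults to the spawn start method), num_workers > 0 requires an if __name__ == '__main__': guard around the code that creates and iterates the DataLoader.
3. torchvision.transforms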
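Image preprocessing and data augmentation library. Chain multiple transforms with Compose; they are applied in the order listed.
```python
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(256),             # shorter side -> 256, aspect ratio kept
    transforms.RandomCrop(224),         # random 224x224 patch (augmentation)
    transforms.RandomHorizontalFlip(),  # flip with probability 0.5
    transforms.ToTensor(),              # PIL image -> float (C, H, W) tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```
Note: The Normalize parameters are the ImageNet per-channel mean and standard deviation. Keep them when fine-tuning an ImageNet-pretrained model (transfer learning) so inputs match the distribution the model was trained on. Normalize works on tensors, so it must come after ToTensor.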
4. torchvision.datasets.ImageFolder
Automatically builds an image classification dataset from a directory tree: each subdirectory name becomes a class label. Class names are sorted alphabetically and mapped to indices 0, 1, 2, and so on.
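```python
from torchvision.datasets import ImageFolder

# Expected layout: data/train/cat/*.jpg, data/train/dog/*.jpg
ds = ImageFolder(root='data/train', transform=transform)
print(ds.classes)       # ['cat', 'dog']
print(ds.class_to_idx)  # {'cat': 0, 'dog': 1}
```
Note: Save the class_to_idx dictionary alongside the model checkpoint so predicted indices can be mapped back to class names at inference time.
5. torch.utils.data.random_split()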
Randomly splits a dataset into train/validation subsets by specified lengths.
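A small sketch of reproducible splitting, assuming a toy TensorDataset of 10 items and a hypothetical helper split(seed). Two calls with the same seeded generator yield identical Subset indices, and the train and validation subsets never overlap.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))  # toy dataset of 10 items
n_val = int(len(dataset) * 0.2)            # 2 items for validation
lengths = [len(dataset) - n_val, n_val]    # [8, 2]; must sum to len(dataset)


def split(seed: int):
    # A dedicated, seeded generator makes the random permutation reproducible
    return random_split(dataset, lengths, generator=torch.Generator().manual_seed(seed))


train_a, val_a = split(42)
train_b, val_b = split(42)

# random_split returns Subset objects; .indices refers to the original dataset
print(len(train_a), len(val_a))                   # 8 2
print(train_a.indices == train_b.indices)         # True: same seed, same split
print(set(train_a.indices) & set(val_a.indices))  # set(): no overlap
```

Without the generator argument, random_split draws from the global RNG, so the split then depends on torch.manual_seed and on any random calls made earlier.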
```python
from torch.utils.data import random_split

# dataset: any Dataset; hold out 20% for validation
n_val = int(len(dataset) * 0.2)
train_ds, val_ds = random_split(dataset, [len(dataset) - n_val, n_val])
```
Note: Pass generator=torch.Generator().manual_seed(42) for reproducible splits.
6. torchvision.models (pretrained)
Provides many pretrained models (ResNet, VGG, ViT, and more), enabling rapid transfer learning.
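```python
import torch.nn as nn
import torchvision.models as models

model = models.resnet50(weights='IMAGENET1K_V2')
model.fc = nn.Linear(2048, 10)  # replace the 1000-class ImageNet head for fine-tuning
```
Note: To freeze the backbone, run for p in model.parameters(): p.requires_grad = False before replacing model.fc. The newly created head is trainable by default, so only it is updated; freezing after the replacement would freeze the new head as well.
💡 One-line Takeaway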
The data pipeline is:
Dataset (what to load) → transforms (how to preprocess and augment each sample) → DataLoader (how to batch, shuffle, and load in parallel).
https://lxy-alexander.github.io/blog/posts/pytorch/api/08data-loading--preprocessing/