torchvision 是 PyTorch 的核心视觉库,提供了图像和视频处理所需的工具,包括数据集、模型、预处理和底层操作。以下是对其核心模块的详细解读:

1. 核心模块概览

模块 主要功能
datasets 预实现的视觉数据集(如 MNIST、CIFAR、ImageNet)
models 预训练的深度学习模型(如 ResNet、VGG、Faster R-CNN)
transforms 图像预处理和增强(如裁剪、旋转、归一化)
utils 辅助工具(如网格生成、边界框绘制)
ops 计算机视觉底层操作(如 NMS、ROI 对齐)
io 图像和视频的输入输出(支持高效解码和 GPU 加速)
datapoints 类型化数据结构(如 Image、BoundingBoxes)
extensions 可选扩展功能(如 CUDA 优化)

2. 各模块深度解析

2.1 torchvision.datasets

功能:提供预实现的数据集类,简化数据加载。
常用数据集

  • 分类数据集:MNIST、CIFAR10/100、ImageNet、FashionMNIST
  • 检测 / 分割数据集:COCO、VOC、Cityscapes
  • 其他:Caltech101/256、Stanford Cars、SVHN

示例

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义预处理操作
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize(   # 归一化
        mean=[0.1307],      # MNIST数据集的均值
        std=[0.3081]        # MNIST数据集的标准差
    )
])

# 加载MNIST训练集,传入定义好的transform
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform  # 注意这里用的是上面定义的transform变量
)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

 

2.2 torchvision.models

功能:提供预训练的深度学习模型,支持微调。
常用模型家族

  • 分类模型:ResNet、VGG、MobileNet、EfficientNet
  • 检测模型:Faster R-CNN、Mask R-CNN、SSD
  • 分割模型:DeepLabV3、FCN、U-Net
  • 生成模型:DCGAN、StyleGAN

示例

from torchvision import models

# 加载预训练的 ResNet18,34,50,101,152
resnet = models.resnet18(pretrained=True)

# 修改最后一层以适应新任务
num_classes = 10
resnet.fc = torch.nn.Linear(resnet.fc.in_features, num_classes)
2.3 torchvision.transforms

功能:图像预处理和增强,支持链式操作。
常用变换

  • 尺寸调整ResizeCenterCropRandomCrop
  • 数据增强RandomHorizontalFlipColorJitterRandomRotation
  • 格式转换ToTensorNormalizePILToTensor

示例

from torchvision import transforms

# 定义预处理管道
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
2.4 torchvision.utils

功能:辅助工具,简化可视化和调试。
常用函数

  • 图像网格make_grid(将多张图像拼接为网格)
  • 边界框绘制draw_bounding_boxes(可视化检测结果)
  • 分割掩码绘制draw_segmentation_masks(可视化分割结果)

示例

from torchvision.utils import make_grid, save_image
from PIL import Image
import matplotlib.pyplot as plt
# 创建图像网格'
images = torch.rand(8, 3, 224, 224)  # 8张随机RGB图像
grid = make_grid(images, nrow=8)  # images: [B, C, H, W]
save_image(grid, 'grid.png')
# img = Image.open('grid.png')
# img.show()

 

2.5 torchvision.ops

功能:底层计算机视觉操作,优化性能。
常用操作

  • 非极大值抑制(NMS)nms(去除重叠边界框)
  • ROI 对齐 / 池化roi_alignroi_pool(目标检测特征提取)
  • 边界框操作box_iou(计算 IoU)、box_convert(格式转换)

示例

from torchvision.ops import nms, roi_align

# NMS
keep_indices = nms(boxes, scores, iou_threshold=0.5)

# ROI 对齐
roi_features = roi_align(
    features,           # 特征图 [B, C, H, W]
    boxes,              # 边界框 [N, 5](batch_idx, x1, y1, x2, y2)
    output_size=(7, 7), # 输出尺寸
    spatial_scale=0.25  # 缩放比例
)
2.6 torchvision.io

功能:高效读写图像和视频,支持 GPU 加速。
核心函数

  • 图像操作read_imagewrite_image
  • 视频操作read_videowrite_videoVideoReader

示例

from torchvision.io import read_image, write_video

# 读取图像为张量
image = read_image("image.jpg")  # [C, H, W],dtype=torch.uint8

# 读取视频
frames, audio, info = read_video("video.mp4")
# frames: [T, H, W, C](帧数, 高, 宽, 通道)
2.7 torchvision.datapoints

功能:类型化数据结构,增强代码安全性。
核心类

  • Image:表示图像张量(形状为 [C, H, W]
  • BoundingBoxes:表示边界框(支持多种格式,如 XYXY、CXCYWH)
  • Mask:表示分割掩码

示例

from torchvision import datapoints

# 创建图像数据点
image = datapoints.Image(torch.rand(3, 224, 224))

# 创建边界框
boxes = datapoints.BoundingBoxes(
    torch.tensor([[10, 20, 100, 200]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:]  # 关联图像尺寸
)
2.8 torchvision.transforms.v2

功能:增强版数据变换,支持多模态数据(如图像 + 边界框)。
优势

  • 自动处理不同数据类型(如同时变换图像和边界框)
  • 支持新的数据点类型(如 datapoints.BoundingBoxes

示例

from torchvision.transforms import v2

# 定义变换(同时处理图像和边界框)
transform = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToTensor(),
])

# 应用变换
output = transform({
    "image": image,          # datapoints.Image
    "boxes": boxes,          # datapoints.BoundingBoxes
    "masks": masks           # datapoints.Mask
})

3.IO操作对比

torchvision.io 确实提供了独立的数据读取功能,并且在特定场景下具有显著优势。

设计目标 性能特点
PIL/Pillow 通用图像处理,API 友好,支持丰富的图像格式和转换操作。 纯 Python 实现,速度较慢;适合小规模数据处理或简单预处理。
OpenCV (cv2) 计算机视觉专用库,提供高效的 C++ 底层实现,支持 GPU 加速(需额外配置)。 速度快,尤其适合大规模数据和实时处理;但 API 设计较复杂,默认返回 BGR 格式。
torchvision.io 专为 PyTorch 设计,直接返回张量(Tensor),无需额外格式转换;支持 GPU 解码和批量操作。 内存高效,减少 CPU-GPU 数据传输开销;适合深度学习训练时的高性能数据加载。
3.1 操作示例:
3.1.1 PIL/Pillow
from PIL import Image

# 读取图像为PIL对象
image = Image.open("example.jpg")  # 返回PIL.Image对象
tensor = transforms.ToTensor()(image)  # 需要手动转换为张量
  • 优点:API 简单直观,支持丰富的图像操作(如旋转、裁剪)。
  • 缺点:需手动转换为张量,涉及多次内存拷贝,效率较低。
3.1.2 OpenCV (cv2)
import cv2

# 读取图像为NumPy数组(BGR格式)
image = cv2.imread("example.jpg")  # 返回numpy.ndarray,形状:(H, W, 3)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 需手动转为RGB
tensor = torch.from_numpy(image).permute(2, 0, 1)  # 需手动调整维度和类型
  • 优点:速度快,支持视频处理和实时操作。
  • 缺点:默认返回 BGR 格式,需手动调整维度和类型,代码冗余。
3.1.3  torchvision.io
import torchvision.io as io

# 直接读取为张量(RGB格式,无需转换)
tensor = io.read_image("example.jpg")  # 返回torch.Tensor,形状:(3, H, W),dtype=uint8
  • 优点:直接返回 PyTorch 张量,无需格式转换,内存高效;支持 GPU 解码。
  • 缺点:功能较单一,仅支持基本读写,复杂处理需配合其他库。
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐