torchvison的模块解读
是 PyTorch 的核心视觉库,提供了图像和视频处理所需的工具,包括数据集、模型、预处理和底层操作。datasetsmodelstransformsutilsopsiodatapointsextensions:提供预实现的数据集类,简化数据加载。:提供预训练的深度学习模型,支持微调。:图像预处理和增强,支持链式操作。ResizeCenterCropRandomCropToTensorNormal
·
torchvision 是 PyTorch 的核心视觉库,提供了图像和视频处理所需的工具,包括数据集、模型、预处理和底层操作。以下是对其核心模块的详细解读:
1. 核心模块概览
| 模块 | 主要功能 |
|---|---|
datasets |
预实现的视觉数据集(如 MNIST、CIFAR、ImageNet) |
models |
预训练的深度学习模型(如 ResNet、VGG、Faster R-CNN) |
transforms |
图像预处理和增强(如裁剪、旋转、归一化) |
utils |
辅助工具(如网格生成、边界框绘制) |
ops |
计算机视觉底层操作(如 NMS、ROI 对齐) |
io |
图像和视频的输入输出(支持高效解码和 GPU 加速) |
datapoints |
类型化数据结构(如 Image、BoundingBoxes) |
extensions |
可选扩展功能(如 CUDA 优化) |
2. 各模块深度解析
2.1 torchvision.datasets
功能:提供预实现的数据集类,简化数据加载。
常用数据集:
- 分类数据集:MNIST、CIFAR10/100、ImageNet、FashionMNIST
- 检测 / 分割数据集:COCO、VOC、Cityscapes
- 其他:Caltech101/256、Stanford Cars、SVHN
示例:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义预处理操作
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize( # 归一化
mean=[0.1307], # MNIST数据集的均值
std=[0.3081] # MNIST数据集的标准差
)
])
# 加载MNIST训练集,传入定义好的transform
train_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform # 注意这里用的是上面定义的transform变量
)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
2.2 torchvision.models
功能:提供预训练的深度学习模型,支持微调。
常用模型家族:
- 分类模型:ResNet、VGG、MobileNet、EfficientNet
- 检测模型:Faster R-CNN、Mask R-CNN、SSD
- 分割模型:DeepLabV3、FCN、U-Net
- 生成模型:DCGAN、StyleGAN
示例:
from torchvision import models
# 加载预训练的 ResNet18,34,50,101,152
resnet = models.resnet18(pretrained=True)
# 修改最后一层以适应新任务
num_classes = 10
resnet.fc = torch.nn.Linear(resnet.fc.in_features, num_classes)
2.3 torchvision.transforms
功能:图像预处理和增强,支持链式操作。
常用变换:
- 尺寸调整:
Resize、CenterCrop、RandomCrop - 数据增强:
RandomHorizontalFlip、ColorJitter、RandomRotation - 格式转换:
ToTensor、Normalize、PILToTensor
示例:
from torchvision import transforms
# 定义预处理管道
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
2.4 torchvision.utils
功能:辅助工具,简化可视化和调试。
常用函数:
- 图像网格:
make_grid(将多张图像拼接为网格) - 边界框绘制:
draw_bounding_boxes(可视化检测结果) - 分割掩码绘制:
draw_segmentation_masks(可视化分割结果)
示例:
from torchvision.utils import make_grid, save_image
from PIL import Image
import matplotlib.pyplot as plt
# 创建图像网格'
images = torch.rand(8, 3, 224, 224) # 8张随机RGB图像
grid = make_grid(images, nrow=8) # images: [B, C, H, W]
save_image(grid, 'grid.png')
# img = Image.open('grid.png')
# img.show()
2.5 torchvision.ops
功能:底层计算机视觉操作,优化性能。
常用操作:
- 非极大值抑制(NMS):
nms(去除重叠边界框) - ROI 对齐 / 池化:
roi_align、roi_pool(目标检测特征提取) - 边界框操作:
box_iou(计算 IoU)、box_convert(格式转换)
示例:
from torchvision.ops import nms, roi_align
# NMS
keep_indices = nms(boxes, scores, iou_threshold=0.5)
# ROI 对齐
roi_features = roi_align(
features, # 特征图 [B, C, H, W]
boxes, # 边界框 [N, 5](batch_idx, x1, y1, x2, y2)
output_size=(7, 7), # 输出尺寸
spatial_scale=0.25 # 缩放比例
)
2.6 torchvision.io
功能:高效读写图像和视频,支持 GPU 加速。
核心函数:
- 图像操作:
read_image、write_image - 视频操作:
read_video、write_video、VideoReader
示例:
from torchvision.io import read_image, write_video
# 读取图像为张量
image = read_image("image.jpg") # [C, H, W],dtype=torch.uint8
# 读取视频
frames, audio, info = read_video("video.mp4")
# frames: [T, H, W, C](帧数, 高, 宽, 通道)
2.7 torchvision.datapoints
功能:类型化数据结构,增强代码安全性。
核心类:
Image:表示图像张量(形状为[C, H, W])BoundingBoxes:表示边界框(支持多种格式,如 XYXY、CXCYWH)Mask:表示分割掩码
示例:
from torchvision import datapoints
# 创建图像数据点
image = datapoints.Image(torch.rand(3, 224, 224))
# 创建边界框
boxes = datapoints.BoundingBoxes(
torch.tensor([[10, 20, 100, 200]]),
format=datapoints.BoundingBoxFormat.XYXY,
canvas_size=image.shape[-2:] # 关联图像尺寸
)
2.8 torchvision.transforms.v2
功能:增强版数据变换,支持多模态数据(如图像 + 边界框)。
优势:
- 自动处理不同数据类型(如同时变换图像和边界框)
- 支持新的数据点类型(如
datapoints.BoundingBoxes)
示例:
from torchvision.transforms import v2
# 定义变换(同时处理图像和边界框)
transform = v2.Compose([
v2.RandomResizedCrop(size=(224, 224)),
v2.RandomHorizontalFlip(p=0.5),
v2.ToTensor(),
])
# 应用变换
output = transform({
"image": image, # datapoints.Image
"boxes": boxes, # datapoints.BoundingBoxes
"masks": masks # datapoints.Mask
})
3.IO操作对比
torchvision.io 确实提供了独立的数据读取功能,并且在特定场景下具有显著优势。
| 库 | 设计目标 | 性能特点 |
|---|---|---|
| PIL/Pillow | 通用图像处理,API 友好,支持丰富的图像格式和转换操作。 | 纯 Python 实现,速度较慢;适合小规模数据处理或简单预处理。 |
| OpenCV (cv2) | 计算机视觉专用库,提供高效的 C++ 底层实现,支持 GPU 加速(需额外配置)。 | 速度快,尤其适合大规模数据和实时处理;但 API 设计较复杂,默认返回 BGR 格式。 |
| torchvision.io | 专为 PyTorch 设计,直接返回张量(Tensor),无需额外格式转换;支持 GPU 解码和批量操作。 | 内存高效,减少 CPU-GPU 数据传输开销;适合深度学习训练时的高性能数据加载。 |
3.1 操作示例:
3.1.1 PIL/Pillow
from PIL import Image
# 读取图像为PIL对象
image = Image.open("example.jpg") # 返回PIL.Image对象
tensor = transforms.ToTensor()(image) # 需要手动转换为张量
- 优点:API 简单直观,支持丰富的图像操作(如旋转、裁剪)。
- 缺点:需手动转换为张量,涉及多次内存拷贝,效率较低。
3.1.2 OpenCV (cv2)
import cv2
# 读取图像为NumPy数组(BGR格式)
image = cv2.imread("example.jpg") # 返回numpy.ndarray,形状:(H, W, 3)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 需手动转为RGB
tensor = torch.from_numpy(image).permute(2, 0, 1) # 需手动调整维度和类型
- 优点:速度快,支持视频处理和实时操作。
- 缺点:默认返回 BGR 格式,需手动调整维度和类型,代码冗余。
3.1.3 torchvision.io
import torchvision.io as io
# 直接读取为张量(RGB格式,无需转换)
tensor = io.read_image("example.jpg") # 返回torch.Tensor,形状:(3, H, W),dtype=uint8
- 优点:直接返回 PyTorch 张量,无需格式转换,内存高效;支持 GPU 解码。
- 缺点:功能较单一,仅支持基本读写,复杂处理需配合其他库。
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)