深度学习 | 数据增强库——albumenations的使用
你可以通过修改transform管道中的参数来调整增强策略albumentations官方文档。
1 什么是数据增强?
定义:通过人工生成多样化训练样本的技术,在不实际收集新数据的前提下,对现有数据进行变换扩展。
通俗理解:像给照片加「滤镜」和「特效」,但目的是让AI模型学习更强大的特征。
示例场景(以安全帽检测为例):
-
原始图片 → 建筑工人正脸佩戴安全帽
-
增强后可能包含:
-
旋转30度的工人侧脸
-
模拟阴天环境的低亮度版本
-
添加雨雾效果的施工场景
-
镜像翻转的左右对称版本
-
2 为什么需要数据增强?
| 场景 | 问题 | 增强解决方案 |
|---|---|---|
| 数据量不足 | 只收集到200张现场照片 | 生成2000+变异样本 |
| 模型过拟合 | 只在晴天数据表现好 | 添加阴雨雾雪模拟 |
| 现实复杂性 | 无法涵盖所有拍摄角度 | 随机旋转/透视变换 |
| 硬件限制 | 高分辨率训练困难 | 随机裁剪缩小尺寸 |
关键作用:
-
提升模型泛化能力(识别不同光照/角度的安全帽)
-
防止过拟合(避免模型死记硬背训练样本)
-
增强鲁棒性(应对实际场景中的噪声干扰)
3 具体实现
直接先上代码,然后再讲注意事项
3.1 数据增强脚本
以下是一个使用albumentations库实现的完整数据增强脚本,支持对指定文件夹内的所有图片进行批量增强处理。命名为:my_script.py(可根据自己的喜欢命名)
import os
import cv2
import albumentations as A
from tqdm import tqdm
def augment_folder(input_dir, output_dir, num_augments=3):
"""
对输入文件夹中的图片进行数据增强
:param input_dir: 输入图片文件夹路径
:param output_dir: 输出结果文件夹路径
:param num_augments: 每张图片生成的增强版本数量
"""
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 定义增强管道
transform = A.Compose([
A.RandomRotate90(p=0.5),
A.Flip(p=0.5),
A.RandomResizedCrop(height=256, width=256, scale=(0.8, 1.0), p=0.5),
A.RandomBrightnessContrast(p=0.3),
A.GaussianBlur(blur_limit=(3, 7), p=0.2),
A.HueSaturationValue(p=0.3),
A.CLAHE(p=0.3),
A.RandomGamma(p=0.2)
])
# 获取图片文件列表
image_files = [f for f in os.listdir(input_dir)
if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
# 处理进度条
pbar = tqdm(image_files, desc="Processing images")
for filename in pbar:
# 读取原始图片
image_path = os.path.join(input_dir, filename)
image = cv2.imread(image_path)
if image is None:
continue
# 生成多个增强版本
for i in range(num_augments):
# 应用数据增强
transformed = transform(image=image)
transformed_image = transformed["image"]
# 生成新文件名
name, ext = os.path.splitext(filename)
new_filename = f"{name}_aug{i}{ext}"
output_path = os.path.join(output_dir, new_filename)
# 保存增强后的图片
cv2.imwrite(output_path, transformed_image)
if __name__ == "__main__":
# 配置参数
input_directory = "input_images" # 输入图片文件夹
output_directory = "augmented_images" # 输出结果文件夹
augmentations_per_image = 5 # 每张图片生成5个增强版本
# 执行数据增强
augment_folder(input_directory,
output_directory,
num_augments=augmentations_per_image)
print(f"数据增强完成!增强后的图片已保存至:{output_directory}")
3.2 代码说明
3.2.1 支持功能
- 支持常见图片格式:jpg、png、jpeg、bmp
- 自动跳过无法读取的图片文件
- 原始文件名会保留并添加_aug0、_aug1等后缀
- 使用tqdm显示处理进度条
- 自动创建输出目录(如果不存在)
3.2.2 支持的增强操作
- 随机90度旋转
- 水平/垂直翻转
- 随机裁剪和缩放
- 亮度对比度调整
- 高斯模糊
- 色调饱和度调整
- CLAHE直方图均衡
- Gamma校正
3.3 使用说明
TIPS:建议在anaconda下,新建一个虚拟环境,来实现数据增强,避免污染其他环境
3.3.1 创建环境并激活
conda create -n data_process python=3.8
conda activate data_process
- data_process:环境名。可替换为自己想要的命名
3.3.2 安装依赖库
pip install albumentations opencv-python tqdm
3.3.3 脚本修改
1、修改输入输出路径(必选)
修改脚本第58、59行的,input_directory和output_directory变量名对应的内容,应为输入图片所在文件夹和输出文件所在文件夹。
PS:尽量使用绝对路径,避免出错
2、调整增强强度(可选)
修改num_augments参数控制每张图片生成的增强版本数量
3、自定义增强管道(可选)
在transform中调整或添加Albumentations的增强操作
4、增强效果预览(可选)
如果需要预览增强效果,可以将以下代码添加到循环中
# 在cv2.imwrite之前添加
preview = cv2.resize(transformed_image, (512, 512))
cv2.imshow("Preview", preview)
cv2.waitKey(100) # 显示100ms
3.3.4 运行
激活虚拟环境后,在终端中,输入以下命令
python my_script.py
4 最后
你可以通过修改transform管道中的参数来调整增强策略
albumentations官方文档提供了完整的增强操作列表和参数说明:albumentations官方文档
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)