BrushNet数据集构建:BrushData与BrushBench使用指南

【免费下载链接】BrushNet The official implementation of paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion" 【免费下载链接】BrushNet 项目地址: https://gitcode.com/GitHub_Trending/br/BrushNet

引言:图像修复的新范式

在图像生成与编辑领域,图像修复(Image Inpainting)一直是一个具有挑战性的任务。传统的修复方法往往难以处理复杂的缺失区域,特别是在需要保持语义一致性和视觉真实性的场景中。BrushNet作为ECCV 2024的最新研究成果,提出了一种即插即用的双分支扩散模型架构,为图像修复任务带来了革命性的突破。

本文将深入探讨BrushNet项目中的两个核心数据集:BrushData(训练数据集)和BrushBench(评估基准),为您提供完整的构建和使用指南。

BrushNet架构概览

在深入了解数据集之前,让我们先快速了解BrushNet的核心设计理念:

mermaid

BrushNet的关键创新在于将掩码图像特征和噪声潜在表示分离处理,显著降低了模型的学习负担,同时通过密集的像素级控制增强了预训练模型在图像修复任务中的适用性。

BrushData:训练数据集详解

数据集结构与格式

BrushData是专门为训练BrushNet模型设计的大规模数据集,其结构组织如下:

BrushData/
├── 00200.tar
├── 00201.tar
├── 00202.tar
└── ... (共200个tar文件)

每个tar文件包含多个样本,每个样本包含以下字段:

字段名 数据类型 描述
image bytes 原始图像数据(JPEG编码)
height string 图像高度
width string 图像宽度
caption string 图像描述文本
segmentation JSON 分割掩码信息(RLE编码)

数据预处理流程

BrushData的数据预处理遵循以下标准化流程:

def preprocess_brush_data_sample(example):
    # 解码图像数据
    image = cv2.imdecode(np.frombuffer(example["image"], np.uint8), cv2.IMREAD_COLOR)
    
    # 解析分割掩码
    segmentation = json.loads(example["segmentation"])
    if len(segmentation["mask"]) > 0:
        mask = rle2mask(random.choice(segmentation["mask"]), (height, width))
    else:
        mask = np.ones_like(image)[:,:,0:1]
    
    # 随机掩码生成(可选)
    if self.random_mask:
        mask = random_mask_gen(image.shape[0], image.shape[1])
    
    # 应用形态学操作增强
    if random.random() < 0.3:
        kernel = np.ones((8,8), np.uint8)
        mask_erosion = cv2.erode(mask, kernel, iterations=1)
        mask_dilation = cv2.dilate(mask_erosion, kernel, iterations=1)
        mask = 1 * (mask_dilation > 0)
    
    # 生成掩码图像
    masked_image = image * mask
    
    # 随机反转掩码(内外修复)
    if random.random() < 0.5:
        masked_image = image - masked_image
        mask = 1 - mask
    
    return {
        "pixel_values": normalized_image,
        "conditioning_pixel_values": normalized_masked_image,
        "masks": mask,
        "input_ids": tokenized_caption
    }

数据集统计特性

BrushData数据集具有以下关键统计特性:

统计指标 数值 说明
总样本数 ~2M 约200万个训练样本
图像分辨率 可变 原始分辨率,训练时统一resize到512x512
掩码类型 分割掩码 + 随机掩码 支持物体形状掩码和随机形状掩码
文本描述 英文 来自LAION数据集的高质量标注

自定义数据集构建

如果您希望构建自己的BrushNet训练数据集,可以参考以下格式:

# 自定义数据集示例
class CustomBrushDataset(Dataset):
    def __init__(self, image_dir, annotation_file, transform=None):
        self.image_dir = image_dir
        self.annotations = self.load_annotations(annotation_file)
        self.transform = transform
        self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    
    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        
        # 加载图像
        image_path = os.path.join(self.image_dir, annotation["image_path"])
        image = Image.open(image_path).convert("RGB")
        
        # 加载掩码(支持多种格式)
        if "mask_path" in annotation:
            mask = Image.open(annotation["mask_path"]).convert("L")
        elif "mask_rle" in annotation:
            mask = rle2mask(annotation["mask_rle"], image.size)
        else:
            mask = generate_random_mask(image.size)
        
        # 数据增强
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        
        # 生成掩码图像
        masked_image = image * mask
        
        # 文本编码
        caption = annotation.get("caption", "")
        text_inputs = self.tokenizer(
            caption, max_length=77, padding="max_length", 
            truncation=True, return_tensors="pt"
        )
        
        return {
            "pixel_values": image,
            "conditioning_pixel_values": masked_image,
            "masks": mask,
            "input_ids": text_inputs.input_ids.squeeze()
        }

BrushBench:评估基准详解

基准结构设计

BrushBench是一个专门为评估图像修复模型性能设计的标准化基准,其结构如下:

BrushBench/
├── images/
│   ├── 00001.jpg
│   ├── 00002.jpg
│   └── ... (测试图像)
└── mapping_file.json

评估指标体系

BrushBench提供了全面的评估指标体系,涵盖多个维度:

class MetricsCalculator:
    def __init__(self, device):
        # 感知质量指标
        self.clip_metric = CLIPScore(model_name="openai/clip-vit-large-patch14")
        self.lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
        
        # 美学评估
        self.aesthetic_model = self.load_aesthetic_model()
        
        # 像素级指标
        self.psnr = PeakSignalNoiseRatio()
        self.ssim = StructuralSimilarityIndexMeasure()
        self.mse = MeanSquaredError()
        
        # 人类偏好指标
        self.imagereward_model = RM.load("ImageReward-v1.0")
        self.hpsv2_model = hpsv2.load("hpsv2.1")

评估流程

完整的BrushBench评估流程如下:

mermaid

使用BrushBench进行评估

使用BrushBench进行评估的完整代码示例:

# 评估脚本示例
def evaluate_brushnet_on_brushbench():
    # 加载模型
    brushnet = BrushNetModel.from_pretrained("path/to/brushnet")
    pipe = StableDiffusionBrushNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", 
        brushnet=brushnet
    )
    
    # 加载BrushBench数据
    with open("data/BrushBench/mapping_file.json", "r") as f:
        test_cases = json.load(f)
    
    results = []
    for case_id, case_data in test_cases.items():
        # 准备输入
        image = load_image(case_data["image"])
        mask = rle2mask(case_data["inpainting_mask"], (512, 512))
        prompt = case_data["caption"]
        
        # 生成修复结果
        result = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            num_inference_steps=50
        ).images[0]
        
        # 计算指标
        metrics = calculate_metrics(result, image, mask, prompt)
        results.append({
            "case_id": case_id,
            "metrics": metrics,
            "result_image": result
        })
    
    # 生成评估报告
    generate_evaluation_report(results)

实战案例:从数据到模型

案例1:商品展示图像修复

场景描述:电商平台需要自动修复商品展示图中被遮挡或损坏的区域。

数据处理策略

  • 使用商品分割掩码作为先验信息
  • 针对商品类别定制文本提示模板
  • 采用高分辨率训练(1024x1024)

代码实现

def train_product_inpainting():
    # 数据准备
    dataset = ProductBrushDataset(
        product_image_dir="data/products/images",
        segmentation_dir="data/products/segmentation",
        prompt_templates=[
            "a high-quality photo of {product_name}",
            "professional product photography of {product_name}",
            "{product_name} on white background"
        ]
    )
    
    # 训练配置
    training_args = {
        "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
        "train_data_dir": dataset,
        "resolution": 1024,
        "learning_rate": 1e-5,
        "train_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "max_train_steps": 10000,
        "checkpointing_steps": 1000
    }
    
    # 启动训练
    accelerate_launch("train_brushnet.py", training_args)

案例2:艺术创作辅助

场景描述:数字艺术创作中需要智能填充画布空白区域。

特色功能

  • 支持风格一致性保持
  • 多尺度修复能力
  • 创造性内容生成

评估指标

artistic_metrics = {
    "style_consistency": calculate_style_similarity,
    "creative_score": assess_creativity,
    "aesthetic_quality": aesthetic_assessment,
    "user_preference": collect_human_feedback
}

最佳实践与技巧

数据准备技巧

  1. 掩码质量至关重要

    • 使用精确的分割掩码而非粗糙的矩形掩码
    • 确保掩码边缘平滑,避免锯齿效应
  2. 文本提示优化

    def enhance_prompts(base_prompts):
        enhanced = []
        for prompt in base_prompts:
            # 添加细节描述
            enhanced.append(f"high detail, 4k, professional, {prompt}")
            # 添加风格描述
            enhanced.append(f"photorealistic, sharp focus, {prompt}")
            # 添加场景上下文
            enhanced.append(f"{prompt}, perfect lighting, studio quality")
        return enhanced
    
  3. 数据增强策略

    • 随机掩码形状和大小
    • 颜色空间变换
    • 几何变换(旋转、缩放、裁剪)

训练优化建议

  1. 学习率调度

    # 推荐的学习率配置
    lr_config = {
        "initial_learning_rate": 1e-5,
        "lr_scheduler": "cosine_with_warmup",
        "lr_warmup_steps": 500,
        "lr_num_cycles": 1
    }
    
  2. 梯度累积策略

    • 根据GPU内存调整batch size
    • 使用梯度累积模拟大批次训练
  3. 正则化技术

    • 权重衰减(weight decay)
    • 梯度裁剪(gradient clipping)
    • EMA(指数移动平均)

常见问题解答

Q1: BrushData与其他修复数据集有何不同?

A: BrushData专门为扩散模型设计,提供:

  • 高质量的文本-图像-掩码三元组
  • 真实世界的物体分割掩码
  • 大规模多样化训练样本
  • 针对修复任务的优化标注

Q2: 如何处理自定义数据格式?

A: 您可以实现自定义数据加载器:

class CustomDataLoader:
    def __init__(self, data_source):
        self.data_source = data_source
    
    def to_brushnet_format(self):
        # 实现格式转换逻辑
        return {
            "image": self.load_image(),
            "mask": self.generate_mask(),
            "caption": self.generate_caption()
        }

【免费下载链接】BrushNet The official implementation of paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion" 【免费下载链接】BrushNet 项目地址: https://gitcode.com/GitHub_Trending/br/BrushNet

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐