Below are customized code examples for image classification (CV) and text classification (NLP) tasks, including runnable code, explanations of key parameters, and a pitfall-avoidance guide:


I. Computer Vision (CV) Image Classification in Practice

1. Complete Training Pipeline (PyTorch)
import torch
from torch.utils.data import WeightedRandomSampler
from torchvision import transforms, models
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ----------------------
# 1. Imbalanced data handling
# ----------------------
# Compute class weights (class_counts holds the sample count per class)
class_counts = [1000, 200, 50]  # three classes, for illustration
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
samples_weights = weights[train_labels]  # train_labels: LongTensor of per-sample class indices

# Create the weighted sampler
sampler = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),  # draws per epoch; here, the full dataset size
    replacement=True
)
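# Sanity check (illustrative): each class's expected share of draws is
# proportional to count_c * (1 / count_c) = 1, so the sampler rebalances
# all three classes to roughly equal frequency within each epoch.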

# ----------------------
# 2. Advanced data augmentation
# ----------------------
train_transform = A.Compose([
    A.RandomResizedCrop(224, 224),  # (height, width); albumentations >= 1.4 takes size=(224, 224)
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, fill_value=0, p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])
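# ----------------------
# (Assumed) Dataset/DataLoader wiring
# ----------------------
# Minimal sketch: the training loop below references train_loader/val_loader,
# which the snippet never builds. AlbumentationsDataset, train_paths and the
# val_* names are illustrative assumptions, not a fixed API.
import cv2
from torch.utils.data import DataLoader, Dataset

class AlbumentationsDataset(Dataset):
    def __init__(self, image_paths, labels, transform):
        self.image_paths, self.labels, self.transform = image_paths, labels, transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        return self.transform(image=image)["image"], int(self.labels[idx])

train_dataset = AlbumentationsDataset(train_paths, train_labels, train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)  # sampler replaces shuffle=True
# Build val_loader the same way, with a deterministic resize + normalize transform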

# ----------------------
# 3. Model and loss function
# ----------------------
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)  # 'pretrained=True' is deprecated
num_ftrs = model.fc.in_features

# Replace the classification head
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),
    torch.nn.Linear(num_ftrs, len(class_counts))
)

# Focal Loss implementation (supports multi-class)
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=2):
        super().__init__()
        self.alpha = alpha  # optional per-class weight vector
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = torch.nn.functional.cross_entropy(
            inputs, targets, reduction='none', weight=self.alpha
        )
        pt = torch.exp(-ce_loss)
        loss = (1 - pt) ** self.gamma * ce_loss
        return loss.mean()

# Pass the inverse-frequency class weights computed above as alpha
focal_loss = FocalLoss(alpha=weights)
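# Intuition (illustrative): gamma=2 scales each sample's CE loss by (1 - pt)^2,
# so a confidently correct sample (pt ~ 0.98) keeps only ~0.04% of its CE loss,
# while a hard sample (pt ~ 0.3) keeps ~49% -- training focuses on hard cases.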

# ----------------------
# 4. Training loop (key parts)
# ----------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

for epoch in range(100):
    model.train()
    for images, labels in train_loader:  # move images/labels to the model's device in real code
        outputs = model(images)
        loss = focal_loss(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Dynamic threshold tuning (validation phase)
    model.eval()
    with torch.no_grad():
        all_probs = []
        all_labels = []
        for images, labels in val_loader:
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs)
            all_labels.append(labels)
        
        # Find the optimal threshold (maximize F1)
        from sklearn.metrics import f1_score
        probs = torch.cat(all_probs)
        labels = torch.cat(all_labels)
        best_threshold = 0.5
        best_f1 = 0
        for thresh in torch.arange(0.3, 0.7, 0.05):
            preds = (probs[:, 1] > thresh).long()  # simplified: one-vs-rest on class 1 (a minority class)
            f1 = f1_score((labels == 1).numpy(), preds.numpy())  # binary simplification
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = thresh.item()
        print(f"Epoch {epoch}: Optimal threshold = {best_threshold:.2f}")
Key Techniques
  1. Weighted sampler: minority-class samples are drawn more often within each batch
  2. CoarseDropout augmentation: simulates occlusion to improve model robustness
  3. Dynamic threshold tuning: automatically optimizes the decision threshold from validation performance (a per-class generalization is sketched below)
  4. AdamW optimizer: decoupled weight decay helps curb overfitting
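The threshold search above is simplified to a single one-vs-rest class. A per-class generalization, sketched here against the same probs/labels tensors collected in the validation loop, picks one threshold per class:

# Per-class threshold search (sketch): treat each class as one-vs-rest
import numpy as np
from sklearn.metrics import f1_score

def per_class_thresholds(probs, labels, grid=np.arange(0.3, 0.71, 0.05)):
    probs, labels = probs.numpy(), labels.numpy()
    thresholds = {}
    for c in range(probs.shape[1]):
        best_t, best_f1 = 0.5, 0.0
        for t in grid:
            f1 = f1_score(labels == c, probs[:, c] > t)
            if f1 > best_f1:
                best_t, best_f1 = t, f1
        thresholds[c] = best_t
    return thresholds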

II. Natural Language Processing (NLP) Text Classification in Practice

1. Complete BERT Fine-Tuning Code (HuggingFace Transformers)
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
import nlpaug.augmenter.word as naw

# ----------------------
# 1. Imbalanced-data augmentation
# ----------------------
class TextAugmenter:
    def __init__(self):
        self.synonym_aug = naw.SynonymAug(aug_src='wordnet')
        self.context_aug = naw.ContextualWordEmbsAug(
            model_path='bert-base-uncased', action="substitute"
        )

    def augment(self, text, num_aug=2):
        """Generate augmented variants of `text`"""
        # nlpaug >= 1.1.4 returns a list from augment(); concatenate both augmenters' outputs
        augmented = self.synonym_aug.augment(text) + self.context_aug.augment(text)
        return list(dict.fromkeys(augmented))[:num_aug]  # dedupe while preserving order
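
# Example usage (illustrative):
#   augmenter = TextAugmenter()
#   augmenter.augment("The battery dies too fast")  # -> up to 2 paraphrased variants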

# ----------------------
# 2. BERT model with class weights
# ----------------------
class WeightedBert(BertForSequenceClassification):
    def __init__(self, config, class_weights):
        super().__init__(config)
        self.class_weights = torch.tensor(class_weights, dtype=torch.float)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Don't pass labels to super(): that would compute an unweighted loss we
        # immediately discard; **kwargs forwards extras such as token_type_ids
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss(
                weight=self.class_weights.to(outputs.logits.device)
            )
            outputs.loss = loss_fct(outputs.logits.view(-1, self.num_labels), labels.view(-1))
        return outputs

# ----------------------
# 3. Data loading and training
# ----------------------
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Toy data for illustration
texts = ["This product is great!", "Terrible experience..."]  # input texts
labels = [1, 0]  # 0 is the minority (negative) class
class_counts = [100, 1000]  # 100 negative, 1,000 positive examples
class_weights = [1./(count/sum(class_counts)) for count in class_counts]  # inverse-frequency weights
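# Worked example: counts [100, 1000], total 1100 -> weights [11.0, 1.1];
# per the pitfall guide below, normalizing by the max would give [1.0, 0.1]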

# Dataset with on-the-fly minority-class augmentation
class AugmentedDataset(Dataset):
    def __init__(self, texts, labels, augmenter=None, max_aug=2):
        self.texts = texts
        self.labels = labels
        self.augmenter = augmenter
        self.max_aug = max_aug
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Augment minority-class samples only
        if label == 0 and self.augmenter:  # 0 is the minority class here
            aug_texts = self.augmenter.augment(text, self.max_aug)
            return {
                'text': aug_texts[0], 
                'label': label,
                'augmented': aug_texts[1:]  # extra augmented variants
            }
        else:
            return {'text': text, 'label': label}

# Initialize the model
model = WeightedBert.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
    class_weights=class_weights
)

# Custom collate_fn that folds augmented samples into the batch
def collate_fn(batch):
    main_texts = [item['text'] for item in batch]
    main_labels = [item['label'] for item in batch]
    
    # Append the extra augmented samples
    for item in batch:
        if 'augmented' in item:
            main_texts.extend(item['augmented'])
            main_labels.extend([item['label']]*len(item['augmented']))
    
    # Tokenize
    tokens = tokenizer(
        main_texts, 
        padding=True, 
        truncation=True, 
        max_length=128,
        return_tensors='pt'
    )
    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'labels': torch.tensor(main_labels)
    }

# Build the dataloader
dataset = AugmentedDataset(texts, labels, augmenter=TextAugmenter())
dataloader = DataLoader(dataset, batch_size=16, collate_fn=collate_fn, shuffle=True)

# ----------------------
# 4. Training loop with the weighted loss
# ----------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
model.train()

for epoch in range(5):
    total_loss = 0
    for batch in dataloader:  # move batch tensors to the model's device in real code
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    print(f"Epoch {epoch}: Loss = {total_loss/len(dataloader):.4f}")
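# ----------------------
# (Assumed) Quick held-out evaluation
# ----------------------
# Sketch: average loss alone is a poor signal under imbalance; report per-class
# precision/recall instead. val_texts/val_labels are assumed holdout data.
from sklearn.metrics import classification_report

model.eval()
with torch.no_grad():
    tokens = tokenizer(val_texts, padding=True, truncation=True,
                       max_length=128, return_tensors='pt')
    preds = model(**tokens).logits.argmax(dim=1)
print(classification_report(val_labels, preds.numpy()))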

# ----------------------
# 5. Dynamic threshold adjustment at inference
# ----------------------
def predict_with_threshold(model, text, threshold=0.5):
    tokens = tokenizer(
        text, 
        return_tensors='pt', 
        padding=True, 
        truncation=True,
        max_length=128
    )
    with torch.no_grad():
        logits = model(**tokens).logits
    probs = torch.softmax(logits, dim=1)
    
    # Use a lower threshold for the minority class (class 0)
    if probs[0][0] > threshold:  # class 0 is the minority class here
        return 0
    else:
        return torch.argmax(probs).item()
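
# Example call (illustrative): a lower threshold trades precision for recall
# on the minority class:
#   predict_with_threshold(model, "Worst purchase ever", threshold=0.3)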
Key Innovations
  1. On-the-fly augmentation: minority-class samples are augmented in real time, during training only
  2. In-batch sample expansion: augmented samples are appended directly inside the batch, avoiding pre-generated copies on disk
  3. Weight-fused model: class weights are embedded directly into the BERT model structure
  4. Real-time threshold decisions: the inference API supports dynamically adjusting the classification boundary (see the example call above)

III. Pitfall-Avoidance Guide

Symptom | Solution
Overfitting after oversampling | Add sample-mixing augmentation such as MixUp/CutMix instead of plain duplication (see the MixUp sketch below)
Inflated validation metrics | Apply augmentation to the training set only; keep the validation set at its original distribution
Semantic drift when augmenting small samples | Use BERT-based contextual augmentation (ContextualWordEmbsAug) to preserve semantic coherence
Unstable training caused by class weights | Normalize the weights: weights = weights / weights.max()
Slow dynamic-threshold search | Use a Bayesian optimization library (e.g., Optuna) to speed up the search (see the Optuna sketch below)
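For the first row, a minimal MixUp sketch (an illustrative, standalone example not wired into the training loop above; alpha is the Beta-distribution parameter, and the resulting soft targets train with probability-target cross-entropy, supported since PyTorch 1.10):

# Minimal MixUp sketch: blend image pairs and their one-hot targets within a batch
import numpy as np
import torch
import torch.nn.functional as F

def mixup_batch(images, labels, num_classes, alpha=0.2):
    lam = np.random.beta(alpha, alpha)     # mixing coefficient in (0, 1)
    perm = torch.randperm(images.size(0))  # random pairing within the batch
    mixed = lam * images + (1 - lam) * images[perm]
    onehot = F.one_hot(labels, num_classes).float()
    targets = lam * onehot + (1 - lam) * onehot[perm]
    return mixed, targets                  # train with F.cross_entropy(logits, targets)

And for the last row, a sketch of the threshold search with Optuna, reusing the probs/labels tensors from the CV validation loop:

# Bayesian threshold search with Optuna (sketch)
import optuna
from sklearn.metrics import f1_score

def objective(trial):
    t = trial.suggest_float("threshold", 0.1, 0.9)
    preds = (probs[:, 1] > t).long()
    return f1_score((labels == 1).numpy(), preds.numpy())

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=30)
best_threshold = study.best_params["threshold"]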

IV. Performance Optimization Tips

CV Task Acceleration
# Mixed-precision training (30%+ speedup)
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
optimizer.zero_grad()
with autocast():
    outputs = model(images)
    loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
NLP Task Optimization
# Gradient accumulation (simulates a larger batch size)
accum_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    loss = model(**batch).loss
    loss = loss / accum_steps
    loss.backward()
    if (i+1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

V. Model Interpretability Tools

CV Feature Visualization
# Visualize salient input regions with the Captum library
from captum.attr import IntegratedGradients

ig = IntegratedGradients(model)
attributions = ig.attribute(
    images,               # batch of input image tensors
    target=target_class,  # class index to explain
    n_steps=50
)
heatmap = attributions.mean(dim=1).cpu().numpy()  # average attributions over channels
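A quick way to render the result (illustrative; assumes a single-image batch and matplotlib installed):

# Display the attribution heatmap
import matplotlib.pyplot as plt

plt.imshow(heatmap[0], cmap='hot')
plt.axis('off')
plt.colorbar()
plt.show()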
NLP Attention Analysis
# Visualize BERT attention with bertviz: head_view takes attention weights
# and tokens, not the model/tokenizer directly
from bertviz import head_view

sentence = "The movie was terrible but acting was good"
inputs = tokenizer.encode(sentence, return_tensors='pt')
outputs = model(inputs, output_attentions=True)
tokens = tokenizer.convert_ids_to_tokens(inputs[0])
head_view(
    outputs.attentions,  # tuple of per-layer attention tensors
    tokens,
    layer=4,   # inspect layer 4
    heads=[3]  # inspect attention head 3
)

All of the code above can be pasted into a Jupyter Notebook or Python script; adjust the data-loading parts for your actual data paths and task. Start by testing small sub-modules (e.g., the augmentation on its own) before integrating them into the full training pipeline.
