Below are customized code examples for image classification (CV) and text classification (NLP) tasks, including directly runnable code, notes on key parameters, and a pitfall guide:
I. Computer Vision (CV): Image Classification in Practice
1. Full Training Pipeline (PyTorch)
import torch
from torch.utils.data import WeightedRandomSampler
from torchvision import transforms, models
import albumentations as A
from albumentations.pytorch import ToTensorV2
# ----------------------
# 1. Handling class imbalance
# ----------------------
# Compute per-class weights (class_counts holds the number of samples per class)
class_counts = [1000, 200, 50]  # assume 3 classes
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
samples_weights = weights[train_labels]  # train_labels is the class index of each sample
# Create a weighted sampler
sampler = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),  # often set to 2-3x the majority-class count
    replacement=True
)
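# The sampler takes effect only when passed to a DataLoader; a minimal sketch,
# assuming `train_dataset` has already been built (this is where the `train_loader`
# used in the training loop below would come from):
from torch.utils.data import DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=sampler,  # note: sampler and shuffle=True are mutually exclusive
    num_workers=4
)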
# ----------------------
# 2. Advanced data augmentation
# ----------------------
train_transform = A.Compose([
    A.RandomResizedCrop(224, 224),  # newer albumentations versions expect size=(224, 224)
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, fill_value=0, p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])
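# Augmentation should be applied to the training set only; the validation set keeps
# its original distribution. A minimal validation transform (an illustrative addition,
# not from the original listing) would only resize, crop and normalize:
val_transform = A.Compose([
    A.Resize(256, 256),
    A.CenterCrop(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])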
# ----------------------
# 3. Model and loss function
# ----------------------
model = models.resnet50(pretrained=True)  # newer torchvision: models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
num_ftrs = model.fc.in_features
# Replace the classification head
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),
    torch.nn.Linear(num_ftrs, len(class_counts))
)
# Focal Loss implementation (supports multi-class classification)
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=2):
        super().__init__()
        self.alpha = alpha  # optional per-class weight vector
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = torch.nn.functional.cross_entropy(
            inputs, targets, reduction='none', weight=self.alpha
        )
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        loss = (1 - pt) ** self.gamma * ce_loss
        return loss.mean()

# Use the inverse-frequency class weights computed above
focal_loss = FocalLoss(alpha=weights)
# ----------------------
# 4. Core training loop
# ----------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
for epoch in range(100):
    model.train()
    for images, labels in train_loader:
        outputs = model(images)
        loss = focal_loss(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Dynamic threshold tuning (validation phase)
    model.eval()
    with torch.no_grad():
        all_probs = []
        all_labels = []
        for images, labels in val_loader:
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs)
            all_labels.append(labels)

    # Search for the threshold that maximizes F1
    from sklearn.metrics import f1_score
    probs = torch.cat(all_probs)
    labels = torch.cat(all_labels)
    best_threshold = 0.5
    best_f1 = 0
    for thresh in torch.arange(0.3, 0.7, 0.05):
        preds = (probs[:, 1] > thresh).long()  # assume class 1 is the minority class of interest
        f1 = f1_score((labels == 1).cpu().numpy(), preds.cpu().numpy())  # simplified binary example
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = thresh.item()
    print(f"Epoch {epoch}: Optimal threshold = {best_threshold:.2f}")
Key techniques
- Weighted sampler: makes minority-class samples more likely to appear in each batch
- CoarseDropout augmentation: simulates occlusion and improves model robustness
- Dynamic threshold tuning: the classification threshold is optimized automatically on the validation set (a usage sketch follows below)
- AdamW optimizer: built-in weight decay helps prevent overfitting
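As a minimal sketch of how the tuned threshold would be used at inference time (the function name and binary-style setup are assumptions, not part of the original code): the minority-class probability is compared against best_threshold instead of taking a plain argmax.
def predict_with_tuned_threshold(model, image, threshold):
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(image.unsqueeze(0)), dim=1)
    # Favor the minority class (index 1 in the simplified binary example above)
    if probs[0, 1] > threshold:
        return 1
    return torch.argmax(probs, dim=1).item()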
II. Natural Language Processing (NLP): Text Classification in Practice
1. Full BERT Fine-Tuning Code (HuggingFace Transformers)
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
import nlpaug.augmenter.word as naw
# ----------------------
# 1. Augmentation for the imbalanced class
# ----------------------
class TextAugmenter:
    def __init__(self):
        self.synonym_aug = naw.SynonymAug(aug_src='wordnet')
        self.context_aug = naw.ContextualWordEmbsAug(
            model_path='bert-base-uncased', action="substitute"
        )

    def augment(self, text, num_aug=2):
        """Generate augmented samples."""
        # Recent nlpaug versions return a list from augment(); flatten to plain strings
        augmented = []
        for aug in (self.synonym_aug, self.context_aug):
            result = aug.augment(text)
            augmented.extend(result if isinstance(result, list) else [result])
        return list(set(augmented))[:num_aug]  # deduplicate
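# Quick sanity check of the augmenter (the sample sentence is illustrative; the
# contextual augmenter downloads bert-base-uncased the first time it runs):
augmenter_demo = TextAugmenter()
print(augmenter_demo.augment("Terrible experience...", num_aug=2))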
# ----------------------
# 2. BERT model with class weights
# ----------------------
class WeightedBert(BertForSequenceClassification):
    def __init__(self, config, class_weights):
        super().__init__(config)
        self.class_weights = torch.tensor(class_weights)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        if labels is not None:
            # Recompute the loss with class weights
            loss_fct = torch.nn.CrossEntropyLoss(
                weight=self.class_weights.to(input_ids.device)
            )
            loss = loss_fct(outputs.logits.view(-1, self.num_labels), labels.view(-1))
            outputs.loss = loss
        return outputs
# ----------------------
# 3. Data loading and training
# ----------------------
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Toy data example
texts = ["This product is great!", "Terrible experience..."]  # input texts
labels = [1, 0]  # assume 0 is the minority class (negative)
class_counts = [100, 1000]  # 100 negative samples, 1000 positive samples
# Inverse-frequency weights; with the counts above this gives [11.0, 1.1]
class_weights = [1. / (count / sum(class_counts)) for count in class_counts]
# Build the augmentation-aware dataset
class AugmentedDataset(Dataset):
    def __init__(self, texts, labels, augmenter=None, max_aug=2):
        self.texts = texts
        self.labels = labels
        self.augmenter = augmenter
        self.max_aug = max_aug

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        # Only augment the minority class
        if label == 0 and self.augmenter:  # assume class 0 is the minority class
            aug_texts = self.augmenter.augment(text, self.max_aug)
            return {
                'text': aug_texts[0],
                'label': label,
                'augmented': aug_texts[1:]  # carry the extra augmented samples
            }
        else:
            return {'text': text, 'label': label}
# Initialize the model
model = WeightedBert.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
    class_weights=class_weights
)
# Custom collate_fn that injects the augmented samples
def collate_fn(batch):
    main_texts = [item['text'] for item in batch]
    main_labels = [item['label'] for item in batch]
    # Append augmented samples to the batch
    for item in batch:
        if 'augmented' in item:
            main_texts.extend(item['augmented'])
            main_labels.extend([item['label']] * len(item['augmented']))
    # Tokenize
    tokens = tokenizer(
        main_texts,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'labels': torch.tensor(main_labels)
    }

# Build the data loader
dataset = AugmentedDataset(texts, labels, augmenter=TextAugmenter())
dataloader = DataLoader(dataset, batch_size=16, collate_fn=collate_fn, shuffle=True)
# ----------------------
# 4. Training loop with class-weighted loss
# ----------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
model.train()
for epoch in range(5):
    total_loss = 0
    for batch in dataloader:
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    print(f"Epoch {epoch}: Loss = {total_loss/len(dataloader):.4f}")
# ----------------------
# 5. Dynamic threshold adjustment at inference time
# ----------------------
def predict_with_threshold(model, text, threshold=0.5):
    tokens = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=128
    )
    with torch.no_grad():
        # Pass only the inputs WeightedBert.forward accepts (it has no token_type_ids argument)
        logits = model(
            input_ids=tokens['input_ids'],
            attention_mask=tokens['attention_mask']
        ).logits
        probs = torch.softmax(logits, dim=1)
    # Use a lower threshold for the minority class (class 0)
    if probs[0][0] > threshold:  # assume class 0 is the minority class
        return 0
    else:
        return torch.argmax(probs).item()
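# Example call (the threshold value is illustrative; in practice it would be tuned
# on a validation set, e.g. with the F1 search shown in the CV section):
print(predict_with_threshold(model, "Terrible experience...", threshold=0.35))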
Key design points
- Dynamic augmentation injection: minority-class samples are augmented on the fly, and only during training
- In-batch augmentation expansion: augmented samples are appended inside the batch, so nothing has to be pre-generated and stored
- Weight-aware model: class weights are embedded directly in the BERT model's loss computation
- Runtime threshold decisions: the classification boundary can be adjusted dynamically in the inference API
III. Pitfall Guide
| Symptom | Solution |
|---|---|
| Overfitting after oversampling | Add sample-mixing augmentation such as MixUp/CutMix instead of plain duplication (see the MixUp sketch below) |
| Inflated validation metrics | Apply augmentation to the training set only; keep the validation set at its original distribution |
| Semantic drift after augmenting small samples | Use BERT-based contextual augmentation (ContextualWordEmbsAug) to preserve semantic coherence |
| Unstable training caused by class weights | Normalize the weights: weights = weights / weights.max() |
| Slow dynamic-threshold search | Use a Bayesian optimization library (e.g. Optuna) to speed up the search |
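A minimal sketch of the MixUp idea mentioned in the first row (an illustrative addition, not code from the original article): each batch is blended with a shuffled copy of itself, and the loss is computed against both label sets.
import numpy as np

def mixup_batch(images, labels, alpha=0.2):
    """Blend a batch with a shuffled copy of itself; return mixed inputs and both label sets."""
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(images.size(0))
    mixed = lam * images + (1 - lam) * images[index]
    return mixed, labels, labels[index], lam

# Inside the training loop the loss would then become:
# mixed, y_a, y_b, lam = mixup_batch(images, labels)
# outputs = model(mixed)
# loss = lam * focal_loss(outputs, y_a) + (1 - lam) * focal_loss(outputs, y_b)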
IV. Performance Optimization Tips
Speeding up CV tasks
# Mixed-precision training (30%+ speedup); this replaces the plain forward/backward
# inside the training loop (criterion is e.g. the focal_loss defined above)
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
optimizer.zero_grad()
with autocast():
    outputs = model(images)
    loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
NLP task optimization
# Gradient accumulation (simulates a larger batch size)
accum_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    loss = model(**batch).loss
    loss = loss / accum_steps  # scale the loss so gradients average over accum_steps
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
V. Model Interpretability Tools
CV feature visualization
# Use the Captum library to highlight the regions driving a prediction
from captum.attr import IntegratedGradients
ig = IntegratedGradients(model)
attributions = ig.attribute(
    images,
    target=target_class,
    n_steps=50
)
heatmap = attributions.mean(dim=1).cpu().numpy()  # average over channels -> (N, H, W)
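# To actually look at the attribution map, a quick matplotlib overlay works
# (this display snippet is an addition, not from the original article):
import matplotlib.pyplot as plt
plt.imshow(heatmap[0], cmap='hot')
plt.colorbar()
plt.show()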
NLP attention analysis
# Visualize BERT attention with bertviz; head_view expects the per-layer attention
# tensors and the token list (a specific layer/head, e.g. layer 4 / head 3 from the
# original example, can then be selected interactively in the rendered view)
from bertviz import head_view
sentence = "The movie was terrible but acting was good"
inputs = tokenizer(sentence, return_tensors='pt')
outputs = model.bert(**inputs, output_attentions=True)  # use the underlying BertModel
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
head_view(outputs.attentions, tokens)
All of the code above can be copied directly into a Jupyter Notebook or Python script; adjust the data-loading parts to your actual data paths and task. It is best to start by testing small sub-modules (e.g. verifying the augmentation output on its own) and then gradually integrate them into the full training pipeline.