MedCLIP模型学习记录
贴一个经过注释的MedCLIP_demo文件,帮助理解medclip模型的搭建。
·
贴一个经过注释的MedCLIP_demo文件,帮助理解medclip模型的搭建
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
#dataset:pytorch中的自定义数据集需要继承torch.utils.data
from transformers import AutoTokenizer,AutoModel #自动加载预训练模型对应的tokenizer,例如clinicalbert
from PIL import Image #用于打开图像文件
import pandas as pd #用于读取csv文件
import os
from torchvision import transforms,models
#step1:数据集定义
class MIMICCXRDataset(Dataset):
"""
pytorch 的Dataloader会用这个类来批量取数据,做训练和测试
"""
def __init__(self, csv_file,img_root,tokenizer,max_len=128, transform=None):
self.df = pd.read_csv(csv_file) #csv文件路径,里面包含图像路径和文本报告路径
self.img_root = img_root #图像的根目录
self.tokenizer = tokenizer #文本的tokenizer 这里是clinicalbert
self.max_len = max_len #
self.transform = transform #图像预处理,例如缩放、归一化等
def __len__(self):
return len(self.df) #返回的是数据集的长度,pytorch需要知道数据集的大小才能做训练
def __getitem__(self, idx): #获取某一条数据,当你想要第idx个数据的时候,会调用这个函数
row = self.df.iloc[idx] #取csv表格的第idx行,对应了一张图片和一份报告
#图像处理
img_path = os.path.join(self.img_root,row["image"]) #拼接完整的图片路径
img = Image.open(img_path).convert('RGB') #打开图片文件,并把图片文件改成RGB格式
if self.transform:
img = self.transform(img) #如果你提供了图像处理方法,就执行
#图像处理的输出是一个数字矩阵
#文本处理
report = row["report"] #取csv表格中的文字报告
text_inputs = self.tokenizer(report, padding="max_length", #如果文字太短就补齐到最大长度
truncation=True,#如果文字太长,就截断到最大长度
max_length=self.max_len,
return_tensors="pt")#返回pytorch张量格式,模型可以直接使用
#文本处理的输出是一个字典,通常包括input_ids文字对应的数字序列和attention_mask:哪些位置是有效的,哪些位置是padding
return img,text_inputs #图像张量和文本张量
#step2:模型定义
class MedCLIP(nn.Module):
"""
把图像和文本都映射到同一个空间向量,方便后续计算相似度或做多模态任务
"""
def __init__(self, img_dim=512,txt_dim=768,embed_dim=512):
"""
:param img_dim: 图像编码器输出的特征维度
:param txt_dim: 文本编码器输出的特征维度
:param embed_dim: 文本和图像最后映射到的共同向量维度(统一维度,便于对比学习)
"""
super().__init__() #调用父类nn.Module的初始化方法
#========================
#图像编码器:用的是resnet50
self.img_encoder = models.resnet50(pretrained=True) #使用resnet50模型,经典的图像卷积神经网络
#pretrained=true表示使用在imagenet预训练的权重,已经学会基本的图像特征
self.img_encoder.fc = nn.Linear(self.img_encoder.fc.in_features,img_dim)
#resnet默认输出1000类,这里改成输出img_dim维,nn.Linear是全连接层,相当于把resnet的输出映射到我们想要的维度
#图像投影
self.img_proj = nn.Linear(img_dim,embed_dim)
#===========================
#文本编码器:用的是clinicalBERT
self.txt_encoder = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
#使用的是clinicalbert模型,它是专门处理医学文本的bert模型
self.text_proj = nn.Linear(txt_dim,embed_dim)
#文本编码器输出的是txt_dim768维,这里用全连接层映射到和图像相同的维度
#这样图像和向量可以在同一个向量空间中比较维度
#=========================
#向前传播方法
def forward(self, images, text_inputs):
"""
forward 是pytorch模型的核心函数,定义模型的计算流程
输入:
:param images: 图像张量[batch_size,3,H,W]
:param text_inputs: 文本张量(字典,包括input_ids,attention_mask)
"""
#图像特征
img_feat = self.img_encoder(images) #把图片送进resnet提取特征,输出[batch_size,img_dim]
img_feat = self.img_proj(img_feat) #映射到统一的向量空间
img_feat = nn.functional.normalize(img_feat,dim=-1) #向量归一化,把每个特征向量长度变成1,方便后面计算余弦相似度
#文本特征
outputs = self.txt_encoder(**text_inputs) #把文本送进clinicalbert,返回字典outputs,包含每个token的向量
txt_feat = outputs.last_hidden_state[:,0,:] #取每条文本的token向量(BERT)用它表示整段文本
txt_feat = self.text_proj(txt_feat) #映射到和图像相同的向量空间
txt_feat = nn.functional.normalize(txt_feat,dim=-1) #向量归一化,长度为0
return img_feat,txt_feat #返回文本向量:batch_size*embed_dim的矩阵 图像向量:batch_size*embed_dim的矩阵
#====================
#step3:对比损失学习
"""
计算图像特征和文本特征之间的对比损失,
目的是 对应的图像和文本特征更相似 不对应的图像和文本特征更不相似
"""
def contrastive_loss(img_feat,txt_feat,temperature=0.07):
#temperature:温度系数,用来缩放相似度,默认是0.07.这个值越小,相似度差异越明显
logits = img_feat @ txt_feat.T / temperature # @是矩阵相乘的符号
labels = torch.arange(len(logits)).to(img_feat.device)
#torch.arange(len(logits))生成一个整数序列[0,1.....B-1]
#to(img_feat.device)把上面那个序列放到和img_feat相同的设备CPU/GPU
loss_i = nn.CrossEntropyLoss()(logits,labels) #以图片为anchor
#nn.CrossEntropyLoss是pytorch的交叉熵损失函数
#logits是预测相似度分数,labels是真实对应关系
loss_t = nn.CrossEntropyLoss()(logits.T,labels) #以文本为anchor
#对称的训练图像和文本可以提高模型匹配效果
return (loss_i+loss_t)/2 #返回图片段loss和文本段loss的平均值,保证模型对两个方向都能学习对齐
#====================
#step4:训练循环
def train(model,dataloader,optimizer,device,epochs=10):
"""
:param model: 一个pytorch的nn.Module(我们的模型)
:param dataloader: torch.utils.data.DataLoader,每次迭代返回一个batch(我的训练数据)
:param optimizer: torch.optim.Adam()负责根据梯度更新模型参数
:param device: 设备 torch.device('cuda')
:param epochs: 训练轮数,默认是10
:return:
"""
def maybe_squeezedim1(tensor):
if tensor.ndim >= 2 and tensor.size(1) == 1:
return tensor.squeeze(1)
return tensor
model.to(device) #把模型的参数和缓冲(如batchnorm的running stats)移动到指定的设备
model.train()
#告诉pytorch进入训练模式,该模式下dropout会启用,batchnorm会使用batch的均值方差(而不是全局的)
#在验证/推理时要用model.eval
for epoch in range(epochs): #外层循环,按照epoch训练
total_loss = 0 #用来累加当前epoch中的每个batch中的loss,以便最后打印平均loss
for img, text_inputs in dataloader:
#内层循环:遍历dataloader,每个迭代返回一个batch,包含两个部分img和text_inputs
#常见的img的形状[batch,channels,height,weight]
#text_inputs是一个字典{'input_ids':tensor,'attention_mask':tensor}.
# 这些tensor的形状通常是[B,seq_len],但有时因为处理方式可能是[B,1,seq_len],所以下面有squeeze操作
img = img.to(device)
text_inputs = {k: maybe_squeezedim1(v).to(device) for k, v in text_inputs.items()}
#向前传播
img_feat,txt_feat = model(img,text_inputs)
"""
#把图像和文本送进模型做向前传播。
模型的forward应该返回两个东西:图像特征img_feat和文本特征txt_feat
"""
loss = contrastive_loss(img_feat,txt_feat)
"""
计算对比损失,把图像向量和文本向量做正负样本进行训练。常见实现的是InfoNCE
计算[B,B]的相似度矩阵(每个样本与batch内其它样本的相似度),对角线为正样本,然后用交叉熵损失
"""
#反向传播与优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
"""
三步标准训练流程:
1.optimizer.zero_grad():把模型参数上的梯度清零(默认会累加梯度,如果不清零会把前一次的梯度加到当前梯度上
2.loss.backward():反向传播,计算每个参数的梯度(基于当前的loss)
3.optimizer.step():用计算得到的梯度更新模型参数(例如adam会根据学习率更新参数
"""
total_loss += loss.item()
"""
把0维度的tensor转成python浮点数(detach+转成float)并加到total_loss中用于统计
"""
print(f"Epoch {epoch+1}, loss: {total_loss/len(dataloader):.4f}")
#=============================
#step5:运行示例
if __name__ == "__main__": #只有这个文件直接运行的时候,下面代码才会运行,如果这个文件被别人当做模块导入,下面内容不会直接运行
#图像处理
transform = transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
#把图像从PIL/numpy数组格式转换成tensor(pytorch能处理的数据格式)
#同时像素值会自动缩放到[0,1]的范围
transforms.Normalize([0.5],[0.5])
#把像素值进一步归一化映射到[-1,1]的范围,有助于加快训练,稳定梯度
]) #torchvision.transforms是专门用于处理图像的
#文本处理
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
#数据集加载
dataset = MIMICCXRDataset(csv_file="/mnt/data/dataset/train.csv",
img_root="/mnt/data/dataset/IMG",
tokenizer=tokenizer,
transform=transform)
dataloader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=4)
"""
dataloader是pytorch提供的工具,用来批量加载数据
batch_size=32:一次取32个样本(提高训练效率)
shuffle=True:每个epoch都随机打乱顺序(防止模型记忆顺序)
num_workers=4:开四个进程同时读数据,加快速度
"""
#模型和优化器
model = MedCLIP(img_dim=512,txt_dim=768,embed_dim=512)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
"""
Adam优化器:训练时更新模型参数的方法
lr=1e-4:学习率,决定参数更新的步子大小
"""
train(model,dataloader,optimizer,device="cuda",epochs = 5)
更多推荐
所有评论(0)