贴一个经过注释的MedCLIP_demo文件,帮助理解medclip模型的搭建


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
#dataset:pytorch中的自定义数据集需要继承torch.utils.data
from transformers import AutoTokenizer,AutoModel  #自动加载预训练模型对应的tokenizer,例如clinicalbert
from PIL import Image   #用于打开图像文件
import pandas as pd    #用于读取csv文件
import os

from torchvision import transforms,models


#step1:数据集定义
class MIMICCXRDataset(Dataset):
    """
    pytorch 的Dataloader会用这个类来批量取数据,做训练和测试

    """
    def __init__(self, csv_file,img_root,tokenizer,max_len=128, transform=None):
        self.df = pd.read_csv(csv_file)  #csv文件路径,里面包含图像路径和文本报告路径
        self.img_root = img_root   #图像的根目录
        self.tokenizer = tokenizer   #文本的tokenizer 这里是clinicalbert
        self.max_len = max_len    #
        self.transform = transform    #图像预处理,例如缩放、归一化等


    def __len__(self):
        return len(self.df)   #返回的是数据集的长度,pytorch需要知道数据集的大小才能做训练

    def __getitem__(self, idx):   #获取某一条数据,当你想要第idx个数据的时候,会调用这个函数
        row = self.df.iloc[idx]    #取csv表格的第idx行,对应了一张图片和一份报告

        #图像处理
        img_path = os.path.join(self.img_root,row["image"])  #拼接完整的图片路径
        img = Image.open(img_path).convert('RGB')   #打开图片文件,并把图片文件改成RGB格式
        if self.transform:
            img = self.transform(img)   #如果你提供了图像处理方法,就执行
        #图像处理的输出是一个数字矩阵

        #文本处理
        report = row["report"]    #取csv表格中的文字报告
        text_inputs = self.tokenizer(report, padding="max_length", #如果文字太短就补齐到最大长度
                                     truncation=True,#如果文字太长,就截断到最大长度
                                     max_length=self.max_len,
                                     return_tensors="pt")#返回pytorch张量格式,模型可以直接使用
        #文本处理的输出是一个字典,通常包括input_ids文字对应的数字序列和attention_mask:哪些位置是有效的,哪些位置是padding

        return img,text_inputs  #图像张量和文本张量

#step2:模型定义
class MedCLIP(nn.Module):
    """
    把图像和文本都映射到同一个空间向量,方便后续计算相似度或做多模态任务
    """

    def __init__(self, img_dim=512,txt_dim=768,embed_dim=512):
        """
        :param img_dim: 图像编码器输出的特征维度
        :param txt_dim:  文本编码器输出的特征维度
        :param embed_dim:  文本和图像最后映射到的共同向量维度(统一维度,便于对比学习)
        """
        super().__init__()  #调用父类nn.Module的初始化方法
        #========================
        #图像编码器:用的是resnet50
        self.img_encoder = models.resnet50(pretrained=True)  #使用resnet50模型,经典的图像卷积神经网络
        #pretrained=true表示使用在imagenet预训练的权重,已经学会基本的图像特征
        self.img_encoder.fc = nn.Linear(self.img_encoder.fc.in_features,img_dim)
        #resnet默认输出1000类,这里改成输出img_dim维,nn.Linear是全连接层,相当于把resnet的输出映射到我们想要的维度
        #图像投影
        self.img_proj = nn.Linear(img_dim,embed_dim)


        #===========================
        #文本编码器:用的是clinicalBERT
        self.txt_encoder = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        #使用的是clinicalbert模型,它是专门处理医学文本的bert模型
        self.text_proj = nn.Linear(txt_dim,embed_dim)
        #文本编码器输出的是txt_dim768维,这里用全连接层映射到和图像相同的维度
        #这样图像和向量可以在同一个向量空间中比较维度


    #=========================
    #向前传播方法
    def forward(self, images, text_inputs):
        """
        forward 是pytorch模型的核心函数,定义模型的计算流程
        输入:
        :param images: 图像张量[batch_size,3,H,W]
        :param text_inputs: 文本张量(字典,包括input_ids,attention_mask)

        """
        #图像特征
        img_feat = self.img_encoder(images)  #把图片送进resnet提取特征,输出[batch_size,img_dim]
        img_feat = self.img_proj(img_feat)   #映射到统一的向量空间
        img_feat = nn.functional.normalize(img_feat,dim=-1)  #向量归一化,把每个特征向量长度变成1,方便后面计算余弦相似度

        #文本特征
        outputs = self.txt_encoder(**text_inputs)  #把文本送进clinicalbert,返回字典outputs,包含每个token的向量
        txt_feat = outputs.last_hidden_state[:,0,:]  #取每条文本的token向量(BERT)用它表示整段文本
        txt_feat = self.text_proj(txt_feat)  #映射到和图像相同的向量空间
        txt_feat = nn.functional.normalize(txt_feat,dim=-1)  #向量归一化,长度为0

        return img_feat,txt_feat  #返回文本向量:batch_size*embed_dim的矩阵  图像向量:batch_size*embed_dim的矩阵

#====================
#step3:对比损失学习
"""
计算图像特征和文本特征之间的对比损失,
目的是   对应的图像和文本特征更相似   不对应的图像和文本特征更不相似


"""
def contrastive_loss(img_feat,txt_feat,temperature=0.07):
    #temperature:温度系数,用来缩放相似度,默认是0.07.这个值越小,相似度差异越明显
    logits = img_feat @ txt_feat.T / temperature  # @是矩阵相乘的符号
    labels = torch.arange(len(logits)).to(img_feat.device)
    #torch.arange(len(logits))生成一个整数序列[0,1.....B-1]
    #to(img_feat.device)把上面那个序列放到和img_feat相同的设备CPU/GPU
    loss_i = nn.CrossEntropyLoss()(logits,labels)  #以图片为anchor
    #nn.CrossEntropyLoss是pytorch的交叉熵损失函数
    #logits是预测相似度分数,labels是真实对应关系
    loss_t = nn.CrossEntropyLoss()(logits.T,labels)  #以文本为anchor
    #对称的训练图像和文本可以提高模型匹配效果

    return (loss_i+loss_t)/2  #返回图片段loss和文本段loss的平均值,保证模型对两个方向都能学习对齐





#====================
#step4:训练循环
def train(model,dataloader,optimizer,device,epochs=10):
    """

    :param model:  一个pytorch的nn.Module(我们的模型)
    :param dataloader: torch.utils.data.DataLoader,每次迭代返回一个batch(我的训练数据)
    :param optimizer: torch.optim.Adam()负责根据梯度更新模型参数
    :param device: 设备 torch.device('cuda')
    :param epochs: 训练轮数,默认是10
    :return:
    """

    def maybe_squeezedim1(tensor):
        if tensor.ndim >= 2 and tensor.size(1) == 1:
            return tensor.squeeze(1)
        return tensor

    model.to(device)  #把模型的参数和缓冲(如batchnorm的running stats)移动到指定的设备
    model.train()
    #告诉pytorch进入训练模式,该模式下dropout会启用,batchnorm会使用batch的均值方差(而不是全局的)
    #在验证/推理时要用model.eval
    for epoch in range(epochs): #外层循环,按照epoch训练
        total_loss = 0 #用来累加当前epoch中的每个batch中的loss,以便最后打印平均loss
        for img, text_inputs in dataloader:
            #内层循环:遍历dataloader,每个迭代返回一个batch,包含两个部分img和text_inputs
            #常见的img的形状[batch,channels,height,weight]
            #text_inputs是一个字典{'input_ids':tensor,'attention_mask':tensor}.
            # 这些tensor的形状通常是[B,seq_len],但有时因为处理方式可能是[B,1,seq_len],所以下面有squeeze操作

            img = img.to(device)


            text_inputs = {k: maybe_squeezedim1(v).to(device) for k, v in text_inputs.items()}
             #向前传播
            img_feat,txt_feat = model(img,text_inputs)
            """
            #把图像和文本送进模型做向前传播。
            模型的forward应该返回两个东西:图像特征img_feat和文本特征txt_feat
            """
            loss = contrastive_loss(img_feat,txt_feat)
            """
            计算对比损失,把图像向量和文本向量做正负样本进行训练。常见实现的是InfoNCE
            计算[B,B]的相似度矩阵(每个样本与batch内其它样本的相似度),对角线为正样本,然后用交叉熵损失
            """
            #反向传播与优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            """
            三步标准训练流程:
            1.optimizer.zero_grad():把模型参数上的梯度清零(默认会累加梯度,如果不清零会把前一次的梯度加到当前梯度上
            2.loss.backward():反向传播,计算每个参数的梯度(基于当前的loss)
            3.optimizer.step():用计算得到的梯度更新模型参数(例如adam会根据学习率更新参数
            """

            total_loss += loss.item()
            """
            把0维度的tensor转成python浮点数(detach+转成float)并加到total_loss中用于统计
            """
        print(f"Epoch {epoch+1}, loss: {total_loss/len(dataloader):.4f}")

#=============================
#step5:运行示例
if __name__ == "__main__":  #只有这个文件直接运行的时候,下面代码才会运行,如果这个文件被别人当做模块导入,下面内容不会直接运行
    #图像处理
    transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        #把图像从PIL/numpy数组格式转换成tensor(pytorch能处理的数据格式)
        #同时像素值会自动缩放到[0,1]的范围
        transforms.Normalize([0.5],[0.5])
        #把像素值进一步归一化映射到[-1,1]的范围,有助于加快训练,稳定梯度
    ])  #torchvision.transforms是专门用于处理图像的

    #文本处理
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    #数据集加载
    dataset = MIMICCXRDataset(csv_file="/mnt/data/dataset/train.csv",
                              img_root="/mnt/data/dataset/IMG",
                              tokenizer=tokenizer,
                              transform=transform)


    dataloader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=4)
    """
    dataloader是pytorch提供的工具,用来批量加载数据
    batch_size=32:一次取32个样本(提高训练效率)
    shuffle=True:每个epoch都随机打乱顺序(防止模型记忆顺序)
    num_workers=4:开四个进程同时读数据,加快速度
    """


    #模型和优化器
    model = MedCLIP(img_dim=512,txt_dim=768,embed_dim=512)

    optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
    """
    Adam优化器:训练时更新模型参数的方法
    lr=1e-4:学习率,决定参数更新的步子大小
    """
    train(model,dataloader,optimizer,device="cuda",epochs = 5)
Logo

中国智能体开发者社区,聚焦智能体与大模型开发,提供前沿资讯、实用工具链、开源项目及行业案例。通过技术沙龙、开发者大赛等活动,促进经验交流与协作,助力开发者快速构建创新智能应用。

更多推荐