本文覆盖:论文解读,代码解读以及实操细节。

背景介绍

        RT-DETR(Real-Time DEtection TRansformer)是一个实时的基于Transformers的目标检测模型,首先是

Transformers(论文:Attention Is All You Need),作为基础知识本文略。

DETR          (论文:End-to-End Object Detection with Transformers),见我上一篇博客。

RT-DETR    (论文:DETRsBeat YOLOsonReal-time Object Detection)。

Transformers是最基础的架构,DETR是将它用到了目标检测,实现了真正端到端的训练和推理,消除了YOLO对锚点、网格中心作为初始猜测以及NMS后处理的依赖。然而其计算成本过高限制了实际应用和无需NMS的优势。RT-DETR则是对其改进,使其性能在当时超越YOLO。5分钟快速理解YOLO-CSDN博客

DETR

       将Transformers用于目标检测的开山之作:DETR_目标检测transformer开源-CSDN博客

RT-DETR

      核心逻辑同DETR,主要是网络架构不同。能理解网络 每一层的设计理念和输入输出维度变换就能理解模型。

        下面的维度为了方便理解我用具体数字替代。

  1. 输入:3*640*640   解释:代表一张分辨率640*640的彩色图片。
  2. 通过Backbone输出:低层特征S3(512*80*80) 中层特征S4(1024*40*40) 高层特征S5(2048*20*20) 解释:Backbone基于ResNet-50,输出多尺度特征图。底层特征 空间分辨率大,通道数低,代表颜色纹理等浅层信息;高层特征 空间分辨率小,通道数高,代表物体部件类别关系等高层语义信息。
  3. S5通过AIFI模块输出F5:400*2048  解释:AIFI模块就是一个注意力机制模块,不过QKV等于展平的S5(400*2048),用QKV计算输出F5。对S5特征层执行尺度内交互。这样设计的原因在于,对语义概念更丰富的高层特征执行自注意力操作,能够捕捉概念实体之间的关联,从而为后续模块实现目标定位与识别提供便利。而低层特征因缺乏语义概念,且其尺度内交互存在与高层特征交互重复、混淆的风险,因此无需对低层特征进行尺度内交互。
  4. S3,S4,F5经过CCFF输出:256*80*80  解释:CCFF就是 基于 CNN 的跨尺度特征融合模块。内部经过一系列操作(上采样、卷积等)进行特征融合(详见论文)。
  5. CCFF的输出展成序列形式(6400*256)进入 不确定性目标选择模块 后输出:(300*256)  解释:这个模块对应DETR的目标查询,不过这里充分利用了图像特征。输出的300可以理解为300个预测框。(详见论文,其中核心是将不确定性融入损失函数,不确定性用于绑定定位和分类预测能力)
  6. 前面输出的高质量初始查询(300*256)进入 解码器与预测头模块 输出 (300,80)(300,4)解释:第一个输出代表每个查询的80类概率分布(coco有80类);第二个输出代表每个查询的边界框坐标(cx, cy, w, h)。先通过6层Transformer解码器(自注意力 + 交叉注意力)每层输出维度不变(300,256);最后分别通过分类头(300,80)和回归头(300,4)。

源码

lyuwenyu/RT-DETR: [CVPR 2024] Official RT-DETR (RTDETR paddle pytorch), Real-Time DEtection TRansformer, DETRs Beat YOLOs on Real-time Object Detection. 🔥 🔥 🔥

下载code,解压后得到,我们主要用rtdetr_pytorch;其中rtdetr_paddle是基于百度的架构Paddle构建的。

内含configs,包含数据集配置文件,和模型配置文件:告诉你种类数,训练数据放哪,模型架构信息等等。

内含src,模型源码。

内含tool,提供导出模型、推理、训练相关代码。

我们的项目很简单:用它的预训练模型在我们特定任务的数据集上再训练一下(微调)。大家可以根据自己的需要丰富该项目。

数据集

真实值(Ground Truth)需要一张图片,以及对应物体的框和种类信息。所以要找标注好的数据集。这里推荐网站:https://universe.roboflow.com/(可能需要科学上网,然后免费注册就可以了。注意下载COCO Json格式,因为代码里的yaml配置文件用的是Json。)

这里,我使用的是一个检测个人防护设备的数据集。大家嫌麻烦可以从网盘下载该数据集:

https://www.123865.com/s/68YFvd-B7ab?pwd=XXDZ#

下载并解压到configs/dataset;之后我们修改源码给的yaml配置文件:把类型改为14,remap=Flase,修改文件位置,帮图片(img)和标注文件(ann)放入对应位置。batch_size根据内存大小适当修改,num_workers使用Windows 系统则设为0。

至此我们的数据集准备完毕。

训练

下载依赖:注意版本,不一致版本可能引发各种不兼容问题。

如果src标红,我们将rtdetr_pytorch目录标记为源代码根目录。

打开train:我们不需要多进程和种子直接注释掉。

可以看到,代码已经给了我们微调接口。

接下来我们打开readme下载想要的预训练模型,然后放到指定文件夹。最后我们直接在终端输入:

python tools/train.py -c configs/rtdetr/rtdetr_r18vd_6x_coco.yml -t pre_model/rtdetr_r18vd_dec3_6x_coco_from_paddle.pth

可以看到模型就跑起来了:大概需要30多个小时才能跑完。

结语

如果我的博客对您有帮助,请点赞收藏。谢谢!

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐