整体架构设计

OCR 的两阶段流程

输入图像
   ↓
【第一阶段:文本检测 TextDetector】
   检测文本在哪里(输出:文本框坐标)
   ↓
【第二阶段:文本识别 TextRecognizer】
   识别每个文本框的内容(输出:文本 + 置信度)
   ↓
输出:[(坐标, (文字, 置信度)), ...]

面试问题:为什么要**分两阶段ocr?**

  • 检测:找到文本区域(Where)
  • 识别:读懂文本内容(What)
  • 分离关注点,每个模型专注一个任务,效果更好

核心类详解

1. OCR 类(主入口,第 536-752 行)

这是用户直接调用的类,负责协调检测和识别两个模块

初始化(__init__,537-582 行)
class OCR:
    def __init__(self, model_dir=None):
        # 支持多 GPU 并行
        if settings.PARALLEL_DEVICES > 0:
            self.text_detector = []      # 多个检测器
            self.text_recognizer = []    # 多个识别器
            for device_id in range(settings.PARALLEL_DEVICES):
                self.text_detector.append(TextDetector(model_dir, device_id))
                self.text_recognizer.append(TextRecognizer(model_dir, device_id))

设计亮点:多 GPU 支持

  • 每个 GPU 有独立的检测器和识别器实例
  • 可以并行处理多张图像,提升吞吐量

模型下载逻辑

try:
    # 先尝试从本地加载
    model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
    self.text_detector = [TextDetector(model_dir)]
except Exception:
    # 本地没有,从 HuggingFace 下载
    model_dir = snapshot_download(
        repo_id="InfiniFlow/deepdoc",
        local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc")
    )

核心方法:__call__(708-751 行)

这是最重要的方法,实现完整的 OCR 流程:

def __call__(self, img, device_id=0, cls=True):
    # 1️⃣ 文本检测:找到所有文本框
    dt_boxes, elapse = self.text_detector[device_id](img)
    
    # 2️⃣ 排序文本框(从上到下,从左到右)
    dt_boxes = self.sorted_boxes(dt_boxes)
    
    # 3️⃣ 裁剪出每个文本区域
    img_crop_list = []
    for bno in range(len(dt_boxes)):
        img_crop = self.get_rotate_crop_image(ori_im, dt_boxes[bno])
        img_crop_list.append(img_crop)
    
    # 4️⃣ 批量识别所有文本(提升效率)
    rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
    
    # 5️⃣ 过滤低置信度结果
    for box, rec_result in zip(dt_boxes, rec_res):
        text, score = rec_result
        if score >= self.drop_score:  # 默认 0.5
            filter_boxes.append(box)
    
    return list(zip(filter_boxes, filter_rec_res))

输出格式示例

[
    ([[10, 20], [100, 20], [100, 50], [10, 50]], ("Hello World", 0.95)),
    ([[10, 60], [150, 60], [150, 90], [10, 90]], ("RAGFlow", 0.88)),
    ...
]
关键辅助方法

1. sorted_boxes(640-661 行):智能排序

def sorted_boxes(self, dt_boxes):
    # 先按 Y 坐标排序(从上到下)
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    
    # 如果两个框在同一行(Y 差距 < 10),按 X 排序
    for i in range(num_boxes - 1):
        if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10:
            # 交换位置,确保左边的在前

面试要点:这是典型的**阅读顺序恢复算法**,模拟人类先从从上到下、再左到右的阅读习惯。


2. get_rotate_crop_image(584-638 行):智能旋转识别

def get_rotate_crop_image(self, img, points):
    # 1. 透视变换,将倾斜文本框变为矩形
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(img, M, (img_crop_width, img_crop_height))
    
    # 2. 如果文本框是竖向的(高/宽 >= 1.5)
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        # 尝试原始方向
        rec_result = self.text_recognizer[0]([dst_img])
        best_score = rec_result[0][0][1]
        
        # 尝试顺时针旋转 90°
        rotated_cw = np.rot90(dst_img, k=3)
        rotated_cw_score = self.text_recognizer[0]([rotated_cw])[0][0][1]
        if rotated_cw_score > best_score:
            best_img = rotated_cw
        
        # 尝试逆时针旋转 90°
        rotated_ccw = np.rot90(dst_img, k=1)
        # ... 选择置信度最高的方向

设计亮点

  • 自动处理竖排文字(如中文古籍)
  • 通过置信度判断最佳旋转角度
  • 提升识别准确率

2. TextDetector 类(文本检测器,414-533 行)

负责找到图像中的文本区域。

初始化(415-451 行)
class TextDetector:
    def __init__(self, model_dir, device_id: int | None = None):
        # 预处理流程配置
        pre_process_list = [
            {'DetResizeForTest': {
                'limit_side_len': 960,  # 图像最长边限制
                'limit_type': "max",
            }},
            {'NormalizeImage': {
                'mean': [0.485, 0.456, 0.406],  # ImageNet 均值
                'std': [0.229, 0.224, 0.225],   # ImageNet 标准差
            }},
            {'ToCHWImage': None},  # HWC -> CHW (高宽通道 -> 通道高宽)
        ]
        
        # 后处理:从概率图恢复文本框
        postprocess_params = {
            "name": "DBPostProcess",
            "thresh": 0.3,          # 二值化阈值
            "box_thresh": 0.5,      # 文本框置信度阈值
            "unclip_ratio": 1.5,    # 文本框扩展比例
        }
        
        # 加载 ONNX 模型
        self.predictor, self.run_options = load_model(model_dir, 'det', device_id)

**面试问题:为什么要 normalize?**图像归一化

  • 神经网络训练时使用 ImageNet 数据集的统计特性
  • 归一化让模型输入分布一致,提升效果
NormalizeImage 的作用
    ✅ 将像素值从 [0, 255] 标准化到约 [-2.5, 2.5]
    ✅ 匹配模型训练时的数据分布 (ImageNet)
    ✅ 加速训练收敛,提升模型效果
    ✅ 消除不同通道之间的量纲差异

ToCHWImage 的作用
    ✅ 将 OpenCV/PIL 的 HWC 格式转换为 PyTorch/ONNX 的 CHW 格式
    ✅ 符合深度学习框架的输入要求
    ✅ 便于批处理(batch 维度在最前)

核心方法:__call__(503-530 行)
def __call__(self, img):
    # 1. 预处理
    data = transform(data, self.preprocess_op)  # 缩放、归一化
    
    # 2. 模型推理(带重试机制)
    for i in range(100000):
        try:
            outputs = self.predictor.run(None, input_dict, self.run_options)
            break
        except Exception as e:
            if i >= 3:
                raise e
            time.sleep(5)  # 重试前等待
    
    # 3. 后处理:从概率图恢复文本框
    post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
    dt_boxes = post_result[0]['points']  # 四边形坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
    
    # 4. 过滤无效框(太小的框)
    dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
    
    return dt_boxes

设计亮点:重试机制

  • 处理 ONNX Runtime 偶发的并发问题
  • 最多重试 3 次,提升稳定性

文本框过滤(470-484 行)
def filter_tag_det_res(self, dt_boxes, image_shape):
    dt_boxes_new = []
    for box in dt_boxes:
        # 1. 按顺时针排序四个顶点
        box = self.order_points_clockwise(box)
        
        # 2. 裁剪到图像边界内
        box = self.clip_det_res(box, img_height, img_width)
        
        # 3. 计算宽度和高度
        rect_width = int(np.linalg.norm(box[0] - box[1]))
        rect_height = int(np.linalg.norm(box[0] - box[3]))
        
        # 4. 过滤太小的框(可能是噪声)
        if rect_width <= 3 or rect_height <= 3:
            continue
        
        dt_boxes_new.append(box)
    return np.array(dt_boxes_new)

3. TextRecognizer 类(文本识别器,133-411 行)

负责识别裁剪出的文本图像的具体内容。

初始化(134-144 行)
class TextRecognizer:
    def __init__(self, model_dir, device_id: int | None = None):
        self.rec_image_shape = [3, 48, 320]  # 通道、高、宽
        self.rec_batch_num = 16              # 批处理大小
        
        # 后处理:CTC 解码器
        postprocess_params = {
            'name': 'CTCLabelDecode',
            'character_dict_path': os.path.join(model_dir, "ocr.res"),  # 字符字典
        }
        self.postprocess_op = build_post_process(postprocess_params)
        
        # 加载识别模型
        self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)

关键概念:CTC(Connectionist Temporal Classification)

  • 一种序列标注算法,将图像转换为文本序列
  • 不需要字符级别的标注,只需要整个文本的标签

核心方法:__call__(363-408 行)
def __call__(self, img_list):
    # 1. 计算所有图像的宽高比
    width_list = []
    for img in img_list:
        width_list.append(img.shape[1] / float(img.shape[0]))
    
    # 2. 按宽高比排序(相似比例的图像一起处理)
    indices = np.argsort(np.array(width_list))
    
    # 3. 批量处理
    for beg_img_no in range(0, img_num, batch_num):
        end_img_no = min(img_num, beg_img_no + batch_num)
        
        # 4. 找到这一批的最大宽高比
        max_wh_ratio = imgW / imgH
        for ino in range(beg_img_no, end_img_no):
            wh_ratio = w * 1.0 / h
            max_wh_ratio = max(max_wh_ratio, wh_ratio)
        
        # 5. 统一缩放到相同高度,宽度按比例
        for ino in range(beg_img_no, end_img_no):
            norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
            norm_img_batch.append(norm_img)
        
        # 6. 模型推理
        outputs = self.predictor.run(None, input_dict, self.run_options)
        
        # 7. CTC 解码
        rec_result = self.postprocess_op(preds)
    
    return rec_res  # [(文本, 置信度), ...]

优化技巧:按宽高比排序

  • 相似形状的图像一起处理,减少 padding
  • 提升批处理效率,降低内存浪费

图像预处理(146-170 行)
def resize_norm_img(self, img, max_wh_ratio):
    imgC, imgH, imgW = self.rec_image_shape  # [3, 48, 320]
    
    # 1. 计算缩放后的宽度(保持宽高比)
    imgW = int((imgH * max_wh_ratio))
    h, w = img.shape[:2]
    ratio = w / float(h)
    
    if math.ceil(imgH * ratio) > imgW:
        resized_w = imgW  # 宽度超限,裁剪
    else:
        resized_w = int(math.ceil(imgH * ratio))  # 按比例缩放
    
    # 2. 缩放图像
    resized_image = cv2.resize(img, (resized_w, imgH))
    
    # 3. 归一化到 [-1, 1]
    resized_image = resized_image.transpose((2, 0, 1)) / 255  # [0, 1]
    resized_image -= 0.5
    resized_image /= 0.5  # [-1, 1]
    
    # 4. 右侧填充 0
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    
    return padding_im

设计原理

  • 固定高度 48 像素(模型输入要求)
  • 宽度动态调整(支持不同长度的文本)
  • 右侧填充 0(批处理时对齐)

辅助函数

1. load_model(71-130 行):模型加载器

def load_model(model_dir, nm, device_id: int | None = None):
    model_file_path = os.path.join(model_dir, nm + ".onnx")
    
    # 模型缓存机制
    global loaded_models
    if model_cached_tag in loaded_models:
        return loaded_models[model_cached_tag]  # 复用已加载的模型
    
    # 检查 CUDA 是否可用
    if cuda_is_available():
        cuda_provider_options = {
            "device_id": device_id,
            "gpu_mem_limit": max(gpu_mem_limit_mb, 0) * 1024 * 1024,
            "arena_extend_strategy": arena_strategy,
        }
        sess = ort.InferenceSession(
            model_file_path,
            providers=['CUDAExecutionProvider'],
            provider_options=[cuda_provider_options]
        )
    else:
        sess = ort.InferenceSession(
            model_file_path,
            providers=['CPUExecutionProvider']
        )
    
    # 缓存模型
    loaded_models[model_cached_tag] = (sess, run_options)
    return sess, run_options

设计亮点

  • 全局缓存:避免重复加载模型
  • GPU 内存限制:防止 OOM(默认 2GB)
  • 多设备支持:每个 GPU 独立模型实例

2. create_operators(49-68 行):动态创建操作符

def create_operators(op_param_list, global_config=None):
    ops = []
    for operator in op_param_list:
        op_name = list(operator)[0]  # 如 'DetResizeForTest'
        param = operator[op_name]
        
        # 动态获取类并实例化
        op = getattr(operators, op_name)(**param)
        ops.append(op)
    
    return ops

配置示例

[
    {'DetResizeForTest': {'limit_side_len': 960}},
    {'NormalizeImage': {'mean': [0.485, 0.456, 0.406]}},
    {'ToCHWImage': None},
]

设计模式责任链模式 + 工厂模式

  • 配置驱动,灵活组合预处理流程
  • 便于添加新的操作符

面试高频问题

Q1:为什么检测和识别要分开?

A:

  • 检测是定位问题(回归),识别是分类问题(序列标注)
  • 分开训练,各自优化,效果更好
  • 检测可以处理任意角度的文本,识别只需要处理水平文本

Q2:批处理为什么按宽高比排序?

A:

  • 相似形状的图像填充浪费更少
  • 批内宽度统一为最大宽高比,避免过度填充
  • 提升 GPU 利用率

Q3:为什么要多次尝试旋转角度?

A:

  • 处理竖排文字、倾斜文档
  • 通过置信度自动选择最佳角度
  • 提升泛化能力

Q4:重试机制的作用?

A:

  • ONNX Runtime 在高并发下可能失败
  • 重试 + 延迟避免资源竞争
  • 提升生产环境稳定性

五、性能优化总结

优化点 实现方式 效果
模型缓存 全局字典 loaded_models 避免重复加载,节省启动时间
批处理 rec_batch_num=16 GPU 吞吐量提升 10x
多 GPU PARALLEL_DEVICES 并行处理,线性扩展
内存限制 gpu_mem_limit 防止 OOM,稳定运行
智能排序 按宽高比分组 减少 padding,提升 20% 效率

六、在 RAGFlow 中的位置

用户上传 PDF
   ↓
PDF 转图像(每页一张)
   ↓
【ocr.py - OCR 类】
   ├─ TextDetector: 找到文本框
   └─ TextRecognizer: 识别文本
   ↓
输出:[(坐标, 文本), ...]
   ↓
【pdf_parser.py】
   使用 OCR 结果 + 布局识别 + 表格识别
   ↓
【分块 + 向量化 + 检索】

边框检测

问题 1:边框检测使用模型(det.onnx)

答案:DB(Differentiable Binarization)模型

从代码中可以明确看到:

postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
                      "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}

DB 模型全称Differentiable Binarization(可微分二值化)


DB 模型的工作原理

核心思想:端到端的文本检测
输入图像
   ↓
【卷积神经网络(Backbone + FPN)】
   ↓
输出概率图(Probability Map)
   - Probability Map (P): 每个像素是文本的概率 [0, 1]
   - Threshold Map (T): 每个像素的自适应阈值
   ↓
【可微分二值化】
   Binary = 1 / (1 + e^(-k(P - T)))
   ↓
二值图(0 或 1)
   ↓
【后处理】
   轮廓检测 → 文本框

为什么叫"可微分"二值化?

传统方法的问题

# 传统的硬二值化(不可微分)
binary = P > threshold  # 梯度为 0,无法反向传播

DB 的创新

# 可微分的软二值化(smooth approximation)
binary = 1 / (1 + exp(-k * (P - T)))  # 可微分,可以端到端训练

面试要点

  • 传统方法:先训练模型输出概率,再用固定阈值二值化(两阶段)
  • DB 方法:二值化过程也参与训练(端到端),自适应学习阈值

模型输出详解

# TextDetector 的推理输出
outputs = self.predictor.run(None, input_dict, self.run_options)
# outputs[0].shape = (1, 1, H, W)
# 值范围:[0, 1],表示每个像素是文本的概率

可视化示例


问题 2:从概率图恢复文本框的原理

**核心流程在 <font style="color:#DF2A3F;">DBPostProcess</font> **类(postprocess.py)中实现

def __call__(self, outs_dict, shape_list):
    pred = outs_dict['maps']
    pred = pred[:, 0, :, :]
    
    # 1️⃣ 二值化:将概率图转为二值图
    segmentation = pred > self.thresh  # thresh=0.3
    
    # 2️⃣ 可选的形态学膨胀(扩大文本区域)
    if self.dilation_kernel is not None:
        mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8),
                         self.dilation_kernel)
    
    # 3️⃣ 轮廓检测:从二值图提取文本框
    if self.box_type == 'quad':  # 四边形
        boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
    
    return boxes_batch

详细步骤拆解

Step 1: 二值化(Binarization)
segmentation = pred > self.thresh  # thresh = 0.3

作用:将连续的概率值转为离散的二值掩码

概率图 (0-1 范围):        二值图 (0或1):
┌──────────────┐         ┌──────────────┐
│ 0.1 0.2 0.8  │         │ 0  0  1      │
│ 0.9 0.9 0.7  │  -->    │ 1  1  1      │
│ 0.2 0.3 0.1  │         │ 0  0  0      │
└──────────────┘         └──────────────┘

Step 2: 轮廓检测(Contour Detection)
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), 
                               cv2.RETR_LIST,
                               cv2.CHAIN_APPROX_SIMPLE)

OpenCV 的 findContours 函数

  • 在二值图中找到所有连通区域的边界
  • 输出轮廓点的坐标序列

示例

二值图:                    检测到的轮廓:
┌──────────────┐         ┌──────────────┐
│ 0  0  0  0   │         │              │
│ 0  1  1  0   │         │  ┌────┐      │
│ 0  1  1  0   │  -->    │  │    │      │
│ 0  0  0  0   │         │  └────┘      │
└──────────────┘         └──────────────┘

Step 3: 计算置信度分数
def box_score_fast(self, bitmap, _box):
    '''
    box_score_fast: use bbox mean score as the mean score
    '''
    # 1. 计算文本框的边界
    xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
    xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
    ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
    ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)

    # 2. 创建掩码
    mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
    cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)

    # 3. 计算文本框区域内的平均概率
    return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

目的:过滤掉低质量的检测框(可能是噪声)

过滤条件

if self.box_thresh > score:  # box_thresh=0.5
    continue  # 跳过这个框


置信度 = 文本框区域内概率值的平均值

   = Σ(概率图中文本框像素值) / 像素数量

Fast 方法(矩形掩码)

# 优点:计算快(只需找边界矩形)
# 缺点:对于倾斜文本,可能包含更多背景

Bounding Box (矩形):
┌─────────┐
│█████████│ ← 包含文本框的最小矩形
│█████████│
│█████████│
└─────────┘

Slow 方法(精确多边形掩码)

# 优点:精确(完全贴合文本轮廓)
# 缺点:计算稍慢(需要多边形填充)

Precise Polygon:
   ┌───┐
  /     \   ← 精确的文本轮廓
 /       \
└─────────┘
参数 默认值 作用 影响
box_thresh 0.5 置信度阈值 越高越严格,误检少但可能漏检
thresh 0.3 二值化阈值 越高文本区域越小,越保守
score_mode “fast” 计算方法 fast 快但不精确,slow 慢但精确

Step 4: 文本框扩展(Unclip)

这是 DB 模型的关键创新

def unclip(self, box, unclip_ratio):
    poly = Polygon(box)
    distance = poly.area * unclip_ratio / poly.length  # 扩展距离
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = np.array(offset.Execute(distance))
    return expanded

为什么要扩展?

问题:神经网络倾向于收缩文本区域(保守预测)

检测到的框(太小):        扩展后的框(正好):
┌──────────────┐         ┌──────────────┐
│   Hello      │         │  Hello       │
│  ┌────┐      │         │ ┌──────┐     │
│  │ell │      │  -->    │ │Hello │     │
│  └────┘      │         │ └──────┘     │
└──────────────┘         └──────────────┘

扩展算法:Vatti Clipping(多边形偏移)

  • <font style="color:#DF2A3F;">unclip_ratio=1.5</font>:扩展系数(默认值)
  • distance = 面积 × ratio / 周长:计算扩展距离
  • 保持四边形形状的同时向外扩展

Step 5: 最小外接矩形
def get_mini_boxes(self, contour):
    bounding_box = cv2.minAreaRect(contour)  # 最小外接矩形
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
    
    # 对四个顶点排序(左上、右上、右下、左下)
    box = [points[index_1], points[index_2], points[index_3], points[index_4]]
    return box, min(bounding_box[1])

作用:将不规则轮廓转为标准的旋转矩形

不规则轮廓:              最小外接矩形:
    ___                    ┌────┐
   /   \                   │    │
  /     \       -->        │    │
 /       \                 │    │
─────────                  └────┘

Step 6: 坐标缩放回原图
box[:, 0] = np.clip(
    np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
    np.round(box[:, 1] / height * dest_height), 0, dest_height)

原因:模型输入图像被缩放过(如 960x640),需要映射回原始尺寸

模型输入尺寸: 960x640        原图尺寸: 1920x1280
检测框: [100, 50]      -->   [200, 100]
        x' = x * 1920/960 = x * 2

┌─────────────────────────────────────────────────────┐
│  原始图像: 任意尺寸                                 │
│  例如: 4000x3000, 800x600, 2400x1600...           │
└─────────────────┬───────────────────────────────────┘
                  ↓
┌─────────────────────────────────────────────────────┐
│  Step 1: 判断是否需要缩放                           │
│  if max(h, w) > 960:                                │
│      ratio = 960 / max(h, w)                        │
│  else:                                              │
│      ratio = 1.0 (不缩放)                           │
└─────────────────┬───────────────────────────────────┘
                  ↓
┌─────────────────────────────────────────────────────┐
│  Step 2: 计算初步尺寸                               │
│  new_h = h * ratio                                  │
│  new_w = w * ratio                                  │
└─────────────────┬───────────────────────────────────┘
                  ↓
┌─────────────────────────────────────────────────────┐
│  Step 3: 向上取整到 32 的倍数 ⭐                    │
│  new_h = round(new_h / 32) * 32                     │
│  new_w = round(new_w / 32) * 32                     │
│  (最小为 32)                                        │
└─────────────────┬───────────────────────────────────┘
                  ↓
┌─────────────────────────────────────────────────────┐
│  Step 4: 执行缩放                                   │
│  img = cv2.resize(img, (new_w, new_h))              │
└─────────────────┬───────────────────────────────────┘
                  ↓
┌─────────────────────────────────────────────────────┐
│  输出: 符合要求的图像                               │
│  - 最长边 ≤ 960                                     │
│  - 宽度和高度都是 32 的倍数                         │
│  - 保持原始宽高比                                   │
└─────────────────────────────────────────────────────┘
预处理配置(ocr.py)

pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': 960,
                'limit_type': "max",
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]

关键配置解读:

  • limit_side_len’: 960:将图像最长边缩放到 960 像素
  • limit_type’: “max”:保持宽高比,只限制最长边
需要缩放的5 个核心原因
原因 说明 解决方案
1. 计算效率 大图像计算量巨大 (4000x3000 = 1200万像素) 限制最长边为 960
2. 内存限制 GPU 内存有限(通常 8-16GB) 统一尺寸到可控范围
3. 模型架构 下采样 5 次 (2^5=32),需要能整除 向上取整到 32 的倍数
4. 训练一致性 推理时的尺寸应与训练时接近 保持在 640-1024 范围
5. 特征尺度 文本大小需要在模型感受野范围内 保持合适的像素尺寸

缩放策略
# 缩放策略
✓ 最长边限制为 960 (可配置)
✓ 保持宽高比
✓ 对齐到 32 的倍数 (神经网络架构要求)
✓ 小图不放大 (避免插值噪声)
✓ 返回缩放比例 (用于坐标映射回原图)

# 计算公式
ratio = min(960 / max(h, w), 1.0)
new_h = round(h * ratio / 32) * 32
new_w = round(w * ratio / 32) * 32

为什么要从概率图恢复文本框?

目的 说明
1. 从像素到结构 概率图是像素级预测,需要转换为结构化的文本框(坐标)
2. 过滤噪声 通过置信度阈值过滤误检测(小块、低概率区域)
3. 精确定位 Unclip 操作补偿模型的收缩倾向,确保文本完整
4. 标准化输出 输出标准的四边形坐标,方便后续裁剪和识别
5. 多尺度适配 将模型输出的小尺寸坐标映射回原始图像

完整流程可视化

【模型推理】
输入: RGB 图像 (1920x1280x3)
   ↓ 缩放到 960x640
【DB 模型】
   ↓
概率图 (960x640, 值 0-1)
   ↓
【后处理流程】

1️⃣ 二值化 (thresh=0.3)
   [0.8, 0.9, 0.2] --> [1, 1, 0]

2️⃣ 轮廓检测 (cv2.findContours)
   二值图 --> [(x1,y1), (x2,y2), ...]

3️⃣ 计算置信度 (box_score_fast)
   平均概率: 0.87 > 0.5 ✓ 保留

4️⃣ 文本框扩展 (unclip)
   面积=1000, 周长=120
   扩展距离 = 1000*1.5/120 = 12.5 像素

5️⃣ 最小外接矩形 (minAreaRect)
   不规则形状 --> 标准四边形

6️⃣ 坐标映射
   [100, 50] @ 960x640 --> [200, 100] @ 1920x1280

输出: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]

面试高频追问

Q1: 为什么不直接用分类器?

A: 文本检测是密集预测问题,需要对每个像素做判断。分类器只能判断"整张图是否有文本",无法定位。

Q2: Unclip 的 ratio 如何选择?

A:

  • 太小(如 1.0):文字边缘被裁剪,影响识别
  • 太大(如 2.0):包含背景噪声
  • 经验值 1.5:平衡完整性和准确性

Q3: 为什么用概率图而不是直接输出坐标?

A:

  • 概率图:像素级监督,训练信号丰富
  • 直接回归坐标:监督信号稀疏,难以训练
  • 类似于"语义分割→物体检测"的思路

文字识别

为什么按宽高比排序?

批处理的挑战:
- 需要统一尺寸 (H, W)
- 不同文本长度不一
- 必须 padding 到最大宽度

优化策略:
1. 按宽高比排序
2. 相似长度的文本放在同一批次
3. 每批次的"最大宽度"更接近"平均宽度"
4. 减少无效 padding

结果:
✓ 更高效的 GPU 利用率
✓ 更快的推理速度
✓ 更低的内存占用
原因 说明 收益
减少 Padding 相似长度的图像放在一起 padding 从 45% → 15%
节省计算 GPU 不浪费算力在填充区域 计算量减少 30%
降低内存 batch 的统一宽度更小 内存占用减少 40%
提升速度 整体推理更快 速度提升 25-33%
无损优化 不影响识别准确率 零副作用

问题 1:[3, 48, 320] 是什么?

self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]

含义:文本识别模型的输入尺寸

[3, 48, 320]
 │   │   │
 │   │   └─ 宽度 (Width): 320 像素
 │   └───── 高度 (Height): 48 像素
 └───────── 通道数 (Channel): 3 (RGB)

格式: [C, H, W] (Channel, Height, Width)

  1. 高度固定为 48 像素
原因:
✓ 文本识别是"序列到序列"任务
✓ 高度固定,便于特征提取的稳定性
✓ 48 像素足够表示大多数字符的细节

例子:
原始文本图像: 100 x 400 (高 x 宽)
缩放后:       48  x 192 (保持宽高比)

可视化

原始裁剪的文本框:
┌────────────────────────────┐
│  Hello World 你好世界      │ ← 高度可能是 50-200 像素
└────────────────────────────┘

缩放到固定高度:
┌────────────────────────────┐
│  Hello World 你好世界      │ ← 固定高度 48 像素
└────────────────────────────┘
   宽度动态变化 (根据文本长度)

宽度最大为 320 像素(动态的)

实际代码逻辑:
imgW = int((imgH * max_wh_ratio))  # 根据宽高比计算
if math.ceil(imgH * ratio) > imgW:
    resized_w = imgW  # 超过最大宽度,裁剪
else:
    resized_w = int(math.ceil(imgH * ratio))  # 保持宽高比

通道数为 3 (RGB)

RGB 三通道彩色图像
- R (Red): 红色通道
- G (Green): 绿色通道  
- B (Blue): 蓝色通道

保留颜色信息有助于:
✓ 区分文字和背景
✓ 处理彩色文本
✓ 提高识别准确率

流程可视化

输入: 裁剪的文本框图像 (60 x 240 x 3)
   ↓
Step 1: 缩放到固定高度 48
   计算宽高比: 240/60 = 4.0
   新宽度: 48 * 4.0 = 192
   结果: 48 x 192 x 3
   ↓
Step 2: 转换为 CHW 格式
   (48, 192, 3)(3, 48, 192)
   ↓
Step 3: 归一化到 [-1, 1]
   /255[0, 1]
   -0.5[-0.5, 0.5]
   /0.5[-1, 1]
   ↓
Step 4: 右侧填充 0 到固定宽度
   (3, 48, 192)(3, 48, 320)
   填充区域: [:, :, 192:320] = 0
   ↓
输出: (3, 48, 320) ← 模型输入

问题 2:CTC 解码器是什么?

CTC 全称:Connectionist Temporal Classification(连接时序分类)

核心问题:序列到序列的对齐

问题场景:
输入: 文本图像
输出: 文本序列

挑战:
- 不知道每个字符在图像中的精确位置
- 字符宽度不一致('i' vs 'W')
- 需要模型自己学习对齐关系

CTC 的工作原理

模型输出:每个时间步的概率分布

模型结构:
图像 (3, 48, 320)
   ↓ CNN 特征提取
特征图 (512, 1, 80)  ← 宽度 32080 个时间步
   ↓ RNN/Transformer
序列特征 (80, 512)
   ↓ 全连接层
概率分布 (80, 6625)80 个时间步,每步预测 6625 个字符的概率

示例:识别 “Hello”

时间步:  0    1    2    3    4    5    6    7    8    9   ...  79
预测:  blank  H   H   H   e   e   l   l   l   o   o  blank ...
概率:  0.8  0.9  0.7  0.6  0.85 0.8  0.9  0.85 0.8  0.9  0.7  0.8

CTC 解码:去重和去空白

def __call__(self, preds, label=None, *args, **kwargs):
    if isinstance(preds, tuple) or isinstance(preds, list):
        preds = preds[-1]
    if not isinstance(preds, np.ndarray):
        preds = preds.numpy()
    preds_idx = preds.argmax(axis=2)  # 取每个时间步概率最大的字符
    preds_prob = preds.max(axis=2)    # 取最大概率值
    text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
    if label is None:
        return text
    label = self.decode(label)
    return text, label

** CTC 解码详细示例**

# 模型输出 (简化版)
模型预测序列 (80 个时间步):
[blank, H, H, H, e, e, blank, l, l, o, blank, blank, ...]
   ↓
Step 1: 去除连续重复
[blank, H, e, blank, l, o, blank, ...]
   ↓
Step 2: 去除 blank (空白符)
[H, e, l, l, o]
   ↓
Step 3: 拼接成文本
"Hello"

更复杂的例子

原始输出:
[b, H, H, e, e, e, l, l, l, l, o, o, b, b, W, W, o, o, r, r, l, l, d, d, b]
b = blank

解码过程:
1. 去连续重复:
   [b, H, e, l, o, b, W, o, r, l, d, b]

2. 去 blank:
   [H, e, l, o, W, o, r, l, d]

3. 拼接:
   "HelloWorld"

****字符字典(ocr.res)

postprocess_params = {
    'name': 'CTCLabelDecode',
    "character_dict_path": os.path.join(model_dir, "ocr.res"),
    "use_space_char": True
}

字典内容示例

blank          ← 索引 0(特殊符号)
0              ← 索引 1
1              ← 索引 2
2              ← 索引 3
...
a              ← 索引 37
b              ← 索引 38
...
A              ← 索引 63
B              ← 索引 64
...
的             ← 索引 100
一             ← 索引 101
...
(空格)         ← 索引 6624

解码过程

模型输出索引: [0, 38, 101, 37, 0]
查字典:      [blank, 'b', '一', 'a', blank]
去blank:     ['b', '一', 'a']
拼接:        "b一a"
字符字典和识别模型的关系
  1. 字符字典和识别模型是强绑定的
    ✓ 模型输出维度 = 字典大小
    ✓ 索引映射关系固定
    ✓ 顺序必须一致
  2. 不能随意替换
    ✗ 单独换字典 → 识别不了新字符
    ✗ 单独换模型 → 可能崩溃
    ✓ 必须同步替换
  3. 扩展字典需要重新训练
    方案1: 从头训练(耗时长)
    方案2: 微调(推荐)
    方案3: 使用现成的大字典模型
  4. 版本管理很重要
    建议: 模型和字典放在同一个文件夹
    文件名包含版本号
    例如: rec_v1.onnx + ocr_v1.res

问题 3:识别模型(**rec.onnx**

模型架构:CRNN (CNN + RNN)

self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)

模型文件:rec.onnx


****CRNN 架构详解

输入图像 (3, 48, 320)
   ↓
┌─────────────────────────────────────────────┐
│  Part 1: CNN (卷积神经网络)                  │
│  作用: 提取图像特征                          │
├─────────────────────────────────────────────┤
│  Conv2D (3→64) + ReLU + MaxPool             │
│  Conv2D (64→128) + ReLU + MaxPool           │
│  Conv2D (128→256) + ReLU                    │
│  Conv2D (256→256) + ReLU + MaxPool          │
│  Conv2D (256→512) + ReLU                    │
│  Conv2D (512→512) + ReLU + MaxPool          │
│  Conv2D (512→512) + BN                      │
└─────────────┬───────────────────────────────┘
              ↓
    特征图 (512, 1, 80)
    ↓ reshape
    序列特征 (80, 512)
    ↓ 80 个时间步,每步 512 维特征
┌─────────────────────────────────────────────┐
│  Part 2: RNN (循环神经网络)                  │
│  作用: 建模序列依赖关系                      │
├─────────────────────────────────────────────┤
│  BiLSTM (512→256→256)                       │
│  BiLSTM (256→256→256)                       │
└─────────────┬───────────────────────────────┘
              ↓
    序列特征 (80, 512) ← 双向拼接
    ↓
┌─────────────────────────────────────────────┐
│  Part 3: Transcription (转录层)              │
│  作用: 输出每个时间步的字符概率              │
├─────────────────────────────────────────────┤
│  全连接层 (512→6625)                         │
│  Softmax                                    │
└─────────────┬───────────────────────────────┘
              ↓
    概率分布 (80, 6625)
    ↓ 每个时间步预测 6625 个字符的概率
┌─────────────────────────────────────────────┐
│  Part 4: CTC Decoder (解码器)                │
│  作用: 将概率序列转换为文本                  │
├─────────────────────────────────────────────┤
│  Greedy Decoding / Beam Search              │
│  去重 + 去 blank                             │
└─────────────┬───────────────────────────────┘
              ↓
    输出文本: "Hello World"

****各部分的作用

CNN 部分:视觉特征提取

作用: 将图像转换为特征序列



输入: (3, 48, 320)
     ↓ 多层卷积 + 池化
输出: (512, 1, 80)
     ↓ reshape
     (80, 512)80 个时间步,每步 512 维特征



可视化:
图像: [  H  e  l  l  o  ]
     ↓ CNN
特征: [f1][f2][f3][f4][f5]...[f80]
     每个 f 是 512 维向量

为什么宽度 320 → 80?

320 / 4 = 80  (经过 4 次步长为 2 的池化)

池化过程:
320 → 160 → 80 → 40 → ... 
最终在宽度方向上得到 80 个特征列

RNN 部分:序列建模

作用: 利用上下文信息增强特征


为什么需要 RNN?
- 字符识别依赖上下文
- "1" vs "l" vs "I" 需要周围字符帮助判断
- RNN 可以"看到"左右的信息


BiLSTM (双向 LSTM):
前向:  f1 → f2 → f3 → ... → f80
后向:  f80 ← f79 ← f78 ← ... ← f1
拼接: [前向特征, 后向特征]

示例

文本: "Hello"

前向 LSTM:
H → He → Hel → Hell → Hello
每步都知道前面的字符

后向 LSTM:
o ← lo ← llo ← ello ← Hello
每步都知道后面的字符

拼接后:
每个位置都有前后文信息
→ 提高识别准确率

转录层:输出概率

作用: 将特征映射到字符空间

全连接层: (512)(6625)


6625 = 字符字典大小
     = 数字 + 字母 + 汉字 + 符号 + blank


输出:
时间步 0: [P(blank)=0.8, P(H)=0.15, P(e)=0.02, ...]
时间步 1: [P(blank)=0.1, P(H)=0.85, P(e)=0.03, ...]
时间步 2: [P(blank)=0.2, P(H)=0.7, P(e)=0.05, ...]
...

****** 完整的识别流程示例**

# 输入: 裁剪的文本框图像
原图: "Hello World" 的图像

Step 1: 预处理
(60, 240, 3) → resize → (48, 192, 3)
            → normalize → [-1, 1]
            → to CHW → (3, 48, 192)
            → padding → (3, 48, 320)


Step 2: 模型推理 (rec.onnx)
(3, 48, 320) → CRNN → (80, 6625)


Step 3: CTC 解码
概率分布 (80, 6625)
   ↓ argmax
索引序列: [0, 38, 38, 101, 101, 0, 124, 124, ...]
   ↓ 查字典(ocr.res)
字符序列: [blank, 'H', 'H', 'e', 'e', blank, 'l', 'l', ...]
   ↓ 去重 + 去 blank
结果: "Hello World"
   ↓ 置信度
score: 0.95

三个问题的关系图

┌──────────────────────────────────────────────┐
│  问题 1: [3, 48, 320]                         │
│  输入尺寸规范                                 │
└───────────────┬──────────────────────────────┘
                ↓
        (3, 48, 320)
                ↓
┌──────────────────────────────────────────────┐
│  问题 3: 识别模型 (rec.onnx)                  │
│  CRNN: CNN + BiLSTM + FC                     │
├──────────────────────────────────────────────┤
│  CNN: 图像 → 特征序列                         │
│  RNN: 序列建模                                │
│  FC:  特征 → 字符概率                         │
└───────────────┬──────────────────────────────┘
                ↓
        (80, 6625) 概率分布
                ↓
┌──────────────────────────────────────────────┐
│  问题 2: CTC 解码器                           │
│  概率序列 → 文本                              │
├──────────────────────────────────────────────┤
│  1. Argmax: 取最大概率                        │
│  2. 去重: 连续重复 → 单个字符                 │
│  3. 去 blank: 移除空白符                      │
│  4. 拼接: 字符序列 → 文本字符串              │
└───────────────┬──────────────────────────────┘
                ↓
        "Hello World"

总结

问题 答案 关键点
[3, 48, 320] 模型输入尺寸 • 高度固定 48 • 宽度动态,最大 320 • RGB 三通道
CTC 解码器 序列解码算法 • 处理对齐问题 • 去重 + 去 blank • 无需字符级标注
识别模型 CRNN 架构 • CNN 提取特征 • BiLSTM 建模序列 • CTC 损失训练
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐