ragflow项目源码解读之文本两阶段识别:ocr.py
本文介绍了OCR系统的两阶段架构设计及核心实现细节。系统采用检测-识别分离架构,TextDetector负责定位文本区域,TextRecognizer负责识别文本内容。核心类OCR协调整个流程,支持多GPU并行处理,包含智能排序、旋转识别等优化方法。TextDetector通过预处理、归一化和后处理实现高效文本检测。系统通过置信度过滤、阅读顺序恢复和竖排文字处理等技术创新,显著提升了OCR的准确率
整体架构设计
OCR 的两阶段流程
输入图像
↓
【第一阶段:文本检测 TextDetector】
检测文本在哪里(输出:文本框坐标)
↓
【第二阶段:文本识别 TextRecognizer】
识别每个文本框的内容(输出:文本 + 置信度)
↓
输出:[(坐标, (文字, 置信度)), ...]
面试问题:为什么要**分两阶段ocr?**
- 检测:找到文本区域(Where)
- 识别:读懂文本内容(What)
- 分离关注点,每个模型专注一个任务,效果更好
核心类详解
1. OCR 类(主入口,第 536-752 行)
这是用户直接调用的类,负责协调检测和识别两个模块。
初始化(__init__,537-582 行)
class OCR:
def __init__(self, model_dir=None):
# 支持多 GPU 并行
if settings.PARALLEL_DEVICES > 0:
self.text_detector = [] # 多个检测器
self.text_recognizer = [] # 多个识别器
for device_id in range(settings.PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
设计亮点:多 GPU 支持
- 每个 GPU 有独立的检测器和识别器实例
- 可以并行处理多张图像,提升吞吐量
模型下载逻辑:
try:
# 先尝试从本地加载
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
self.text_detector = [TextDetector(model_dir)]
except Exception:
# 本地没有,从 HuggingFace 下载
model_dir = snapshot_download(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc")
)
核心方法:__call__(708-751 行)
这是最重要的方法,实现完整的 OCR 流程:
def __call__(self, img, device_id=0, cls=True):
# 1️⃣ 文本检测:找到所有文本框
dt_boxes, elapse = self.text_detector[device_id](img)
# 2️⃣ 排序文本框(从上到下,从左到右)
dt_boxes = self.sorted_boxes(dt_boxes)
# 3️⃣ 裁剪出每个文本区域
img_crop_list = []
for bno in range(len(dt_boxes)):
img_crop = self.get_rotate_crop_image(ori_im, dt_boxes[bno])
img_crop_list.append(img_crop)
# 4️⃣ 批量识别所有文本(提升效率)
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
# 5️⃣ 过滤低置信度结果
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score: # 默认 0.5
filter_boxes.append(box)
return list(zip(filter_boxes, filter_rec_res))
输出格式示例:
[
([[10, 20], [100, 20], [100, 50], [10, 50]], ("Hello World", 0.95)),
([[10, 60], [150, 60], [150, 90], [10, 90]], ("RAGFlow", 0.88)),
...
]
关键辅助方法
1. sorted_boxes(640-661 行):智能排序
def sorted_boxes(self, dt_boxes):
# 先按 Y 坐标排序(从上到下)
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
# 如果两个框在同一行(Y 差距 < 10),按 X 排序
for i in range(num_boxes - 1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10:
# 交换位置,确保左边的在前
面试要点:这是典型的**阅读顺序恢复算法**,模拟人类先从从上到下、再左到右的阅读习惯。
2. get_rotate_crop_image(584-638 行):智能旋转识别
def get_rotate_crop_image(self, img, points):
# 1. 透视变换,将倾斜文本框变为矩形
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(img, M, (img_crop_width, img_crop_height))
# 2. 如果文本框是竖向的(高/宽 >= 1.5)
if dst_img_height * 1.0 / dst_img_width >= 1.5:
# 尝试原始方向
rec_result = self.text_recognizer[0]([dst_img])
best_score = rec_result[0][0][1]
# 尝试顺时针旋转 90°
rotated_cw = np.rot90(dst_img, k=3)
rotated_cw_score = self.text_recognizer[0]([rotated_cw])[0][0][1]
if rotated_cw_score > best_score:
best_img = rotated_cw
# 尝试逆时针旋转 90°
rotated_ccw = np.rot90(dst_img, k=1)
# ... 选择置信度最高的方向
设计亮点:
- 自动处理竖排文字(如中文古籍)
- 通过置信度判断最佳旋转角度
- 提升识别准确率
2. TextDetector 类(文本检测器,414-533 行)
负责找到图像中的文本区域。
初始化(415-451 行)
class TextDetector:
def __init__(self, model_dir, device_id: int | None = None):
# 预处理流程配置
pre_process_list = [
{'DetResizeForTest': {
'limit_side_len': 960, # 图像最长边限制
'limit_type': "max",
}},
{'NormalizeImage': {
'mean': [0.485, 0.456, 0.406], # ImageNet 均值
'std': [0.229, 0.224, 0.225], # ImageNet 标准差
}},
{'ToCHWImage': None}, # HWC -> CHW (高宽通道 -> 通道高宽)
]
# 后处理:从概率图恢复文本框
postprocess_params = {
"name": "DBPostProcess",
"thresh": 0.3, # 二值化阈值
"box_thresh": 0.5, # 文本框置信度阈值
"unclip_ratio": 1.5, # 文本框扩展比例
}
# 加载 ONNX 模型
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
**面试问题:为什么要 normalize?**图像归一化
- 神经网络训练时使用 ImageNet 数据集的统计特性
- 归一化让模型输入分布一致,提升效果
NormalizeImage 的作用
✅ 将像素值从 [0, 255] 标准化到约 [-2.5, 2.5]
✅ 匹配模型训练时的数据分布 (ImageNet)
✅ 加速训练收敛,提升模型效果
✅ 消除不同通道之间的量纲差异
ToCHWImage 的作用
✅ 将 OpenCV/PIL 的 HWC 格式转换为 PyTorch/ONNX 的 CHW 格式
✅ 符合深度学习框架的输入要求
✅ 便于批处理(batch 维度在最前)
核心方法:__call__(503-530 行)
def __call__(self, img):
# 1. 预处理
data = transform(data, self.preprocess_op) # 缩放、归一化
# 2. 模型推理(带重试机制)
for i in range(100000):
try:
outputs = self.predictor.run(None, input_dict, self.run_options)
break
except Exception as e:
if i >= 3:
raise e
time.sleep(5) # 重试前等待
# 3. 后处理:从概率图恢复文本框
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
dt_boxes = post_result[0]['points'] # 四边形坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
# 4. 过滤无效框(太小的框)
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
return dt_boxes
设计亮点:重试机制
- 处理 ONNX Runtime 偶发的并发问题
- 最多重试 3 次,提升稳定性
文本框过滤(470-484 行)
def filter_tag_det_res(self, dt_boxes, image_shape):
dt_boxes_new = []
for box in dt_boxes:
# 1. 按顺时针排序四个顶点
box = self.order_points_clockwise(box)
# 2. 裁剪到图像边界内
box = self.clip_det_res(box, img_height, img_width)
# 3. 计算宽度和高度
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
# 4. 过滤太小的框(可能是噪声)
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
return np.array(dt_boxes_new)
3. TextRecognizer 类(文本识别器,133-411 行)
负责识别裁剪出的文本图像的具体内容。
初始化(134-144 行)
class TextRecognizer:
def __init__(self, model_dir, device_id: int | None = None):
self.rec_image_shape = [3, 48, 320] # 通道、高、宽
self.rec_batch_num = 16 # 批处理大小
# 后处理:CTC 解码器
postprocess_params = {
'name': 'CTCLabelDecode',
'character_dict_path': os.path.join(model_dir, "ocr.res"), # 字符字典
}
self.postprocess_op = build_post_process(postprocess_params)
# 加载识别模型
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
关键概念:CTC(Connectionist Temporal Classification)
- 一种序列标注算法,将图像转换为文本序列
- 不需要字符级别的标注,只需要整个文本的标签
核心方法:__call__(363-408 行)
def __call__(self, img_list):
# 1. 计算所有图像的宽高比
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# 2. 按宽高比排序(相似比例的图像一起处理)
indices = np.argsort(np.array(width_list))
# 3. 批量处理
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
# 4. 找到这一批的最大宽高比
max_wh_ratio = imgW / imgH
for ino in range(beg_img_no, end_img_no):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
# 5. 统一缩放到相同高度,宽度按比例
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
norm_img_batch.append(norm_img)
# 6. 模型推理
outputs = self.predictor.run(None, input_dict, self.run_options)
# 7. CTC 解码
rec_result = self.postprocess_op(preds)
return rec_res # [(文本, 置信度), ...]
优化技巧:按宽高比排序
- 相似形状的图像一起处理,减少 padding
- 提升批处理效率,降低内存浪费
图像预处理(146-170 行)
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape # [3, 48, 320]
# 1. 计算缩放后的宽度(保持宽高比)
imgW = int((imgH * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW # 宽度超限,裁剪
else:
resized_w = int(math.ceil(imgH * ratio)) # 按比例缩放
# 2. 缩放图像
resized_image = cv2.resize(img, (resized_w, imgH))
# 3. 归一化到 [-1, 1]
resized_image = resized_image.transpose((2, 0, 1)) / 255 # [0, 1]
resized_image -= 0.5
resized_image /= 0.5 # [-1, 1]
# 4. 右侧填充 0
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
设计原理:
- 固定高度 48 像素(模型输入要求)
- 宽度动态调整(支持不同长度的文本)
- 右侧填充 0(批处理时对齐)
辅助函数
1. load_model(71-130 行):模型加载器
def load_model(model_dir, nm, device_id: int | None = None):
model_file_path = os.path.join(model_dir, nm + ".onnx")
# 模型缓存机制
global loaded_models
if model_cached_tag in loaded_models:
return loaded_models[model_cached_tag] # 复用已加载的模型
# 检查 CUDA 是否可用
if cuda_is_available():
cuda_provider_options = {
"device_id": device_id,
"gpu_mem_limit": max(gpu_mem_limit_mb, 0) * 1024 * 1024,
"arena_extend_strategy": arena_strategy,
}
sess = ort.InferenceSession(
model_file_path,
providers=['CUDAExecutionProvider'],
provider_options=[cuda_provider_options]
)
else:
sess = ort.InferenceSession(
model_file_path,
providers=['CPUExecutionProvider']
)
# 缓存模型
loaded_models[model_cached_tag] = (sess, run_options)
return sess, run_options
设计亮点:
- 全局缓存:避免重复加载模型
- GPU 内存限制:防止 OOM(默认 2GB)
- 多设备支持:每个 GPU 独立模型实例
2. create_operators(49-68 行):动态创建操作符
def create_operators(op_param_list, global_config=None):
ops = []
for operator in op_param_list:
op_name = list(operator)[0] # 如 'DetResizeForTest'
param = operator[op_name]
# 动态获取类并实例化
op = getattr(operators, op_name)(**param)
ops.append(op)
return ops
配置示例:
[
{'DetResizeForTest': {'limit_side_len': 960}},
{'NormalizeImage': {'mean': [0.485, 0.456, 0.406]}},
{'ToCHWImage': None},
]
设计模式:责任链模式 + 工厂模式
- 配置驱动,灵活组合预处理流程
- 便于添加新的操作符
面试高频问题
Q1:为什么检测和识别要分开?
A:
- 检测是定位问题(回归),识别是分类问题(序列标注)
- 分开训练,各自优化,效果更好
- 检测可以处理任意角度的文本,识别只需要处理水平文本
Q2:批处理为什么按宽高比排序?
A:
- 相似形状的图像填充浪费更少
- 批内宽度统一为最大宽高比,避免过度填充
- 提升 GPU 利用率
Q3:为什么要多次尝试旋转角度?
A:
- 处理竖排文字、倾斜文档
- 通过置信度自动选择最佳角度
- 提升泛化能力
Q4:重试机制的作用?
A:
- ONNX Runtime 在高并发下可能失败
- 重试 + 延迟避免资源竞争
- 提升生产环境稳定性
五、性能优化总结
| 优化点 | 实现方式 | 效果 |
|---|---|---|
| 模型缓存 | 全局字典 loaded_models |
避免重复加载,节省启动时间 |
| 批处理 | rec_batch_num=16 |
GPU 吞吐量提升 10x |
| 多 GPU | PARALLEL_DEVICES |
并行处理,线性扩展 |
| 内存限制 | gpu_mem_limit |
防止 OOM,稳定运行 |
| 智能排序 | 按宽高比分组 | 减少 padding,提升 20% 效率 |
六、在 RAGFlow 中的位置
用户上传 PDF
↓
PDF 转图像(每页一张)
↓
【ocr.py - OCR 类】
├─ TextDetector: 找到文本框
└─ TextRecognizer: 识别文本
↓
输出:[(坐标, 文本), ...]
↓
【pdf_parser.py】
使用 OCR 结果 + 布局识别 + 表格识别
↓
【分块 + 向量化 + 检索】
边框检测
问题 1:边框检测使用模型(det.onnx)
答案:DB(Differentiable Binarization)模型
从代码中可以明确看到:
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
DB 模型全称:Differentiable Binarization(可微分二值化)
DB 模型的工作原理
核心思想:端到端的文本检测
输入图像
↓
【卷积神经网络(Backbone + FPN)】
↓
输出概率图(Probability Map)
- Probability Map (P): 每个像素是文本的概率 [0, 1]
- Threshold Map (T): 每个像素的自适应阈值
↓
【可微分二值化】
Binary = 1 / (1 + e^(-k(P - T)))
↓
二值图(0 或 1)
↓
【后处理】
轮廓检测 → 文本框
为什么叫"可微分"二值化?
传统方法的问题:
# 传统的硬二值化(不可微分)
binary = P > threshold # 梯度为 0,无法反向传播
DB 的创新:
# 可微分的软二值化(smooth approximation)
binary = 1 / (1 + exp(-k * (P - T))) # 可微分,可以端到端训练
面试要点:
- 传统方法:先训练模型输出概率,再用固定阈值二值化(两阶段)
- DB 方法:二值化过程也参与训练(端到端),自适应学习阈值
模型输出详解
# TextDetector 的推理输出
outputs = self.predictor.run(None, input_dict, self.run_options)
# outputs[0].shape = (1, 1, H, W)
# 值范围:[0, 1],表示每个像素是文本的概率
可视化示例:

问题 2:从概率图恢复文本框的原理
**核心流程在 <font style="color:#DF2A3F;">DBPostProcess</font> **类(postprocess.py)中实现
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
pred = pred[:, 0, :, :]
# 1️⃣ 二值化:将概率图转为二值图
segmentation = pred > self.thresh # thresh=0.3
# 2️⃣ 可选的形态学膨胀(扩大文本区域)
if self.dilation_kernel is not None:
mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
# 3️⃣ 轮廓检测:从二值图提取文本框
if self.box_type == 'quad': # 四边形
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h)
return boxes_batch
详细步骤拆解
Step 1: 二值化(Binarization)
segmentation = pred > self.thresh # thresh = 0.3
作用:将连续的概率值转为离散的二值掩码
概率图 (0-1 范围): 二值图 (0或1):
┌──────────────┐ ┌──────────────┐
│ 0.1 0.2 0.8 │ │ 0 0 1 │
│ 0.9 0.9 0.7 │ --> │ 1 1 1 │
│ 0.2 0.3 0.1 │ │ 0 0 0 │
└──────────────┘ └──────────────┘
Step 2: 轮廓检测(Contour Detection)
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
OpenCV 的 findContours 函数:
- 在二值图中找到所有连通区域的边界
- 输出轮廓点的坐标序列
示例:
二值图: 检测到的轮廓:
┌──────────────┐ ┌──────────────┐
│ 0 0 0 0 │ │ │
│ 0 1 1 0 │ │ ┌────┐ │
│ 0 1 1 0 │ --> │ │ │ │
│ 0 0 0 0 │ │ └────┘ │
└──────────────┘ └──────────────┘
Step 3: 计算置信度分数
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
# 1. 计算文本框的边界
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
# 2. 创建掩码
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
# 3. 计算文本框区域内的平均概率
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
目的:过滤掉低质量的检测框(可能是噪声)
过滤条件:
if self.box_thresh > score: # box_thresh=0.5
continue # 跳过这个框
置信度 = 文本框区域内概率值的平均值
= Σ(概率图中文本框像素值) / 像素数量
Fast 方法(矩形掩码)
# 优点:计算快(只需找边界矩形)
# 缺点:对于倾斜文本,可能包含更多背景
Bounding Box (矩形):
┌─────────┐
│█████████│ ← 包含文本框的最小矩形
│█████████│
│█████████│
└─────────┘
Slow 方法(精确多边形掩码)
# 优点:精确(完全贴合文本轮廓)
# 缺点:计算稍慢(需要多边形填充)
Precise Polygon:
┌───┐
/ \ ← 精确的文本轮廓
/ \
└─────────┘
| 参数 | 默认值 | 作用 | 影响 |
|---|---|---|---|
| box_thresh | 0.5 | 置信度阈值 | 越高越严格,误检少但可能漏检 |
| thresh | 0.3 | 二值化阈值 | 越高文本区域越小,越保守 |
| score_mode | “fast” | 计算方法 | fast 快但不精确,slow 慢但精确 |
Step 4: 文本框扩展(Unclip)
这是 DB 模型的关键创新!
def unclip(self, box, unclip_ratio):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length # 扩展距离
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
为什么要扩展?
问题:神经网络倾向于收缩文本区域(保守预测)
检测到的框(太小): 扩展后的框(正好):
┌──────────────┐ ┌──────────────┐
│ Hello │ │ Hello │
│ ┌────┐ │ │ ┌──────┐ │
│ │ell │ │ --> │ │Hello │ │
│ └────┘ │ │ └──────┘ │
└──────────────┘ └──────────────┘
扩展算法:Vatti Clipping(多边形偏移)
<font style="color:#DF2A3F;">unclip_ratio=1.5</font>:扩展系数(默认值)distance = 面积 × ratio / 周长:计算扩展距离- 保持四边形形状的同时向外扩展
Step 5: 最小外接矩形
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour) # 最小外接矩形
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
# 对四个顶点排序(左上、右上、右下、左下)
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
作用:将不规则轮廓转为标准的旋转矩形
不规则轮廓: 最小外接矩形:
___ ┌────┐
/ \ │ │
/ \ --> │ │
/ \ │ │
───────── └────┘
Step 6: 坐标缩放回原图
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
原因:模型输入图像被缩放过(如 960x640),需要映射回原始尺寸
模型输入尺寸: 960x640 原图尺寸: 1920x1280
检测框: [100, 50] --> [200, 100]
x' = x * 1920/960 = x * 2
┌─────────────────────────────────────────────────────┐
│ 原始图像: 任意尺寸 │
│ 例如: 4000x3000, 800x600, 2400x1600... │
└─────────────────┬───────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────┐
│ Step 1: 判断是否需要缩放 │
│ if max(h, w) > 960: │
│ ratio = 960 / max(h, w) │
│ else: │
│ ratio = 1.0 (不缩放) │
└─────────────────┬───────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────┐
│ Step 2: 计算初步尺寸 │
│ new_h = h * ratio │
│ new_w = w * ratio │
└─────────────────┬───────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────┐
│ Step 3: 向上取整到 32 的倍数 ⭐ │
│ new_h = round(new_h / 32) * 32 │
│ new_w = round(new_w / 32) * 32 │
│ (最小为 32) │
└─────────────────┬───────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────┐
│ Step 4: 执行缩放 │
│ img = cv2.resize(img, (new_w, new_h)) │
└─────────────────┬───────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────┐
│ 输出: 符合要求的图像 │
│ - 最长边 ≤ 960 │
│ - 宽度和高度都是 32 的倍数 │
│ - 保持原始宽高比 │
└─────────────────────────────────────────────────────┘
预处理配置(ocr.py)
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': 960,
'limit_type': "max",
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
关键配置解读:
- limit_side_len’: 960:将图像最长边缩放到 960 像素
- limit_type’: “max”:保持宽高比,只限制最长边
需要缩放的5 个核心原因
| 原因 | 说明 | 解决方案 |
|---|---|---|
| 1. 计算效率 | 大图像计算量巨大 (4000x3000 = 1200万像素) | 限制最长边为 960 |
| 2. 内存限制 | GPU 内存有限(通常 8-16GB) | 统一尺寸到可控范围 |
| 3. 模型架构 | 下采样 5 次 (2^5=32),需要能整除 | 向上取整到 32 的倍数 |
| 4. 训练一致性 | 推理时的尺寸应与训练时接近 | 保持在 640-1024 范围 |
| 5. 特征尺度 | 文本大小需要在模型感受野范围内 | 保持合适的像素尺寸 |
缩放策略
# 缩放策略
✓ 最长边限制为 960 (可配置)
✓ 保持宽高比
✓ 对齐到 32 的倍数 (神经网络架构要求)
✓ 小图不放大 (避免插值噪声)
✓ 返回缩放比例 (用于坐标映射回原图)
# 计算公式
ratio = min(960 / max(h, w), 1.0)
new_h = round(h * ratio / 32) * 32
new_w = round(w * ratio / 32) * 32
为什么要从概率图恢复文本框?
| 目的 | 说明 |
|---|---|
| 1. 从像素到结构 | 概率图是像素级预测,需要转换为结构化的文本框(坐标) |
| 2. 过滤噪声 | 通过置信度阈值过滤误检测(小块、低概率区域) |
| 3. 精确定位 | Unclip 操作补偿模型的收缩倾向,确保文本完整 |
| 4. 标准化输出 | 输出标准的四边形坐标,方便后续裁剪和识别 |
| 5. 多尺度适配 | 将模型输出的小尺寸坐标映射回原始图像 |
完整流程可视化
【模型推理】
输入: RGB 图像 (1920x1280x3)
↓ 缩放到 960x640
【DB 模型】
↓
概率图 (960x640, 值 0-1)
↓
【后处理流程】
1️⃣ 二值化 (thresh=0.3)
[0.8, 0.9, 0.2] --> [1, 1, 0]
2️⃣ 轮廓检测 (cv2.findContours)
二值图 --> [(x1,y1), (x2,y2), ...]
3️⃣ 计算置信度 (box_score_fast)
平均概率: 0.87 > 0.5 ✓ 保留
4️⃣ 文本框扩展 (unclip)
面积=1000, 周长=120
扩展距离 = 1000*1.5/120 = 12.5 像素
5️⃣ 最小外接矩形 (minAreaRect)
不规则形状 --> 标准四边形
6️⃣ 坐标映射
[100, 50] @ 960x640 --> [200, 100] @ 1920x1280
输出: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
面试高频追问
Q1: 为什么不直接用分类器?
A: 文本检测是密集预测问题,需要对每个像素做判断。分类器只能判断"整张图是否有文本",无法定位。
Q2: Unclip 的 ratio 如何选择?
A:
- 太小(如 1.0):文字边缘被裁剪,影响识别
- 太大(如 2.0):包含背景噪声
- 经验值 1.5:平衡完整性和准确性
Q3: 为什么用概率图而不是直接输出坐标?
A:
- 概率图:像素级监督,训练信号丰富
- 直接回归坐标:监督信号稀疏,难以训练
- 类似于"语义分割→物体检测"的思路
文字识别
为什么按宽高比排序?
批处理的挑战:
- 需要统一尺寸 (H, W)
- 不同文本长度不一
- 必须 padding 到最大宽度
优化策略:
1. 按宽高比排序
2. 相似长度的文本放在同一批次
3. 每批次的"最大宽度"更接近"平均宽度"
4. 减少无效 padding
结果:
✓ 更高效的 GPU 利用率
✓ 更快的推理速度
✓ 更低的内存占用
| 原因 | 说明 | 收益 |
|---|---|---|
| 减少 Padding | 相似长度的图像放在一起 | padding 从 45% → 15% |
| 节省计算 | GPU 不浪费算力在填充区域 | 计算量减少 30% |
| 降低内存 | batch 的统一宽度更小 | 内存占用减少 40% |
| 提升速度 | 整体推理更快 | 速度提升 25-33% |
| 无损优化 | 不影响识别准确率 | 零副作用 |
问题 1:[3, 48, 320] 是什么?
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
含义:文本识别模型的输入尺寸
[3, 48, 320]
│ │ │
│ │ └─ 宽度 (Width): 320 像素
│ └───── 高度 (Height): 48 像素
└───────── 通道数 (Channel): 3 (RGB)
格式: [C, H, W] (Channel, Height, Width)
- 高度固定为 48 像素
原因:
✓ 文本识别是"序列到序列"任务
✓ 高度固定,便于特征提取的稳定性
✓ 48 像素足够表示大多数字符的细节
例子:
原始文本图像: 100 x 400 (高 x 宽)
缩放后: 48 x 192 (保持宽高比)
可视化:
原始裁剪的文本框:
┌────────────────────────────┐
│ Hello World 你好世界 │ ← 高度可能是 50-200 像素
└────────────────────────────┘
缩放到固定高度:
┌────────────────────────────┐
│ Hello World 你好世界 │ ← 固定高度 48 像素
└────────────────────────────┘
宽度动态变化 (根据文本长度)
宽度最大为 320 像素(动态的)
实际代码逻辑:
imgW = int((imgH * max_wh_ratio)) # 根据宽高比计算
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW # 超过最大宽度,裁剪
else:
resized_w = int(math.ceil(imgH * ratio)) # 保持宽高比
通道数为 3 (RGB)
RGB 三通道彩色图像
- R (Red): 红色通道
- G (Green): 绿色通道
- B (Blue): 蓝色通道
保留颜色信息有助于:
✓ 区分文字和背景
✓ 处理彩色文本
✓ 提高识别准确率
流程可视化:
输入: 裁剪的文本框图像 (60 x 240 x 3)
↓
Step 1: 缩放到固定高度 48
计算宽高比: 240/60 = 4.0
新宽度: 48 * 4.0 = 192
结果: 48 x 192 x 3
↓
Step 2: 转换为 CHW 格式
(48, 192, 3) → (3, 48, 192)
↓
Step 3: 归一化到 [-1, 1]
/255 → [0, 1]
-0.5 → [-0.5, 0.5]
/0.5 → [-1, 1]
↓
Step 4: 右侧填充 0 到固定宽度
(3, 48, 192) → (3, 48, 320)
填充区域: [:, :, 192:320] = 0
↓
输出: (3, 48, 320) ← 模型输入
问题 2:CTC 解码器是什么?
CTC 全称:Connectionist Temporal Classification(连接时序分类)
核心问题:序列到序列的对齐
问题场景:
输入: 文本图像
输出: 文本序列
挑战:
- 不知道每个字符在图像中的精确位置
- 字符宽度不一致('i' vs 'W')
- 需要模型自己学习对齐关系
CTC 的工作原理
模型输出:每个时间步的概率分布
模型结构:
图像 (3, 48, 320)
↓ CNN 特征提取
特征图 (512, 1, 80) ← 宽度 320 → 80 个时间步
↓ RNN/Transformer
序列特征 (80, 512)
↓ 全连接层
概率分布 (80, 6625) ← 80 个时间步,每步预测 6625 个字符的概率
示例:识别 “Hello”
时间步: 0 1 2 3 4 5 6 7 8 9 ... 79
预测: blank H H H e e l l l o o blank ...
概率: 0.8 0.9 0.7 0.6 0.85 0.8 0.9 0.85 0.8 0.9 0.7 0.8
CTC 解码:去重和去空白
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if not isinstance(preds, np.ndarray):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2) # 取每个时间步概率最大的字符
preds_prob = preds.max(axis=2) # 取最大概率值
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
** CTC 解码详细示例**
# 模型输出 (简化版)
模型预测序列 (80 个时间步):
[blank, H, H, H, e, e, blank, l, l, o, blank, blank, ...]
↓
Step 1: 去除连续重复
[blank, H, e, blank, l, o, blank, ...]
↓
Step 2: 去除 blank (空白符)
[H, e, l, l, o]
↓
Step 3: 拼接成文本
"Hello"
更复杂的例子:
原始输出:
[b, H, H, e, e, e, l, l, l, l, o, o, b, b, W, W, o, o, r, r, l, l, d, d, b]
b = blank
解码过程:
1. 去连续重复:
[b, H, e, l, o, b, W, o, r, l, d, b]
2. 去 blank:
[H, e, l, o, W, o, r, l, d]
3. 拼接:
"HelloWorld"
****字符字典(ocr.res)
postprocess_params = {
'name': 'CTCLabelDecode',
"character_dict_path": os.path.join(model_dir, "ocr.res"),
"use_space_char": True
}
字典内容示例:
blank ← 索引 0(特殊符号)
0 ← 索引 1
1 ← 索引 2
2 ← 索引 3
...
a ← 索引 37
b ← 索引 38
...
A ← 索引 63
B ← 索引 64
...
的 ← 索引 100
一 ← 索引 101
...
(空格) ← 索引 6624
解码过程:
模型输出索引: [0, 38, 101, 37, 0]
查字典: [blank, 'b', '一', 'a', blank]
去blank: ['b', '一', 'a']
拼接: "b一a"
字符字典和识别模型的关系
- 字符字典和识别模型是强绑定的
✓ 模型输出维度 = 字典大小
✓ 索引映射关系固定
✓ 顺序必须一致 - 不能随意替换
✗ 单独换字典 → 识别不了新字符
✗ 单独换模型 → 可能崩溃
✓ 必须同步替换 - 扩展字典需要重新训练
方案1: 从头训练(耗时长)
方案2: 微调(推荐)
方案3: 使用现成的大字典模型 - 版本管理很重要
建议: 模型和字典放在同一个文件夹
文件名包含版本号
例如: rec_v1.onnx + ocr_v1.res
问题 3:识别模型(**rec.onnx**)
模型架构:CRNN (CNN + RNN)
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
模型文件:rec.onnx
****CRNN 架构详解
输入图像 (3, 48, 320)
↓
┌─────────────────────────────────────────────┐
│ Part 1: CNN (卷积神经网络) │
│ 作用: 提取图像特征 │
├─────────────────────────────────────────────┤
│ Conv2D (3→64) + ReLU + MaxPool │
│ Conv2D (64→128) + ReLU + MaxPool │
│ Conv2D (128→256) + ReLU │
│ Conv2D (256→256) + ReLU + MaxPool │
│ Conv2D (256→512) + ReLU │
│ Conv2D (512→512) + ReLU + MaxPool │
│ Conv2D (512→512) + BN │
└─────────────┬───────────────────────────────┘
↓
特征图 (512, 1, 80)
↓ reshape
序列特征 (80, 512)
↓ 80 个时间步,每步 512 维特征
┌─────────────────────────────────────────────┐
│ Part 2: RNN (循环神经网络) │
│ 作用: 建模序列依赖关系 │
├─────────────────────────────────────────────┤
│ BiLSTM (512→256→256) │
│ BiLSTM (256→256→256) │
└─────────────┬───────────────────────────────┘
↓
序列特征 (80, 512) ← 双向拼接
↓
┌─────────────────────────────────────────────┐
│ Part 3: Transcription (转录层) │
│ 作用: 输出每个时间步的字符概率 │
├─────────────────────────────────────────────┤
│ 全连接层 (512→6625) │
│ Softmax │
└─────────────┬───────────────────────────────┘
↓
概率分布 (80, 6625)
↓ 每个时间步预测 6625 个字符的概率
┌─────────────────────────────────────────────┐
│ Part 4: CTC Decoder (解码器) │
│ 作用: 将概率序列转换为文本 │
├─────────────────────────────────────────────┤
│ Greedy Decoding / Beam Search │
│ 去重 + 去 blank │
└─────────────┬───────────────────────────────┘
↓
输出文本: "Hello World"
****各部分的作用
CNN 部分:视觉特征提取
作用: 将图像转换为特征序列
输入: (3, 48, 320)
↓ 多层卷积 + 池化
输出: (512, 1, 80)
↓ reshape
(80, 512) ← 80 个时间步,每步 512 维特征
可视化:
图像: [ H e l l o ]
↓ CNN
特征: [f1][f2][f3][f4][f5]...[f80]
每个 f 是 512 维向量
为什么宽度 320 → 80?
320 / 4 = 80 (经过 4 次步长为 2 的池化)
池化过程:
320 → 160 → 80 → 40 → ...
最终在宽度方向上得到 80 个特征列
RNN 部分:序列建模
作用: 利用上下文信息增强特征
为什么需要 RNN?
- 字符识别依赖上下文
- "1" vs "l" vs "I" 需要周围字符帮助判断
- RNN 可以"看到"左右的信息
BiLSTM (双向 LSTM):
前向: f1 → f2 → f3 → ... → f80
后向: f80 ← f79 ← f78 ← ... ← f1
拼接: [前向特征, 后向特征]
示例:
文本: "Hello"
前向 LSTM:
H → He → Hel → Hell → Hello
每步都知道前面的字符
后向 LSTM:
o ← lo ← llo ← ello ← Hello
每步都知道后面的字符
拼接后:
每个位置都有前后文信息
→ 提高识别准确率
转录层:输出概率
作用: 将特征映射到字符空间
全连接层: (512) → (6625)
6625 = 字符字典大小
= 数字 + 字母 + 汉字 + 符号 + blank
输出:
时间步 0: [P(blank)=0.8, P(H)=0.15, P(e)=0.02, ...]
时间步 1: [P(blank)=0.1, P(H)=0.85, P(e)=0.03, ...]
时间步 2: [P(blank)=0.2, P(H)=0.7, P(e)=0.05, ...]
...
****** 完整的识别流程示例**
# 输入: 裁剪的文本框图像
原图: "Hello World" 的图像
Step 1: 预处理
(60, 240, 3) → resize → (48, 192, 3)
→ normalize → [-1, 1]
→ to CHW → (3, 48, 192)
→ padding → (3, 48, 320)
Step 2: 模型推理 (rec.onnx)
(3, 48, 320) → CRNN → (80, 6625)
Step 3: CTC 解码
概率分布 (80, 6625)
↓ argmax
索引序列: [0, 38, 38, 101, 101, 0, 124, 124, ...]
↓ 查字典(ocr.res)
字符序列: [blank, 'H', 'H', 'e', 'e', blank, 'l', 'l', ...]
↓ 去重 + 去 blank
结果: "Hello World"
↓ 置信度
score: 0.95
三个问题的关系图
┌──────────────────────────────────────────────┐
│ 问题 1: [3, 48, 320] │
│ 输入尺寸规范 │
└───────────────┬──────────────────────────────┘
↓
(3, 48, 320)
↓
┌──────────────────────────────────────────────┐
│ 问题 3: 识别模型 (rec.onnx) │
│ CRNN: CNN + BiLSTM + FC │
├──────────────────────────────────────────────┤
│ CNN: 图像 → 特征序列 │
│ RNN: 序列建模 │
│ FC: 特征 → 字符概率 │
└───────────────┬──────────────────────────────┘
↓
(80, 6625) 概率分布
↓
┌──────────────────────────────────────────────┐
│ 问题 2: CTC 解码器 │
│ 概率序列 → 文本 │
├──────────────────────────────────────────────┤
│ 1. Argmax: 取最大概率 │
│ 2. 去重: 连续重复 → 单个字符 │
│ 3. 去 blank: 移除空白符 │
│ 4. 拼接: 字符序列 → 文本字符串 │
└───────────────┬──────────────────────────────┘
↓
"Hello World"
总结
| 问题 | 答案 | 关键点 |
|---|---|---|
| [3, 48, 320] | 模型输入尺寸 | • 高度固定 48 • 宽度动态,最大 320 • RGB 三通道 |
| CTC 解码器 | 序列解码算法 | • 处理对齐问题 • 去重 + 去 blank • 无需字符级标注 |
| 识别模型 | CRNN 架构 | • CNN 提取特征 • BiLSTM 建模序列 • CTC 损失训练 |
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐

所有评论(0)