paddleocr 自己封装使用:
目录:下载命令 / 详细代码 / 多进程版本 / 服务器托管版本。依赖安装:pip install gunicorn && pip install flask onnxruntime opencv-python numpy pyclipper
·
下载命令:
# 创建并进入模型目录
mkdir -p ~/paddleocr_models && cd ~/paddleocr_models
# 下载检测模型
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar -xvf ch_PP-OCRv3_det_infer.tar
# 下载识别模型
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar -xvf ch_PP-OCRv3_rec_infer.tar
# 下载方向分类模型
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar -xvf ch_ppocr_mobile_v2.0_cls_infer.tar
# 下载字典
wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.7/ppocr/utils/ppocr_keys_v1.txt
paddleocr_onnx/    (转换完成后的目录结构:即上面的 ~/paddleocr_models 目录,可重命名为 paddleocr_onnx;下面的 paddle2onnx 命令需在该目录中执行)
├── det/
│ ├── model.onnx
│ └── inference.pdiparams (原始paddle文件,可留作参考)
├── rec/
│ ├── model.onnx
│ └── inference.pdiparams
├── cls/
│ ├── model.onnx
│ └── inference.pdiparams
└── ppocr_keys_v1.txt
pip install paddle2onnx
# det 模型转换
paddle2onnx \
--model_dir ./ch_PP-OCRv3_det_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ./det/model.onnx \
--opset_version 11 \
--enable_onnx_checker True
# rec 模型转换
paddle2onnx \
--model_dir ./ch_PP-OCRv3_rec_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ./rec/model.onnx \
--opset_version 11 \
--enable_onnx_checker True
# cls 模型转换
paddle2onnx \
--model_dir ./ch_ppocr_mobile_v2.0_cls_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ./cls/model.onnx \
--opset_version 11 \
--enable_onnx_checker True
使用 ONNXRuntime 时,关键参数有:
det_model_path: 检测模型路径 (paddleocr_onnx/det/model.onnx)
rec_model_path: 识别模型路径 (paddleocr_onnx/rec/model.onnx)
cls_model_path: 方向分类模型路径 (paddleocr_onnx/cls/model.onnx)
rec_char_dict_path: 字典路径 (paddleocr_onnx/ppocr_keys_v1.txt)
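use_gpu: 是否使用 GPU(True / False;在 ONNXRuntime 中对应安装 onnxruntime-gpu 并启用 CUDAExecutionProvider)
providers: 指定推理后端(按优先级排列),例如 ["CUDAExecutionProvider", "CPUExecutionProvider"];本文代码通过环境变量 ONNX_PROVIDERS(逗号分隔)设置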
详细代码:
"""
app.py — Flask + ONNXRuntime 部署 PP-OCRv3(det/cls/rec)完整推理服务(线程安全版)
功能:
1) 直接加载本地 ONNX 模型(不从网络下载)。
2) 提供 /ocr 接口:输入图片,输出文本、置信度与检测框(四点坐标)。
3) det -> crop & rectify -> cls -> rec 全流程;包含 DB 检测后处理(阈值、膨胀、排序)。
4) 使用全局 ONNX Runtime Session(线程安全);并提供额外锁作保守保护。
5) 详细中文注释,便于二次开发。
依赖:
pip install flask onnxruntime opencv-python numpy pyclipper
目录建议:
project_root/
app.py # 本文件
paddleocr_onnx/
det/model.onnx
rec/model.onnx
cls/model.onnx
ppocr_keys_v1.txt
启动:
python app.py --host 0.0.0.0 --port 5000
生产部署(示例):
# 使用 gunicorn 多进程(避免 GIL 影响),每个 worker 内部 Session 复用
gunicorn -w 2 -b 0.0.0.0:5000 'app:create_app()' --timeout 120 --threads 1
说明:
- PP-OCRv3 rec 默认输入尺寸为 (3, 48, 320)。参考官方说明。
- DB 检测常用后处理参数:box_thresh=0.6, unclip_ratio=1.5(可按需调整)。
"""
import argparse
import io
import os
import math
import threading
from typing import List, Tuple, Dict, Any
import cv2
import numpy as np
import onnxruntime as ort
import pyclipper
from flask import Flask, request, jsonify
# ------------------------------
# Global configuration (edit as needed)
# ------------------------------
# Root directory of the exported ONNX models; override with the PPOCR_ONNX_DIR env var.
MODEL_DIR = os.getenv("PPOCR_ONNX_DIR", os.path.join(os.path.dirname(__file__), "paddleocr_onnx"))
DET_MODEL_PATH = os.path.join(MODEL_DIR, "det", "model.onnx")
REC_MODEL_PATH = os.path.join(MODEL_DIR, "rec", "model.onnx")
CLS_MODEL_PATH = os.path.join(MODEL_DIR, "cls", "model.onnx")
CHAR_DICT_PATH = os.path.join(MODEL_DIR, "ppocr_keys_v1.txt")

# Detection pre-processing: limit the longest (or shortest) side, then snap H/W to
# multiples of 32 (the DB network downsamples by 32).
DET_LIMIT_SIDE_LEN = 960  # side-length limit; aspect ratio is preserved
DET_LIMIT_TYPE = "max"  # "max": cap the longest side; "min": raise the shortest side

# DB post-processing parameters (overridable via environment variables).
DB_THRESH = float(os.getenv("DB_THRESH", 0.3))  # binarization threshold on the probability map
DB_BOX_THRESH = float(os.getenv("DB_BOX_THRESH", 0.6))  # minimum mean probability inside a box
DB_UNCLIP_RATIO = float(os.getenv("DB_UNCLIP_RATIO", 1.5))  # box expansion ratio
DB_MAX_CANDIDATES = int(os.getenv("DB_MAX_CANDIDATES", 1000))  # max contours examined

# Recognition input shape (C, H, W); PP-OCRv3 default is 3x48x320.
REC_IMAGE_SHAPE = (3, 48, 320)
# Direction-classifier input shape (C, H, W).
CLS_IMAGE_SHAPE = (3, 48, 192)
CLS_THRESH = 0.9  # rotate a crop by 180 degrees only if the "180" probability exceeds this

# ONNX Runtime execution providers in priority order, comma-separated, e.g.
# ONNX_PROVIDERS="CUDAExecutionProvider,CPUExecutionProvider". Surrounding whitespace and
# empty items are ignored, so "A, B" or a trailing comma no longer yields an invalid
# provider name such as " B" or ""; an empty value falls back to the CPU provider.
ONNX_PROVIDERS = [
    name.strip()
    for name in os.getenv("ONNX_PROVIDERS", "CPUExecutionProvider").split(",")
    if name.strip()
] or ["CPUExecutionProvider"]

# Optional lock held around session.run calls (set USE_SESSION_LOCK=0 to disable).
USE_SESSION_LOCK = os.getenv("USE_SESSION_LOCK", "1") == "1"
# ------------------------------
# 工具函数
# ------------------------------
def create_ort_session(model_path: str) -> Tuple[ort.InferenceSession, threading.Lock]:
    """Create an ONNX Runtime session for ``model_path`` and return ``(session, lock)``.

    The session is created with the providers listed in ``ONNX_PROVIDERS``.
    The second element is a new ``threading.Lock`` when ``USE_SESSION_LOCK`` is true and
    ``None`` otherwise -- callers must check for ``None`` despite the annotation.
    The lock is an optional safeguard callers may hold around ``session.run``; the
    module's own notes state that ``run`` is already thread-safe, so holding the lock
    mainly serializes inference. NOTE(review): confirm the serialization is wanted.
    """
    so = ort.SessionOptions()
    # Optional thread tuning (left at ONNX Runtime defaults):
    # so.intra_op_num_threads = 1
    # so.inter_op_num_threads = 1
    # so.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    sess = ort.InferenceSession(model_path, sess_options=so, providers=ONNX_PROVIDERS)
    lock = threading.Lock() if USE_SESSION_LOCK else None
    return sess, lock
def read_charset(dict_path: str) -> List[str]:
    """Load a character dictionary with one entry per line (UTF-8).

    Only line-terminator characters are stripped, so entries that are themselves
    whitespace are kept intact. The list contains the dictionary characters only; no
    CTC blank token is added here -- its class index is the decoder's concern.
    NOTE(review): PaddleOCR's CTC decoder places the blank at class index 0 (not at
    ``len(charset)``); verify whichever decoder consumes this list.
    """
    entries: List[str] = []
    with open(dict_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            entries.append(raw.strip("\r\n"))
    return entries
def resize_det(img: np.ndarray, limit_side_len: int, limit_type: str = "max") -> Tuple[np.ndarray, float, float]:
    """Rescale ``img`` for the DB detector and snap both sides to multiples of 32.

    With ``limit_type == "max"`` the image is only shrunk, so that its longest side is at
    most ``limit_side_len``; with any other value it is only enlarged, so that its
    shortest side is at least ``limit_side_len``. The aspect ratio is kept before each
    side is rounded to the nearest multiple of 32 (minimum 32).

    Returns:
        (resized image, ratio_h, ratio_w) with ratio_* = resized size / original size.
    """
    src_h, src_w = img.shape[:2]
    if limit_type == "max":
        longest = max(src_h, src_w)
        scale = float(limit_side_len) / longest if longest > limit_side_len else 1.0
    else:
        shortest = min(src_h, src_w)
        scale = float(limit_side_len) / shortest if shortest < limit_side_len else 1.0

    def _snap32(length: float) -> int:
        # The DB network downsamples by 32, so both sides must be multiples of 32.
        return max(32, int(round(int(length) / 32) * 32))

    dst_h = _snap32(src_h * scale)
    dst_w = _snap32(src_w * scale)
    resized = cv2.resize(img, (dst_w, dst_h))
    return resized, dst_h / float(src_h), dst_w / float(src_w)
def normalize_img(img: np.ndarray, mean: List[float], std: List[float], is_scale: bool = True) -> np.ndarray:
    """Normalize an image per channel: optionally scale to [0, 1], then (x - mean) / std.

    Args:
        img: image array of any numeric dtype; ``mean``/``std`` broadcast against its
            last axis (the channel axis for HWC images).
        mean: per-channel mean.
        std: per-channel standard deviation.
        is_scale: when True, divide by 255 before applying mean/std.

    Returns:
        float32 array with the same shape as ``img``.

    Typical PaddleOCR statistics: det uses mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]; rec/cls use 0.5 for both.
    """
    out = img.astype(np.float32)  # astype copies, so in-place ops below are safe
    if is_scale:
        out /= 255.0
    # Convert the statistics to float32: subtracting a plain Python list would silently
    # promote the whole image to float64 (double the memory, slower arithmetic).
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)
    return (out - mean_arr) / std_arr
def chw(img: np.ndarray) -> np.ndarray:
    """Reorder an HWC image to CHW (a transposed view; no data is copied)."""
    channels_first = (2, 0, 1)
    return img.transpose(channels_first)
def order_points_clockwise(pts: np.ndarray) -> np.ndarray:
    """Order four corner points as top-left, top-right, bottom-right, bottom-left.

    Args:
        pts: array-like of shape (4, 2) with (x, y) points in image coordinates
            (y grows downward).

    Returns:
        float32 array of shape (4, 2) in tl, tr, br, bl order.

    The former sum/difference heuristic could select the same point twice when the
    quadrilateral is rotated by about 45 degrees (e.g. a diamond), which produced a
    degenerate box for the perspective transform. Instead, the two left-most and the
    two right-most points (by x) are paired and, within each pair, the smaller y is the
    top corner -- so every input point appears exactly once in the output.
    """
    pts = np.asarray(pts, dtype=np.float32).reshape(4, 2)
    by_x = pts[np.argsort(pts[:, 0], kind="stable")]
    left = by_x[:2][np.argsort(by_x[:2, 1], kind="stable")]
    right = by_x[2:][np.argsort(by_x[2:, 1], kind="stable")]
    top_left, bottom_left = left
    top_right, bottom_right = right
    return np.stack([top_left, top_right, bottom_right, bottom_left]).astype(np.float32)
def box_score_fast(prob_map: np.ndarray, box: np.ndarray) -> float:
    """Return the mean of ``prob_map`` inside the polygon ``box``.

    Args:
        prob_map: (H, W) probability map, expected in [0, 1].
        box: (N, 2) polygon vertices (x, y) in prob_map coordinates.

    Returns:
        Mean probability over the rasterized polygon, or 0.0 when the polygon covers no
        pixel of the map.

    Only the polygon's bounding rectangle (clipped to the map) is rasterized, as
    PaddleOCR's DB post-processing does. The previous full-size mask made each call
    O(H*W), and it runs for up to DB_MAX_CANDIDATES contours per image.
    """
    h, w = prob_map.shape[:2]
    pts = np.asarray(box, dtype=np.float32).reshape(-1, 2)
    xmin = int(np.clip(np.floor(pts[:, 0].min()), 0, w - 1))
    xmax = int(np.clip(np.ceil(pts[:, 0].max()), 0, w - 1))
    ymin = int(np.clip(np.floor(pts[:, 1].min()), 0, h - 1))
    ymax = int(np.clip(np.ceil(pts[:, 1].max()), 0, h - 1))
    mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
    # Same integer vertices as before (astype truncation), shifted into the crop.
    local = pts.astype(np.int32) - np.array([xmin, ymin], dtype=np.int32)
    cv2.fillPoly(mask, [local], 1)
    selected = prob_map[ymin:ymax + 1, xmin:xmax + 1][mask == 1]
    if selected.size == 0:
        return 0.0
    return float(selected.mean())
def unclip(box: np.ndarray, unclip_ratio: float) -> np.ndarray:
    """Grow a polygon outward with pyclipper and return its minimum-area rectangle.

    The offset distance is area * unclip_ratio / perimeter. If pyclipper yields no
    polygon, ``box`` itself is returned; otherwise the 4 corner points of the min-area
    rectangle around the first expanded polygon are returned.
    """
    polygon = box.reshape(-1, 2)
    perimeter = cv2.arcLength(polygon, True) + 1e-6
    offset_dist = (cv2.contourArea(polygon) * unclip_ratio) / perimeter
    clipper = pyclipper.PyclipperOffset()
    clipper.AddPath(polygon, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    grown = clipper.Execute(offset_dist)
    if not grown:
        return box
    outline = np.array(grown[0]).astype(np.float32)
    return cv2.boxPoints(cv2.minAreaRect(outline))
def clip_box(box: np.ndarray, img_h: int, img_w: int) -> np.ndarray:
    """Clamp box coordinates into the image, modifying ``box`` in place.

    Column 0 (x) is limited to [0, img_w - 1] and column 1 (y) to [0, img_h - 1].
    The same array object is returned.
    """
    for col, upper in ((0, img_w - 1), (1, img_h - 1)):
        box[:, col] = np.clip(box[:, col], 0, upper)
    return box
def sort_boxes(boxes: List[np.ndarray], y_tolerance: float = 10.0) -> List[np.ndarray]:
    """Sort boxes into reading order: top-to-bottom, then left-to-right within a line.

    Boxes are first ordered by the mean y (then mean x) of their points. Sorting on y
    alone puts boxes of the same text line out of order when their heights differ by a
    few pixels, so -- as in PaddleOCR's ``sorted_boxes`` -- a following box whose mean y
    is within ``y_tolerance`` pixels and which lies further left is moved in front.

    Args:
        boxes: list of (N, 2) point arrays.
        y_tolerance: max vertical distance (pixels) for two boxes to count as the same
            line; 0 gives a plain (y, x) sort.

    Returns:
        A new list; the input list is not modified.
    """
    def _center(b: np.ndarray) -> Tuple[float, float]:
        return float(np.mean(b[:, 1])), float(np.mean(b[:, 0]))

    ordered = sorted(boxes, key=_center)
    centers = [_center(b) for b in ordered]
    for i in range(len(ordered) - 1):
        for j in range(i, -1, -1):
            (y0, x0), (y1, x1) = centers[j], centers[j + 1]
            if abs(y1 - y0) < y_tolerance and x1 < x0:
                ordered[j], ordered[j + 1] = ordered[j + 1], ordered[j]
                centers[j], centers[j + 1] = centers[j + 1], centers[j]
            else:
                break
    return ordered
def get_rotate_crop_image(img: np.ndarray, box: np.ndarray) -> np.ndarray:
    """Perspective-crop the quadrilateral ``box`` out of ``img`` into an upright patch.

    Args:
        img: source image (H, W, C).
        box: (4, 2) corner points in image coordinates, in any order.

    Returns:
        The rectified crop. If it is at least 1.5x taller than wide it is rotated by 90
        degrees with ``np.rot90`` (counter-clockwise) so vertical lines become horizontal.
    """
    box = order_points_clockwise(box.astype(np.float32))
    # Output size from the longer of each pair of opposite edges, clamped to >= 1 px:
    # a zero width/height is not a valid output size and would later make the
    # recognizer/classifier divide by zero when computing the resize ratio.
    w = max(1, int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3]))))
    h = max(1, int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2]))))
    dst = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]], dtype=np.float32)
    M = cv2.getPerspectiveTransform(box, dst)
    # Replicate edge pixels (as PaddleOCR's crop helper does) instead of padding with
    # black when the box touches the image border.
    warped = cv2.warpPerspective(img, M, (w, h), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC)
    if h >= 1.5 * w:
        warped = np.rot90(warped)
    return warped
# ------------------------------
# 模型封装类:Det / Cls / Rec
# ------------------------------
class DBDetector:
    """DB (Differentiable Binarization) text detector running on ONNX Runtime.

    Call an instance with an (H, W, 3) image as decoded by ``cv2.imdecode`` (BGR) to get
    a list of (4, 2) float32 boxes in original-image pixel coordinates, ordered by
    ``sort_boxes``.
    """

    # Boxes whose shorter side (in resized-image pixels) is below this after unclipping
    # are dropped -- they only produce degenerate crops. Same rule as PaddleOCR's DB
    # post-process (min_size + 2 with min_size = 3).
    MIN_BOX_SIDE = 5.0

    def __init__(self, model_path: str):
        self.session, self.lock = create_ort_session(model_path)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Resize and normalize ``img`` into a (1, 3, H', W') float32 tensor.

        Returns:
            (tensor, ratio_h, ratio_w) with ratio_* = resized size / original size.
        """
        resized, ratio_h, ratio_w = resize_det(img, DET_LIMIT_SIDE_LEN, DET_LIMIT_TYPE)
        # PaddleOCR's inference pipeline normalizes the cv2 (BGR) image directly with
        # these ImageNet statistics -- that is how the PP-OCR det models were trained --
        # so the channels are intentionally NOT swapped to RGB here.
        img_norm = normalize_img(resized,
                                 mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225],
                                 is_scale=True)
        return chw(img_norm)[np.newaxis, :].astype("float32"), ratio_h, ratio_w

    @staticmethod
    def _extract_prob_map(pred: np.ndarray) -> np.ndarray:
        """Return the (H, W) float32 probability map of the first image in ``pred``.

        Accepted layouts: (N, 1, H, W), (N, H, W, 1), (N, C, H, W) [channel 0 used],
        (N, H, W) and (H, W). If any value lies outside [0, 1] the map is treated as
        logits and passed through a sigmoid.

        Raises:
            ValueError: for any other rank.
        """
        pred = np.asarray(pred)
        if pred.ndim == 4:
            if pred.shape[1] != 1 and pred.shape[-1] == 1:
                prob_map = pred[0, :, :, 0]  # N, H, W, 1
            else:
                # N, 1, H, W. For multi-channel exports, channel 0 is assumed to be the
                # probability (shrink) map -- NOTE(review): confirm for such models.
                prob_map = pred[0, 0]
        elif pred.ndim == 3:
            prob_map = pred[0]  # N, H, W
        elif pred.ndim == 2:
            prob_map = pred
        else:
            raise ValueError("Unexpected det output shape: {}".format(pred.shape))
        prob_map = prob_map.astype(np.float32)
        if prob_map.max() > 1.0 or prob_map.min() < 0.0:
            prob_map = 1.0 / (1.0 + np.exp(-prob_map))
        return prob_map

    def postprocess(self, pred: np.ndarray, ratio_h: float, ratio_w: float, ori_h: int, ori_w: int) -> List[np.ndarray]:
        """DB post-process: probability map -> binarize -> contours -> score filter ->
        unclip -> size filter -> map back to the original image -> sort.

        Args:
            pred: raw detector output (layouts listed in ``_extract_prob_map``).
            ratio_h, ratio_w: resized / original size ratios from ``preprocess``.
            ori_h, ori_w: original image size, used to clip the boxes.
        """
        prob_map = self._extract_prob_map(pred)
        # Threshold the float map directly rather than quantizing to uint8 first.
        binary = (prob_map > DB_THRESH).astype(np.uint8) * 255
        contours, _ = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        boxes = []
        for cnt in contours[:DB_MAX_CANDIDATES]:
            if cv2.contourArea(cnt) < 10:
                continue
            box = order_points_clockwise(cv2.boxPoints(cv2.minAreaRect(cnt)))
            if box_score_fast(prob_map, box) < DB_BOX_THRESH:
                continue
            box = order_points_clockwise(unclip(box, DB_UNCLIP_RATIO))
            short_side = min(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]))
            if short_side < self.MIN_BOX_SIDE:
                continue
            # Map from resized-image coordinates back to the original image.
            box[:, 0] = box[:, 0] / ratio_w
            box[:, 1] = box[:, 1] / ratio_h
            boxes.append(clip_box(box, ori_h, ori_w))
        return sort_boxes(boxes)

    def _run(self, inp: np.ndarray) -> np.ndarray:
        """Run the session on ``inp``, holding ``self.lock`` when one was created."""
        if self.lock:
            with self.lock:
                return self.session.run([self.output_name], {self.input_name: inp})[0]
        return self.session.run([self.output_name], {self.input_name: inp})[0]

    def __call__(self, img: np.ndarray) -> List[np.ndarray]:
        inp, r_h, r_w = self.preprocess(img)
        pred = self._run(inp)
        return self.postprocess(pred, r_h, r_w, img.shape[0], img.shape[1])
class ClsClassifier:
    """Text-direction classifier (0 vs. 180 degrees) running on ONNX Runtime.

    Calling an instance with a crop returns ``(label, confidence)``. The model output is
    assumed to be [N, 2] with class 0 = 0 degrees and class 1 = 180 degrees.
    """

    def __init__(self, model_path: str):
        self.session, self.lock = create_ort_session(model_path)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Build a (1, 3, H, W) float32 tensor, with H and W taken from CLS_IMAGE_SHAPE.

        The crop is resized to height H and width ceil(H * w / h) capped at W, scaled to
        [-1, 1] via (x / 255 - 0.5) / 0.5 and right-padded. As in PaddleOCR's classifier,
        padding is applied *after* normalization (padded values are 0, not -1) and the
        cv2 (BGR) channel order is kept.
        """
        _, H, W = CLS_IMAGE_SHAPE
        h, w = img.shape[:2]
        new_w = int(math.ceil(H * w / float(max(h, 1))))  # max(): guard zero-height crops
        new_w = max(1, min(W, new_w))
        resized = cv2.resize(img, (new_w, H)).astype(np.float32)
        resized = (resized / 255.0 - 0.5) / 0.5
        x = np.zeros((1, 3, H, W), dtype=np.float32)
        x[0, :, :, :new_w] = resized.transpose(2, 0, 1)
        return x

    @staticmethod
    def _as_probabilities(scores: np.ndarray, axis: int = -1) -> np.ndarray:
        """Return ``scores`` as a probability distribution along ``axis``.

        PaddleOCR's exported classifier already ends in softmax. Applying softmax again
        squashes a 0.99 confidence to ~0.73 (two classes can never exceed ~0.731), so the
        180-degree branch gated by CLS_THRESH = 0.9 could never fire. Softmax is therefore
        applied only when the output is not already non-negative and normalized.
        """
        scores = np.asarray(scores, dtype=np.float32)
        if scores.size and scores.min() >= 0.0 and np.allclose(scores.sum(axis=axis), 1.0, atol=1e-3):
            return scores
        return softmax(scores, axis=axis)

    def _run(self, x: np.ndarray) -> np.ndarray:
        """Run the session on ``x``, holding ``self.lock`` when one was created."""
        if self.lock:
            with self.lock:
                return self.session.run([self.output_name], {self.input_name: x})[0]
        return self.session.run([self.output_name], {self.input_name: x})[0]

    def __call__(self, img: np.ndarray) -> Tuple[int, float]:
        x = self.preprocess(img)
        prob = self._as_probabilities(self._run(x)[0])
        label = int(np.argmax(prob))
        return label, float(prob[label])
class TextRecognizer:
    """CTC text recognizer (PP-OCR rec model) running on ONNX Runtime.

    Calling an instance with a text-line crop returns ``(text, mean_confidence)``.
    """

    def __init__(self, model_path: str, charset_path: str):
        self.session, self.lock = create_ort_session(model_path)
        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        self.output_name = self.session.get_outputs()[0].name
        self.charset = read_charset(charset_path)
        # PaddleOCR's CTC label decoder puts the CTC blank at class index 0, followed by
        # the dictionary characters and -- for models trained with use_space_char=True,
        # the PP-OCR default -- a trailing space (ppocr_keys_v1.txt has 6623 entries, the
        # Chinese rec models output 6625 classes). Using len(charset) as the blank shifted
        # every character by one and emitted real blanks as text. If a model has no space
        # class, the trailing " " is simply never predicted.
        self.blank_idx = 0
        self.labels = ["<blank>"] + list(self.charset) + [" "]
        # Fixed input width when the model declares one; None means dynamic width
        # (ONNX Runtime reports dynamic dims as strings/None).
        width = model_input.shape[3] if len(model_input.shape) == 4 else None
        self.fixed_width = width if isinstance(width, int) and width > 0 else None

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Build a (1, 3, H, W') float32 tensor for the recognizer.

        The crop is resized to height H (REC_IMAGE_SHAPE) with width ceil(H * w / h),
        scaled to [-1, 1] via (x / 255 - 0.5) / 0.5 and right-padded with 0 *after*
        normalization, as PaddleOCR does; the cv2 (BGR) channel order is kept. W' is the
        model's fixed width if it has one; otherwise max(REC_IMAGE_SHAPE W, resized
        width), so long lines are not squashed into 320 px.
        """
        _, H, W = REC_IMAGE_SHAPE
        h, w = img.shape[:2]
        new_w = int(math.ceil(H * w / float(max(h, 1))))  # max(): guard zero-height crops
        target_w = self.fixed_width if self.fixed_width is not None else max(W, new_w)
        new_w = max(1, min(target_w, new_w))
        resized = cv2.resize(img, (new_w, H)).astype(np.float32)
        resized = (resized / 255.0 - 0.5) / 0.5
        x = np.zeros((1, 3, H, target_w), dtype=np.float32)
        x[0, :, :, :new_w] = resized.transpose(2, 0, 1)
        return x

    @staticmethod
    def _as_probabilities(scores: np.ndarray, axis: int = -1) -> np.ndarray:
        """Return ``scores`` as a probability distribution along ``axis``.

        PaddleOCR's exported CTC head already ends in softmax; applying it again keeps
        the argmax but flattens every confidence to roughly 1/num_classes. Softmax is
        therefore applied only when the output is not already normalized.
        """
        scores = np.asarray(scores, dtype=np.float32)
        if scores.size and scores.min() >= 0.0 and np.allclose(scores.sum(axis=axis), 1.0, atol=1e-3):
            return scores
        return softmax(scores, axis=axis)

    def _seq_by_class(self, out: np.ndarray) -> np.ndarray:
        """Return the first batch item as a (seq_len, num_classes) array.

        Accepts [N, seq_len, C] or [N, C, seq_len]. The layout is recognized by C
        matching the label table (with or without the trailing space class); otherwise
        the shorter axis is taken as the sequence axis.
        """
        if out.ndim != 3:
            raise ValueError("Unexpected rec output shape: {}".format(out.shape))
        n_labels = len(self.labels)
        if out.shape[2] in (n_labels, n_labels - 1) or out.shape[1] < out.shape[2]:
            return out[0]
        return np.transpose(out[0], (1, 0))

    def _ctc_decode(self, indices: np.ndarray, probs: np.ndarray) -> Tuple[str, List[float]]:
        """Greedy CTC decode: collapse repeats, drop blanks, map indices via self.labels.

        Returns the text and one confidence per emitted character; indices outside the
        label table are skipped.
        """
        chars: List[str] = []
        confs: List[float] = []
        prev = None
        for idx, prob in zip(indices, probs):
            idx = int(idx)
            if idx != prev and idx != self.blank_idx and 0 <= idx < len(self.labels):
                chars.append(self.labels[idx])
                confs.append(float(prob))
            prev = idx
        return "".join(chars), confs

    def _run(self, x: np.ndarray) -> np.ndarray:
        """Run the session on ``x``, holding ``self.lock`` when one was created."""
        if self.lock:
            with self.lock:
                return self.session.run([self.output_name], {self.input_name: x})[0]
        return self.session.run([self.output_name], {self.input_name: x})[0]

    def __call__(self, img: np.ndarray) -> Tuple[str, float]:
        x = self.preprocess(img)
        probs = self._as_probabilities(self._seq_by_class(self._run(x)), axis=1)
        text, confs = self._ctc_decode(probs.argmax(axis=1), probs.max(axis=1))
        return text, float(np.mean(confs)) if confs else 0.0
# ------------------------------
# 解码/数学函数
# ------------------------------
def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large inputs cannot
    overflow; the result sums to 1 along ``axis``.
    """
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    totals = np.sum(exps, axis=axis, keepdims=True)
    return exps / totals
def ctc_greedy_decode(indices: np.ndarray, probs: np.ndarray, blank: int,
                      charset: "List[str] | None" = None) -> Tuple[str, List[float]]:
    """Greedy CTC decoding: collapse repeated indices, drop ``blank``, map to characters.

    Args:
        indices: (seq_len,) argmax class index per time step.
        probs: (seq_len,) probability of that class per time step.
        blank: class index of the CTC blank.
        charset: table mapping class index -> character (``charset[i]``). When omitted,
            falls back to the module-level ``ocr_pipeline.recognizer.charset`` for
            backward compatibility; that global only exists once the app module has
            built the pipeline, so pass ``charset`` explicitly when using this directly.

    Returns:
        (text, confidences) with one confidence per kept index. Indices outside
        ``charset`` are left out of the text but still contribute a confidence.
    """
    kept_idx: List[int] = []
    kept_prob: List[float] = []
    prev = -1
    for i, cls in enumerate(indices):
        if cls == blank:
            prev = -1  # a blank separates genuine repeats ("a <blank> a" -> "aa")
            continue
        if cls == prev:
            continue
        kept_idx.append(int(cls))
        kept_prob.append(float(probs[i]))
        prev = cls
    if kept_idx and charset is None:
        charset = ocr_pipeline.recognizer.charset  # legacy global lookup
    chars = [charset[i] for i in kept_idx if 0 <= i < len(charset)]
    return "".join(chars), kept_prob
# ------------------------------
# OCR 流水线封装
# ------------------------------
class OCRPipeline:
    """End-to-end OCR: detect boxes -> crop/rectify -> direction-classify -> recognize.

    Call with an image array (as decoded by cv2) to get
    ``{"results": [{"text": str, "score": float, "box": [[x, y], ...]}, ...]}``,
    one entry per detected box in the detector's output order.
    """

    def __init__(self, det_path: str, rec_path: str, cls_path: str, dict_path: str):
        self.detector = DBDetector(det_path)
        self.classifier = ClsClassifier(cls_path)
        self.recognizer = TextRecognizer(rec_path, dict_path)

    def _read_box(self, img: np.ndarray, box: np.ndarray) -> Dict[str, Any]:
        """Crop one box, undo a confident 180-degree rotation, then recognize it."""
        patch = get_rotate_crop_image(img, box)
        direction, direction_conf = self.classifier(patch)
        if direction == 1 and direction_conf > CLS_THRESH:
            patch = cv2.rotate(patch, cv2.ROTATE_180)
        text, score = self.recognizer(patch)
        return {
            "text": text,
            "score": float(score),
            "box": box.astype(float).tolist(),
        }

    def __call__(self, img: np.ndarray) -> Dict[str, Any]:
        return {"results": [self._read_box(img, box) for box in self.detector(img)]}
# ------------------------------
# Flask 应用
# ------------------------------
def create_app() -> Flask:
    """Build the Flask app exposing ``GET /health`` and ``POST /ocr``.

    The handlers use the module-level ``ocr_pipeline`` that is built when this module
    is run or imported.
    """
    app = Flask(__name__)

    @app.route("/health", methods=["GET"])
    def health():
        return jsonify({"status": "ok"})

    @app.route("/ocr", methods=["POST"])
    def ocr_endpoint():
        """POST a multipart form with the image in field ``image``.

        Returns ``{"results": [{"text": str, "score": float, "box": [[x, y], ...]}, ...]}``
        or ``{"error": ...}`` with HTTP 400 for a missing, empty or undecodable image.
        """
        if "image" not in request.files:
            return jsonify({"error": "no image"}), 400
        data = request.files["image"].read()
        if not data:
            # cv2.imdecode raises on an empty buffer instead of returning None, which
            # previously surfaced as an HTTP 500.
            return jsonify({"error": "invalid image"}), 400
        img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            return jsonify({"error": "invalid image"}), 400
        return jsonify(ocr_pipeline(img))

    return app
# ------------------------------
# 主入口
# ------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=5000)
    args = parser.parse_args()
    # Build the pipeline once per process; all requests in this process reuse it.
    # Arguments are parsed first, so `--help` exits before any model is loaded.
    ocr_pipeline = OCRPipeline(DET_MODEL_PATH, REC_MODEL_PATH, CLS_MODEL_PATH, CHAR_DICT_PATH)
    app = create_app()
    # Single-threaded development server: requests are handled one at a time.
    # For production, prefer several gunicorn worker processes.
    app.run(host=args.host, port=args.port, threaded=False)
else:
    # Imported as a module (e.g. by gunicorn): build the pipeline and a default app at
    # import time so both `app:app` and `app:create_app()` find `ocr_pipeline`.
    ocr_pipeline = OCRPipeline(DET_MODEL_PATH, REC_MODEL_PATH, CLS_MODEL_PATH, CHAR_DICT_PATH)
    app = create_app()
多进程版本:
import os
import sys
import argparse
import multiprocessing
from flask import Flask, request, jsonify
from pipeline import OCRPipeline
# Model locations.
# NOTE(review): these appear to be Paddle inference directories (PP-OCRv4), not ONNX
# files; whether that matches what `pipeline.OCRPipeline` expects cannot be verified here.
DET_MODEL_PATH = './models/ch_PP-OCRv4_det_infer'
REC_MODEL_PATH = './models/ch_PP-OCRv4_rec_infer'
CLS_MODEL_PATH = './models/ch_ppocr_mobile_v2.0_cls_infer'
CHAR_DICT_PATH = './models/ppocr_keys_v1.txt'
ocr_pipeline = None  # set separately in each worker process by run_server()
def create_app():
    """Build the Flask app exposing ``POST /ocr`` (multipart field ``file``).

    The raw uploaded bytes are passed straight to the process-global ``ocr_pipeline``
    and its return value is wrapped as ``{"results": ...}``.
    NOTE(review): the single-file app.py pipeline takes a decoded image array and
    already returns ``{"results": [...]}``; confirm that `pipeline.OCRPipeline` accepts
    raw bytes and returns a plain list, otherwise decode first / avoid double wrapping.
    """
    app = Flask(__name__)

    @app.route("/ocr", methods=["POST"])
    def ocr():
        if "file" not in request.files:
            return jsonify({"error": "No file uploaded"}), 400
        file = request.files["file"]
        img_bytes = file.read()
        results = ocr_pipeline(img_bytes)
        return jsonify({"results": results})

    return app
def run_server(host, port):
    """Worker entry point: build this process's own OCRPipeline, then serve host:port.

    The pipeline is created inside the worker process, so no model session is shared
    between processes. Uses Flask's single-threaded development server.
    """
    global ocr_pipeline
    ocr_pipeline = OCRPipeline(DET_MODEL_PATH, REC_MODEL_PATH, CLS_MODEL_PATH, CHAR_DICT_PATH)
    app = create_app()
    app.run(host=host, port=port, threaded=False)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=5000,
                        help="base port; worker i listens on port + i")
    parser.add_argument("--workers", type=int, default=0, help="进程数 (0 表示自动选择)")
    args = parser.parse_args()
    # 0 (or negative) means auto: one worker per CPU core, capped at 4.
    workers = args.workers if args.workers > 0 else min(4, multiprocessing.cpu_count())
    if workers <= 1:
        run_server(args.host, args.port)
    else:
        # Flask's development server cannot share one listening port across processes:
        # with the same port, every worker after the first failed to bind
        # ("Address already in use"). Each worker therefore gets its own port; put a
        # load balancer (e.g. nginx) in front, or use gunicorn, which shares a single
        # socket among its workers.
        processes = []
        for i in range(workers):
            port = args.port + i
            p = multiprocessing.Process(target=run_server, args=(args.host, port),
                                        name="ocr-worker-{}".format(port))
            p.start()
            print("started OCR worker pid={} on {}:{}".format(p.pid, args.host, port))
            processes.append(p)
        for p in processes:
            p.join()
服务器托管版本:
gunicorn -w 2 -b 0.0.0.0:5000 'app:create_app()' --timeout 120 --threads 1
"""
app.py — Flask + ONNXRuntime 部署 PP-OCRv3(det/cls/rec)完整推理服务(线程安全版)
功能:
1) 直接加载本地 ONNX 模型(不从网络下载)。
2) 提供 /ocr 接口:输入图片,输出文本、置信度与检测框(四点坐标)。
3) det -> crop & rectify -> cls -> rec 全流程;包含 DB 检测后处理(阈值、膨胀、排序)。
4) 使用全局 ONNX Runtime Session(线程安全);并提供额外锁作保守保护。
5) 详细中文注释,便于二次开发。
依赖:
pip install flask onnxruntime opencv-python numpy pyclipper
目录建议:
project_root/
app.py # 本文件
paddleocr_onnx/
det/model.onnx
rec/model.onnx
cls/model.onnx
ppocr_keys_v1.txt
启动:
python app.py --host 0.0.0.0 --port 5000
生产部署(示例):
# 使用 gunicorn 多进程(避免 GIL 影响),每个 worker 内部 Session 复用
gunicorn -w 2 -b 0.0.0.0:5000 'app:create_app()' --timeout 120 --threads 1
说明:
- PP-OCRv3 rec 默认输入尺寸为 (3, 48, 320)。参考官方说明。
- DB 检测常用后处理参数:box_thresh=0.6, unclip_ratio=1.5(可按需调整)。
"""
import argparse
import io
import os
import math
import threading
from typing import List, Tuple, Dict, Any
import cv2
import numpy as np
import onnxruntime as ort
import pyclipper
from flask import Flask, request, jsonify
# ------------------------------
# Global configuration (edit as needed)
# ------------------------------
# Root directory of the exported ONNX models; override with the PPOCR_ONNX_DIR env var.
MODEL_DIR = os.getenv("PPOCR_ONNX_DIR", os.path.join(os.path.dirname(__file__), "paddleocr_onnx"))
DET_MODEL_PATH = os.path.join(MODEL_DIR, "det", "model.onnx")
REC_MODEL_PATH = os.path.join(MODEL_DIR, "rec", "model.onnx")
CLS_MODEL_PATH = os.path.join(MODEL_DIR, "cls", "model.onnx")
CHAR_DICT_PATH = os.path.join(MODEL_DIR, "ppocr_keys_v1.txt")

# Detection pre-processing: limit the longest (or shortest) side, then snap H/W to
# multiples of 32 (the DB network downsamples by 32).
DET_LIMIT_SIDE_LEN = 960  # side-length limit; aspect ratio is preserved
DET_LIMIT_TYPE = "max"  # "max": cap the longest side; "min": raise the shortest side

# DB post-processing parameters (overridable via environment variables).
DB_THRESH = float(os.getenv("DB_THRESH", 0.3))  # binarization threshold on the probability map
DB_BOX_THRESH = float(os.getenv("DB_BOX_THRESH", 0.6))  # minimum mean probability inside a box
DB_UNCLIP_RATIO = float(os.getenv("DB_UNCLIP_RATIO", 1.5))  # box expansion ratio
DB_MAX_CANDIDATES = int(os.getenv("DB_MAX_CANDIDATES", 1000))  # max contours examined

# Recognition input shape (C, H, W); PP-OCRv3 default is 3x48x320.
REC_IMAGE_SHAPE = (3, 48, 320)
# Direction-classifier input shape (C, H, W).
CLS_IMAGE_SHAPE = (3, 48, 192)
CLS_THRESH = 0.9  # rotate a crop by 180 degrees only if the "180" probability exceeds this

# ONNX Runtime execution providers in priority order, comma-separated, e.g.
# ONNX_PROVIDERS="CUDAExecutionProvider,CPUExecutionProvider". Surrounding whitespace and
# empty items are ignored, so "A, B" or a trailing comma no longer yields an invalid
# provider name such as " B" or ""; an empty value falls back to the CPU provider.
ONNX_PROVIDERS = [
    name.strip()
    for name in os.getenv("ONNX_PROVIDERS", "CPUExecutionProvider").split(",")
    if name.strip()
] or ["CPUExecutionProvider"]

# Optional lock held around session.run calls (set USE_SESSION_LOCK=0 to disable).
USE_SESSION_LOCK = os.getenv("USE_SESSION_LOCK", "1") == "1"
# ------------------------------
# 工具函数
# ------------------------------
def create_ort_session(model_path: str) -> Tuple[ort.InferenceSession, threading.Lock]:
    """Create an ONNX Runtime session for ``model_path`` and return ``(session, lock)``.

    The session is created with the providers listed in ``ONNX_PROVIDERS``.
    The second element is a new ``threading.Lock`` when ``USE_SESSION_LOCK`` is true and
    ``None`` otherwise -- callers must check for ``None`` despite the annotation.
    The lock is an optional safeguard callers may hold around ``session.run``; the
    module's own notes state that ``run`` is already thread-safe, so holding the lock
    mainly serializes inference. NOTE(review): confirm the serialization is wanted.
    """
    so = ort.SessionOptions()
    # Optional thread tuning (left at ONNX Runtime defaults):
    # so.intra_op_num_threads = 1
    # so.inter_op_num_threads = 1
    # so.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    sess = ort.InferenceSession(model_path, sess_options=so, providers=ONNX_PROVIDERS)
    lock = threading.Lock() if USE_SESSION_LOCK else None
    return sess, lock
def read_charset(dict_path: str) -> List[str]:
    """Load a character dictionary with one entry per line (UTF-8).

    Only line-terminator characters are stripped, so entries that are themselves
    whitespace are kept intact. The list contains the dictionary characters only; no
    CTC blank token is added here -- its class index is the decoder's concern.
    NOTE(review): PaddleOCR's CTC decoder places the blank at class index 0 (not at
    ``len(charset)``); verify whichever decoder consumes this list.
    """
    entries: List[str] = []
    with open(dict_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            entries.append(raw.strip("\r\n"))
    return entries
def resize_det(img: np.ndarray, limit_side_len: int, limit_type: str = "max") -> Tuple[np.ndarray, float, float]:
    """Rescale ``img`` for the DB detector and snap both sides to multiples of 32.

    With ``limit_type == "max"`` the image is only shrunk, so that its longest side is at
    most ``limit_side_len``; with any other value it is only enlarged, so that its
    shortest side is at least ``limit_side_len``. The aspect ratio is kept before each
    side is rounded to the nearest multiple of 32 (minimum 32).

    Returns:
        (resized image, ratio_h, ratio_w) with ratio_* = resized size / original size.
    """
    src_h, src_w = img.shape[:2]
    if limit_type == "max":
        longest = max(src_h, src_w)
        scale = float(limit_side_len) / longest if longest > limit_side_len else 1.0
    else:
        shortest = min(src_h, src_w)
        scale = float(limit_side_len) / shortest if shortest < limit_side_len else 1.0

    def _snap32(length: float) -> int:
        # The DB network downsamples by 32, so both sides must be multiples of 32.
        return max(32, int(round(int(length) / 32) * 32))

    dst_h = _snap32(src_h * scale)
    dst_w = _snap32(src_w * scale)
    resized = cv2.resize(img, (dst_w, dst_h))
    return resized, dst_h / float(src_h), dst_w / float(src_w)
def normalize_img(img: np.ndarray, mean: List[float], std: List[float], is_scale: bool = True) -> np.ndarray:
    """Normalize an image per channel: optionally scale to [0, 1], then (x - mean) / std.

    Args:
        img: image array of any numeric dtype; ``mean``/``std`` broadcast against its
            last axis (the channel axis for HWC images).
        mean: per-channel mean.
        std: per-channel standard deviation.
        is_scale: when True, divide by 255 before applying mean/std.

    Returns:
        float32 array with the same shape as ``img``.

    Typical PaddleOCR statistics: det uses mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]; rec/cls use 0.5 for both.
    """
    out = img.astype(np.float32)  # astype copies, so in-place ops below are safe
    if is_scale:
        out /= 255.0
    # Convert the statistics to float32: subtracting a plain Python list would silently
    # promote the whole image to float64 (double the memory, slower arithmetic).
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)
    return (out - mean_arr) / std_arr
def chw(img: np.ndarray) -> np.ndarray:
    """Reorder an HWC image to CHW (a transposed view; no data is copied)."""
    channels_first = (2, 0, 1)
    return img.transpose(channels_first)
def order_points_clockwise(pts: np.ndarray) -> np.ndarray:
    """Order four corner points as top-left, top-right, bottom-right, bottom-left.

    Args:
        pts: array-like of shape (4, 2) with (x, y) points in image coordinates
            (y grows downward).

    Returns:
        float32 array of shape (4, 2) in tl, tr, br, bl order.

    The former sum/difference heuristic could select the same point twice when the
    quadrilateral is rotated by about 45 degrees (e.g. a diamond), which produced a
    degenerate box for the perspective transform. Instead, the two left-most and the
    two right-most points (by x) are paired and, within each pair, the smaller y is the
    top corner -- so every input point appears exactly once in the output.
    """
    pts = np.asarray(pts, dtype=np.float32).reshape(4, 2)
    by_x = pts[np.argsort(pts[:, 0], kind="stable")]
    left = by_x[:2][np.argsort(by_x[:2, 1], kind="stable")]
    right = by_x[2:][np.argsort(by_x[2:, 1], kind="stable")]
    top_left, bottom_left = left
    top_right, bottom_right = right
    return np.stack([top_left, top_right, bottom_right, bottom_left]).astype(np.float32)
def box_score_fast(prob_map: np.ndarray, box: np.ndarray) -> float:
    """Return the mean of ``prob_map`` inside the polygon ``box``.

    Args:
        prob_map: (H, W) probability map, expected in [0, 1].
        box: (N, 2) polygon vertices (x, y) in prob_map coordinates.

    Returns:
        Mean probability over the rasterized polygon, or 0.0 when the polygon covers no
        pixel of the map.

    Only the polygon's bounding rectangle (clipped to the map) is rasterized, as
    PaddleOCR's DB post-processing does. The previous full-size mask made each call
    O(H*W), and it runs for up to DB_MAX_CANDIDATES contours per image.
    """
    h, w = prob_map.shape[:2]
    pts = np.asarray(box, dtype=np.float32).reshape(-1, 2)
    xmin = int(np.clip(np.floor(pts[:, 0].min()), 0, w - 1))
    xmax = int(np.clip(np.ceil(pts[:, 0].max()), 0, w - 1))
    ymin = int(np.clip(np.floor(pts[:, 1].min()), 0, h - 1))
    ymax = int(np.clip(np.ceil(pts[:, 1].max()), 0, h - 1))
    mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
    # Same integer vertices as before (astype truncation), shifted into the crop.
    local = pts.astype(np.int32) - np.array([xmin, ymin], dtype=np.int32)
    cv2.fillPoly(mask, [local], 1)
    selected = prob_map[ymin:ymax + 1, xmin:xmax + 1][mask == 1]
    if selected.size == 0:
        return 0.0
    return float(selected.mean())
def unclip(box: np.ndarray, unclip_ratio: float) -> np.ndarray:
    """Grow a polygon outward with pyclipper and return its minimum-area rectangle.

    The offset distance is area * unclip_ratio / perimeter. If pyclipper yields no
    polygon, ``box`` itself is returned; otherwise the 4 corner points of the min-area
    rectangle around the first expanded polygon are returned.
    """
    polygon = box.reshape(-1, 2)
    perimeter = cv2.arcLength(polygon, True) + 1e-6
    offset_dist = (cv2.contourArea(polygon) * unclip_ratio) / perimeter
    clipper = pyclipper.PyclipperOffset()
    clipper.AddPath(polygon, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    grown = clipper.Execute(offset_dist)
    if not grown:
        return box
    outline = np.array(grown[0]).astype(np.float32)
    return cv2.boxPoints(cv2.minAreaRect(outline))
def clip_box(box: np.ndarray, img_h: int, img_w: int) -> np.ndarray:
    """Clamp box coordinates into the image, modifying ``box`` in place.

    Column 0 (x) is limited to [0, img_w - 1] and column 1 (y) to [0, img_h - 1].
    The same array object is returned.
    """
    for col, upper in ((0, img_w - 1), (1, img_h - 1)):
        box[:, col] = np.clip(box[:, col], 0, upper)
    return box
def sort_boxes(boxes: List[np.ndarray], y_tolerance: float = 10.0) -> List[np.ndarray]:
    """Sort boxes into reading order: top-to-bottom, then left-to-right within a line.

    Boxes are first ordered by the mean y (then mean x) of their points. Sorting on y
    alone puts boxes of the same text line out of order when their heights differ by a
    few pixels, so -- as in PaddleOCR's ``sorted_boxes`` -- a following box whose mean y
    is within ``y_tolerance`` pixels and which lies further left is moved in front.

    Args:
        boxes: list of (N, 2) point arrays.
        y_tolerance: max vertical distance (pixels) for two boxes to count as the same
            line; 0 gives a plain (y, x) sort.

    Returns:
        A new list; the input list is not modified.
    """
    def _center(b: np.ndarray) -> Tuple[float, float]:
        return float(np.mean(b[:, 1])), float(np.mean(b[:, 0]))

    ordered = sorted(boxes, key=_center)
    centers = [_center(b) for b in ordered]
    for i in range(len(ordered) - 1):
        for j in range(i, -1, -1):
            (y0, x0), (y1, x1) = centers[j], centers[j + 1]
            if abs(y1 - y0) < y_tolerance and x1 < x0:
                ordered[j], ordered[j + 1] = ordered[j + 1], ordered[j]
                centers[j], centers[j + 1] = centers[j + 1], centers[j]
            else:
                break
    return ordered
def get_rotate_crop_image(img: np.ndarray, box: np.ndarray) -> np.ndarray:
    """Perspective-crop the quadrilateral ``box`` out of ``img`` into an upright patch.

    Args:
        img: source image (H, W, C).
        box: (4, 2) corner points in image coordinates, in any order.

    Returns:
        The rectified crop. If it is at least 1.5x taller than wide it is rotated by 90
        degrees with ``np.rot90`` (counter-clockwise) so vertical lines become horizontal.
    """
    box = order_points_clockwise(box.astype(np.float32))
    # Output size from the longer of each pair of opposite edges, clamped to >= 1 px:
    # a zero width/height is not a valid output size and would later make the
    # recognizer/classifier divide by zero when computing the resize ratio.
    w = max(1, int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3]))))
    h = max(1, int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2]))))
    dst = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]], dtype=np.float32)
    M = cv2.getPerspectiveTransform(box, dst)
    # Replicate edge pixels (as PaddleOCR's crop helper does) instead of padding with
    # black when the box touches the image border.
    warped = cv2.warpPerspective(img, M, (w, h), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC)
    if h >= 1.5 * w:
        warped = np.rot90(warped)
    return warped
# ------------------------------
# 模型封装类:Det / Cls / Rec
# ------------------------------
class DBDetector:
    """DB (Differentiable Binarization) text detector running on ONNX Runtime.

    Call an instance with an (H, W, 3) image as decoded by ``cv2.imdecode`` (BGR) to get
    a list of (4, 2) float32 boxes in original-image pixel coordinates, ordered by
    ``sort_boxes``.
    """

    # Boxes whose shorter side (in resized-image pixels) is below this after unclipping
    # are dropped -- they only produce degenerate crops. Same rule as PaddleOCR's DB
    # post-process (min_size + 2 with min_size = 3).
    MIN_BOX_SIDE = 5.0

    def __init__(self, model_path: str):
        self.session, self.lock = create_ort_session(model_path)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Resize and normalize ``img`` into a (1, 3, H', W') float32 tensor.

        Returns:
            (tensor, ratio_h, ratio_w) with ratio_* = resized size / original size.
        """
        resized, ratio_h, ratio_w = resize_det(img, DET_LIMIT_SIDE_LEN, DET_LIMIT_TYPE)
        # PaddleOCR's inference pipeline normalizes the cv2 (BGR) image directly with
        # these ImageNet statistics -- that is how the PP-OCR det models were trained --
        # so the channels are intentionally NOT swapped to RGB here.
        img_norm = normalize_img(resized,
                                 mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225],
                                 is_scale=True)
        return chw(img_norm)[np.newaxis, :].astype("float32"), ratio_h, ratio_w

    @staticmethod
    def _extract_prob_map(pred: np.ndarray) -> np.ndarray:
        """Return the (H, W) float32 probability map of the first image in ``pred``.

        Accepted layouts: (N, 1, H, W), (N, H, W, 1), (N, C, H, W) [channel 0 used],
        (N, H, W) and (H, W). If any value lies outside [0, 1] the map is treated as
        logits and passed through a sigmoid.

        Raises:
            ValueError: for any other rank.
        """
        pred = np.asarray(pred)
        if pred.ndim == 4:
            if pred.shape[1] != 1 and pred.shape[-1] == 1:
                prob_map = pred[0, :, :, 0]  # N, H, W, 1
            else:
                # N, 1, H, W. For multi-channel exports, channel 0 is assumed to be the
                # probability (shrink) map -- NOTE(review): confirm for such models.
                prob_map = pred[0, 0]
        elif pred.ndim == 3:
            prob_map = pred[0]  # N, H, W
        elif pred.ndim == 2:
            prob_map = pred
        else:
            raise ValueError("Unexpected det output shape: {}".format(pred.shape))
        prob_map = prob_map.astype(np.float32)
        if prob_map.max() > 1.0 or prob_map.min() < 0.0:
            prob_map = 1.0 / (1.0 + np.exp(-prob_map))
        return prob_map

    def postprocess(self, pred: np.ndarray, ratio_h: float, ratio_w: float, ori_h: int, ori_w: int) -> List[np.ndarray]:
        """DB post-process: probability map -> binarize -> contours -> score filter ->
        unclip -> size filter -> map back to the original image -> sort.

        Args:
            pred: raw detector output (layouts listed in ``_extract_prob_map``).
            ratio_h, ratio_w: resized / original size ratios from ``preprocess``.
            ori_h, ori_w: original image size, used to clip the boxes.
        """
        prob_map = self._extract_prob_map(pred)
        # Threshold the float map directly rather than quantizing to uint8 first.
        binary = (prob_map > DB_THRESH).astype(np.uint8) * 255
        contours, _ = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        boxes = []
        for cnt in contours[:DB_MAX_CANDIDATES]:
            if cv2.contourArea(cnt) < 10:
                continue
            box = order_points_clockwise(cv2.boxPoints(cv2.minAreaRect(cnt)))
            if box_score_fast(prob_map, box) < DB_BOX_THRESH:
                continue
            box = order_points_clockwise(unclip(box, DB_UNCLIP_RATIO))
            short_side = min(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]))
            if short_side < self.MIN_BOX_SIDE:
                continue
            # Map from resized-image coordinates back to the original image.
            box[:, 0] = box[:, 0] / ratio_w
            box[:, 1] = box[:, 1] / ratio_h
            boxes.append(clip_box(box, ori_h, ori_w))
        return sort_boxes(boxes)

    def _run(self, inp: np.ndarray) -> np.ndarray:
        """Run the session on ``inp``, holding ``self.lock`` when one was created."""
        if self.lock:
            with self.lock:
                return self.session.run([self.output_name], {self.input_name: inp})[0]
        return self.session.run([self.output_name], {self.input_name: inp})[0]

    def __call__(self, img: np.ndarray) -> List[np.ndarray]:
        inp, r_h, r_w = self.preprocess(img)
        pred = self._run(inp)
        return self.postprocess(pred, r_h, r_w, img.shape[0], img.shape[1])
class ClsClassifier:
    """Text-direction classifier (0 vs. 180 degrees) running on ONNX Runtime.

    Calling an instance with a crop returns ``(label, confidence)``. The model output is
    assumed to be [N, 2] with class 0 = 0 degrees and class 1 = 180 degrees.
    """

    def __init__(self, model_path: str):
        self.session, self.lock = create_ort_session(model_path)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Build a (1, 3, H, W) float32 tensor, with H and W taken from CLS_IMAGE_SHAPE.

        The crop is resized to height H and width ceil(H * w / h) capped at W, scaled to
        [-1, 1] via (x / 255 - 0.5) / 0.5 and right-padded. As in PaddleOCR's classifier,
        padding is applied *after* normalization (padded values are 0, not -1) and the
        cv2 (BGR) channel order is kept.
        """
        _, H, W = CLS_IMAGE_SHAPE
        h, w = img.shape[:2]
        new_w = int(math.ceil(H * w / float(max(h, 1))))  # max(): guard zero-height crops
        new_w = max(1, min(W, new_w))
        resized = cv2.resize(img, (new_w, H)).astype(np.float32)
        resized = (resized / 255.0 - 0.5) / 0.5
        x = np.zeros((1, 3, H, W), dtype=np.float32)
        x[0, :, :, :new_w] = resized.transpose(2, 0, 1)
        return x

    @staticmethod
    def _as_probabilities(scores: np.ndarray, axis: int = -1) -> np.ndarray:
        """Return ``scores`` as a probability distribution along ``axis``.

        PaddleOCR's exported classifier already ends in softmax. Applying softmax again
        squashes a 0.99 confidence to ~0.73 (two classes can never exceed ~0.731), so the
        180-degree branch gated by CLS_THRESH = 0.9 could never fire. Softmax is therefore
        applied only when the output is not already non-negative and normalized.
        """
        scores = np.asarray(scores, dtype=np.float32)
        if scores.size and scores.min() >= 0.0 and np.allclose(scores.sum(axis=axis), 1.0, atol=1e-3):
            return scores
        return softmax(scores, axis=axis)

    def _run(self, x: np.ndarray) -> np.ndarray:
        """Run the session on ``x``, holding ``self.lock`` when one was created."""
        if self.lock:
            with self.lock:
                return self.session.run([self.output_name], {self.input_name: x})[0]
        return self.session.run([self.output_name], {self.input_name: x})[0]

    def __call__(self, img: np.ndarray) -> Tuple[int, float]:
        x = self.preprocess(img)
        prob = self._as_probabilities(self._run(x)[0])
        label = int(np.argmax(prob))
        return label, float(prob[label])
class TextRecognizer:
    """CTC text recognizer (PP-OCR rec model) running on ONNX Runtime.

    Calling an instance with a text-line crop returns ``(text, mean_confidence)``.
    """

    def __init__(self, model_path: str, charset_path: str):
        self.session, self.lock = create_ort_session(model_path)
        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        self.output_name = self.session.get_outputs()[0].name
        self.charset = read_charset(charset_path)
        # PaddleOCR's CTC label decoder puts the CTC blank at class index 0, followed by
        # the dictionary characters and -- for models trained with use_space_char=True,
        # the PP-OCR default -- a trailing space (ppocr_keys_v1.txt has 6623 entries, the
        # Chinese rec models output 6625 classes). Using len(charset) as the blank shifted
        # every character by one and emitted real blanks as text. If a model has no space
        # class, the trailing " " is simply never predicted.
        self.blank_idx = 0
        self.labels = ["<blank>"] + list(self.charset) + [" "]
        # Fixed input width when the model declares one; None means dynamic width
        # (ONNX Runtime reports dynamic dims as strings/None).
        width = model_input.shape[3] if len(model_input.shape) == 4 else None
        self.fixed_width = width if isinstance(width, int) and width > 0 else None

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Build a (1, 3, H, W') float32 tensor for the recognizer.

        The crop is resized to height H (REC_IMAGE_SHAPE) with width ceil(H * w / h),
        scaled to [-1, 1] via (x / 255 - 0.5) / 0.5 and right-padded with 0 *after*
        normalization, as PaddleOCR does; the cv2 (BGR) channel order is kept. W' is the
        model's fixed width if it has one; otherwise max(REC_IMAGE_SHAPE W, resized
        width), so long lines are not squashed into 320 px.
        """
        _, H, W = REC_IMAGE_SHAPE
        h, w = img.shape[:2]
        new_w = int(math.ceil(H * w / float(max(h, 1))))  # max(): guard zero-height crops
        target_w = self.fixed_width if self.fixed_width is not None else max(W, new_w)
        new_w = max(1, min(target_w, new_w))
        resized = cv2.resize(img, (new_w, H)).astype(np.float32)
        resized = (resized / 255.0 - 0.5) / 0.5
        x = np.zeros((1, 3, H, target_w), dtype=np.float32)
        x[0, :, :, :new_w] = resized.transpose(2, 0, 1)
        return x

    @staticmethod
    def _as_probabilities(scores: np.ndarray, axis: int = -1) -> np.ndarray:
        """Return ``scores`` as a probability distribution along ``axis``.

        PaddleOCR's exported CTC head already ends in softmax; applying it again keeps
        the argmax but flattens every confidence to roughly 1/num_classes. Softmax is
        therefore applied only when the output is not already normalized.
        """
        scores = np.asarray(scores, dtype=np.float32)
        if scores.size and scores.min() >= 0.0 and np.allclose(scores.sum(axis=axis), 1.0, atol=1e-3):
            return scores
        return softmax(scores, axis=axis)

    def _seq_by_class(self, out: np.ndarray) -> np.ndarray:
        """Return the first batch item as a (seq_len, num_classes) array.

        Accepts [N, seq_len, C] or [N, C, seq_len]. The layout is recognized by C
        matching the label table (with or without the trailing space class); otherwise
        the shorter axis is taken as the sequence axis.
        """
        if out.ndim != 3:
            raise ValueError("Unexpected rec output shape: {}".format(out.shape))
        n_labels = len(self.labels)
        if out.shape[2] in (n_labels, n_labels - 1) or out.shape[1] < out.shape[2]:
            return out[0]
        return np.transpose(out[0], (1, 0))

    def _ctc_decode(self, indices: np.ndarray, probs: np.ndarray) -> Tuple[str, List[float]]:
        """Greedy CTC decode: collapse repeats, drop blanks, map indices via self.labels.

        Returns the text and one confidence per emitted character; indices outside the
        label table are skipped.
        """
        chars: List[str] = []
        confs: List[float] = []
        prev = None
        for idx, prob in zip(indices, probs):
            idx = int(idx)
            if idx != prev and idx != self.blank_idx and 0 <= idx < len(self.labels):
                chars.append(self.labels[idx])
                confs.append(float(prob))
            prev = idx
        return "".join(chars), confs

    def _run(self, x: np.ndarray) -> np.ndarray:
        """Run the session on ``x``, holding ``self.lock`` when one was created."""
        if self.lock:
            with self.lock:
                return self.session.run([self.output_name], {self.input_name: x})[0]
        return self.session.run([self.output_name], {self.input_name: x})[0]

    def __call__(self, img: np.ndarray) -> Tuple[str, float]:
        x = self.preprocess(img)
        probs = self._as_probabilities(self._seq_by_class(self._run(x)), axis=1)
        text, confs = self._ctc_decode(probs.argmax(axis=1), probs.max(axis=1))
        return text, float(np.mean(confs)) if confs else 0.0
# ------------------------------
# 解码/数学函数
# ------------------------------
def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large inputs cannot
    overflow; the result sums to 1 along ``axis``.
    """
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    totals = np.sum(exps, axis=axis, keepdims=True)
    return exps / totals
def ctc_greedy_decode(indices: np.ndarray, probs: np.ndarray, blank: int,
                      charset: "List[str] | None" = None) -> Tuple[str, List[float]]:
    """Greedy CTC decoding: collapse repeated indices, drop ``blank``, map to characters.

    Args:
        indices: (seq_len,) argmax class index per time step.
        probs: (seq_len,) probability of that class per time step.
        blank: class index of the CTC blank.
        charset: table mapping class index -> character (``charset[i]``). When omitted,
            falls back to the module-level ``ocr_pipeline.recognizer.charset`` for
            backward compatibility; that global only exists once the app module has
            built the pipeline, so pass ``charset`` explicitly when using this directly.

    Returns:
        (text, confidences) with one confidence per kept index. Indices outside
        ``charset`` are left out of the text but still contribute a confidence.
    """
    kept_idx: List[int] = []
    kept_prob: List[float] = []
    prev = -1
    for i, cls in enumerate(indices):
        if cls == blank:
            prev = -1  # a blank separates genuine repeats ("a <blank> a" -> "aa")
            continue
        if cls == prev:
            continue
        kept_idx.append(int(cls))
        kept_prob.append(float(probs[i]))
        prev = cls
    if kept_idx and charset is None:
        charset = ocr_pipeline.recognizer.charset  # legacy global lookup
    chars = [charset[i] for i in kept_idx if 0 <= i < len(charset)]
    return "".join(chars), kept_prob
# ------------------------------
# OCR 流水线封装
# ------------------------------
class OCRPipeline:
    """End-to-end OCR: detect boxes -> crop/rectify -> direction-classify -> recognize.

    Call with an image array (as decoded by cv2) to get
    ``{"results": [{"text": str, "score": float, "box": [[x, y], ...]}, ...]}``,
    one entry per detected box in the detector's output order.
    """

    def __init__(self, det_path: str, rec_path: str, cls_path: str, dict_path: str):
        self.detector = DBDetector(det_path)
        self.classifier = ClsClassifier(cls_path)
        self.recognizer = TextRecognizer(rec_path, dict_path)

    def _read_box(self, img: np.ndarray, box: np.ndarray) -> Dict[str, Any]:
        """Crop one box, undo a confident 180-degree rotation, then recognize it."""
        patch = get_rotate_crop_image(img, box)
        direction, direction_conf = self.classifier(patch)
        if direction == 1 and direction_conf > CLS_THRESH:
            patch = cv2.rotate(patch, cv2.ROTATE_180)
        text, score = self.recognizer(patch)
        return {
            "text": text,
            "score": float(score),
            "box": box.astype(float).tolist(),
        }

    def __call__(self, img: np.ndarray) -> Dict[str, Any]:
        return {"results": [self._read_box(img, box) for box in self.detector(img)]}
# ------------------------------
# Flask 应用
# ------------------------------
def create_app() -> Flask:
    """Build the Flask app exposing ``GET /health`` and ``POST /ocr``.

    The handlers use the module-level ``ocr_pipeline`` that is built when this module
    is run or imported.
    """
    app = Flask(__name__)

    @app.route("/health", methods=["GET"])
    def health():
        return jsonify({"status": "ok"})

    @app.route("/ocr", methods=["POST"])
    def ocr_endpoint():
        """POST a multipart form with the image in field ``image``.

        Returns ``{"results": [{"text": str, "score": float, "box": [[x, y], ...]}, ...]}``
        or ``{"error": ...}`` with HTTP 400 for a missing, empty or undecodable image.
        """
        if "image" not in request.files:
            return jsonify({"error": "no image"}), 400
        data = request.files["image"].read()
        if not data:
            # cv2.imdecode raises on an empty buffer instead of returning None, which
            # previously surfaced as an HTTP 500.
            return jsonify({"error": "invalid image"}), 400
        img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            return jsonify({"error": "invalid image"}), 400
        return jsonify(ocr_pipeline(img))

    return app
# ------------------------------
# 主入口
# ------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=5000)
    args = parser.parse_args()
    # Build the pipeline once per process; all requests in this process reuse it.
    # Arguments are parsed first, so `--help` exits before any model is loaded.
    ocr_pipeline = OCRPipeline(DET_MODEL_PATH, REC_MODEL_PATH, CLS_MODEL_PATH, CHAR_DICT_PATH)
    app = create_app()
    # Single-threaded development server: requests are handled one at a time.
    # For production, prefer several gunicorn worker processes.
    app.run(host=args.host, port=args.port, threaded=False)
else:
    # Imported as a module (e.g. by gunicorn): build the pipeline and a default app at
    # import time so both `app:app` and `app:create_app()` find `ocr_pipeline`.
    ocr_pipeline = OCRPipeline(DET_MODEL_PATH, REC_MODEL_PATH, CLS_MODEL_PATH, CHAR_DICT_PATH)
    app = create_app()
pip install gunicorn
pip install flask onnxruntime opencv-python numpy pyclipper
更多推荐
所有评论(0)