手写文字识别技术

OpenCV实现与深度学习模型训练

OpenCV

计算机视觉库,提供图像预处理和特征提取功能

深度学习

CNN和RNN模型用于字符识别和序列建模

PySide6

现代GUI框架,构建用户友好的识别界面

图像预处理技术详解

图像去噪算法

在手写文字识别中,图像去噪是关键的预处理步骤。OpenCV提供了多种去噪算法,包括高斯滤波、双边滤波和非局部均值去噪。对于手写文字图像,我们通常采用自适应去噪策略。

  • 高斯滤波:适用于去除高频噪声
  • 双边滤波:保持边缘的同时平滑噪声
  • 形态学操作:去除椒盐噪声
  • 中值滤波:处理脉冲噪声

二值化处理

二值化是将灰度图像转换为黑白图像的过程,这对后续的字符分割和识别至关重要。我们需要根据图像的特点选择合适的二值化方法。

  • 全局阈值:适用于光照均匀的图像
  • 自适应阈值:处理光照不均的情况
  • OTSU方法:自动确定最优阈值
  • Niblack算法:局部自适应二值化

图像倾斜校正算法

手写文字往往存在倾斜问题,需要进行角度校正。我们采用基于霍夫变换的倾斜检测算法,通过检测文本行的主要方向来确定倾斜角度。算法流程包括:边缘检测、霍夫直线变换、角度统计分析、仿射变换校正。

首先使用Canny边缘检测算法提取图像中的边缘信息,然后应用霍夫变换检测直线,通过统计分析确定文本的主要倾斜角度。最后使用仿射变换矩阵对图像进行旋转校正,确保文本行水平对齐。

字符分割与连通域分析

连通域分析原理

连通域分析是字符分割的核心技术,通过识别图像中相互连接的像素区域来定位单个字符。在OpenCV中,我们使用cv2.connectedComponents函数进行连通域标记,结合形态学操作来优化分割效果。

四连通与八连通

连通性定义了相邻像素的判断标准。四连通只考虑上下左右四个方向的像素,而八连通还包括对角线方向的像素。对于手写字符,八连通通常能提供更好的字符完整性。

在处理笔画较细或有断笔的手写字符时,八连通能够更好地保持字符的连续性。

噪声过滤策略

通过面积阈值过滤可以去除小的噪声连通域,通过长宽比分析可以过滤不符合字符特征的区域。同时使用密度分析可以识别空洞字符如"O"、"A"等。

设置合理的面积阈值范围(通常为总像素的0.1%-10%)能有效平衡噪声去除和字符保留。

 

特征提取与描述子

传统特征提取方法

方向梯度直方图 (HOG)

HOG特征描述符通过计算图像局部区域的梯度方向分布来描述物体的形状特征。对于手写字符识别,HOG能够有效捕捉笔画的方向信息和局部形状特征。

算法步骤:
  1. 计算图像梯度(使用Sobel算子)
  2. 计算梯度方向和幅值
  3. 构建方向直方图(通常9个bin)
  4. 在块内进行归一化处理
  5. 连接所有块的特征向量
参数配置:
  • 细胞大小:8×8像素
  • 块大小:2×2细胞
  • 方向bins:9个
  • 归一化方法:L2-Hys
局部二值模式 (LBP)

LBP是一种简单而有效的纹理描述算子,通过比较中心像素与其邻域像素的灰度值来生成二进制编码。在手写字符识别中,LBP能够捕捉笔画的局部纹理特征。

改进的均匀LBP(Uniform LBP)只考虑二进制模式中最多有两次跳变的模式,这些模式在自然纹理中占主导地位。对于手写字符,均匀LBP能够更好地描述笔画的平滑过渡区域。

旋转不变性: 通过选择LBP值的最小旋转形式,可以实现对字符旋转的鲁棒性,这对于手写字符的多样性特别重要。

深度学习特征提取

现代深度学习方法通过卷积神经网络自动学习层次化特征表示。CNN能够从低级的边缘和纹理特征逐步构建高级的语义特征,这种端到端的学习方式在手写字符识别中表现出色。

卷积层

提取局部特征和模式

池化层

降维和特征选择

全连接层

高级语义特征融合

CNN模型架构设计

LeNet-5改进架构

LeNet-5是最早成功应用于手写数字识别的CNN架构。我们基于LeNet-5进行改进,增加了批量归一化、Dropout正则化和残差连接,以提高模型的收敛速度和泛化能力。

网络层次结构

1

输入层

32×32单通道灰度图像

2

卷积层1 + 池化层1

6个5×5卷积核,ReLU激活,2×2最大池化

3

卷积层2 + 池化层2

16个5×5卷积核,ReLU激活,2×2最大池化

4

全连接层

120个神经元,ReLU激活,Dropout(0.5)

5

输出层

类别数个神经元,Softmax激活

ResNet残差网络应用

残差网络通过引入跳跃连接解决了深度网络的梯度消失问题。在手写字符识别中,我们使用轻量级的ResNet架构,在保证识别精度的同时控制计算复杂度。

残差块设计

每个残差块包含两个3×3卷积层,使用批量归一化和ReLU激活函数。跳跃连接直接将输入添加到块的输出,形成残差映射。

当输入和输出维度不匹配时,使用1×1卷积进行维度调整。

网络深度选择

对于手写字符识别,我们通常使用18层或34层的ResNet架构。更深的网络能够学习更复杂的特征表示,但也需要更多的训练数据和计算资源。

实验表明,在中等规模数据集上,18层ResNet已能达到良好的识别效果。

注意力机制集成

注意力机制能够让模型自动关注输入中的重要区域。在手写字符识别中,我们在CNN的特征映射上应用空间注意力和通道注意力,提高模型对关键笔画区域的敏感性。

空间注意力

通过学习空间权重图,突出显示字符的关键区域,如笔画交叉点、拐角等特征位置。

通道注意力

学习不同特征通道的重要性权重,自适应地强调对当前字符识别最有用的特征。

RNN序列建模与CTC解码

LSTM网络架构

长短期记忆网络(LSTM)通过门控机制解决了传统RNN的梯度消失问题,能够有效建模长序列依赖关系。在手写文字识别中,LSTM用于建模字符序列的上下文信息。

LSTM单元结构
门控机制:
  • 忘记门:决定从细胞状态中丢弃什么信息
  • 输入门:决定存储什么新信息到细胞状态
  • 输出门:控制细胞状态的哪些部分输出
状态更新:
  • 细胞状态:长期记忆的载体
  • 隐藏状态:短期记忆和输出
  • 候选值:待添加的新信息

双向LSTM: 通过同时考虑前向和后向的序列信息,双向LSTM能够更好地理解字符的上下文关系,提高识别准确性。

CTC损失函数

连接主义时序分类(CTC)允许模型在不需要精确字符级对齐的情况下进行训练。这对于手写文字识别特别重要,因为很难准确标注每个字符的边界位置。

CTC对齐机制

CTC引入空白符号(blank)来处理重复字符和字符间隔。对于输入序列,CTC考虑所有可能的对齐路径,并计算目标序列的总概率。

对齐规则:
  • • 连续相同字符间必须有空白符分隔
  • • 空白符可以出现在任何位置
  • • 移除空白符和重复字符得到最终输出
概率计算:
  • • 使用前向-后向算法
  • • 动态规划计算所有路径概率和
  • • 梯度可通过路径概率计算
解码策略

CTC解码可以使用贪心搜索或束搜索。贪心搜索简单快速,但可能不是全局最优。束搜索考虑多个候选路径,能够获得更好的解码结果。

贪心解码

在每个时间步选择概率最高的字符,速度快但可能局部最优

束搜索解码

维护多个候选路径,选择整体概率最高的序列,准确性更高

数据增强与正则化策略

几何变换增强

几何变换是最常用的数据增强方法,通过对图像进行旋转、缩放、平移和剪切来增加数据的多样性。这些变换模拟了现实场景中手写文字的自然变化。

🔄

旋转变换

±15°随机旋转,模拟书写角度变化

📏

缩放变换

0.8-1.2倍缩放,模拟字符大小变化

↔️

平移变换

±10%图像大小的随机平移

📐

剪切变换

±0.2弧度剪切,模拟倾斜书写

弹性变形

弹性变形通过在图像上应用随机位移场来模拟手写字符的自然变形。这种方法特别适用于手写体,因为每个人的书写风格都有细微的差异。

实现时,我们在图像上生成平滑的随机位移场,然后使用双线性插值对像素进行重新采样。位移幅度通常控制在图像大小的5-10%范围内。

像素级增强技术

亮度和对比度调整

模拟不同的光照条件和图像质量。亮度调整范围通常为±20%,对比度调整范围为0.8-1.2倍。这些变换能够提高模型对光照变化的鲁棒性。

亮度调整

I_new = I + α,其中α∈[-50, 50]

对比度调整

I_new = α × I,其中α∈[0.8, 1.2]

噪声添加

添加不同类型的噪声可以提高模型的抗噪能力。常用的噪声类型包括高斯噪声、椒盐噪声和模糊噪声。

高斯噪声

σ=5-15的正态分布噪声

椒盐噪声

0.5-2%像素的随机噪声点

模糊噪声

3×3或5×5高斯核模糊

正则化技术

除了数据增强,我们还采用多种正则化技术来防止过拟合,包括Dropout、批量归一化、权重衰减和早停策略。

网络正则化
  • • Dropout: 0.2-0.5的随机失活率
  • • 批量归一化: 加速收敛和稳定训练
  • • 权重衰减: L2正则化系数1e-4
训练策略
  • • 早停: 验证损失不下降时停止
  • • 学习率衰减: 指数或余弦衰减
  • • 梯度裁剪: 防止梯度爆炸

模型训练与优化策略

损失函数设计

合适的损失函数对模型性能至关重要。在手写字符识别中,我们通常使用交叉熵损失进行字符分类,CTC损失处理序列对齐问题,焦点损失解决类别不平衡。

多任务损失函数

在实际应用中,我们往往需要同时进行字符分类和序列识别。多任务学习通过联合优化多个相关任务来提高整体性能。

L_total = α × L_classification + β × L_ctc + γ × L_attention

分类损失

字符级别的交叉熵损失

CTC损失

序列级别的对齐损失

注意力损失

注意力权重的正则化

焦点损失应用

手写字符数据集往往存在类别不平衡问题,某些字符出现频率远高于其他字符。焦点损失通过降低易分类样本的权重,让模型更关注困难样本。

FL(p_t) = -α_t × (1-p_t)^γ × log(p_t)

其中α_t是类别权重,γ是聚焦参数(通常取2),p_t是模型预测的目标类别概率

优化器选择与调参

Adam优化器配置

Adam优化器结合了动量和自适应学习率的优点,在深度学习中表现出色。对于手写字符识别,我们通常使用以下配置参数。

基本参数:
  • • 学习率: 1e-3 (初始值)
  • • β1: 0.9 (一阶矩估计的衰减率)
  • • β2: 0.999 (二阶矩估计的衰减率)
  • • ε: 1e-8 (数值稳定性参数)
调整策略:
  • • 权重衰减: 1e-4
  • • 预热策略: 前1000步线性增长
  • • 学习率调度: 余弦退火
  • • 梯度裁剪: 最大范数1.0
学习率调度策略

动态调整学习率对模型收敛至关重要。我们采用多种学习率调度策略来平衡收敛速度和最终性能。

指数衰减

lr = lr_0 × γ^epoch

衰减因子γ=0.95,每个epoch衰减

余弦退火

平滑的余弦函数衰减

能够跳出局部最优,获得更好解

阶梯衰减

每隔固定epoch降低学习率

简单有效,易于控制

分布式训练策略

对于大规模手写字符数据集,单GPU训练往往效率较低。我们采用数据并行和模型并行相结合的分布式训练策略来加速训练过程。

数据并行

将批次数据分配到多个GPU上并行计算,通过AllReduce操作同步梯度。

  • 线性加速比随GPU数量增长
  • 通信开销相对较小
梯度累积

在显存有限的情况下,通过累积多个小批次的梯度来模拟大批次训练。

  • 有效批次大小可达数千
  • 保持训练稳定性

性能评估与模型优化

评估指标体系

全面的评估指标体系是衡量模型性能的基础。在手写字符识别中,我们需要从字符级、单词级和句子级多个层面评估模型性能。

字符级指标
  • • 准确率: 正确识别的字符数/总字符数
  • • Top-k准确率: 前k个预测中包含正确答案
  • • 混淆矩阵: 分析具体的错误模式
序列级指标
  • • 编辑距离: 最小编辑操作数
  • • BLEU分数: n-gram匹配度量
  • • 完全匹配率: 整个序列完全正确的比例
性能指标
  • • 推理速度: 每秒处理的图像数
  • • 模型大小: 参数量和存储空间
  • • 能耗分析: 移动设备功耗评估
评估数据集构建

构建高质量的评估数据集对模型性能评估至关重要。我们需要确保测试集的代表性和多样性,包含不同书写风格、质量水平和场景的样本。

数据分层:
  • • 按书写者年龄分组
  • • 按书写质量分级
  • • 按字符频率平衡
  • • 按图像质量分类
交叉验证:
  • • k折交叉验证(k=5)
  • • 时间序列分割验证
  • • 书写者无关验证
  • • 领域适应评估

模型压缩技术

知识蒸馏

知识蒸馏通过让小模型(学生)学习大模型(教师)的知识来提高小模型的性能。在手写字符识别中,我们可以用大型CNN作为教师模型,轻量级MobileNet作为学生模型。

软目标学习

学生模型不仅学习硬标签,还学习教师模型输出的概率分布,获得更丰富的信息。

特征对齐

让学生模型的中间特征向教师模型的对应特征对齐,学习更深层的表示。

模型剪枝

通过移除不重要的权重或神经元来减小模型大小。结构化剪枝移除整个通道或层,非结构化剪枝移除个别权重。

幅值剪枝

移除绝对值小的权重

结构化剪枝

移除整个通道或层

梯度剪枝

基于梯度信息剪枝

量化技术

将模型权重从32位浮点数量化为8位或16位整数,显著减小模型大小和计算量。现代量化技术能够在保持精度的同时大幅提升推理速度。

训练后量化:
  • • 动态量化:运行时量化激活
  • • 静态量化:预计算量化参数
  • • QAT:量化感知训练
量化策略:
  • • INT8量化:8倍压缩比
  • • 混合精度:关键层保持FP16
  • • 校准集优化:提高量化精度

实际应用案例与部署

移动端部署优化

移动端部署面临计算资源有限、内存约束、功耗限制等挑战。我们需要在模型精度和运行效率之间找到最佳平衡点。

模型优化策略
  • 轻量级架构: 使用MobileNet、EfficientNet等专为移动设备设计的网络架构
  • 深度可分离卷积: 将标准卷积分解为深度卷积和点卷积,大幅减少参数量
  • 通道洗牌: 通过ShuffleNet的通道洗牌操作提高组卷积的信息交换
  • 神经网络搜索: 使用AutoML技术自动搜索最优的移动端网络架构
运行时优化
  • 图像预处理流水线: 优化图像缩放、归一化等预处理操作的计算效率
  • 批处理策略: 根据设备性能动态调整批处理大小以平衡延迟和吞吐量
  • 内存管理: 使用内存池和对象复用减少内存分配开销
  • 多线程优化: 合理利用多核CPU资源进行并行计算
硬件加速

现代移动设备配备了专门的AI加速器,如Apple的Neural Engine、高通的Hexagon DSP、华为的NPU等。合理利用这些硬件能够显著提升推理性能。

Core ML

iOS设备的机器学习框架,支持Neural Engine加速

TensorFlow Lite

跨平台轻量级推理框架,支持GPU和DSP加速

ONNX Runtime

高性能推理引擎,支持多种硬件后端

云端服务架构

微服务设计

采用微服务架构将手写文字识别系统分解为多个独立的服务模块,提高系统的可扩展性和可维护性。

📸

图像处理服务

负责图像预处理和格式转换

🧠

模型推理服务

执行深度学习模型推理

📝

后处理服务

结果优化和格式化输出

💾

缓存服务

缓存频繁访问的结果

负载均衡与容灾

面对大规模并发请求,系统需要具备良好的负载均衡能力和容灾机制,确保服务的高可用性和稳定性。

负载均衡策略:
  • • 轮询调度:请求均匀分配到各节点
  • • 加权轮询:根据节点性能分配权重
  • • 最少连接:优先选择连接数最少的节点
  • • 一致性哈希:保证会话亲和性
容灾机制:
  • • 健康检查:定期检测服务状态
  • • 自动故障转移:故障节点自动下线
  • • 熔断器模式:防止级联故障
  • • 限流降级:过载时的保护机制

性能监控与运维

完善的监控体系对于保证生产环境的稳定运行至关重要。我们需要从多个维度监控系统的运行状态。

业务指标
  • • 识别准确率
  • • 平均响应时间
  • • 请求成功率
  • • 用户满意度
系统指标
  • • CPU使用率
  • • 内存占用
  • • GPU利用率
  • • 网络带宽
模型指标
  • • 模型精度漂移
  • • 推理延迟分布
  • • 异常检测
  • • A/B测试结果

完整实现代码

以下是基于OpenCV和PySide6的完整手写文字识别系统实现,包含图像预处理、深度学习模型训练、GUI界面等核心功能模块。

# -*- coding: utf-8 -*-
"""
手写文字识别系统 - 完整实现
使用OpenCV进行图像处理,深度学习模型进行字符识别
基于PySide6构建现代化GUI界面   作者丁林松 
"""

import sys
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd
import os
import json
import pickle
from pathlib import Path
import logging
from typing import List, Tuple, Dict, Optional, Union
import time
from dataclasses import dataclass
from enum import Enum

# PySide6 GUI组件
from PySide6.QtWidgets import (
    QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, 
    QGridLayout, QPushButton, QLabel, QTextEdit, QProgressBar, 
    QFileDialog, QTabWidget, QGroupBox, QSlider, QSpinBox,
    QComboBox, QCheckBox, QListWidget, QTableWidget, QTableWidgetItem,
    QSplitter, QScrollArea, QFrame, QStatusBar, QMenuBar, QToolBar,
    QDialog, QDialogButtonBox, QFormLayout, QLineEdit, QSpacerItem,
    QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsPixmapItem
)
from PySide6.QtCore import (
    Qt, QThread, QTimer, Signal, QMutex, QWaitCondition, 
    QSettings, QStandardPaths, QRect, QPoint, QSize
)
from PySide6.QtGui import (
    QPixmap, QImage, QPainter, QPen, QColor, QFont, QAction,
    QIcon, QCursor, QKeySequence, QShortcut, QPalette
)

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('handwriting_recognition.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

@dataclass
class ModelConfig:
    """模型配置类"""
    input_size: Tuple[int, int] = (64, 64)
    num_classes: int = 62  # 26小写+26大写+10数字
    learning_rate: float = 0.001
    batch_size: int = 32
    epochs: int = 100
    dropout_rate: float = 0.5
    weight_decay: float = 1e-4
    
class ModelType(Enum):
    """模型类型枚举"""
    LENET = "LeNet"
    RESNET = "ResNet"
    MOBILENET = "MobileNet"
    CUSTOM_CNN = "CustomCNN"

class ImagePreprocessor:
    """图像预处理器类"""
    
    def __init__(self):
        self.config = {
            'target_size': (64, 64),
            'gaussian_blur_ksize': 5,
            'bilateral_d': 9,
            'bilateral_sigma_color': 75,
            'bilateral_sigma_space': 75,
            'adaptive_threshold_block_size': 11,
            'adaptive_threshold_c': 2,
            'morph_kernel_size': 3,
            'contour_min_area': 100,
            'contour_max_area': 5000
        }
        
    def remove_noise(self, image: np.ndarray, method: str = 'bilateral') -> np.ndarray:
        """
        图像去噪处理
        
        Args:
            image: 输入图像
            method: 去噪方法 ('gaussian', 'bilateral', 'median', 'morphology')
            
        Returns:
            去噪后的图像
        """
        try:
            if method == 'gaussian':
                return cv2.GaussianBlur(image, 
                    (self.config['gaussian_blur_ksize'], self.config['gaussian_blur_ksize']), 0)
            elif method == 'bilateral':
                return cv2.bilateralFilter(image, 
                    self.config['bilateral_d'], 
                    self.config['bilateral_sigma_color'], 
                    self.config['bilateral_sigma_space'])
            elif method == 'median':
                return cv2.medianBlur(image, 5)
            elif method == 'morphology':
                kernel = np.ones((self.config['morph_kernel_size'], 
                                self.config['morph_kernel_size']), np.uint8)
                return cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel)
            else:
                return image
        except Exception as e:
            logger.error(f"去噪处理失败: {e}")
            return image
    
    def binarize_image(self, image: np.ndarray, method: str = 'adaptive') -> np.ndarray:
        """
        图像二值化处理
        
        Args:
            image: 输入灰度图像
            method: 二值化方法 ('otsu', 'adaptive', 'global')
            
        Returns:
            二值化后的图像
        """
        try:
            if method == 'otsu':
                _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                return binary
            elif method == 'adaptive':
                return cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 
                    cv2.THRESH_BINARY, self.config['adaptive_threshold_block_size'], 
                    self.config['adaptive_threshold_c'])
            elif method == 'global':
                _, binary = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
                return binary
            else:
                return image
        except Exception as e:
            logger.error(f"二值化处理失败: {e}")
            return image
    
    def correct_skew(self, image: np.ndarray) -> Tuple[np.ndarray, float]:
        """
        图像倾斜校正
        
        Args:
            image: 输入二值化图像
            
        Returns:
            校正后的图像和检测到的角度
        """
        try:
            # 边缘检测
            edges = cv2.Canny(image, 50, 150, apertureSize=3)
            
            # 霍夫直线变换
            lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=100)
            
            if lines is not None:
                angles = []
                for rho, theta in lines[:, 0]:
                    angle = theta * 180 / np.pi
                    if angle < 45:
                        angles.append(angle)
                    elif angle > 135:
                        angles.append(angle - 180)
                
                if angles:
                    # 计算中位数角度
                    median_angle = np.median(angles)
                    
                    # 图像旋转
                    h, w = image.shape
                    center = (w // 2, h // 2)
                    rotation_matrix = cv2.getRotationMatrix2D(center, median_angle, 1.0)
                    corrected = cv2.warpAffine(image, rotation_matrix, (w, h), 
                        flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
                    
                    return corrected, median_angle
            
            return image, 0.0
            
        except Exception as e:
            logger.error(f"倾斜校正失败: {e}")
            return image, 0.0
    
    def segment_characters(self, image: np.ndarray) -> List[Tuple[np.ndarray, Tuple[int, int, int, int]]]:
        """
        字符分割
        
        Args:
            image: 输入二值化图像
            
        Returns:
            分割后的字符图像列表和边界框
        """
        try:
            # 连通域分析
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
                image, connectivity=8, ltype=cv2.CV_32S)
            
            characters = []
            
            for i in range(1, num_labels):  # 跳过背景
                x, y, w, h, area = stats[i]
                
                # 过滤噪声
                if (self.config['contour_min_area'] <= area <= self.config['contour_max_area'] 
                    and w > 10 and h > 10):
                    
                    # 提取字符区域
                    char_image = image[y:y+h, x:x+w]
                    characters.append((char_image, (x, y, w, h)))
            
            # 按x坐标排序(从左到右)
            characters.sort(key=lambda x: x[1][0])
            
            return characters
            
        except Exception as e:
            logger.error(f"字符分割失败: {e}")
            return []
    
    def normalize_character(self, char_image: np.ndarray) -> np.ndarray:
        """
        字符图像归一化
        
        Args:
            char_image: 输入字符图像
            
        Returns:
            归一化后的图像
        """
        try:
            # 添加边距
            padding = 10
            h, w = char_image.shape
            padded = np.ones((h + 2*padding, w + 2*padding), dtype=np.uint8) * 255
            padded[padding:padding+h, padding:padding+w] = char_image
            
            # 调整大小
            resized = cv2.resize(padded, self.config['target_size'], 
                interpolation=cv2.INTER_AREA)
            
            # 归一化到[0, 1]
            normalized = resized.astype(np.float32) / 255.0
            
            return normalized
            
        except Exception as e:
            logger.error(f"字符归一化失败: {e}")
            return np.zeros(self.config['target_size'], dtype=np.float32)
    
    def extract_features_hog(self, image: np.ndarray) -> np.ndarray:
        """
        提取HOG特征
        
        Args:
            image: 输入图像
            
        Returns:
            HOG特征向量
        """
        try:
            # HOG参数配置
            win_size = (64, 64)
            block_size = (16, 16)
            block_stride = (8, 8)
            cell_size = (8, 8)
            nbins = 9
            
            # 创建HOG描述符
            hog = cv2.HOGDescriptor(win_size, block_size, block_stride, 
                cell_size, nbins)
            
            # 调整图像大小
            resized = cv2.resize(image, win_size)
            
            # 计算HOG特征
            features = hog.compute(resized)
            
            return features.flatten()
            
        except Exception as e:
            logger.error(f"HOG特征提取失败: {e}")
            return np.array([])
    
    def extract_features_lbp(self, image: np.ndarray, radius: int = 3, 
                           n_points: int = 24) -> np.ndarray:
        """
        提取LBP特征
        
        Args:
            image: 输入图像
            radius: LBP半径
            n_points: 采样点数
            
        Returns:
            LBP特征直方图
        """
        try:
            def lbp_calculate(image, radius, n_points):
                """计算LBP"""
                h, w = image.shape
                lbp = np.zeros((h, w), dtype=np.uint8)
                
                for y in range(radius, h - radius):
                    for x in range(radius, w - radius):
                        center_pixel = image[y, x]
                        binary_string = ""
                        
                        for i in range(n_points):
                            angle = 2 * np.pi * i / n_points
                            sample_x = x + radius * np.cos(angle)
                            sample_y = y + radius * np.sin(angle)
                            
                            # 双线性插值
                            x1, y1 = int(sample_x), int(sample_y)
                            x2, y2 = x1 + 1, y1 + 1
                            
                            if x2 < w and y2 < h:
                                # 双线性插值计算像素值
                                dx, dy = sample_x - x1, sample_y - y1
                                pixel_value = (1-dx)*(1-dy)*image[y1,x1] + \
                                             dx*(1-dy)*image[y1,x2] + \
                                             (1-dx)*dy*image[y2,x1] + \
                                             dx*dy*image[y2,x2]
                                
                                binary_string += "1" if pixel_value >= center_pixel else "0"
                        
                        lbp[y, x] = int(binary_string, 2)
                
                return lbp
            
            # 计算LBP
            lbp_image = lbp_calculate(image, radius, n_points)
            
            # 计算直方图
            hist, _ = np.histogram(lbp_image.ravel(), bins=2**n_points, 
                range=(0, 2**n_points))
            
            # 归一化
            hist = hist.astype(np.float32)
            hist /= (hist.sum() + 1e-7)
            
            return hist
            
        except Exception as e:
            logger.error(f"LBP特征提取失败: {e}")
            return np.array([])
    
    def process_image(self, image: np.ndarray) -> Tuple[List[np.ndarray], Dict]:
        """
        完整的图像处理流水线
        
        Args:
            image: 输入图像
            
        Returns:
            处理后的字符图像列表和处理信息
        """
        try:
            info = {'steps': [], 'stats': {}}
            
            # 转换为灰度图
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image.copy()
            info['steps'].append('灰度转换')
            
            # 去噪
            denoised = self.remove_noise(gray, 'bilateral')
            info['steps'].append('双边滤波去噪')
            
            # 二值化
            binary = self.binarize_image(denoised, 'adaptive')
            info['steps'].append('自适应二值化')
            
            # 倾斜校正
            corrected, angle = self.correct_skew(binary)
            info['steps'].append(f'倾斜校正 (角度: {angle:.2f}°)')
            info['stats']['skew_angle'] = angle
            
            # 字符分割
            characters_data = self.segment_characters(corrected)
            info['steps'].append(f'字符分割 (检测到 {len(characters_data)} 个字符)')
            info['stats']['character_count'] = len(characters_data)
            
            # 字符归一化
            normalized_chars = []
            for char_img, bbox in characters_data:
                normalized = self.normalize_character(char_img)
                normalized_chars.append(normalized)
            
            info['steps'].append('字符归一化')
            
            return normalized_chars, info
            
        except Exception as e:
            logger.error(f"图像处理流水线失败: {e}")
            return [], {'steps': ['处理失败'], 'stats': {}}

class LeNetModel(nn.Module):
    """改进的LeNet-5模型"""
    
    def __init__(self, num_classes: int = 62, dropout_rate: float = 0.5):
        super(LeNetModel, self).__init__()
        
        # 卷积层
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=5)
        self.bn3 = nn.BatchNorm2d(120)
        
        # 全连接层
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, num_classes)
        
        # Dropout
        self.dropout = nn.Dropout(dropout_rate)
        
        # 池化层
        self.pool = nn.MaxPool2d(2, 2)
        
    def forward(self, x):
        # 第一个卷积块
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        
        # 第二个卷积块
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        
        # 第三个卷积块
        x = F.relu(self.bn3(self.conv3(x)))
        
        # 展平
        x = x.view(x.size(0), -1)
        
        # 全连接层
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

class ResidualBlock(nn.Module):
    """残差块"""
    
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super(ResidualBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
            stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 
            stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # 如果输入输出维度不匹配,使用1x1卷积调整
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, 
                    stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = self.shortcut(x)
        
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        out += identity
        out = F.relu(out)
        
        return out

class ResNetModel(nn.Module):
    """简化的ResNet模型"""
    
    def __init__(self, num_classes: int = 62, num_blocks: List[int] = [2, 2, 2, 2]):
        super(ResNetModel, self).__init__()
        
        self.in_channels = 64
        
        # 初始卷积层
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # 残差层
        self.layer1 = self._make_layer(64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(512, num_blocks[3], stride=2)
        
        # 全局平均池化和分类器
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        
    def _make_layer(self, out_channels: int, num_blocks: int, stride: int):
        """构建残差层"""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        
        for stride in strides:
            layers.append(ResidualBlock(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
            
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x

class LSTMModel(nn.Module):
    """LSTM序列模型"""
    
    def __init__(self, input_size: int, hidden_size: int = 256, 
                 num_layers: int = 2, num_classes: int = 62):
        super(LSTMModel, self).__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, 
            batch_first=True, bidirectional=True)
        
        self.fc = nn.Linear(hidden_size * 2, num_classes)  # 双向LSTM
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        # LSTM前向传播
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        
        out, _ = self.lstm(x, (h0, c0))
        
        # 应用dropout
        out = self.dropout(out)
        
        # 最后一个时间步的输出
        out = self.fc(out[:, -1, :])
        
        return out

class HandwritingDataset(Dataset):
    """手写文字数据集类"""
    
    def __init__(self, images: List[np.ndarray], labels: List[int], 
                 transform=None, augment: bool = False):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.augment = augment
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        # 数据增强
        if self.augment:
            image = self._augment_image(image)
        
        # 转换为张量
        if len(image.shape) == 2:
            image = np.expand_dims(image, axis=0)  # 添加通道维度
        
        image = torch.FloatTensor(image)
        label = torch.LongTensor([label])
        
        if self.transform:
            image = self.transform(image)
            
        return image, label.squeeze()
    
    def _augment_image(self, image: np.ndarray) -> np.ndarray:
        """数据增强"""
        try:
            h, w = image.shape
            
            # 随机旋转 (-15° to 15°)
            if np.random.random() < 0.5:
                angle = np.random.uniform(-15, 15)
                center = (w // 2, h // 2)
                rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
                image = cv2.warpAffine(image, rotation_matrix, (w, h))
            
            # 随机缩放 (0.8 to 1.2)
            if np.random.random() < 0.5:
                scale = np.random.uniform(0.8, 1.2)
                new_h, new_w = int(h * scale), int(w * scale)
                resized = cv2.resize(image, (new_w, new_h))
                
                # 居中放置
                if scale > 1:
                    # 裁剪
                    start_y = (new_h - h) // 2
                    start_x = (new_w - w) // 2
                    image = resized[start_y:start_y+h, start_x:start_x+w]
                else:
                    # 填充
                    pad_y = (h - new_h) // 2
                    pad_x = (w - new_w) // 2
                    image = np.ones((h, w), dtype=np.float32)
                    image[pad_y:pad_y+new_h, pad_x:pad_x+new_w] = resized
            
            # 添加噪声
            if np.random.random() < 0.3:
                noise = np.random.normal(0, 0.05, image.shape)
                image = np.clip(image + noise, 0, 1)
            
            # 弹性变形
            if np.random.random() < 0.2:
                image = self._elastic_deformation(image)
                
            return image.astype(np.float32)
            
        except Exception as e:
            logger.error(f"数据增强失败: {e}")
            return image
    
    def _elastic_deformation(self, image: np.ndarray, alpha: float = 20, 
                           sigma: float = 4) -> np.ndarray:
        """弹性变形"""
        try:
            h, w = image.shape
            
            # 生成随机位移场
            dx = np.random.uniform(-1, 1, (h, w)) * alpha
            dy = np.random.uniform(-1, 1, (h, w)) * alpha
            
            # 高斯平滑
            dx = cv2.GaussianBlur(dx, (0, 0), sigma)
            dy = cv2.GaussianBlur(dy, (0, 0), sigma)
            
            # 生成坐标网格
            x, y = np.meshgrid(np.arange(w), np.arange(h))
            map_x = (x + dx).astype(np.float32)
            map_y = (y + dy).astype(np.float32)
            
            # 重新映射
            deformed = cv2.remap(image, map_x, map_y, cv2.INTER_LINEAR)
            
            return deformed
            
        except Exception as e:
            logger.error(f"弹性变形失败: {e}")
            return image

class ModelTrainer:
    """模型训练器"""
    
    def __init__(self, config: ModelConfig):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        
    def create_model(self, model_type: ModelType) -> nn.Module:
        """创建模型"""
        try:
            if model_type == ModelType.LENET:
                model = LeNetModel(
                    num_classes=self.config.num_classes,
                    dropout_rate=self.config.dropout_rate
                )
            elif model_type == ModelType.RESNET:
                model = ResNetModel(num_classes=self.config.num_classes)
            else:
                raise ValueError(f"不支持的模型类型: {model_type}")
            
            model = model.to(self.device)
            self.model = model
            
            # 创建优化器
            self.optimizer = optim.Adam(
                model.parameters(),
                lr=self.config.learning_rate,
                weight_decay=self.config.weight_decay
            )
            
            # 学习率调度器
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=self.config.epochs
            )
            
            logger.info(f"创建模型成功: {model_type.value}")
            logger.info(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
            
            return model
            
        except Exception as e:
            logger.error(f"创建模型失败: {e}")
            raise
    
    def train_epoch(self, train_loader: DataLoader, criterion) -> Tuple[float, float]:
        """训练一个epoch"""
        self.model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            # 前向传播
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = criterion(output, target)
            
            # 反向传播
            loss.backward()
            
            # 梯度裁剪
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            self.optimizer.step()
            
            # 统计
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            total_correct += pred.eq(target).sum().item()
            total_samples += data.size(0)
            
            if batch_idx % 100 == 0:
                logger.info(f'Batch {batch_idx}/{len(train_loader)}, '
                          f'Loss: {loss.item():.6f}')
        
        avg_loss = total_loss / len(train_loader)
        accuracy = total_correct / total_samples
        
        return avg_loss, accuracy
    
    def validate_epoch(self, val_loader: DataLoader, criterion) -> Tuple[float, float]:
        """验证一个epoch"""
        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                
                output = self.model(data)
                loss = criterion(output, target)
                
                total_loss += loss.item()
                pred = output.argmax(dim=1)
                total_correct += pred.eq(target).sum().item()
                total_samples += data.size(0)
        
        avg_loss = total_loss / len(val_loader)
        accuracy = total_correct / total_samples
        
        return avg_loss, accuracy
    
    def train(self, train_loader: DataLoader, val_loader: DataLoader,
              save_path: str = 'model_checkpoint.pth') -> Dict:
        """完整训练流程"""
        try:
            # 损失函数
            criterion = nn.CrossEntropyLoss()
            
            best_val_acc = 0.0
            patience = 10
            patience_counter = 0
            
            logger.info("开始训练...")
            
            for epoch in range(self.config.epochs):
                start_time = time.time()
                
                # 训练
                train_loss, train_acc = self.train_epoch(train_loader, criterion)
                
                # 验证
                val_loss, val_acc = self.validate_epoch(val_loader, criterion)
                
                # 学习率调度
                self.scheduler.step()
                
                # 记录历史
                self.train_losses.append(train_loss)
                self.val_losses.append(val_loss)
                self.train_accuracies.append(train_acc)
                self.val_accuracies.append(val_acc)
                
                epoch_time = time.time() - start_time
                
                logger.info(f'Epoch {epoch+1}/{self.config.epochs} '
                          f'({epoch_time:.2f}s) - '
                          f'Train Loss: {train_loss:.6f}, Train Acc: {train_acc:.4f}, '
                          f'Val Loss: {val_loss:.6f}, Val Acc: {val_acc:.4f}')
                
                # 保存最佳模型
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    patience_counter = 0
                    
                    checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'scheduler_state_dict': self.scheduler.state_dict(),
                        'best_val_acc': best_val_acc,
                        'config': self.config
                    }
                    torch.save(checkpoint, save_path)
                    logger.info(f"保存最佳模型,验证准确率: {best_val_acc:.4f}")
                else:
                    patience_counter += 1
                
                # 早停
                if patience_counter >= patience:
                    logger.info(f"早停在epoch {epoch+1},最佳验证准确率: {best_val_acc:.4f}")
                    break
            
            # 训练总结
            training_summary = {
                'best_val_accuracy': best_val_acc,
                'final_train_accuracy': self.train_accuracies[-1],
                'final_val_accuracy': self.val_accuracies[-1],
                'total_epochs': len(self.train_losses),
                'train_losses': self.train_losses,
                'val_losses': self.val_losses,
                'train_accuracies': self.train_accuracies,
                'val_accuracies': self.val_accuracies
            }
            
            logger.info("训练完成!")
            return training_summary
            
        except Exception as e:
            logger.error(f"训练失败: {e}")
            raise
    
    def evaluate_model(self, test_loader: DataLoader, 
                      class_names: List[str] = None) -> Dict:
        """模型评估"""
        try:
            self.model.eval()
            all_predictions = []
            all_targets = []
            all_probabilities = []
            
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(self.device), target.to(self.device)
                    
                    output = self.model(data)
                    probabilities = F.softmax(output, dim=1)
                    predictions = output.argmax(dim=1)
                    
                    all_predictions.extend(predictions.cpu().numpy())
                    all_targets.extend(target.cpu().numpy())
                    all_probabilities.extend(probabilities.cpu().numpy())
            
            # 计算指标
            accuracy = np.mean(np.array(all_predictions) == np.array(all_targets))
            
            # 分类报告
            if class_names is None:
                class_names = [str(i) for i in range(self.config.num_classes)]
            
            classification_rep = classification_report(
                all_targets, all_predictions, 
                target_names=class_names, output_dict=True
            )
            
            # 混淆矩阵
            conf_matrix = confusion_matrix(all_targets, all_predictions)
            
            evaluation_results = {
                'accuracy': accuracy,
                'predictions': all_predictions,
                'targets': all_targets,
                'probabilities': all_probabilities,
                'classification_report': classification_rep,
                'confusion_matrix': conf_matrix
            }
            
            logger.info(f"测试准确率: {accuracy:.4f}")
            
            return evaluation_results
            
        except Exception as e:
            logger.error(f"模型评估失败: {e}")
            raise

class HandwritingRecognitionGUI(QMainWindow):
    """手写文字识别GUI主窗口"""
    
    def __init__(self):
        super().__init__()
        self.preprocessor = ImagePreprocessor()
        self.trainer = None
        self.model = None
        self.class_names = self._generate_class_names()
        self.current_image = None
        self.processed_characters = []
        
        self.init_ui()
        self.setup_connections()
        
    def _generate_class_names(self) -> List[str]:
        """生成类别名称列表"""
        names = []
        # 数字 0-9
        names.extend([str(i) for i in range(10)])
        # 大写字母 A-Z
        names.extend([chr(i) for i in range(ord('A'), ord('Z') + 1)])
        # 小写字母 a-z
        names.extend([chr(i) for i in range(ord('a'), ord('z') + 1)])
        return names
    
    def init_ui(self):
        """初始化用户界面"""
        self.setWindowTitle("手写文字识别系统 v2.0")
        self.setGeometry(100, 100, 1400, 900)
        
        # 创建中央部件
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        
        # 创建主布局
        main_layout = QHBoxLayout(central_widget)
        
        # 创建分割器
        splitter = QSplitter(Qt.Horizontal)
        main_layout.addWidget(splitter)
        
        # 左侧面板
        left_panel = self.create_left_panel()
        splitter.addWidget(left_panel)
        
        # 中央面板
        central_panel = self.create_central_panel()
        splitter.addWidget(central_panel)
        
        # 右侧面板
        right_panel = self.create_right_panel()
        splitter.addWidget(right_panel)
        
        # 设置分割器比例
        splitter.setSizes([300, 600, 500])
        
        # 创建菜单栏
        self.create_menu_bar()
        
        # 创建工具栏
        self.create_toolbar()
        
        # 创建状态栏
        self.statusBar().showMessage("就绪")
        
        # 应用样式
        self.apply_styles()
    
    def create_left_panel(self) -> QWidget:
        """创建左侧控制面板"""
        panel = QWidget()
        layout = QVBoxLayout(panel)
        
        # 图像输入组
        input_group = QGroupBox("图像输入")
        input_layout = QVBoxLayout(input_group)
        
        self.load_image_btn = QPushButton("加载图像")
        self.load_image_btn.setMinimumHeight(40)
        input_layout.addWidget(self.load_image_btn)
        
        self.capture_btn = QPushButton("摄像头拍摄")
        self.capture_btn.setMinimumHeight(40)
        input_layout.addWidget(self.capture_btn)
        
        layout.addWidget(input_group)
        
        # 预处理参数组
        preprocess_group = QGroupBox("预处理参数")
        preprocess_layout = QFormLayout(preprocess_group)
        
        # 去噪方法
        self.denoise_combo = QComboBox()
        self.denoise_combo.addItems(['bilateral', 'gaussian', 'median', 'morphology'])
        preprocess_layout.addRow("去噪方法:", self.denoise_combo)
        
        # 二值化方法
        self.binarize_combo = QComboBox()
        self.binarize_combo.addItems(['adaptive', 'otsu', 'global'])
        preprocess_layout.addRow("二值化方法:", self.binarize_combo)
        
        # 阈值参数
        self.threshold_slider = QSlider(Qt.Horizontal)
        self.threshold_slider.setRange(1, 21)
        self.threshold_slider.setValue(11)
        self.threshold_label = QLabel("11")
        threshold_layout = QHBoxLayout()
        threshold_layout.addWidget(self.threshold_slider)
        threshold_layout.addWidget(self.threshold_label)
        preprocess_layout.addRow("阈值块大小:", threshold_layout)
        
        # 处理按钮
        self.process_btn = QPushButton("开始处理")
        self.process_btn.setMinimumHeight(40)
        preprocess_layout.addRow(self.process_btn)
        
        layout.addWidget(preprocess_group)
        
        # 模型设置组
        model_group = QGroupBox("模型设置")
        model_layout = QFormLayout(model_group)
        
        # 模型类型
        self.model_combo = QComboBox()
        self.model_combo.addItems(['LeNet', 'ResNet', 'MobileNet'])
        model_layout.addRow("模型类型:", self.model_combo)
        
        # 加载模型按钮
        self.load_model_btn = QPushButton("加载模型")
        model_layout.addRow(self.load_model_btn)
        
        # 识别按钮
        self.recognize_btn = QPushButton("开始识别")
        self.recognize_btn.setMinimumHeight(40)
        self.recognize_btn.setEnabled(False)
        model_layout.addRow(self.recognize_btn)
        
        layout.addWidget(model_group)
        
        # 训练设置组
        train_group = QGroupBox("模型训练")
        train_layout = QFormLayout(train_group)
        
        # 学习率
        self.lr_spinbox = QSpinBox()
        self.lr_spinbox.setDecimals(4)
        self.lr_spinbox.setRange(0.0001, 0.1)
        self.lr_spinbox.setValue(0.001)
        self.lr_spinbox.setSingleStep(0.0001)
        train_layout.addRow("学习率:", self.lr_spinbox)
        
        # 批大小
        self.batch_spinbox = QSpinBox()
        self.batch_spinbox.setRange(1, 128)
        self.batch_spinbox.setValue(32)
        train_layout.addRow("批大小:", self.batch_spinbox)
        
        # 训练轮数
        self.epochs_spinbox = QSpinBox()
        self.epochs_spinbox.setRange(1, 1000)
        self.epochs_spinbox.setValue(100)
        train_layout.addRow("训练轮数:", self.epochs_spinbox)
        
        # 开始训练按钮
        self.train_btn = QPushButton("开始训练")
        self.train_btn.setMinimumHeight(40)
        train_layout.addRow(self.train_btn)
        
        layout.addWidget(train_group)
        
        layout.addStretch()
        
        return panel
    
    def create_central_panel(self) -> QWidget:
        """创建中央图像显示面板"""
        panel = QWidget()
        layout = QVBoxLayout(panel)
        
        # 创建标签页
        self.image_tabs = QTabWidget()
        
        # 原始图像标签页
        self.original_view = QGraphicsView()
        self.original_scene = QGraphicsScene()
        self.original_view.setScene(self.original_scene)
        self.image_tabs.addTab(self.original_view, "原始图像")
        
        # 处理步骤标签页
        self.processed_view = QGraphicsView()
        self.processed_scene = QGraphicsScene()
        self.processed_view.setScene(self.processed_scene)
        self.image_tabs.addTab(self.processed_view, "处理结果")
        
        # 字符分割标签页
        self.segments_view = QGraphicsView()
        self.segments_scene = QGraphicsScene()
        self.segments_view.setScene(self.segments_scene)
        self.image_tabs.addTab(self.segments_view, "字符分割")
        
        layout.addWidget(self.image_tabs)
        
        # 处理进度条
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)
        
        return panel
    
    def create_right_panel(self) -> QWidget:
        """创建右侧结果面板"""
        panel = QWidget()
        layout = QVBoxLayout(panel)
        
        # 识别结果组
        result_group = QGroupBox("识别结果")
        result_layout = QVBoxLayout(result_group)
        
        # 识别文本显示
        self.result_text = QTextEdit()
        self.result_text.setMinimumHeight(150)
        self.result_text.setFont(QFont("Arial", 14))
        result_layout.addWidget(self.result_text)
        
        # 置信度显示
        self.confidence_label = QLabel("置信度: N/A")
        result_layout.addWidget(self.confidence_label)
        
        # 操作按钮
        button_layout = QHBoxLayout()
        self.copy_btn = QPushButton("复制结果")
        self.save_btn = QPushButton("保存结果")
        button_layout.addWidget(self.copy_btn)
        button_layout.addWidget(self.save_btn)
        result_layout.addLayout(button_layout)
        
        layout.addWidget(result_group)
        
        # 处理信息组
        info_group = QGroupBox("处理信息")
        info_layout = QVBoxLayout(info_group)
        
        self.info_list = QListWidget()
        info_layout.addWidget(self.info_list)
        
        layout.addWidget(info_group)
        
        # 字符详细信息组
        detail_group = QGroupBox("字符详细信息")
        detail_layout = QVBoxLayout(detail_group)
        
        self.detail_table = QTableWidget()
        self.detail_table.setColumnCount(3)
        self.detail_table.setHorizontalHeaderLabels(["字符", "置信度", "位置"])
        detail_layout.addWidget(self.detail_table)
        
        layout.addWidget(detail_group)
        
        return panel
    
    def create_menu_bar(self):
        """创建菜单栏"""
        menubar = self.menuBar()
        
        # 文件菜单
        file_menu = menubar.addMenu('文件')
        
        open_action = QAction('打开图像', self)
        open_action.setShortcut('Ctrl+O')
        open_action.triggered.connect(self.load_image)
        file_menu.addAction(open_action)
        
        save_action = QAction('保存结果', self)
        save_action.setShortcut('Ctrl+S')
        save_action.triggered.connect(self.save_results)
        file_menu.addAction(save_action)
        
        file_menu.addSeparator()
        
        exit_action = QAction('退出', self)
        exit_action.setShortcut('Ctrl+Q')
        exit_action.triggered.connect(self.close)
        file_menu.addAction(exit_action)
        
        # 模型菜单
        model_menu = menubar.addMenu('模型')
        
        load_model_action = QAction('加载模型', self)
        load_model_action.triggered.connect(self.load_model)
        model_menu.addAction(load_model_action)
        
        train_model_action = QAction('训练模型', self)
        train_model_action.triggered.connect(self.start_training)
        model_menu.addAction(train_model_action)
        
        # 帮助菜单
        help_menu = menubar.addMenu('帮助')
        
        about_action = QAction('关于', self)
        about_action.triggered.connect(self.show_about)
        help_menu.addAction(about_action)
    
    def create_toolbar(self):
        """创建工具栏"""
        toolbar = self.addToolBar('工具栏')
        
        # 加载图像
        load_action = QAction('📁', self)
        load_action.setToolTip('加载图像')
        load_action.triggered.connect(self.load_image)
        toolbar.addAction(load_action)
        
        # 处理图像
        process_action = QAction('🔧', self)
        process_action.setToolTip('处理图像')
        process_action.triggered.connect(self.process_image)
        toolbar.addAction(process_action)
        
        # 识别文字
        recognize_action = QAction('🔍', self)
        recognize_action.setToolTip('识别文字')
        recognize_action.triggered.connect(self.recognize_text)
        toolbar.addAction(recognize_action)
        
        toolbar.addSeparator()
        
        # 保存结果
        save_action = QAction('💾', self)
        save_action.setToolTip('保存结果')
        save_action.triggered.connect(self.save_results)
        toolbar.addAction(save_action)
    
    def apply_styles(self):
        """应用样式"""
        style = """
        QMainWindow {
            background-color: #f5f5f5;
        }
        QGroupBox {
            font-weight: bold;
            border: 2px solid #cccccc;
            border-radius: 8px;
            margin: 5px;
            padding-top: 10px;
        }
        QGroupBox::title {
            subcontrol-origin: margin;
            left: 10px;
            padding: 0 5px 0 5px;
        }
        QPushButton {
            background-color: #5D5CDE;
            color: white;
            border: none;
            border-radius: 6px;
            padding: 8px 16px;
            font-size: 13px;
            font-weight: 500;
        }
        QPushButton:hover {
            background-color: #4c4bb8;
        }
        QPushButton:pressed {
            background-color: #3d3a96;
        }
        QPushButton:disabled {
            background-color: #cccccc;
            color: #666666;
        }
        QComboBox, QSpinBox {
            border: 1px solid #ddd;
            border-radius: 4px;
            padding: 5px;
            background-color: white;
        }
        QTextEdit, QListWidget, QTableWidget {
            border: 1px solid #ddd;
            border-radius: 4px;
            background-color: white;
        }
        QTabWidget::pane {
            border: 1px solid #ddd;
            border-radius: 4px;
        }
        QTabBar::tab {
            background-color: #e0e0e0;
            padding: 8px 16px;
            margin-right: 2px;
            border-top-left-radius: 4px;
            border-top-right-radius: 4px;
        }
        QTabBar::tab:selected {
            background-color: #5D5CDE;
            color: white;
        }
        QProgressBar {
            border: 1px solid #ddd;
            border-radius: 4px;
            text-align: center;
        }
        QProgressBar::chunk {
            background-color: #5D5CDE;
            border-radius: 3px;
        }
        """
        self.setStyleSheet(style)
    
    def setup_connections(self):
        """设置信号连接"""
        # 按钮连接
        self.load_image_btn.clicked.connect(self.load_image)
        self.process_btn.clicked.connect(self.process_image)
        self.recognize_btn.clicked.connect(self.recognize_text)
        self.load_model_btn.clicked.connect(self.load_model)
        self.train_btn.clicked.connect(self.start_training)
        self.copy_btn.clicked.connect(self.copy_results)
        self.save_btn.clicked.connect(self.save_results)
        
        # 滑块连接
        self.threshold_slider.valueChanged.connect(
            lambda v: self.threshold_label.setText(str(v))
        )
    
    def load_image(self):
        """加载图像"""
        try:
            file_path, _ = QFileDialog.getOpenFileName(
                self, "选择图像文件", "",
                "图像文件 (*.png *.jpg *.jpeg *.bmp *.tiff)"
            )
            
            if file_path:
                # 读取图像
                self.current_image = cv2.imread(file_path)
                if self.current_image is None:
                    self.show_message("错误", "无法读取图像文件")
                    return
                
                # 显示原始图像
                self.display_image(self.current_image, self.original_scene)
                
                # 更新状态
                self.statusBar().showMessage(f"已加载图像: {Path(file_path).name}")
                self.process_btn.setEnabled(True)
                
                # 清空之前的结果
                self.result_text.clear()
                self.info_list.clear()
                self.detail_table.setRowCount(0)
                
        except Exception as e:
            logger.error(f"加载图像失败: {e}")
            self.show_message("错误", f"加载图像失败: {str(e)}")
    
    def display_image(self, image: np.ndarray, scene: QGraphicsScene):
        """在场景中显示图像"""
        try:
            # 转换图像格式
            if len(image.shape) == 3:
                height, width, channel = image.shape
                bytes_per_line = 3 * width
                q_image = QImage(image.data, width, height, bytes_per_line, 
                    QImage.Format_BGR888).rgbSwapped()
            else:
                height, width = image.shape
                bytes_per_line = width
                q_image = QImage(image.data, width, height, bytes_per_line, 
                    QImage.Format_Grayscale8)
            
            # 创建像素图
            pixmap = QPixmap.fromImage(q_image)
            
            # 清除场景并添加图像
            scene.clear()
            scene.addPixmap(pixmap)
            scene.setSceneRect(pixmap.rect())
            
        except Exception as e:
            logger.error(f"显示图像失败: {e}")
    
    def process_image(self):
        """处理图像"""
        if self.current_image is None:
            self.show_message("警告", "请先加载图像")
            return
        
        try:
            # 显示进度条
            self.progress_bar.setVisible(True)
            self.progress_bar.setValue(0)
            
            # 更新预处理器配置
            self.preprocessor.config['adaptive_threshold_block_size'] = \
                self.threshold_slider.value()
            
            # 处理图像
            self.progress_bar.setValue(20)
            processed_chars, info = self.preprocessor.process_image(self.current_image)
            
            self.progress_bar.setValue(60)
            self.processed_characters = processed_chars
            
            # 显示处理信息
            self.info_list.clear()
            for step in info['steps']:
                self.info_list.addItem(step)
            
            # 显示统计信息
            if 'stats' in info:
                for key, value in info['stats'].items():
                    self.info_list.addItem(f"{key}: {value}")
            
            self.progress_bar.setValue(80)
            
            # 显示字符分割结果
            if processed_chars:
                self.display_character_segments(processed_chars)
                self.recognize_btn.setEnabled(bool(self.model))
            
            self.progress_bar.setValue(100)
            self.statusBar().showMessage(f"处理完成,检测到 {len(processed_chars)} 个字符")
            
            # 隐藏进度条
            QTimer.singleShot(1000, lambda: self.progress_bar.setVisible(False))
            
        except Exception as e:
            logger.error(f"图像处理失败: {e}")
            self.show_message("错误", f"图像处理失败: {str(e)}")
            self.progress_bar.setVisible(False)
    
    def display_character_segments(self, characters: List[np.ndarray]):
        """显示字符分割结果"""
        try:
            self.segments_scene.clear()
            
            if not characters:
                return
            
            # 计算网格布局
            cols = min(10, len(characters))
            rows = (len(characters) + cols - 1) // cols
            
            cell_size = 80
            margin = 10
            
            for i, char_img in enumerate(characters):
                row = i // cols
                col = i % cols
                
                # 转换为显示格式
                if char_img.dtype == np.float32:
                    display_img = (char_img * 255).astype(np.uint8)
                else:
                    display_img = char_img
                
                # 调整大小
                resized = cv2.resize(display_img, (cell_size-margin, cell_size-margin))
                
                # 转换为QImage
                height, width = resized.shape
                q_image = QImage(resized.data, width, height, width, 
                    QImage.Format_Grayscale8)
                pixmap = QPixmap.fromImage(q_image)
                
                # 添加到场景
                x = col * cell_size
                y = row * cell_size
                item = self.segments_scene.addPixmap(pixmap)
                item.setPos(x, y)
                
                # 添加序号
                text_item = self.segments_scene.addText(str(i), QFont("Arial", 10))
                text_item.setPos(x, y + cell_size - 20)
            
            # 设置场景矩形
            scene_rect = QRect(0, 0, cols * cell_size, rows * cell_size)
            self.segments_scene.setSceneRect(scene_rect)
            
        except Exception as e:
            logger.error(f"显示字符分割结果失败: {e}")
    
    def load_model(self):
        """加载预训练模型"""
        try:
            file_path, _ = QFileDialog.getOpenFileName(
                self, "选择模型文件", "",
                "模型文件 (*.pth *.pt)"
            )
            
            if file_path:
                # 加载检查点
                checkpoint = torch.load(file_path, map_location='cpu')
                
                # 创建模型
                config = checkpoint.get('config', ModelConfig())
                model_type = ModelType.LENET  # 默认
                
                if self.model_combo.currentText() == 'ResNet':
                    model_type = ModelType.RESNET
                
                # 创建训练器和模型
                self.trainer = ModelTrainer(config)
                self.model = self.trainer.create_model(model_type)
                
                # 加载权重
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.model.eval()
                
                self.statusBar().showMessage(f"已加载模型: {Path(file_path).name}")
                
                # 启用识别按钮
                if self.processed_characters:
                    self.recognize_btn.setEnabled(True)
                
        except Exception as e:
            logger.error(f"加载模型失败: {e}")
            self.show_message("错误", f"加载模型失败: {str(e)}")
    
    def recognize_text(self):
        """识别文字"""
        if not self.processed_characters:
            self.show_message("警告", "请先处理图像")
            return
        
        if not self.model:
            self.show_message("警告", "请先加载模型")
            return
        
        try:
            # 显示进度条
            self.progress_bar.setVisible(True)
            self.progress_bar.setValue(0)
            
            recognized_text = ""
            confidences = []
            predictions = []
            
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model.to(device)
            self.model.eval()
            
            # 更新详细信息表格
            self.detail_table.setRowCount(len(self.processed_characters))
            
            with torch.no_grad():
                for i, char_img in enumerate(self.processed_characters):
                    # 预处理
                    if len(char_img.shape) == 2:
                        char_tensor = torch.FloatTensor(char_img).unsqueeze(0).unsqueeze(0)
                    else:
                        char_tensor = torch.FloatTensor(char_img).unsqueeze(0)
                    
                    char_tensor = char_tensor.to(device)
                    
                    # 推理
                    output = self.model(char_tensor)
                    probabilities = F.softmax(output, dim=1)
                    
                    # 获取预测结果
                    confidence, predicted = torch.max(probabilities, 1)
                    char_class = predicted.item()
                    char_confidence = confidence.item()
                    
                    # 转换为字符
                    predicted_char = self.class_names[char_class]
                    recognized_text += predicted_char
                    
                    confidences.append(char_confidence)
                    predictions.append(predicted_char)
                    
                    # 更新表格
                    self.detail_table.setItem(i, 0, QTableWidgetItem(predicted_char))
                    self.detail_table.setItem(i, 1, QTableWidgetItem(f"{char_confidence:.3f}"))
                    self.detail_table.setItem(i, 2, QTableWidgetItem(f"({i})"))
                    
                    # 更新进度
                    progress = int((i + 1) / len(self.processed_characters) * 100)
                    self.progress_bar.setValue(progress)
            
            # 显示结果
            self.result_text.setText(recognized_text)
            
            avg_confidence = np.mean(confidences) if confidences else 0
            self.confidence_label.setText(f"平均置信度: {avg_confidence:.3f}")
            
            self.statusBar().showMessage(f"识别完成,识别出 {len(recognized_text)} 个字符")
            
            # 隐藏进度条
            QTimer.singleShot(1000, lambda: self.progress_bar.setVisible(False))
            
        except Exception as e:
            logger.error(f"文字识别失败: {e}")
            self.show_message("错误", f"文字识别失败: {str(e)}")
            self.progress_bar.setVisible(False)
    
    def start_training(self):
        """开始训练模型"""
        # 这里应该打开一个训练对话框或切换到训练界面
        # 由于篇幅限制,这里只显示一个简单的消息
        self.show_message("信息", "训练功能需要额外的数据集和配置。\n请参考完整的训练代码实现。")
    
    def copy_results(self):
        """复制识别结果"""
        text = self.result_text.toPlainText()
        if text:
            clipboard = QApplication.clipboard()
            clipboard.setText(text)
            self.statusBar().showMessage("结果已复制到剪贴板")
        else:
            self.show_message("警告", "没有可复制的内容")
    
    def save_results(self):
        """保存识别结果"""
        text = self.result_text.toPlainText()
        if not text:
            self.show_message("警告", "没有可保存的内容")
            return
        
        try:
            file_path, _ = QFileDialog.getSaveFileName(
                self, "保存识别结果", "",
                "文本文件 (*.txt);;所有文件 (*)"
            )
            
            if file_path:
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write(text)
                self.statusBar().showMessage(f"结果已保存到: {Path(file_path).name}")
                
        except Exception as e:
            logger.error(f"保存结果失败: {e}")
            self.show_message("错误", f"保存结果失败: {str(e)}")
    
    def show_about(self):
        """显示关于对话框"""
        about_text = """
        手写文字识别系统 v2.0
        基于OpenCV和深度学习的手写文字识别系统
        主要特性:
        
            智能图像预处理
            深度学习字符识别
            实时处理和识别
            现代化GUI界面
        
        技术栈: OpenCV, PyTorch, PySide6
        作者: AI Assistant
        """
        
        from PySide6.QtWidgets import QMessageBox
        msg = QMessageBox(self)
        msg.setWindowTitle("关于")
        msg.setText(about_text)
        msg.setIcon(QMessageBox.Information)
        msg.exec()
    
    def show_message(self, title: str, message: str):
        """显示消息对话框"""
        from PySide6.QtWidgets import QMessageBox
        msg = QMessageBox(self)
        msg.setWindowTitle(title)
        msg.setText(message)
        
        if "错误" in title or "失败" in message:
            msg.setIcon(QMessageBox.Critical)
        elif "警告" in title:
            msg.setIcon(QMessageBox.Warning)
        else:
            msg.setIcon(QMessageBox.Information)
        
        msg.exec()

def main():
    """主函数"""
    try:
        # 创建应用
        app = QApplication(sys.argv)
        app.setApplicationName("手写文字识别系统")
        app.setApplicationVersion("2.0")
        
        # 设置应用图标和样式
        app.setStyle('Fusion')
        
        # 创建主窗口
        window = HandwritingRecognitionGUI()
        window.show()
        
        # 运行应用
        sys.exit(app.exec())
        
    except Exception as e:
        logger.error(f"应用启动失败: {e}")
        print(f"应用启动失败: {e}")

if __name__ == "__main__":
    main()

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐