AI:免费版本的 Grok,经过 4 次调试后达到目前的效果。

提示词:用 Python PyQt5 写一个模型视图框架,例如 YOLOv8、v11 这种模型,要求可以检测图片和视频,分 2 个窗口:一个显示原图片/原视频,另一个是检测窗口,显示实时效果;另外设置保存选项和阈值。

环境:YOLO11(Ultralytics)环境

# -*- coding: utf-8 -*-
"""
YOLOv8 / YOLO11 实时检测 GUI(双窗口 + FPS + 图片持久显示 + 自定义保存)
作者:Grok (2025)
"""

import sys
import os
import cv2
import numpy as np
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QWidget, QHBoxLayout, QVBoxLayout,
    QLabel, QPushButton, QSlider, QFileDialog, QComboBox, QCheckBox,
    QGroupBox, QFormLayout, QLineEdit, QMessageBox
)
from PyQt5.QtCore import Qt, QTimer, QThread, pyqtSignal
from PyQt5.QtGui import QImage, QPixmap

from ultralytics import YOLO


# ==================== 推理线程 ====================
class InferenceThread(QThread):
    """Worker thread that runs YOLO inference on an image, a video or a camera.

    Assign ``source`` (an image/video file path, or an integer camera index)
    before calling ``start()``.  ``conf`` is read again for every inference
    call, so it may be changed while the thread is running.

    Signals:
        frame_ready(np.ndarray, np.ndarray): the original frame and the
            annotated frame produced by ``results.plot()``.
        finished_signal(): emitted once, right before ``run()`` returns after
            a source was processed (normally, after ``stop()``, or on error).
        error_signal(str): human-readable description of a failure.
    """

    # File extensions (lower case) that are handled as still images.
    IMAGE_EXTENSIONS = frozenset(
        {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.webp'}
    )

    frame_ready = pyqtSignal(np.ndarray, np.ndarray)   # (raw frame, result frame)
    finished_signal = pyqtSignal()
    error_signal = pyqtSignal(str)

    def __init__(self, model, conf_threshold=0.25, save_dir=None):
        """Store the inference settings.

        Args:
            model: callable model object; it is invoked as
                ``model(frame, conf=..., verbose=False)``.
            conf_threshold: confidence threshold passed to the model.
            save_dir: directory for saved results; falsy disables saving.
        """
        super().__init__()
        self.model = model
        self.conf = conf_threshold
        self.save_dir = save_dir
        self._stop = False
        self.source = None
        self.writer = None

    def run(self):
        """Thread entry point: dispatch to image or video/camera processing."""
        if self.source is None:
            return

        try:
            if self._is_image_source():
                self._run_image(self.source)
            else:
                self._run_video(self.source)
        except Exception as exc:
            # Report the failure through error_signal instead of letting the
            # exception escape from the worker thread.
            self.error_signal.emit(f"推理出错: {exc}")
        finally:
            self.finished_signal.emit()

    def stop(self):
        """Ask the thread to stop and block until ``run()`` has returned.

        Only the stop flag is set here.  The capture and the writer are
        released by ``run()`` itself, so they are never touched from the
        calling thread while the worker is still using them.
        """
        self._stop = True
        self.wait()

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _is_image_source(self):
        """Return True when ``source`` is a path with a known image extension."""
        if not isinstance(self.source, str):
            return False
        _, ext = os.path.splitext(self.source.lower())
        return ext in self.IMAGE_EXTENSIONS

    def _run_image(self, img_path):
        """Detect on one image, optionally save it, then keep re-emitting it."""
        img = cv2.imread(img_path)
        if img is None:
            self.error_signal.emit(f"无法读取图片: {img_path}")
            return

        results = self.model(img, conf=self.conf, verbose=False)[0]
        annotated = results.plot()

        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
            name, ext = os.path.splitext(os.path.basename(img_path))
            save_path = os.path.join(self.save_dir, f"{name}_det{ext.lower()}")
            # cv2.imwrite returns False on failure; only report success
            # when the file was actually written.
            if cv2.imwrite(save_path, annotated):
                print(f"图片保存: {save_path}")
            else:
                print(f"图片保存失败: {save_path}")

        # Keep the thread alive and re-send the same pair every 100 ms until
        # stop() is called, so the image stays on screen.
        while not self._stop:
            self.frame_ready.emit(img, annotated)
            self.msleep(100)

    def _run_video(self, source):
        """Detect frame by frame on a video file or camera until stopped."""
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            cap.release()
            self.error_signal.emit(f"无法打开视频源: {source}")
            return

        video_save_path = None
        try:
            fps_input = cap.get(cv2.CAP_PROP_FPS)
            # "not (x > 0)" also rejects NaN, which "x <= 0" would let through.
            if not (fps_input > 0):
                fps_input = 30.0
            delay_ms = max(1, int(1000 / fps_input))
            tick_freq = cv2.getTickFrequency()

            # FPS counter, refreshed about once per second.
            prev_time = cv2.getTickCount()
            fps_counter = 0
            fps_display = 0.0

            save_enabled = bool(self.save_dir)

            while not self._stop and cap.isOpened():
                frame_start = cv2.getTickCount()
                ret, frame = cap.read()
                if not ret:
                    break

                results = self.model(frame, conf=self.conf, verbose=False)[0]
                annotated = results.plot()

                fps_counter += 1
                curr_time = cv2.getTickCount()
                time_diff = (curr_time - prev_time) / tick_freq
                if time_diff >= 1.0:
                    fps_display = fps_counter / time_diff
                    fps_counter = 0
                    prev_time = curr_time

                # Green FPS text on a black box in the top-left corner.
                fps_text = f"FPS: {fps_display:.1f}"
                cv2.rectangle(annotated, (5, 5), (160, 40), (0, 0, 0), -1)
                cv2.putText(annotated, fps_text, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2,
                            cv2.LINE_AA)

                self.frame_ready.emit(frame.copy(), annotated)

                if save_enabled:
                    if self.writer is None:
                        # Open the writer lazily so its frame size is taken
                        # from a real frame rather than capture properties.
                        video_save_path = self._open_writer(annotated, fps_input)
                        save_enabled = self.writer is not None
                    if self.writer is not None:
                        self.writer.write(annotated)

                # Sleep only for what is left of the frame interval, so the
                # inference time is not added on top of the source frame rate.
                elapsed_ms = (cv2.getTickCount() - frame_start) * 1000.0 / tick_freq
                remaining_ms = int(delay_ms - elapsed_ms)
                if remaining_ms > 0:
                    self.msleep(remaining_ms)
        finally:
            cap.release()
            if self.writer is not None:
                self.writer.release()
                self.writer = None
                print(f"视频保存: {video_save_path}")

    def _open_writer(self, frame, fps):
        """Create ``self.writer`` sized for ``frame``; return its path or None."""
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(self.save_dir, f"det_output_{os.getpid()}.mp4")
        height, width = frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(path, fourcc, fps, (width, height))
        if not writer.isOpened():
            writer.release()
            print(f"无法创建视频文件: {path}")
            return None
        self.writer = writer
        return path


# ==================== 主窗口 ====================
class YOLODetectorGUI(QMainWindow):
    """Main window: raw view and detection view side by side, plus controls.

    The window owns the loaded model, starts one ``InferenceThread`` per
    opened source, stores the most recent frame pair delivered by that thread
    and repaints both labels from a timer.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("YOLOv8 / YOLO11 实时检测 (双窗口 + FPS + 持久显示)")
        self.resize(1500, 800)

        # Most recent frames; set before the timer that reads them is started.
        self.raw_frame = None
        self.result_frame = None

        # Worker thread of the current run (None until a source is opened).
        self.infer_thread = None

        # Default output directory.
        self.save_dir = os.path.join(os.getcwd(), "output")

        # Selectable weights.  Ultralytics publishes the YOLO11 weights as
        # "yolo11*.pt" (without a "v"), unlike the "yolov8*.pt" naming.
        self.model = None
        self.model_list = [
            "yolov8n.pt", "yolov8s.pt", "yolov8m.pt", "yolov8l.pt", "yolov8x.pt",
            "yolo11n.pt", "yolo11s.pt", "yolo11m.pt", "yolo11l.pt", "yolo11x.pt"
        ]
        self.load_model(self.model_list[0])

        self.init_ui()

        # Periodic repaint of both views (every 30 ms).
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_frames)
        self.timer.start(30)

    def init_ui(self):
        """Build the two image views and the control panel on the right."""
        central = QWidget()
        self.setCentralWidget(central)
        main_layout = QHBoxLayout(central)

        # ----- raw / result views -----
        view_layout = QHBoxLayout()
        self.lbl_raw = QLabel("原始画面")
        self.lbl_raw.setAlignment(Qt.AlignCenter)
        self.lbl_raw.setMinimumSize(640, 480)
        self.lbl_raw.setStyleSheet("background:#111; color:#aaa; border:1px solid #444;")
        view_layout.addWidget(self.lbl_raw)

        self.lbl_result = QLabel("检测结果")
        self.lbl_result.setAlignment(Qt.AlignCenter)
        self.lbl_result.setMinimumSize(640, 480)
        self.lbl_result.setStyleSheet("background:#111; color:#aaa; border:1px solid #444;")
        view_layout.addWidget(self.lbl_result)

        main_layout.addLayout(view_layout, stretch=9)

        # ----- control panel -----
        ctrl = QVBoxLayout()
        ctrl.setAlignment(Qt.AlignTop)

        # Model selection.  The items are added before the signal is
        # connected, so filling the combo box does not trigger a model load.
        model_box = QGroupBox("模型")
        ml = QFormLayout()
        self.cmb_model = QComboBox()
        self.cmb_model.addItems(self.model_list)
        self.cmb_model.currentTextChanged.connect(self.on_model_changed)
        ml.addRow("模型:", self.cmb_model)
        model_box.setLayout(ml)
        ctrl.addWidget(model_box)

        # Confidence threshold: slider 0..100 maps to 0.00..1.00.
        conf_box = QGroupBox("置信度阈值")
        cl = QFormLayout()
        self.slider_conf = QSlider(Qt.Horizontal)
        self.slider_conf.setRange(0, 100)
        self.slider_conf.setValue(25)
        self.lbl_conf_val = QLabel("0.25")
        self.slider_conf.valueChanged.connect(self.on_conf_changed)
        cl.addRow(self.slider_conf)
        cl.addRow("当前值:", self.lbl_conf_val)
        conf_box.setLayout(cl)
        ctrl.addWidget(conf_box)

        # Save settings.
        save_dir_box = QGroupBox("保存设置")
        sl = QFormLayout()
        self.edit_save_dir = QLineEdit(self.save_dir)
        self.edit_save_dir.setReadOnly(True)
        btn_choose_dir = QPushButton("选择目录")
        btn_choose_dir.clicked.connect(self.choose_save_dir)
        dir_layout = QHBoxLayout()
        dir_layout.addWidget(self.edit_save_dir)
        dir_layout.addWidget(btn_choose_dir)
        self.chk_save = QCheckBox("启用保存")
        self.chk_save.setChecked(True)
        sl.addRow("保存目录:", dir_layout)
        sl.addRow(self.chk_save)
        save_dir_box.setLayout(sl)
        ctrl.addWidget(save_dir_box)

        # Source / stop buttons.
        btn_box = QVBoxLayout()
        self.btn_open_img = QPushButton("打开图片")
        self.btn_open_vid = QPushButton("打开视频")
        self.btn_open_cam = QPushButton("打开摄像头")
        self.btn_stop = QPushButton("停止")
        self.btn_stop.setEnabled(False)
        self.btn_stop.clicked.connect(self.stop_inference)

        self.btn_open_img.clicked.connect(self.open_image)
        self.btn_open_vid.clicked.connect(self.open_video)
        self.btn_open_cam.clicked.connect(self.open_camera)

        btn_box.addWidget(self.btn_open_img)
        btn_box.addWidget(self.btn_open_vid)
        btn_box.addWidget(self.btn_open_cam)
        btn_box.addWidget(self.btn_stop)
        ctrl.addLayout(btn_box)

        ctrl.addStretch()
        main_layout.addLayout(ctrl, stretch=1)

    def load_model(self, name):
        """Load the weights called ``name``; on failure show an error dialog.

        When loading fails ``self.model`` keeps its previous value.
        """
        try:
            self.model = YOLO(name)
            print(f"模型加载成功: {name}")
        except Exception as e:
            QMessageBox.critical(self, "错误", f"模型加载失败:\n{e}")

    def on_model_changed(self, name):
        """Slot for the model combo box: load the newly selected weights."""
        self.load_model(name)

    def on_conf_changed(self, val):
        """Slot for the slider: update the label and the running thread."""
        conf = val / 100.0
        self.lbl_conf_val.setText(f"{conf:.2f}")
        if self.infer_thread and self.infer_thread.isRunning():
            self.infer_thread.conf = conf

    def choose_save_dir(self):
        """Let the user pick the output directory."""
        dir_path = QFileDialog.getExistingDirectory(self, "选择保存目录", self.save_dir)
        if dir_path:
            self.save_dir = dir_path
            self.edit_save_dir.setText(dir_path)

    def open_image(self):
        """Ask for an image file and start detection on it."""
        path, _ = QFileDialog.getOpenFileName(
            self, "打开图片", "",
            "图片 (*.png *.jpg *.jpeg *.bmp *.tiff *.webp)")
        if path:
            self.start_inference(path)

    def open_video(self):
        """Ask for a video file and start detection on it."""
        path, _ = QFileDialog.getOpenFileName(
            self, "打开视频", "",
            "视频 (*.mp4 *.avi *.mov *.mkv *.wmv *.flv)")
        if path:
            self.start_inference(path)

    def open_camera(self):
        """Start detection on camera index 0."""
        self.start_inference(0)

    def start_inference(self, source):
        """Stop any running thread and start a new one for ``source``.

        Args:
            source: image/video path, or an integer camera index.
        """
        if not self.model:
            return

        self.stop_inference()

        save_dir = self.save_dir if self.chk_save.isChecked() else None

        self.infer_thread = InferenceThread(
            model=self.model,
            conf_threshold=self.slider_conf.value() / 100.0,
            save_dir=save_dir
        )
        self.infer_thread.source = source
        self.infer_thread.frame_ready.connect(self.on_frame_ready)
        self.infer_thread.finished_signal.connect(self.on_inference_finished)
        self.infer_thread.error_signal.connect(self.on_inference_error)
        self.infer_thread.start()

        self._set_running_ui(True)

    def stop_inference(self):
        """Stop the running thread (if any) and unlock the open buttons.

        The last frames are kept on purpose so the final image or video frame
        stays visible after stopping.
        """
        if self.infer_thread and self.infer_thread.isRunning():
            # InferenceThread.stop() already waits for the thread to finish.
            self.infer_thread.stop()

        self._set_running_ui(False)

    def _set_running_ui(self, running):
        """Enable Stop and disable the open buttons while running, else invert."""
        self.btn_stop.setEnabled(running)
        self.btn_open_img.setEnabled(not running)
        self.btn_open_vid.setEnabled(not running)
        self.btn_open_cam.setEnabled(not running)

    def on_frame_ready(self, raw, result):
        """Store private copies of the newest raw/annotated frame pair."""
        self.raw_frame = raw.copy()
        self.result_frame = result.copy()

    def on_inference_finished(self):
        """Unlock the buttons when the current thread ends on its own.

        The last frame stays displayed.  A queued signal from a thread that
        has already been replaced is ignored, so it cannot unlock the buttons
        of the run that replaced it.
        """
        if self.sender() is not self.infer_thread:
            return
        self._set_running_ui(False)

    def on_inference_error(self, msg):
        """Show ``msg`` and stop the run that reported it."""
        # Capture the sender before the modal dialog spins its event loop.
        failed_thread = self.sender()
        QMessageBox.warning(self, "错误", msg)
        # Do not stop a newer run because of a stale error from an old thread.
        if failed_thread is None or failed_thread is self.infer_thread:
            self.stop_inference()

    def update_frames(self):
        """Timer slot: repaint both labels from the stored frames."""
        if self.raw_frame is not None:
            self.show_cv_img(self.raw_frame, self.lbl_raw)
        if self.result_frame is not None:
            self.show_cv_img(self.result_frame, self.lbl_result)

    def show_cv_img(self, cv_img, label):
        """Draw a 3-dimensional (h, w, channels) OpenCV image into ``label``.

        The buffer is wrapped as RGB888 and then channel-swapped, which
        assumes BGR input as produced by OpenCV.  The pixmap is scaled to the
        label size, keeping the aspect ratio.
        """
        h, w, ch = cv_img.shape
        bytes_per_line = ch * w
        q_img = QImage(cv_img.data, w, h, bytes_per_line, QImage.Format_RGB888).rgbSwapped()
        pix = QPixmap.fromImage(q_img)
        label.setPixmap(pix.scaled(label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))

    def closeEvent(self, event):
        """Stop the worker thread before the window closes."""
        self.stop_inference()
        super().closeEvent(event)


# ==================== 程序入口 ====================
if __name__ == "__main__":
    # Create the Qt application, show the detector window and hand the event
    # loop's return code to the interpreter as the process exit status.
    qt_app = QApplication(sys.argv)
    main_window = YOLODetectorGUI()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐