【前置背景】

        Websocket是一种在客户端和服务器之间建立(全双工通信)长连接的协议,解决了HTTP的单项通信和短链接问题,实现了双向通信。服务器可以主动向客户端推送信息,客户端也可以向服务器发送信息,实现平等对话。

        主要适用于实时性要求较高的应用场景,如在线聊天、实时数据更新等。在目前的AI技术场景中,为实现流式对话,一般会采用websocket来实现,比如在构建ChatAI等大模型集成平台时,有使用WebSocket实现流式对话的规划。

        当然,也存在其他与大语言模型集成的方式。例如通过LangChain与text - generation - webui API集成来运行大型语言模型,这表明大语言模型也可以通过其他API接口进行服务化和集成,并非局限于WebSocket接口,本章主要是针对websocket接口进行实现,不多赘述。

【steps】

1.构建WebSocket 服务 / 客户端的基础导入部分

        这部分通常需要导入一些库和函数做好基础准备,包括:

import asyncio   #处理事件循环、协程等异步操作
import websockets  #用于实现 WebSocket 协议的第三方库,提供了异步的 WebSocket 服务器和客户端功能,可与asyncio无缝集成
import ssl   #用于处理加密通信
import json   #解析和生成 JSON 格式的数据。WebSocket 通信中常使用 JSON 作为数据交换格式,因此需要该模块处理数据序列化和反序列化
from utils import safe_json_loads, deep_parse, decode_utf8
#safe_json_loads:安全的 JSON 解析函数(可能包含异常处理,避免解析失败导致程序崩溃)
#deep_parse:深度解析数据的函数(可能用于处理嵌套结构的数据)
#decode_utf8:UTF-8 编码解码函数(可能用于处理 WebSocket 传输的字节数据解码)

后续主要实现

1)异步的websocket连接处理逻辑;

2)消息的接收、解析、响应(使用json和自定义工具函数)

3)可能的加密通信配置(通过ssl模块)

2.建立 WebSocket 连接并进行通信:

#异步的WebSocket 客户端函数 ws_client 的定义,用于建立 WebSocket 连接并进行通信
async def ws_client(
    base_url: str,   #WebSocket 服务器的基础 URL,用于建立连接
    chat_id: str,   #聊天会话的唯一标识符,用于区分不同的对话上下文,确保消息正确路由到对应的会话。
    # 快捷消息
    send_auth: str | None = None,   
    send_type: str | None = "互联网搜索",   
    send_content: str | None = "",   
    # 全格式消息
    send_message: str | None = None,    
    max_take_over_count: int = 30,      
) -> str:
    """
    :param base_url: 需要请求的地址
    :param chat_id: 聊天id
    :param send_auth: 权限Token
    :param send_type: 发送消息的类型, 目前只写了: '互联网搜索',也有包括【深度思考】等其他选项
    :param send_content: 发送的消息内容
    :param send_message: 完整的消息, 如果有则优先使用
    :param max_take_over_count: 最大接收次数
    :return: 请求单次的结果: 如果返回是空字符串则为成功返回结果; 若非空字符串则为失败原因
    """

3.WebSocket 客户端与服务器建立连接、发送消息并处理响应:

    #建立websocket连接
    async with websockets.connect(      #创建 WebSocket 连接,async with语法确保连接会被自动关闭
        uri=f"{base_url}/api/ws/runs/{chat_id}",   #定位到具体的聊天会话
        ssl=ssl._create_unverified_context(),     #禁用 SSL 证书验证(不推荐生产环境使用)
    ) as ws:   #将连接对象赋值给ws变量,后续通过该对象发送 / 接收消息。
        #构造并发送信息    
        # 优先使用全格式消息send_message
        if send_message != None:
            send_message = json.dumps(send_message)
        # 若未提供全格式消息,则用快捷参数构造消息
        elif send_type != None and send_auth != None:
            if send_type == "互联网搜索":
                task_data = {
                    "content": send_content,
                    "Authorization": send_auth,
                    "local_db_name": [],    # 本地数据库(空表示不使用)
                    "collection_db_name": [],   # 集合数据库(空表示不使用)
                    "use_Internet": True,    # 启用互联网搜索
                }
                # 完整消息结构(符合服务器预期的格式)
                send_message = json.dumps(
                    {
                        "type": "start",
                        "task": json.dumps(task_data, ensure_ascii=False),
                        "files": [],    # 附加文件(空表示无文件)
                        "team_config": {     # 团队配置(默认值)
                            "name": "Default Team",
                            "participants": [],
                            "team_type": "RoundRobinGroupChat",
                            "component_type": "team",
                        },
                    }
                )

            #  发送消息到服务器
        await ws.send(send_message)
#逻辑优先级:如果提供了send_message,则直接使用(并序列化为 JSON);否则用send_type、send_auth等参数构造消息
#消息结构:针对 "互联网搜索" 类型,构造了包含任务数据、团队配置等的复杂结构,且task字段再次进行了 JSON 序列化(可能是服务器要求的格式)

        # 接收消息并处理服务器响应
        # 循环接收消息,最多接收max_take_over_count次
        while max_take_over_count > 0:
            sou_data = await ws.recv()    # 异步接收消息(字节类型)
            # 解析消息:JSON反序列化 -> 深度解析 -> UTF-8解码
            data = decode_utf8(deep_parse(json.loads(sou_data)))
            # 判断消息类型
            print(data)
            # type为system的且status为connected的: 表示连接成功,忽略“连接成功” 的系统信息
            if data["type"] == "system" and data["status"] == "connected":
                continue    #跳过,继续等待下一条信息
            #若收到错误信息,直接返回错误信息
            if data["type"] == "error":  
                return data     #终止函数,返回错误
            print(data)
            #减少接收次数计数,防止无限循环
            max_take_over_count -= 1

若要对接口进行压力测试,则执行:

#压力测试
async def PressureTest():
    base_url = ""
    chat_id = ""
    send_auth = """"""
    send_content = "西方此次削减对东南亚援助的具体领域和幅度是怎样的?澳智库担忧该举措会对东南亚地区的经济发展、民生改善以及区域稳定产生哪些具体负面影响?"
    tasks = [
        ws_client(base_url=base_url, chat_id=chat_id,
                  send_auth=send_auth, send_content=send_content),
        ws_client(base_url=base_url, chat_id=chat_id,
                  send_auth=send_auth, send_content=send_content),
        ws_client(base_url=base_url, chat_id=chat_id,
                  send_auth=send_auth, send_content=send_content),
        ws_client(base_url=base_url, chat_id=chat_id,
                  send_auth=send_auth, send_content=send_content),
        ws_client(base_url=base_url, chat_id=chat_id,
                  send_auth=send_auth, send_content=send_content)
    ]
    # asyncio.gather 会同时执行所有任务,并等待完成
    results = await asyncio.gather(*tasks)
    print("全部执行完毕: \n", results)

4.main()

def main():
    # 发送单条信息
    asyncio.run(
        ws_client(
            base_url="{url}",
            chat_id="{chat_id}",
            send_auth="Bearer {token}",
            send_content="西方此次削减对东南亚援助的具体领域和幅度是怎样的?澳智库担忧该举措会对东南亚地区的经济发展、民生改善以及区域稳定产生哪些具体负面影响?",
            # send_message="""{"type":"start","task":"{\"content\":\"西方此次削减对东南亚援助的具体领域和幅度是怎样的?澳智库担忧该举措会对东南亚地区的经济发展、民生改善以及区域稳定产生哪些具体负面影响?\",\"Authorization\":\"Bearer {token}\",\"local_db_name\":[],\"collection_db_name\":[],\"use_Internet\":true}","files":[],"team_config":{"name":"Default Team","participants":[],"team_type":"RoundRobinGroupChat","component_type":"team"}}"""
        )
    )
    # 要压测, 使用下面这种方式
    # asyncio.run(PressureTest())
    pass


# main.py 执行函数
if __name__ == "__main__":
    main()

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐