业务流程

实现步骤

1. 加载数据库配置

在项目的根目录下创建.env 文件,设置文件内容:

DB_HOST=xxx
DB_PORT=3306
DB_USER=xxx
DB_PASSWORD=xxx
DB_NAME=xxx
DB_CHARSET=utf8mb4

加载环境变量,从 .env 文件中读取数据库配置信息

使用 os.getenv() 从环境变量中获取数据库的主机地址、端口、用户名、密码、数据库名和字符集

配置数据库连接参数

使用 quote 对密码进行 URL 编码,确保密码中的特殊字符不会导致连接失败

import os
from urllib.parse import quote
from dotenv import load_dotenv

load_dotenv()
B_CONFIG = {
    "host": os.getenv('DB_HOST'),           
    "port": int(os.getenv('DB_PORT')),     
    "user": os.getenv('DB_USER'),
    "password": os.getenv('DB_PASSWORD'),
    "database": os.getenv('DB_NAME'),
    "charset": os.getenv('DB_CHARSET')     
}


# 处理特殊字符密码
encoded_password = quote(DB_CONFIG['password'])

构建 MySQL 数据库连接 URI,并连接数据库

构建 SQLAlchemy 的连接 URI,使用 pymysql 作为驱动程序。
设置连接超时时间为 10 秒

创建一个 SQLDatabase 实例,用于与 MySQL 数据库交互

MYSQL_URI = (
    f"mysql+pymysql://{DB_CONFIG['user']}:{encoded_password}@"
    f"{DB_CONFIG['host']}:{DB_CONFIG['port']}/"
    f"{DB_CONFIG['database']}?"
    f"charset={DB_CONFIG['charset']}&connect_timeout=10"
)

db = SQLDatabase.from_uri(MYSQL_URI)

2.初始化大语言模型

初始化一个基于 ChatOpenAI 的模型,使用智谱 AI 的 GLM-4 模型。
配置 API 密钥和基础 URL

llm = ChatOpenAI(
    temperature=1,
    model='glm-4-0520',
    api_key='*****',
    base_url='https://open.bigmodel.cn/api/paas/v4/'
)

3.定义提示模板

提示模板指导 LLM 根据给定的表结构和用户问题生成 SQL 查询语句

custom_prompt = PromptTemplate.from_template("""
你是一个专业的SQL工程师,请根据以下表结构生成标准SQL查询语句:

{table_info}

请最多返回 {top_k} 条记录。

问题:{input}
SQL查询:
""")

4.SQL 查询链的创建和调用

定义表结构 table_info 和最大返回记录数 top_k。
调用 invoke 方法生成 SQL 查询语句


chian = create_sql_query_chain(
    llm=llm,
    db=db,
    prompt=custom_prompt
)

# chian.get_prompts()[0].pretty_print()
# 表结构信息和 top_k 的值
table_info = "这里是表结构信息,例如:member(id, name, tenant_code, deleted)"
top_k = 3
resp = chian.invoke({
    "input": "member表中lf租户下体系id为15286788且deleted=0的会员,一共有多少人?",
    "question": "member表中lf租户下体系id为15286788且deleted=0的会员,一共有多少人?",
    'table_info': table_info,
    'top_k': top_k
})

5.输出打印

执行生成的 SQL 查询。
使用 ast.literal_eval 安全地解析结果。
输出最终的查询结果。


print('大语言模型生成的SQL:' + resp)
sql = resp.replace('```sql', '').replace('```', '')
print('提取之后的SQL:' + sql)

try:
    result = db.run(sql)
    # 清洗结果
    result_list = ast.literal_eval(result)
    total_count = result_list[0][0]
    print(f"最终的查询结果为:{total_count}")
except Exception as e:
    print(f"❌ SQL 执行失败: {str(e)}")

输出结果:

完整代码:

import ast
from langchain.chains.sql_database.query import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
import os
from urllib.parse import quote
from dotenv import load_dotenv

load_dotenv()
# 基础配置(建议通过环境变量获取)
DB_CONFIG = {
    "host": os.getenv('DB_HOST'),          # 移除默认值
    "port": int(os.getenv('DB_PORT')),     # 必须转换为整数
    "user": os.getenv('DB_USER'),
    "password": os.getenv('DB_PASSWORD'),
    "database": os.getenv('DB_NAME'),
    "charset": os.getenv('DB_CHARSET')     # 动态获取字符集
}


# 处理特殊字符密码
encoded_password = quote(DB_CONFIG['password'])

# SQLAlchemy连接URI
MYSQL_URI = (
    f"mysql+pymysql://{DB_CONFIG['user']}:{encoded_password}@"
    f"{DB_CONFIG['host']}:{DB_CONFIG['port']}/"
    f"{DB_CONFIG['database']}?"
    f"charset={DB_CONFIG['charset']}&connect_timeout=10"
)
# 创建模型
llm = ChatOpenAI(
    temperature=1,
    model='glm-4-0520',
    api_key='****',
    base_url='https://open.bigmodel.cn/api/paas/v4/'
)

db = SQLDatabase.from_uri(MYSQL_URI)
# print(db.dialect)
# print(db.get_usable_table_names())
# print(db.run("SELECT COUNT(1) FROM member where saas_tenant_code ='linefriends' and deleted=0;"))

# 自定义提示模板
custom_prompt = PromptTemplate.from_template("""
你是一个专业的SQL工程师,请根据以下表结构生成标准SQL查询语句:

{table_info}

请最多返回 {top_k} 条记录。

问题:{input}
SQL查询:
""")


chian = create_sql_query_chain(
    llm=llm,
    db=db,
    prompt=custom_prompt
)

# chian.get_prompts()[0].pretty_print()
# 表结构信息和 top_k 的值
table_info = "这里是表结构信息,例如:member(id, name, tenant_code, deleted)"
top_k = 3
resp = chian.invoke({
    "input": "member表中lf租户下体系id为15286788且deleted=0的会员,一共有多少人?",
    "question": "member表中lf租户下体系id为15286788且deleted=0的会员,一共有多少人?",
    'table_info': table_info,
    'top_k': top_k
})

print('大语言模型生成的SQL:' + resp)
sql = resp.replace('```sql', '').replace('```', '')
print('提取之后的SQL:' + sql)

try:
    result = db.run(sql)
    # 清洗结果
    result_list = ast.literal_eval(result)
    total_count = result_list[0][0]
    print(f"最终的查询结果为:{total_count}")
except Exception as e:
    print(f"❌ SQL 执行失败: {str(e)}")

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐