6.1 案例

6.1.1 NL2SQL

如果可以的话,还应该加一个反馈机制:当第一次写出来的SQL语句在执行时发生错误后,不应该停止,而是将错误信息与错误的SQL语句喂给大模型,让大模型修正SQL语句,直至正确。

import os
import pandas as pd
from dotenv import load_dotenv
from dataclasses import dataclass
from typing import List, Dict, Any
from langchain.agents import create_agent
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain.agents.structured_output import ToolStrategy

# 大模型
llm = ChatOllama(
    model="qwen3:4b",
    base_url="http://localhost:11434",
    temperature=0
)

# 连接数据库
load_dotenv()

MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
MYSQL_PORT = os.getenv("MYSQL_PORT", "3306")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_DB = os.getenv("MYSQL_DB")


DB_URI = (
    f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}?charset=utf8mb4"
)

db = SQLDatabase.from_uri(
    DB_URI,
    include_tables=["df_customers", "df_orders"]  # 白名单
)
table_info = db.get_table_info()   # 获取数据库中表结构

# 定义提示词模板
system_template = """
你是MySQL查询专家,严格遵守以下规则:
1. 仅执行查询操作,禁止增、删、改这些危险操作
2. 严格根据表结构进行查询,表结构信息:{table_info}
3. 调用工具前必须验证字段是否存在,生成SQL语句后自检语法
4. 当生成完SQL语句后,需要执行SQL语句,从数据库中提取相应的数据
5. 如果涉及多表连接、聚合等复杂操作,必要时可以创建临时表、使用子查询等方法
"""
system_prompt = SystemMessagePromptTemplate.from_template(system_template)
user_prompt = HumanMessagePromptTemplate.from_template("{question}")
prompt = ChatPromptTemplate.from_messages([
    system_prompt,
    user_prompt
])

# 定义工具
def generate_sql(question: str) -> str:
    """根据自然语言问题生成一条 MySQL SQL语句。"""
    msg = prompt.format_messages(
        table_info = table_info,
        question=question
    )
    sql = llm.invoke(msg).content.strip()

    return sql

def run_sql(sql: str) -> dict:
    """在数据库中执行SQL语句,提取相应的数据"""
    with db._engine.connect() as con:
        df = pd.read_sql(sql, con)
        df = df.to_dict(orient="records")
        return df

# 格式化输出
@dataclass
class FinalAnswer:
    """输出结果"""
    sql: str
    df: List[Dict[str, Any]]
    explanation: str

# 定义智能体
agent = create_agent(
    model=llm,
    tools=[generate_sql, run_sql],
    response_format=ToolStrategy(FinalAnswer)
)

def sql_query_agent(question):
    out = agent.invoke(
        {"messages": [{"role": "user", "content": question}]}
    )
    sql = out["structured_response"].sql
    df = pd.DataFrame(out["structured_response"].df)
    explanation = out["structured_response"].explanation
    
    return sql, df, explanation

# 流式输出
# for step in agent.stream(
#     {"messages": [{"role": "user", "content": "查询ID为2的用户的个人信息及其总下单次数,总下单次数列用“order_num”表示"}]},
#     stream_mode="values",
# ):
#     step["messages"][-1].pretty_print()

if __name__ == "__main__":
    question = "查询ID为2的用户的个人信息及其总下单次数,总下单次数列用“order_num”表示"
    sql, df, _ = sql_query_agent(question)