From f55bec1c80679664464732619723b57f31c041f4 Mon Sep 17 00:00:00 2001 From: liubiren Date: Wed, 13 May 2026 21:04:57 +0800 Subject: [PATCH] =?UTF-8?q?20260513=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 产品需求文档AI生成/main.py | 43 +++++++ .../utils/{agnet.py => agent.py} | 34 ++---- 产品需求文档AI生成/utils/memory.py | 111 ++++++++++++++++++ 产品需求文档AI生成/utils/memory_database.db | Bin 0 -> 20480 bytes 4 files changed, 166 insertions(+), 22 deletions(-) rename 产品需求文档AI生成/utils/{agnet.py => agent.py} (69%) create mode 100644 产品需求文档AI生成/utils/memory.py create mode 100644 产品需求文档AI生成/utils/memory_database.db diff --git a/产品需求文档AI生成/main.py b/产品需求文档AI生成/main.py index 358d017..509d83e 100644 --- a/产品需求文档AI生成/main.py +++ b/产品需求文档AI生成/main.py @@ -4,3 +4,46 @@ """ # 列举导入模块 +import asyncio +from uuid import uuid4 + +from utils.agent import BaseAgent +from utils.memory import Memory + + +# 主函数 +async def main(): + print("进入会话模式,输入 exit 结束会话") + + # 实例智能体 + agent = BaseAgent(instructions="请用一句话简洁回复。") + # 实例记忆体 + memory = Memory() + + # 生成会话唯一标识 + session_id = uuid4().hex.lower() + + dialogue_round = 1 + while True: + user_prompt = input("用户:").strip().lower() + if user_prompt == "exit": + print("会话结束") + break + # 查询会话历史消息 + message_history = memory.read(session_id=session_id) + result = await agent.run( + user_prompt=user_prompt, message_history=message_history + ) + # 记录会话历史消息 + memory.create( + session_id=session_id, + dialogue_round=dialogue_round, + dialogue_message=result.new_messages(), + ) + + print("智能体:", result.output) + dialogue_round += 1 + + +if __name__ == "__main__": + asyncio.run(main=main()) diff --git a/产品需求文档AI生成/utils/agnet.py b/产品需求文档AI生成/utils/agent.py similarity index 69% rename from 产品需求文档AI生成/utils/agnet.py rename to 产品需求文档AI生成/utils/agent.py index b666ee1..5853b15 100644 --- a/产品需求文档AI生成/utils/agnet.py +++ b/产品需求文档AI生成/utils/agent.py @@ -4,10 +4,9 @@ """ # 列举导入模块 -import asyncio from typing import List -from pydantic_ai import Agent, AgentRunResult +from pydantic_ai import Agent, AgentRunResult, ModelMessage from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.output import OutputSpec from pydantic_ai.providers.openai import OpenAIProvider @@ -15,7 +14,9 @@ from pydantic_ai.providers.openai import OpenAIProvider class BaseAgent: """ - 通用智能体基类 + 通用智能体基类,支持: + 1)实例智能体 + 2)异步运行 """ def __init__(self, instructions: str, output_type: OutputSpec = str): @@ -51,26 +52,15 @@ class BaseAgent: ) return agent - async def run(self, user_prompt: str | List[str]) -> AgentRunResult: + async def run( + self, user_prompt: str | List[str], message_history: List[ModelMessage] = [] + ) -> AgentRunResult: """ - 异步运行智能体 + 异步运行 :param user_prompt: 用户提示词 + :param message_history: 历史消息 :return: 智能体回复 """ - return await self.agent.run(user_prompt=user_prompt) - - -async def test(): - # 1. 创建智能体(给系统提示词) - agent = BaseAgent(instructions="请用一句话简洁回复。") - - # 2. 运行智能体(给用户问题) - result = await agent.run(user_prompt="Hello World 最早出现在哪里?") - - # 3. 输出结果 - print("答案:", result.output) - - -# 运行 -if __name__ == "__main__": - asyncio.run(test()) + return await self.agent.run( + user_prompt=user_prompt, message_history=message_history + ) diff --git a/产品需求文档AI生成/utils/memory.py b/产品需求文档AI生成/utils/memory.py new file mode 100644 index 0000000..f067185 --- /dev/null +++ b/产品需求文档AI生成/utils/memory.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +记忆模块 +""" + +# 列举导入模块 +from pathlib import Path +import sys +import time +from typing import List +from uuid import uuid4 + +from pydantic_ai import ModelMessage +from pydantic_ai.messages import ModelMessagesTypeAdapter + +sys.path.append(Path(__file__).parent.parent.parent.as_posix()) +from utils.sqlite import SQLite + + +class Memory(SQLite): + """ + 记忆体,支持: + create:新增对话消息 + read:查询会话历史消息 + """ + + def __init__(self): + """ + 初始化记忆体 + """ + # 构建数据库路径 + super().__init__( + database=Path(__file__).parent.resolve() / "memory_database.db" + ) + + try: + with self: + self.execute( + sql=""" + CREATE TABLE IF NOT EXISTS messages + ( + --唯一标识 + id TEXT PRIMARY KEY, + --会话唯一标识 + session_id TEXT NOT NULL, + --对话轮次 + dialogue_round INTEGER NOT NULL, + --对话消息 + dialogue_message TEXT NOT NULL, + --时间戳(毫秒) + timestamp INTEGER NOT NULL + ) + """ + ) + except Exception as exception: + raise RuntimeError(f"初始化记忆体发生异常:{str(exception)}") from exception + + def create( + self, session_id: str, dialogue_round: int, dialogue_message: List[ModelMessage] + ) -> bool: + """ + 新增对话消息 + :param session_id: 会话唯一标识 + :param dialogue_round: 对话轮次 + :param dialogue_message: 对话消息 + :return: 新增是否成功 + """ + try: + with self: + return self.execute( + sql=""" + INSERT INTO messages (id, session_id, dialogue_round, dialogue_message, timestamp) VALUES (?, ?, ?, ?, ?) + """, + parameters=( + uuid4().hex.lower(), + session_id, + dialogue_round, + ModelMessagesTypeAdapter.dump_json(dialogue_message), + int(time.time() * 1000), + ), + ) + except Exception as exception: + raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception + + def read(self, session_id: str) -> List[ModelMessage]: + """ + 查询会话历史消息 + :param session_id: 会话唯一标识 + :return: 会话历史消息 + """ + try: + with self: + result = self.query_all( + sql=""" + SELECT dialogue_message + FROM messages + WHERE session_id = ? + ORDER BY timestamp ASC + """, + parameters=(session_id,), + ) + message_history = [] + for row in result: + message_history.extend( + ModelMessagesTypeAdapter.validate_json(row["dialogue_message"]) + ) + return message_history + except Exception as exception: + raise RuntimeError( + f"查询会话历史消息发生异常:{str(exception)}" + ) from exception diff --git a/产品需求文档AI生成/utils/memory_database.db b/产品需求文档AI生成/utils/memory_database.db new file mode 100644 index 0000000000000000000000000000000000000000..d41aeb50b6213f6b46c3f5691eddcd3c05d3a725 GIT binary patch literal 20480 zcmeHO%Wo4$7`GEr3RRSu3lx zf4#f0>7PJdhx$nduw_ko6lV<{B&Y^;!stCb_^`pxFwe1Ssx6< z5(e30?T*7YeH3Txr1y8UH1MKa>8>8;)2r^BazFBeO^>`=|Ki@#gCal?pa{H}2poR1 zZgqS6np1BzAcdHmpmV%th)9*ON_ZL0Sd=tXL6V?K0xPgc;Y3qMih?l0f@~5(1P;jx zhrZk!PT{m|g(GmMClfYF*x^Vx|LfWOFBb|!`wN5L!Wq`W{F&VJ#D(z2@GMC~HJ;@q zh7}o3=;XKvhav*kEGbCkl`o*Ct2nv~2pDQ+ERtqY>10pJ4)5E!7%d_%a;}zG!m`sD z-G;>#ybKp}V}+Af;U&O1b_G_^>rYR~9f<45>2DWiH zlE}p4`!-g9HhX;X-u0=|BTn|_lQ2UYo-73H^XG3kW5e*Y_}kgZJ3oW?)a5hIgDkY) z88|gPainl=G=KDp4}>#t%Q=?w;qY3XCfG_QVu`M3sf#y;tyotA+nF@+-G%N-CwpTC zNkitJ>I`|cPs(OcZJ<%myEzTD!C16Tu){MAGb8G zu$mRJ#~(f!_a!qN*$bm1m9e9Cau;z&2PZ=JVZEC~ze>9=0Ht!aR182E<)z7tT`tcH z?s#)caK?xw%WL{-{o)E*o{*Vs$=>i>50rf3)P4J4%p?sGFClIalClWd#q=V^jAN@C zP^&)Ogl18ckm^m8*@17RV5*Svj;q+CnRFa_y4$u>R-~oHbFDk0H3JLiKm-7q)=Z=RyT zE}F6zp3Tca13SI#{JOPGe+KHBP6V39nohtiJtzVc0g3=cfFeKZz4n2CqZiILJ;F`ez{Lq6UKoOt_Py{Ff6ak6= zMc}z2a73?%VEW)0siD?xP^~Pdx=PXXkc)HYyV-mih}(u3?!{#~A-2!j=>$>k`RrIRcNuWa z&miqa%lqN+Zt(oQzRB?{Bri=3^%Z|~v$&=%J(wIn>>L_!W9hw=Xo#Zsy67=MUo+@= z*f+Fbx)h|UJ?uLVn!zrpg*g1lyZvru@8cF8xxx8IZ#qL)A5&F%ILvADg87|;{e?>- zzGhzoI6Ti!r~iO6GX5fDVNKq@IL|i#zdCetZqLik@>7ZC70xCE&6JlV{@FEBws@tK zoAq=lTjC0S|0DUC;{W{_FEq