diff --git a/产品需求文档AI生成/main.py b/产品需求文档AI生成/main.py index 358d017..509d83e 100644 --- a/产品需求文档AI生成/main.py +++ b/产品需求文档AI生成/main.py @@ -4,3 +4,46 @@ """ # 列举导入模块 +import asyncio +from uuid import uuid4 + +from utils.agent import BaseAgent +from utils.memory import Memory + + +# 主函数 +async def main(): + print("进入会话模式,输入 exit 结束会话") + + # 实例智能体 + agent = BaseAgent(instructions="请用一句话简洁回复。") + # 实例记忆体 + memory = Memory() + + # 生成会话唯一标识 + session_id = uuid4().hex.lower() + + dialogue_round = 1 + while True: + user_prompt = input("用户:").strip().lower() + if user_prompt == "exit": + print("会话结束") + break + # 查询会话历史消息 + message_history = memory.read(session_id=session_id) + result = await agent.run( + user_prompt=user_prompt, message_history=message_history + ) + # 记录会话历史消息 + memory.create( + session_id=session_id, + dialogue_round=dialogue_round, + dialogue_message=result.new_messages(), + ) + + print("智能体:", result.output) + dialogue_round += 1 + + +if __name__ == "__main__": + asyncio.run(main=main()) diff --git a/产品需求文档AI生成/utils/agnet.py b/产品需求文档AI生成/utils/agent.py similarity index 69% rename from 产品需求文档AI生成/utils/agnet.py rename to 产品需求文档AI生成/utils/agent.py index b666ee1..5853b15 100644 --- a/产品需求文档AI生成/utils/agnet.py +++ b/产品需求文档AI生成/utils/agent.py @@ -4,10 +4,9 @@ """ # 列举导入模块 -import asyncio from typing import List -from pydantic_ai import Agent, AgentRunResult +from pydantic_ai import Agent, AgentRunResult, ModelMessage from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.output import OutputSpec from pydantic_ai.providers.openai import OpenAIProvider @@ -15,7 +14,9 @@ from pydantic_ai.providers.openai import OpenAIProvider class BaseAgent: """ - 通用智能体基类 + 通用智能体基类,支持: + 1)实例智能体 + 2)异步运行 """ def __init__(self, instructions: str, output_type: OutputSpec = str): @@ -51,26 +52,15 @@ class BaseAgent: ) return agent - async def run(self, user_prompt: str | List[str]) -> AgentRunResult: + async def run( + self, user_prompt: str | List[str], message_history: List[ModelMessage] = [] + ) -> AgentRunResult: """ - 异步运行智能体 + 异步运行 :param user_prompt: 用户提示词 + :param message_history: 历史消息 :return: 智能体回复 """ - return await self.agent.run(user_prompt=user_prompt) - - -async def test(): - # 1. 创建智能体(给系统提示词) - agent = BaseAgent(instructions="请用一句话简洁回复。") - - # 2. 运行智能体(给用户问题) - result = await agent.run(user_prompt="Hello World 最早出现在哪里?") - - # 3. 输出结果 - print("答案:", result.output) - - -# 运行 -if __name__ == "__main__": - asyncio.run(test()) + return await self.agent.run( + user_prompt=user_prompt, message_history=message_history + ) diff --git a/产品需求文档AI生成/utils/memory.py b/产品需求文档AI生成/utils/memory.py new file mode 100644 index 0000000..f067185 --- /dev/null +++ b/产品需求文档AI生成/utils/memory.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +记忆模块 +""" + +# 列举导入模块 +from pathlib import Path +import sys +import time +from typing import List +from uuid import uuid4 + +from pydantic_ai import ModelMessage +from pydantic_ai.messages import ModelMessagesTypeAdapter + +sys.path.append(Path(__file__).parent.parent.parent.as_posix()) +from utils.sqlite import SQLite + + +class Memory(SQLite): + """ + 记忆体,支持: + create:新增对话消息 + read:查询会话历史消息 + """ + + def __init__(self): + """ + 初始化记忆体 + """ + # 构建数据库路径 + super().__init__( + database=Path(__file__).parent.resolve() / "memory_database.db" + ) + + try: + with self: + self.execute( + sql=""" + CREATE TABLE IF NOT EXISTS messages + ( + --唯一标识 + id TEXT PRIMARY KEY, + --会话唯一标识 + session_id TEXT NOT NULL, + --对话轮次 + dialogue_round INTEGER NOT NULL, + --对话消息 + dialogue_message TEXT NOT NULL, + --时间戳(毫秒) + timestamp INTEGER NOT NULL + ) + """ + ) + except Exception as exception: + raise RuntimeError(f"初始化记忆体发生异常:{str(exception)}") from exception + + def create( + self, session_id: str, dialogue_round: int, dialogue_message: List[ModelMessage] + ) -> bool: + """ + 新增对话消息 + :param session_id: 会话唯一标识 + :param dialogue_round: 对话轮次 + :param dialogue_message: 对话消息 + :return: 新增是否成功 + """ + try: + with self: + return self.execute( + sql=""" + INSERT INTO messages (id, session_id, dialogue_round, dialogue_message, timestamp) VALUES (?, ?, ?, ?, ?) + """, + parameters=( + uuid4().hex.lower(), + session_id, + dialogue_round, + ModelMessagesTypeAdapter.dump_json(dialogue_message), + int(time.time() * 1000), + ), + ) + except Exception as exception: + raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception + + def read(self, session_id: str) -> List[ModelMessage]: + """ + 查询会话历史消息 + :param session_id: 会话唯一标识 + :return: 会话历史消息 + """ + try: + with self: + result = self.query_all( + sql=""" + SELECT dialogue_message + FROM messages + WHERE session_id = ? + ORDER BY timestamp ASC + """, + parameters=(session_id,), + ) + message_history = [] + for row in result: + message_history.extend( + ModelMessagesTypeAdapter.validate_json(row["dialogue_message"]) + ) + return message_history + except Exception as exception: + raise RuntimeError( + f"查询会话历史消息发生异常:{str(exception)}" + ) from exception diff --git a/产品需求文档AI生成/utils/memory_database.db b/产品需求文档AI生成/utils/memory_database.db new file mode 100644 index 0000000..d41aeb5 Binary files /dev/null and b/产品需求文档AI生成/utils/memory_database.db differ