20260513更新
This commit is contained in:
parent
5e7b10c939
commit
f55bec1c80
|
|
@ -4,3 +4,46 @@
|
|||
"""
|
||||
|
||||
# 列举导入模块
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
|
||||
from utils.agent import BaseAgent
|
||||
from utils.memory import Memory
|
||||
|
||||
|
||||
# 主函数
|
||||
async def main():
|
||||
print("进入会话模式,输入 exit 结束会话")
|
||||
|
||||
# 实例智能体
|
||||
agent = BaseAgent(instructions="请用一句话简洁回复。")
|
||||
# 实例记忆体
|
||||
memory = Memory()
|
||||
|
||||
# 生成会话唯一标识
|
||||
session_id = uuid4().hex.lower()
|
||||
|
||||
dialogue_round = 1
|
||||
while True:
|
||||
user_prompt = input("用户:").strip().lower()
|
||||
if user_prompt == "exit":
|
||||
print("会话结束")
|
||||
break
|
||||
# 查询会话历史消息
|
||||
message_history = memory.read(session_id=session_id)
|
||||
result = await agent.run(
|
||||
user_prompt=user_prompt, message_history=message_history
|
||||
)
|
||||
# 记录会话历史消息
|
||||
memory.create(
|
||||
session_id=session_id,
|
||||
dialogue_round=dialogue_round,
|
||||
dialogue_message=result.new_messages(),
|
||||
)
|
||||
|
||||
print("智能体:", result.output)
|
||||
dialogue_round += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main=main())
|
||||
|
|
|
|||
|
|
@ -4,10 +4,9 @@
|
|||
"""
|
||||
|
||||
# 列举导入模块
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from pydantic_ai import Agent, AgentRunResult
|
||||
from pydantic_ai import Agent, AgentRunResult, ModelMessage
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.output import OutputSpec
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
|
|
@ -15,7 +14,9 @@ from pydantic_ai.providers.openai import OpenAIProvider
|
|||
|
||||
class BaseAgent:
|
||||
"""
|
||||
通用智能体基类
|
||||
通用智能体基类,支持:
|
||||
1)实例智能体
|
||||
2)异步运行
|
||||
"""
|
||||
|
||||
def __init__(self, instructions: str, output_type: OutputSpec = str):
|
||||
|
|
@ -51,26 +52,15 @@ class BaseAgent:
|
|||
)
|
||||
return agent
|
||||
|
||||
async def run(self, user_prompt: str | List[str]) -> AgentRunResult:
|
||||
async def run(
|
||||
self, user_prompt: str | List[str], message_history: List[ModelMessage] = []
|
||||
) -> AgentRunResult:
|
||||
"""
|
||||
异步运行智能体
|
||||
异步运行
|
||||
:param user_prompt: 用户提示词
|
||||
:param message_history: 历史消息
|
||||
:return: 智能体回复
|
||||
"""
|
||||
return await self.agent.run(user_prompt=user_prompt)
|
||||
|
||||
|
||||
async def test():
|
||||
# 1. 创建智能体(给系统提示词)
|
||||
agent = BaseAgent(instructions="请用一句话简洁回复。")
|
||||
|
||||
# 2. 运行智能体(给用户问题)
|
||||
result = await agent.run(user_prompt="Hello World 最早出现在哪里?")
|
||||
|
||||
# 3. 输出结果
|
||||
print("答案:", result.output)
|
||||
|
||||
|
||||
# 运行
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test())
|
||||
return await self.agent.run(
|
||||
user_prompt=user_prompt, message_history=message_history
|
||||
)
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆模块
|
||||
"""
|
||||
|
||||
# 列举导入模块
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic_ai import ModelMessage
|
||||
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
||||
|
||||
sys.path.append(Path(__file__).parent.parent.parent.as_posix())
|
||||
from utils.sqlite import SQLite
|
||||
|
||||
|
||||
class Memory(SQLite):
|
||||
"""
|
||||
记忆体,支持:
|
||||
create:新增对话消息
|
||||
read:查询会话历史消息
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化记忆体
|
||||
"""
|
||||
# 构建数据库路径
|
||||
super().__init__(
|
||||
database=Path(__file__).parent.resolve() / "memory_database.db"
|
||||
)
|
||||
|
||||
try:
|
||||
with self:
|
||||
self.execute(
|
||||
sql="""
|
||||
CREATE TABLE IF NOT EXISTS messages
|
||||
(
|
||||
--唯一标识
|
||||
id TEXT PRIMARY KEY,
|
||||
--会话唯一标识
|
||||
session_id TEXT NOT NULL,
|
||||
--对话轮次
|
||||
dialogue_round INTEGER NOT NULL,
|
||||
--对话消息
|
||||
dialogue_message TEXT NOT NULL,
|
||||
--时间戳(毫秒)
|
||||
timestamp INTEGER NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
except Exception as exception:
|
||||
raise RuntimeError(f"初始化记忆体发生异常:{str(exception)}") from exception
|
||||
|
||||
def create(
|
||||
self, session_id: str, dialogue_round: int, dialogue_message: List[ModelMessage]
|
||||
) -> bool:
|
||||
"""
|
||||
新增对话消息
|
||||
:param session_id: 会话唯一标识
|
||||
:param dialogue_round: 对话轮次
|
||||
:param dialogue_message: 对话消息
|
||||
:return: 新增是否成功
|
||||
"""
|
||||
try:
|
||||
with self:
|
||||
return self.execute(
|
||||
sql="""
|
||||
INSERT INTO messages (id, session_id, dialogue_round, dialogue_message, timestamp) VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
parameters=(
|
||||
uuid4().hex.lower(),
|
||||
session_id,
|
||||
dialogue_round,
|
||||
ModelMessagesTypeAdapter.dump_json(dialogue_message),
|
||||
int(time.time() * 1000),
|
||||
),
|
||||
)
|
||||
except Exception as exception:
|
||||
raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception
|
||||
|
||||
def read(self, session_id: str) -> List[ModelMessage]:
|
||||
"""
|
||||
查询会话历史消息
|
||||
:param session_id: 会话唯一标识
|
||||
:return: 会话历史消息
|
||||
"""
|
||||
try:
|
||||
with self:
|
||||
result = self.query_all(
|
||||
sql="""
|
||||
SELECT dialogue_message
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY timestamp ASC
|
||||
""",
|
||||
parameters=(session_id,),
|
||||
)
|
||||
message_history = []
|
||||
for row in result:
|
||||
message_history.extend(
|
||||
ModelMessagesTypeAdapter.validate_json(row["dialogue_message"])
|
||||
)
|
||||
return message_history
|
||||
except Exception as exception:
|
||||
raise RuntimeError(
|
||||
f"查询会话历史消息发生异常:{str(exception)}"
|
||||
) from exception
|
||||
Binary file not shown.
Loading…
Reference in New Issue