20260513更新
This commit is contained in:
parent
5e7b10c939
commit
f55bec1c80
|
|
@ -4,3 +4,46 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 列举导入模块
|
# 列举导入模块
|
||||||
|
import asyncio
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from utils.agent import BaseAgent
|
||||||
|
from utils.memory import Memory
|
||||||
|
|
||||||
|
|
||||||
|
# 主函数
|
||||||
|
async def main():
|
||||||
|
print("进入会话模式,输入 exit 结束会话")
|
||||||
|
|
||||||
|
# 实例智能体
|
||||||
|
agent = BaseAgent(instructions="请用一句话简洁回复。")
|
||||||
|
# 实例记忆体
|
||||||
|
memory = Memory()
|
||||||
|
|
||||||
|
# 生成会话唯一标识
|
||||||
|
session_id = uuid4().hex.lower()
|
||||||
|
|
||||||
|
dialogue_round = 1
|
||||||
|
while True:
|
||||||
|
user_prompt = input("用户:").strip().lower()
|
||||||
|
if user_prompt == "exit":
|
||||||
|
print("会话结束")
|
||||||
|
break
|
||||||
|
# 查询会话历史消息
|
||||||
|
message_history = memory.read(session_id=session_id)
|
||||||
|
result = await agent.run(
|
||||||
|
user_prompt=user_prompt, message_history=message_history
|
||||||
|
)
|
||||||
|
# 记录会话历史消息
|
||||||
|
memory.create(
|
||||||
|
session_id=session_id,
|
||||||
|
dialogue_round=dialogue_round,
|
||||||
|
dialogue_message=result.new_messages(),
|
||||||
|
)
|
||||||
|
|
||||||
|
print("智能体:", result.output)
|
||||||
|
dialogue_round += 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main=main())
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,9 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 列举导入模块
|
# 列举导入模块
|
||||||
import asyncio
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic_ai import Agent, AgentRunResult
|
from pydantic_ai import Agent, AgentRunResult, ModelMessage
|
||||||
from pydantic_ai.models.openai import OpenAIChatModel
|
from pydantic_ai.models.openai import OpenAIChatModel
|
||||||
from pydantic_ai.output import OutputSpec
|
from pydantic_ai.output import OutputSpec
|
||||||
from pydantic_ai.providers.openai import OpenAIProvider
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
|
|
@ -15,7 +14,9 @@ from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
|
|
||||||
class BaseAgent:
|
class BaseAgent:
|
||||||
"""
|
"""
|
||||||
通用智能体基类
|
通用智能体基类,支持:
|
||||||
|
1)实例智能体
|
||||||
|
2)异步运行
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, instructions: str, output_type: OutputSpec = str):
|
def __init__(self, instructions: str, output_type: OutputSpec = str):
|
||||||
|
|
@ -51,26 +52,15 @@ class BaseAgent:
|
||||||
)
|
)
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
async def run(self, user_prompt: str | List[str]) -> AgentRunResult:
|
async def run(
|
||||||
|
self, user_prompt: str | List[str], message_history: List[ModelMessage] = []
|
||||||
|
) -> AgentRunResult:
|
||||||
"""
|
"""
|
||||||
异步运行智能体
|
异步运行
|
||||||
:param user_prompt: 用户提示词
|
:param user_prompt: 用户提示词
|
||||||
|
:param message_history: 历史消息
|
||||||
:return: 智能体回复
|
:return: 智能体回复
|
||||||
"""
|
"""
|
||||||
return await self.agent.run(user_prompt=user_prompt)
|
return await self.agent.run(
|
||||||
|
user_prompt=user_prompt, message_history=message_history
|
||||||
|
)
|
||||||
async def test():
|
|
||||||
# 1. 创建智能体(给系统提示词)
|
|
||||||
agent = BaseAgent(instructions="请用一句话简洁回复。")
|
|
||||||
|
|
||||||
# 2. 运行智能体(给用户问题)
|
|
||||||
result = await agent.run(user_prompt="Hello World 最早出现在哪里?")
|
|
||||||
|
|
||||||
# 3. 输出结果
|
|
||||||
print("答案:", result.output)
|
|
||||||
|
|
||||||
|
|
||||||
# 运行
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(test())
|
|
||||||
|
|
@ -0,0 +1,111 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
记忆模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 列举导入模块
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic_ai import ModelMessage
|
||||||
|
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
||||||
|
|
||||||
|
sys.path.append(Path(__file__).parent.parent.parent.as_posix())
|
||||||
|
from utils.sqlite import SQLite
|
||||||
|
|
||||||
|
|
||||||
|
class Memory(SQLite):
|
||||||
|
"""
|
||||||
|
记忆体,支持:
|
||||||
|
create:新增对话消息
|
||||||
|
read:查询会话历史消息
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
初始化记忆体
|
||||||
|
"""
|
||||||
|
# 构建数据库路径
|
||||||
|
super().__init__(
|
||||||
|
database=Path(__file__).parent.resolve() / "memory_database.db"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self:
|
||||||
|
self.execute(
|
||||||
|
sql="""
|
||||||
|
CREATE TABLE IF NOT EXISTS messages
|
||||||
|
(
|
||||||
|
--唯一标识
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
--会话唯一标识
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
--对话轮次
|
||||||
|
dialogue_round INTEGER NOT NULL,
|
||||||
|
--对话消息
|
||||||
|
dialogue_message TEXT NOT NULL,
|
||||||
|
--时间戳(毫秒)
|
||||||
|
timestamp INTEGER NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except Exception as exception:
|
||||||
|
raise RuntimeError(f"初始化记忆体发生异常:{str(exception)}") from exception
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self, session_id: str, dialogue_round: int, dialogue_message: List[ModelMessage]
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
新增对话消息
|
||||||
|
:param session_id: 会话唯一标识
|
||||||
|
:param dialogue_round: 对话轮次
|
||||||
|
:param dialogue_message: 对话消息
|
||||||
|
:return: 新增是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self:
|
||||||
|
return self.execute(
|
||||||
|
sql="""
|
||||||
|
INSERT INTO messages (id, session_id, dialogue_round, dialogue_message, timestamp) VALUES (?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
parameters=(
|
||||||
|
uuid4().hex.lower(),
|
||||||
|
session_id,
|
||||||
|
dialogue_round,
|
||||||
|
ModelMessagesTypeAdapter.dump_json(dialogue_message),
|
||||||
|
int(time.time() * 1000),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as exception:
|
||||||
|
raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception
|
||||||
|
|
||||||
|
def read(self, session_id: str) -> List[ModelMessage]:
|
||||||
|
"""
|
||||||
|
查询会话历史消息
|
||||||
|
:param session_id: 会话唯一标识
|
||||||
|
:return: 会话历史消息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self:
|
||||||
|
result = self.query_all(
|
||||||
|
sql="""
|
||||||
|
SELECT dialogue_message
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ?
|
||||||
|
ORDER BY timestamp ASC
|
||||||
|
""",
|
||||||
|
parameters=(session_id,),
|
||||||
|
)
|
||||||
|
message_history = []
|
||||||
|
for row in result:
|
||||||
|
message_history.extend(
|
||||||
|
ModelMessagesTypeAdapter.validate_json(row["dialogue_message"])
|
||||||
|
)
|
||||||
|
return message_history
|
||||||
|
except Exception as exception:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"查询会话历史消息发生异常:{str(exception)}"
|
||||||
|
) from exception
|
||||||
Binary file not shown.
Loading…
Reference in New Issue