177 lines
5.7 KiB
Python
177 lines
5.7 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
智能体模块
|
||
"""
|
||
|
||
# 列举导入模块
|
||
from pathlib import Path
|
||
import time
|
||
from typing import AsyncGenerator, List, Optional
|
||
from uuid import uuid4
|
||
|
||
from pydantic_ai import Agent as PydanticAIAgent, ModelMessage
|
||
from pydantic_ai.capabilities import AgentCapability
|
||
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
||
from pydantic_ai.models.openai import OpenAIChatModel
|
||
from pydantic_ai.output import OutputSpec
|
||
from pydantic_ai.providers.openai import OpenAIProvider
|
||
|
||
from .sqlite import SQLite
|
||
|
||
|
||
class AgentMemory(SQLite):
|
||
"""
|
||
智能体的记忆体,支持:
|
||
create:新增对话消息
|
||
read:查询会话历史消息
|
||
"""
|
||
|
||
def __init__(self):
|
||
"""
|
||
初始化智能体的记忆体
|
||
"""
|
||
# 构建智能体的记忆体的数据库路径
|
||
super().__init__(database=Path(__file__).parent.resolve() / "agent_memory.db")
|
||
|
||
try:
|
||
with self:
|
||
self.execute(
|
||
sql="""
|
||
CREATE TABLE IF NOT EXISTS new_messages
|
||
(
|
||
--唯一标识
|
||
id TEXT PRIMARY KEY,
|
||
--会话唯一标识
|
||
session_id TEXT NOT NULL,
|
||
--新对话消息
|
||
new_messages TEXT NOT NULL,
|
||
--时间戳(毫秒)
|
||
timestamp INTEGER NOT NULL
|
||
)
|
||
"""
|
||
)
|
||
except Exception as exception:
|
||
raise RuntimeError(
|
||
f"初始化智能体的记忆体发生异常:{str(exception)}"
|
||
) from exception
|
||
|
||
def create_new_messages(
|
||
self, session_id: str, new_messages: List[ModelMessage]
|
||
) -> bool:
|
||
"""
|
||
新增新对话消息
|
||
:param session_id: 会话唯一标识
|
||
:param new_messages: 新对话消息
|
||
:return: 新增是否成功
|
||
"""
|
||
try:
|
||
with self:
|
||
return self.execute(
|
||
sql="""
|
||
INSERT INTO new_messages (id, session_id, new_messages, timestamp) VALUES (?, ?, ?, ?)
|
||
""",
|
||
parameters=(
|
||
uuid4().hex.lower(),
|
||
session_id,
|
||
ModelMessagesTypeAdapter.dump_json(new_messages),
|
||
int(time.time() * 1000),
|
||
),
|
||
)
|
||
except Exception as exception:
|
||
raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception
|
||
|
||
def read_message_history(self, session_id: str) -> List[ModelMessage]:
|
||
"""
|
||
查询对话消息历史
|
||
:param session_id: 会话唯一标识
|
||
:return: 对话消息历史
|
||
"""
|
||
try:
|
||
with self:
|
||
result = self.query_all(
|
||
sql="""
|
||
SELECT new_messages
|
||
FROM new_messages
|
||
WHERE session_id = ?
|
||
ORDER BY timestamp ASC
|
||
""",
|
||
parameters=(session_id,),
|
||
)
|
||
message_history = []
|
||
for row in result:
|
||
message_history.extend(
|
||
ModelMessagesTypeAdapter.validate_json(row["new_messages"])
|
||
)
|
||
return message_history
|
||
except Exception as exception:
|
||
raise RuntimeError(
|
||
f"查询会话历史消息发生异常:{str(exception)}"
|
||
) from exception
|
||
|
||
|
||
class Agent:
|
||
"""
|
||
智能体,支持:
|
||
1 实例智能体
|
||
2 异步运行
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
instructions: str,
|
||
output_type: OutputSpec = str,
|
||
capabilities: Optional[List[AgentCapability]] = None,
|
||
):
|
||
"""
|
||
初始化智能体
|
||
:param instructions: 指令
|
||
:param skills: 智能体技能列表,默认为不使用技能
|
||
:param output_type: 输出类型
|
||
:return: 智能体实例
|
||
"""
|
||
# 生成会话唯一标识
|
||
self.session_id = uuid4().hex.lower()
|
||
|
||
# 实例智能体的记忆体
|
||
self.agent_memory = AgentMemory()
|
||
|
||
# 实例智能体
|
||
self.agent = PydanticAIAgent(
|
||
model=OpenAIChatModel(
|
||
model_name="deepseek-v4-flash",
|
||
provider=OpenAIProvider(
|
||
base_url="https://tokenhub.tencentmaas.com/v1",
|
||
api_key="sk-D9Y1mCe8VlvNqLuSC4mAjqEwxJ2nW4C0h8a7EPn8kg9RLsHq",
|
||
),
|
||
),
|
||
instructions=instructions,
|
||
capabilities=capabilities,
|
||
output_type=output_type,
|
||
retries=1,
|
||
)
|
||
|
||
async def output_message_streamed(
|
||
self, user_prompt: str | List[str]
|
||
) -> AsyncGenerator[str, None]:
|
||
"""
|
||
智能体流式输出消息
|
||
:param user_prompt: 用户提示词(用户输入消息)
|
||
:return: 流式消息
|
||
"""
|
||
"""定义:一次会话(session)包含若干论对话(turn),每一轮对话由用户输入消息(message)和智能体输出消息组成"""
|
||
# 查询该会话历史对话消息列表
|
||
message_history = self.agent_memory.read_message_history(
|
||
session_id=self.session_id
|
||
)
|
||
|
||
async with self.agent.run_stream(
|
||
user_prompt=user_prompt, message_history=message_history
|
||
) as result:
|
||
async for chunk in result.stream_text():
|
||
yield chunk
|
||
|
||
self.agent_memory.create_new_messages(
|
||
session_id=self.session_id,
|
||
new_messages=result.new_messages(),
|
||
)
|