Python/utils/agent.py

182 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
智能体模块
"""
# 列举导入模块
from pathlib import Path
import time
from typing import AsyncGenerator, List, Optional
from uuid import uuid4
from pydantic_ai import Agent as PydanticAIAgent, ModelMessage
from pydantic_ai.capabilities import AgentCapability
from pydantic_ai.messages import ModelMessagesTypeAdapter
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.output import OutputSpec
from pydantic_ai.providers.openai import OpenAIProvider
from .sqlite import SQLite
class AgentMemory(SQLite):
"""
智能体的记忆体,支持:
create新增对话消息
read查询会话历史消息
"""
def __init__(self):
"""
初始化智能体的记忆体
"""
# 构建智能体的记忆体的数据库路径
super().__init__(database=Path(__file__).parent.resolve() / "agent_memory.db")
try:
with self:
self.execute(
sql="""
CREATE TABLE IF NOT EXISTS new_messages
(
--唯一标识
id TEXT PRIMARY KEY,
--会话唯一标识
session_id TEXT NOT NULL,
--新对话消息
new_messages TEXT NOT NULL,
--时间戳(毫秒)
timestamp INTEGER NOT NULL
)
"""
)
except Exception as exception:
raise RuntimeError(
f"初始化智能体的记忆体发生异常:{str(exception)}"
) from exception
def create_new_messages(
self, session_id: str, new_messages: List[ModelMessage]
) -> bool:
"""
新增新对话消息
:param session_id: 会话唯一标识
:param new_messages: 新对话消息
:return: 新增是否成功
"""
try:
with self:
return self.execute(
sql="""
INSERT INTO new_messages (id, session_id, new_messages, timestamp) VALUES (?, ?, ?, ?)
""",
parameters=(
uuid4().hex.lower(),
session_id,
ModelMessagesTypeAdapter.dump_json(new_messages),
int(time.time() * 1000),
),
)
except Exception as exception:
raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception
def read_message_history(self, session_id: str) -> List[ModelMessage]:
"""
查询对话消息历史
:param session_id: 会话唯一标识
:return: 对话消息历史
"""
try:
with self:
result = self.query_all(
sql="""
SELECT new_messages
FROM new_messages
WHERE session_id = ?
ORDER BY timestamp ASC
""",
parameters=(session_id,),
)
message_history = []
for row in result:
message_history.extend(
ModelMessagesTypeAdapter.validate_json(row["new_messages"])
)
return message_history
except Exception as exception:
raise RuntimeError(
f"查询会话历史消息发生异常:{str(exception)}"
) from exception
class Agent:
"""
智能体,支持:
1 实例智能体
2 异步运行
"""
def __init__(
self,
session_id: str,
instructions: str,
output_type: OutputSpec = str,
capabilities: Optional[List[AgentCapability]] = None,
):
"""
初始化智能体
:param session_id: 会话唯一标识
:param instructions: 指令
:param skills: 智能体技能列表,默认为不使用技能
:param output_type: 输出类型
:return: 智能体实例
"""
# 会话唯一标识
self.session_id = session_id
# 实例智能体的记忆体
self.agent_memory = AgentMemory()
# 实例智能体
self.agent = PydanticAIAgent(
model=OpenAIChatModel(
model_name="deepseek-v4-flash",
provider=OpenAIProvider(
base_url="https://tokenhub.tencentmaas.com/v1",
api_key="sk-D9Y1mCe8VlvNqLuSC4mAjqEwxJ2nW4C0h8a7EPn8kg9RLsHq",
),
),
instructions=instructions,
capabilities=capabilities,
output_type=output_type,
retries=1,
)
self.agent.to_web()
async def output_message_streamed(
self, user_prompt: str | List[str]
) -> AsyncGenerator[str, None]:
"""
智能体流式输出消息
:param user_prompt: 用户提示词(用户输入消息)
:return: 流式消息
"""
"""定义一次会话session包含若干论对话turn每一轮对话由用户输入消息message和智能体输出消息组成"""
# 查询该会话历史对话消息列表
message_history = self.agent_memory.read_message_history(
session_id=self.session_id
)
async with self.agent.run_stream(
user_prompt=user_prompt, message_history=message_history
) as result:
async for chunk in result.stream_text(delta=True): # 只返回新增内容
if chunk:
yield chunk
self.agent_memory.create_new_messages(
session_id=self.session_id,
new_messages=result.new_messages(),
)