# Python/utils/agent.py (250 lines, 8.7 KiB, Python)

# -*- coding: utf-8 -*-
"""
Agent module.
"""
# Imports
from pathlib import Path
from sys import path
from typing import List, Optional, Union, cast
from typing import TypeVar
from uuid import uuid4
from pydantic_ai import Agent as BaseAgent, AgentRunResult
from pydantic_ai.capabilities import AgentCapability
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.native_tools import AbstractNativeTool
from pydantic_ai.output import OutputSpec
from pydantic_ai.providers.openai import OpenAIProvider
from starlette.applications import Starlette
from starlette.routing import Mount, Route
# Put this file's directory on sys.path before the import below -- presumably so
# the `memory` module located next to this file resolves by its bare name.
path.append(Path(__file__).resolve().parent.as_posix())
from memory import Memory
# Type variables for an agent's dependency type and output data type, referenced
# in generic annotations such as BaseAgent[AgentDepsT, OutputDataT].
AgentDepsT = TypeVar("AgentDepsT")
OutputDataT = TypeVar("OutputDataT")
class Agent:
    """
    Wrapper around a pydantic_ai agent.

    Supports:
    1. Instantiating a configured agent
    2. Running it asynchronously with per-session message history
    3. Serving it through a Starlette chat application
    """

    def __init__(
        self,
        instructions: str,
        output_type: OutputSpec = str,
        capabilities: Optional[List[AgentCapability]] = None,
    ):
        """
        Initialize the agent.

        :param instructions: Instructions passed to the underlying agent
        :param output_type: Output type specification, ``str`` by default
        :param capabilities: Capability list, forwarded unchanged to the
            underlying agent (``None`` by default)
        """
        # Unique identifier of this conversation session
        # (UUID.hex is already a lowercase hexadecimal string).
        self.session_id = uuid4().hex
        # Underlying pydantic_ai agent
        self.agent = self._create_agent(
            instructions=instructions,
            capabilities=capabilities,
            output_type=output_type,
        )
        # Memory store used by `run` to read and record message history
        self.memory = Memory()

    def _create_agent(
        self,
        instructions: str,
        capabilities: Optional[List[AgentCapability]],
        output_type: OutputSpec,
    ) -> BaseAgent:
        """
        Build the underlying pydantic_ai agent.

        The model name, endpoint and API key can be overridden with the
        ``AGENT_MODEL_NAME``, ``AGENT_BASE_URL`` and ``AGENT_API_KEY``
        environment variables; unset or empty variables fall back to the
        built-in defaults.

        :param instructions: Instructions passed to the agent
        :param capabilities: Capability list passed to the agent
        :param output_type: Output type specification
        :return: The created agent
        """
        from os import environ

        # SECURITY: the fallback API key below is committed to source control.
        # It is kept only so existing deployments keep working; rotate it and
        # supply the replacement through AGENT_API_KEY, then drop the literal.
        return BaseAgent(
            model=OpenAIChatModel(
                model_name=environ.get("AGENT_MODEL_NAME") or "deepseek-v4-flash",
                provider=OpenAIProvider(
                    base_url=environ.get("AGENT_BASE_URL")
                    or "https://tokenhub.tencentmaas.com/v1",
                    api_key=environ.get("AGENT_API_KEY")
                    or "sk-D9Y1mCe8VlvNqLuSC4mAjqEwxJ2nW4C0h8a7EPn8kg9RLsHq",
                ),
            ),
            instructions=instructions,
            capabilities=capabilities,
            output_type=output_type,
            retries=1,
        )

    async def run(self, user_prompt: str | List[str]) -> AgentRunResult:
        """
        Run the agent asynchronously for one turn of this session.

        :param user_prompt: User prompt
        :return: Run result returned by the underlying agent
        """
        # Load the messages previously recorded for this session
        message_history = self.memory.read(session_id=self.session_id)
        result = await self.agent.run(
            user_prompt=user_prompt, message_history=message_history
        )
        # Record the messages produced by this run
        self.memory.create(
            session_id=self.session_id,
            dialogue_message=result.new_messages(),
        )
        return result

    def create_web_application(
        self,
        native_tools: Optional[List[AbstractNativeTool]] = None,
        html_source: Optional[Union[str, Path]] = None,
    ) -> Starlette:
        """
        Create a Starlette application serving an interactive chat UI for the agent.

        The chat API is mounted under ``/api``; ``/`` and ``/{id}`` serve the
        chat page. Requests handled here go straight to the underlying agent
        and do not pass through `run`, so they are not recorded in
        ``self.memory``.

        :param native_tools: Native tools the frontend may offer for selection
        :param html_source: Source of the chat page HTML, passed through to
            pydantic_ai's ``_get_ui_html``
        :return: Starlette application
        """
        from starlette.requests import Request
        from starlette.responses import JSONResponse, HTMLResponse, Response

        def api(
            agent: BaseAgent[AgentDepsT, OutputDataT],
            native_tools: Optional[List[AbstractNativeTool]] = None,
        ) -> Starlette:
            """
            Create the application that handles the chat API requests.

            :param agent: Agent that serves the requests
            :param native_tools: Native tools the frontend may offer for selection
            :return: Starlette application
            """
            from pydantic_ai.models import Model
            from pydantic_ai.ui._web.api import ModelInfo
            from pydantic_ai.ui._web.api import (
                ConfigureFrontend,
                BuiltinToolInfo,
                ChatRequestExtra,
                validate_request_options,
            )
            from pydantic_ai.ui.vercel_ai import VercelAIAdapter
            from pydantic_ai.capabilities import NativeTool

            agent_model = cast(Model, agent.model)

            # Native tools the frontend may select: the requested tools minus
            # those already attached to the agent. The attached ids are
            # collected once, and only when there is something to filter, so
            # the agent's private attribute is not read needlessly.
            requested_tools = native_tools or []
            attached_tool_ids = (
                {
                    attached.unique_id
                    for attached in agent._cap_native_tools
                    if isinstance(attached, AbstractNativeTool)
                }
                if requested_tools
                else set()
            )
            frontend_native_tools = [
                tool
                for tool in requested_tools
                if tool.unique_id not in attached_tool_ids
            ]

            async def chat_options(request: Request) -> Response:
                """Answer the cross-origin preflight request with an empty response."""
                return Response()

            async def configurations(request: Request) -> Response:
                """Return the model and tool configuration for the frontend."""
                frontend_config = ConfigureFrontend(
                    # The frontend can only select the model configured on the agent
                    models=[
                        ModelInfo(
                            id=agent_model.model_id,
                            name=agent_model.label,
                            builtin_tools=[
                                tool.unique_id
                                for tool in frontend_native_tools
                                if type(tool)
                                in agent_model.profile.supported_native_tools
                            ],
                        )
                    ],
                    builtin_tools=[
                        BuiltinToolInfo(id=tool.unique_id, name=tool.label)
                        for tool in frontend_native_tools
                    ],
                )
                return JSONResponse(content=frontend_config.model_dump(by_alias=True))

            async def chat_post(request: Request) -> Response:
                """Handle a chat request and return the adapter's response."""
                # Reuse the module-level type variables rather than creating
                # new TypeVar objects on every request.
                adapter_type = VercelAIAdapter[AgentDepsT, OutputDataT]
                adapter = await adapter_type.from_request(request=request, agent=agent)
                # Extra request data: the model id, native tool ids and other
                # options selected in the frontend
                extra_data = ChatRequestExtra.model_validate(
                    adapter.run_input.__pydantic_extra__
                )
                if error := validate_request_options(
                    extra_data=extra_data,
                    # The frontend can only select the model configured on the agent
                    model_ids={agent_model.model_id},
                    builtin_tool_ids={tool.unique_id for tool in frontend_native_tools},
                ):
                    return JSONResponse(content={"error": error}, status_code=400)
                return await adapter_type.dispatch_request(
                    request=request,
                    agent=agent,
                    # Enable only the native tools selected in this request
                    capabilities=[
                        NativeTool(tool)
                        for tool in frontend_native_tools
                        if tool.unique_id in extra_data.builtin_tools
                    ],
                )

            async def health(request: Request) -> Response:
                """Handle the health check request."""
                return JSONResponse(content={"ok": True})

            return Starlette(
                routes=[
                    Route(path="/chat", endpoint=chat_options, methods=["OPTIONS"]),
                    Route(path="/configure", endpoint=configurations, methods=["GET"]),
                    Route(path="/chat", endpoint=chat_post, methods=["POST"]),
                    Route(path="/health", endpoint=health, methods=["GET"]),
                ]
            )

        async def index(request: Request) -> Response:
            """Serve the chat page."""
            from pydantic_ai.ui._web.app import _get_ui_html

            content = await _get_ui_html(html_source)
            return HTMLResponse(
                content=content,
                headers={
                    "Cache-Control": "public, max-age=3600",
                },
            )

        # Same routes, in the same order, as mounting the API and then
        # registering the two page routes on the router.
        return Starlette(
            routes=[
                Mount(
                    "/api",
                    app=api(agent=self.agent, native_tools=native_tools),
                ),
                Route(path="/", endpoint=index, methods=["GET"]),
                Route(path="/{id}", endpoint=index, methods=["GET"]),
            ]
        )