# File listing metadata: 251 lines, 8.7 KiB, Python
# -*- coding: utf-8 -*-
"""
Agent module.
"""
|
|
|
|
# Imports
|
|
from pathlib import Path
|
|
from sys import path
|
|
from typing import List, Optional, Union, cast
|
|
from typing import TypeVar
|
|
from uuid import uuid4
|
|
|
|
from pydantic_ai import Agent as BaseAgent, AgentRunResult
|
|
from pydantic_ai.capabilities import AgentCapability
|
|
from pydantic_ai.models.openai import OpenAIChatModel
|
|
from pydantic_ai.native_tools import AbstractNativeTool
|
|
from pydantic_ai.output import OutputSpec
|
|
from pydantic_ai.providers.openai import OpenAIProvider
|
|
from starlette.applications import Starlette
|
|
from starlette.routing import Mount, Route
|
|
|
|
# Put the directory containing this file on the import path before importing
# ``memory`` (presumably a module living next to this file — confirm).
path.append(
    Path(__file__).resolve().parent.as_posix(),
)

from memory import Memory  # noqa: E402  (must follow the import-path tweak above)

# Type variables used to parameterise the underlying agent type
# (``BaseAgent[AgentDepsT, OutputDataT]``).
OutputDataT = TypeVar("OutputDataT")
AgentDepsT = TypeVar("AgentDepsT")
|
|
|
|
|
|
class Agent:
    """
    Thin wrapper around ``pydantic_ai.Agent``.

    Supports:
    1. Instantiating an agent
    2. Running it asynchronously, with the message history of this
       instance's session read before and recorded after every run
    3. Building a Starlette application that serves the agent
    """

    def __init__(
        self,
        instructions: str,
        output_type: OutputSpec = str,
        capabilities: Optional[List[AgentCapability]] = None,
    ):
        """
        Initialise the agent.

        :param instructions: Instructions forwarded to the underlying agent
        :param output_type: Output type, defaults to ``str``
        :param capabilities: Agent capabilities, forwarded unchanged
            (``None`` by default)
        """
        # Unique identifier of this instance's session
        # (``UUID.hex`` is already lowercase, so no extra normalisation).
        self.session_id = uuid4().hex

        # Create the underlying agent
        self.agent = self._create_agent(
            instructions=instructions,
            capabilities=capabilities,
            output_type=output_type,
        )

        # Instantiate the memory store used by ``run``
        self.memory = Memory()

    def _create_agent(
        self,
        instructions: str,
        capabilities: Optional[List[AgentCapability]],
        output_type: OutputSpec,
    ) -> BaseAgent:
        """
        Create the underlying agent.

        The API key is taken from the ``AGENT_API_KEY`` environment variable
        when it is set to a non-empty value; otherwise the legacy built-in
        key is used.

        :param instructions: Instructions
        :param capabilities: Agent capabilities
        :param output_type: Output type
        :return: The configured ``pydantic_ai`` agent
        """
        from os import environ

        # NOTE(security): a credential is committed to source control here.
        # It is kept only as a fallback so existing deployments keep working;
        # rotate the key, supply it via AGENT_API_KEY and delete the literal.
        api_key = (
            environ.get("AGENT_API_KEY")
            or "sk-D9Y1mCe8VlvNqLuSC4mAjqEwxJ2nW4C0h8a7EPn8kg9RLsHq"
        )

        provider = OpenAIProvider(
            base_url="https://tokenhub.tencentmaas.com/v1",
            api_key=api_key,
        )
        model = OpenAIChatModel(
            model_name="deepseek-v4-flash",
            provider=provider,
        )
        return BaseAgent(
            model=model,
            instructions=instructions,
            capabilities=capabilities,
            output_type=output_type,
            retries=1,
        )

    async def run(self, user_prompt: str | List[str]) -> AgentRunResult:
        """
        Run the agent asynchronously.

        :param user_prompt: User prompt
        :return: Result of the agent run
        """
        # Look up the message history of this session
        message_history = self.memory.read(session_id=self.session_id)

        result = await self.agent.run(
            user_prompt=user_prompt,
            message_history=message_history,
        )

        # Record the messages produced by this run
        self.memory.create(
            session_id=self.session_id,
            dialogue_message=result.new_messages(),
        )
        return result

    def run_agent_application(
        self,
        native_tools: Optional[List[AbstractNativeTool]] = None,
        html_source: Optional[Union[str, Path]] = None,
    ) -> Starlette:
        """
        Build the agent web application.

        The returned application mounts the JSON/streaming API under
        ``/api`` and serves the chat UI page on ``/`` and ``/{id}``.
        It does not start a server itself.

        :param native_tools: Native tools the frontend may optionally enable
        :param html_source: Source of the chat UI page, handed unchanged to
            pydantic_ai's UI HTML loader (presumably ``None`` selects its
            default page — confirm)
        :return: Starlette application
        """
        from starlette.requests import Request
        from starlette.responses import HTMLResponse, JSONResponse, Response

        def api(
            agent: BaseAgent[AgentDepsT, OutputDataT],
            native_tools: Optional[List[AbstractNativeTool]] = None,
        ) -> Starlette:
            """
            Create the application that handles API requests.

            :param agent: Agent served by the API
            :param native_tools: Native tools the frontend may optionally enable
            :return: Starlette application
            """
            from pydantic_ai.capabilities import NativeTool
            from pydantic_ai.models import Model
            from pydantic_ai.ui._web.api import (
                BuiltinToolInfo,
                ChatRequestExtra,
                ConfigureFrontend,
                ModelInfo,
                validate_request_options,
            )
            from pydantic_ai.ui.vercel_ai import VercelAIAdapter

            agent_model = cast(Model, agent.model)

            # Native tools selectable in the frontend: the requested tools
            # minus those already attached to the agent. The agent's tool ids
            # are collected once instead of once per candidate, and only when
            # there is at least one candidate to filter.
            candidate_tools = list(native_tools or [])
            frontend_native_tools: List[AbstractNativeTool] = []
            if candidate_tools:
                configured_tool_ids = {
                    tool.unique_id
                    for tool in agent._cap_native_tools
                    if isinstance(tool, AbstractNativeTool)
                }
                frontend_native_tools = [
                    tool
                    for tool in candidate_tools
                    if tool.unique_id not in configured_tool_ids
                ]

            async def chat_options(request: Request) -> Response:
                """Handle the CORS preflight request."""
                return Response()

            async def configurations(request: Request) -> Response:
                """Handle the frontend's model and tool configuration request."""
                payload = ConfigureFrontend(
                    # The frontend may only select the model configured on the agent
                    models=[
                        ModelInfo(
                            id=agent_model.model_id,
                            name=agent_model.label,
                            builtin_tools=[
                                tool.unique_id
                                for tool in frontend_native_tools
                                if type(tool)
                                in agent_model.profile.supported_native_tools
                            ],
                        )
                    ],
                    builtin_tools=[
                        BuiltinToolInfo(id=tool.unique_id, name=tool.label)
                        for tool in frontend_native_tools
                    ],
                )
                return JSONResponse(content=payload.model_dump(by_alias=True))

            async def chat_post(request: Request) -> Response:
                """Handle a chat request."""
                # Build the Vercel AI adapter from the incoming request
                adapter = await VercelAIAdapter.from_request(
                    request=request,
                    agent=agent,
                )

                # Parse the extra request data: the model id and native tool
                # ids chosen by the frontend, plus any other options
                extra_data = ChatRequestExtra.model_validate(
                    adapter.run_input.__pydantic_extra__
                )
                if error := validate_request_options(
                    extra_data=extra_data,
                    # The frontend may only select the model configured on the agent
                    model_ids={agent_model.model_id},
                    builtin_tool_ids={
                        tool.unique_id for tool in frontend_native_tools
                    },
                ):
                    return JSONResponse(content={"error": error}, status_code=400)

                # Enable only the native tools the frontend selected
                selected_capabilities = [
                    NativeTool(tool)
                    for tool in frontend_native_tools
                    if tool.unique_id in extra_data.builtin_tools
                ]
                return await VercelAIAdapter.dispatch_request(
                    request=request,
                    agent=agent,
                    capabilities=selected_capabilities,
                )

            async def health(request: Request) -> Response:
                """Handle the health-check request."""
                return JSONResponse(content={"ok": True})

            return Starlette(
                routes=[
                    Route(path="/chat", endpoint=chat_options, methods=["OPTIONS"]),
                    Route(path="/configure", endpoint=configurations, methods=["GET"]),
                    Route(path="/chat", endpoint=chat_post, methods=["POST"]),
                    Route(path="/health", endpoint=health, methods=["GET"]),
                ]
            )

        application = Starlette(
            routes=[
                Mount(
                    "/api",
                    app=api(agent=self.agent, native_tools=native_tools),
                )
            ]
        )

        async def index(request: Request) -> Response:
            """Handle the chat UI page request."""
            from pydantic_ai.ui._web.app import _get_ui_html

            content = await _get_ui_html(html_source)
            return HTMLResponse(
                content=content,
                headers={
                    "Cache-Control": "public, max-age=3600",
                },
            )

        # The same page is served for the root and for any single-segment path
        application.router.add_route(path="/", endpoint=index, methods=["GET"])
        application.router.add_route(path="/{id}", endpoint=index, methods=["GET"])

        return application
|