20260512更新
This commit is contained in:
parent
3a8b7d7948
commit
5e7b10c939
|
|
@ -4,11 +4,13 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 列举导入模块
|
# 列举导入模块
|
||||||
from pydantic_ai import Agent
|
import asyncio
|
||||||
from pydantic_ai.providers.openai import OpenAIProvider
|
from typing import List
|
||||||
from pydantic_ai.models.openai import OpenAIChatModel
|
|
||||||
|
|
||||||
from typing import Any, Optional, List
|
from pydantic_ai import Agent, AgentRunResult
|
||||||
|
from pydantic_ai.models.openai import OpenAIChatModel
|
||||||
|
from pydantic_ai.output import OutputSpec
|
||||||
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent:
|
class BaseAgent:
|
||||||
|
|
@ -16,23 +18,25 @@ class BaseAgent:
|
||||||
通用智能体基类
|
通用智能体基类
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, system_prompt: str, output_type: Any, tools: List = []):
|
def __init__(self, instructions: str, output_type: OutputSpec = str):
|
||||||
"""
|
"""
|
||||||
初始化智能体
|
初始化智能体
|
||||||
:param system_prompt: 系统提示词
|
:param instructions: 指令
|
||||||
:param output_type: 输出类型
|
:param output_type: 输出类型
|
||||||
:param tools: 工具列表
|
|
||||||
:return: 智能体实例
|
:return: 智能体实例
|
||||||
"""
|
"""
|
||||||
self.agent = self._instantiate_agent(
|
# 实例智能体
|
||||||
system_prompt=system_prompt, output_type=output_type, tools=tools
|
self.agent = self._instantiate(
|
||||||
|
instructions=instructions,
|
||||||
|
output_type=output_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _instantiate_agent(
|
def _instantiate(self, instructions: str, output_type: OutputSpec) -> Agent:
|
||||||
self, system_prompt: str, output_type: Any, tools: list = []
|
|
||||||
) -> Agent:
|
|
||||||
"""
|
"""
|
||||||
实例化智能体
|
实例智能体
|
||||||
|
:param instructions: 指令
|
||||||
|
:param output_type: 输出类型
|
||||||
|
:return: 智能体实例
|
||||||
"""
|
"""
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
model=OpenAIChatModel(
|
model=OpenAIChatModel(
|
||||||
|
|
@ -42,18 +46,31 @@ class BaseAgent:
|
||||||
api_key="sk-D9Y1mCe8VlvNqLuSC4mAjqEwxJ2nW4C0h8a7EPn8kg9RLsHq",
|
api_key="sk-D9Y1mCe8VlvNqLuSC4mAjqEwxJ2nW4C0h8a7EPn8kg9RLsHq",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
system_prompt=system_prompt,
|
instructions=instructions,
|
||||||
output_type=output_type,
|
output_type=output_type,
|
||||||
)
|
)
|
||||||
for tool in tools:
|
|
||||||
agent.tool(tool) # 注册工具
|
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
async def run(self, user_prompt: str, **kwargs):
|
async def run(self, user_prompt: str | List[str]) -> AgentRunResult:
|
||||||
"""
|
"""
|
||||||
运行智能体
|
异步运行智能体
|
||||||
:param user_prompt: 用户提示词
|
:param user_prompt: 用户提示词
|
||||||
:param kwargs: 其它参数
|
|
||||||
:return: 智能体回复
|
:return: 智能体回复
|
||||||
"""
|
"""
|
||||||
return await self.agent.run(user_prompt=user_prompt, **kwargs)
|
return await self.agent.run(user_prompt=user_prompt)
|
||||||
|
|
||||||
|
|
||||||
|
async def test():
|
||||||
|
# 1. 创建智能体(给系统提示词)
|
||||||
|
agent = BaseAgent(instructions="请用一句话简洁回复。")
|
||||||
|
|
||||||
|
# 2. 运行智能体(给用户问题)
|
||||||
|
result = await agent.run(user_prompt="Hello World 最早出现在哪里?")
|
||||||
|
|
||||||
|
# 3. 输出结果
|
||||||
|
print("答案:", result.output)
|
||||||
|
|
||||||
|
|
||||||
|
# 运行
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue