This commit is contained in:
parent
f18b3a9c82
commit
7ee474044e
|
|
@ -8,10 +8,10 @@ from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, List, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
from pydantic_ai.messages import ModelMessage
|
||||||
from pydantic_ai import Agent as PydanticAIAgent, ModelMessage
|
from pydantic_ai import Agent as PydanticAIAgent, ModelMessage
|
||||||
from pydantic_ai.capabilities import AgentCapability
|
from pydantic_ai.capabilities import AgentCapability
|
||||||
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
from pydantic_ai.messages import ModelMessagesTypeAdapter, ModelResponse
|
||||||
from pydantic_ai.models.openai import OpenAIChatModel
|
from pydantic_ai.models.openai import OpenAIChatModel
|
||||||
from pydantic_ai.output import OutputSpec
|
from pydantic_ai.output import OutputSpec
|
||||||
from pydantic_ai.providers.openai import OpenAIProvider
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
|
|
@ -172,24 +172,36 @@ class Agent:
|
||||||
user_prompt=user_prompt,
|
user_prompt=user_prompt,
|
||||||
message_history=message_history,
|
message_history=message_history,
|
||||||
) as result:
|
) as result:
|
||||||
async for part in result.stream_output(): # 全量结构化输出分片
|
async for response in result.stream_response(): # 完整结构化相应对象
|
||||||
match part.part_kind:
|
if not isinstance(response, ModelResponse):
|
||||||
case "text": # 就文本分片拆分 content 和 thinking
|
continue
|
||||||
if not (part_content := part.content.strip()):
|
|
||||||
continue
|
if not hasattr(response, "parts") or not response.parts:
|
||||||
if not part.is_reasoning:
|
continue
|
||||||
yield f"0:{part_content}" # content
|
|
||||||
else:
|
for part in response.parts:
|
||||||
yield f"1:{part_content}" # thinking
|
if isinstance(part, str):
|
||||||
case "tool-call": # 技能调用分片
|
content = part.strip()
|
||||||
yield f"2:技能名称:{part.tool_name},调用参数:{part.args}"
|
if content:
|
||||||
case "tool-return": # 技能返回分片
|
yield f"00:{content}"
|
||||||
yield f"3:技能结果:{part.content}"
|
|
||||||
case "error": # 错误分片
|
|
||||||
yield f"4:{str(part.error)}"
|
|
||||||
case _:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
match part.part_kind:
|
||||||
|
case "text":
|
||||||
|
content = part.content.strip()
|
||||||
|
if content:
|
||||||
|
yield f"00:{content}"
|
||||||
|
case "thinking":
|
||||||
|
content = part.content.strip()
|
||||||
|
if content:
|
||||||
|
yield f"01:{content}"
|
||||||
|
case "tool-call" | "builtin-tool-call":
|
||||||
|
yield f"02:技能名称:{part.tool_name},调用参数:{part.args}"
|
||||||
|
case "tool-return" | "builtin-tool-return":
|
||||||
|
yield f"03:技能结果:{part.content}"
|
||||||
|
case _:
|
||||||
|
continue
|
||||||
|
|
||||||
self.agent_memory.create_new_messages(
|
self.agent_memory.create_new_messages(
|
||||||
session_id=self.session_id,
|
session_id=self.session_id,
|
||||||
new_messages=result.new_messages(),
|
new_messages=result.new_messages(),
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -5,12 +5,12 @@
|
||||||
import reflex
|
import reflex
|
||||||
from reflex.constants.colors import ColorType
|
from reflex.constants.colors import ColorType
|
||||||
|
|
||||||
from ..state import State, Turn
|
from ..state import State, MessageBlockType, MessageBlock, Turn
|
||||||
|
|
||||||
|
|
||||||
def message_bubble(message: str, color: ColorType) -> reflex.Component:
|
def input_bubble(message: str, color: ColorType) -> reflex.Component:
|
||||||
"""
|
"""
|
||||||
对话组件中一个消息气泡组件
|
输入气泡组件
|
||||||
:param message: 消息
|
:param message: 消息
|
||||||
:param color: 颜色
|
:param color: 颜色
|
||||||
:return: Component
|
:return: Component
|
||||||
|
|
@ -25,6 +25,45 @@ def message_bubble(message: str, color: ColorType) -> reflex.Component:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def output_bubble(message_block: MessageBlock) -> reflex.Component:
|
||||||
|
"""
|
||||||
|
输出气泡组件
|
||||||
|
:param message_block: 消息块
|
||||||
|
:return: 气泡组件
|
||||||
|
"""
|
||||||
|
color = reflex.cond(
|
||||||
|
message_block.type == MessageBlockType.content,
|
||||||
|
"accent",
|
||||||
|
reflex.cond(
|
||||||
|
message_block.type == MessageBlockType.thinking,
|
||||||
|
"iris",
|
||||||
|
reflex.cond(
|
||||||
|
message_block.type == MessageBlockType.tool_call,
|
||||||
|
"orange",
|
||||||
|
reflex.cond(
|
||||||
|
message_block.type == MessageBlockType.tool_result,
|
||||||
|
"teal",
|
||||||
|
reflex.cond(
|
||||||
|
message_block.type == MessageBlockType.error,
|
||||||
|
"red",
|
||||||
|
"mauve", # 兜底
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return reflex.markdown(
|
||||||
|
message_block.content,
|
||||||
|
color=reflex.color(color=color, shade=12),
|
||||||
|
background_color=reflex.color(color=color, shade=4),
|
||||||
|
display="inline-block",
|
||||||
|
padding_inline="1em",
|
||||||
|
padding_block="0.5em",
|
||||||
|
border_radius="8px",
|
||||||
|
margin_bottom="4px",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def turn(turn: Turn) -> reflex.Component:
|
def turn(turn: Turn) -> reflex.Component:
|
||||||
"""
|
"""
|
||||||
对话组件
|
对话组件
|
||||||
|
|
@ -33,12 +72,12 @@ def turn(turn: Turn) -> reflex.Component:
|
||||||
"""
|
"""
|
||||||
return reflex.box(
|
return reflex.box(
|
||||||
reflex.box(
|
reflex.box(
|
||||||
message_bubble(message=turn.input, color="mauve"),
|
input_bubble(message=turn.input, color="mauve"),
|
||||||
text_align="right",
|
text_align="right",
|
||||||
margin_bottom="8px",
|
margin_bottom="8px",
|
||||||
),
|
),
|
||||||
reflex.box(
|
reflex.box(
|
||||||
message_bubble(message=turn.output, color="accent"),
|
reflex.foreach(turn.output, output_bubble),
|
||||||
text_align="left",
|
text_align="left",
|
||||||
margin_bottom="8px",
|
margin_bottom="8px",
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Literal
|
from typing import Any, AsyncGenerator, Dict, List, Literal
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
from enum import StrEnum
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import reflex
|
import reflex
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -27,17 +27,29 @@ def retrieve_agent(state) -> Agent:
|
||||||
if current_session_name not in agents:
|
if current_session_name not in agents:
|
||||||
agents[current_session_name] = Agent(
|
agents[current_session_name] = Agent(
|
||||||
session_id=state.sessions[current_session_name].id,
|
session_id=state.sessions[current_session_name].id,
|
||||||
instructions="You are a friendly chatbot named Reflex. Respond in markdown.",
|
instructions="You are a friendly chatbot",
|
||||||
)
|
)
|
||||||
return agents[current_session_name]
|
return agents[current_session_name]
|
||||||
|
|
||||||
|
|
||||||
|
# 消息块类型
|
||||||
|
class MessageBlockType(StrEnum):
|
||||||
|
|
||||||
|
content = "content"
|
||||||
|
thinking = "thinking"
|
||||||
|
tool_call = "tool_call"
|
||||||
|
tool_result = "tool_result"
|
||||||
|
error = "error"
|
||||||
|
|
||||||
|
|
||||||
|
# 消息块类型前缀映射
|
||||||
|
MESSAGE_BLOCK_TYPE_PREFIX_MAP = {f"{i:02d}:": m for i, m in enumerate(MessageBlockType)}
|
||||||
|
|
||||||
|
|
||||||
class MessageBlock(BaseModel):
|
class MessageBlock(BaseModel):
|
||||||
"""消息块数据模型,包含类型和内容"""
|
"""消息块数据模型,包含类型和内容"""
|
||||||
|
|
||||||
type: Literal[
|
type: MessageBlockType = Field(..., description="类型")
|
||||||
"thinking", "content", "skill_call", "skill_result", "skill_error"
|
|
||||||
] = Field(..., description="类型")
|
|
||||||
content: str = Field(default="", description="内容")
|
content: str = Field(default="", description="内容")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -170,7 +182,7 @@ class State(reflex.State):
|
||||||
:param form_data: 对话表单数据
|
:param form_data: 对话表单数据
|
||||||
:return: AsyncGenerator
|
:return: AsyncGenerator
|
||||||
"""
|
"""
|
||||||
input = form_data["input_message"].strip()
|
input = form_data["input"].strip()
|
||||||
if not input:
|
if not input:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -193,17 +205,36 @@ class State(reflex.State):
|
||||||
input=input,
|
input=input,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
yield # 通知前端更新状态(显示用户输入)
|
||||||
|
|
||||||
# 当前对话
|
# 当前对话
|
||||||
current_turn = current_session.turns[-1]
|
current_turn = current_session.turns[-1]
|
||||||
yield
|
|
||||||
|
|
||||||
|
# 获取当前会话绑定的智能体
|
||||||
agent = retrieve_agent(self)
|
agent = retrieve_agent(self)
|
||||||
async for chunk in agent.output_message_streamed(user_prompt=input):
|
async for chunk in agent.output_message_streamed(user_prompt=input):
|
||||||
|
# 跳过空分块
|
||||||
if not chunk:
|
if not chunk:
|
||||||
|
yield
|
||||||
continue
|
continue
|
||||||
current_session.turns[-1].output_message += chunk
|
|
||||||
|
|
||||||
yield
|
# 匹配消息块类型
|
||||||
|
prefix_matched = next(
|
||||||
|
(t for t in MESSAGE_BLOCK_TYPE_PREFIX_MAP if chunk.startswith(t)), None
|
||||||
|
)
|
||||||
|
# 跳过未匹配分块
|
||||||
|
if not prefix_matched:
|
||||||
|
yield
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 消息块类型
|
||||||
|
type = MESSAGE_BLOCK_TYPE_PREFIX_MAP[prefix_matched]
|
||||||
|
# 若当前对话输出为空或当前消息块类型和上一个消息块类型不一致则创建消息块
|
||||||
|
if not current_turn.output or current_turn.output[-1].type != type:
|
||||||
|
current_turn.output.append(MessageBlock(type=type))
|
||||||
|
current_turn.output[-1].content += chunk.removeprefix(prefix_matched)
|
||||||
|
|
||||||
|
yield # 通知前端更新状态(打字机效果显示输出)
|
||||||
|
|
||||||
# 当前会话处理完成
|
# 当前会话处理完成
|
||||||
current_session.is_processing = False
|
current_session.is_processing = False
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue