This commit is contained in:
liubiren 2026-06-11 21:04:13 +08:00
parent f18b3a9c82
commit 7ee474044e
4 changed files with 114 additions and 32 deletions

View File

@ -8,10 +8,10 @@ from pathlib import Path
import time import time
from typing import AsyncGenerator, List, Optional from typing import AsyncGenerator, List, Optional
from uuid import uuid4 from uuid import uuid4
from pydantic_ai.messages import ModelMessage
from pydantic_ai import Agent as PydanticAIAgent, ModelMessage from pydantic_ai import Agent as PydanticAIAgent, ModelMessage
from pydantic_ai.capabilities import AgentCapability from pydantic_ai.capabilities import AgentCapability
from pydantic_ai.messages import ModelMessagesTypeAdapter from pydantic_ai.messages import ModelMessagesTypeAdapter, ModelResponse
from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.output import OutputSpec from pydantic_ai.output import OutputSpec
from pydantic_ai.providers.openai import OpenAIProvider from pydantic_ai.providers.openai import OpenAIProvider
@ -172,24 +172,36 @@ class Agent:
user_prompt=user_prompt, user_prompt=user_prompt,
message_history=message_history, message_history=message_history,
) as result: ) as result:
async for part in result.stream_output(): # 全量结构化输出分片 async for response in result.stream_response(): # 完整结构化相应对象
match part.part_kind: if not isinstance(response, ModelResponse):
case "text": # 就文本分片拆分 content 和 thinking continue
if not (part_content := part.content.strip()):
continue if not hasattr(response, "parts") or not response.parts:
if not part.is_reasoning: continue
yield f"0:{part_content}" # content
else: for part in response.parts:
yield f"1:{part_content}" # thinking if isinstance(part, str):
case "tool-call": # 技能调用分片 content = part.strip()
yield f"2:技能名称:{part.tool_name},调用参数:{part.args}" if content:
case "tool-return": # 技能返回分片 yield f"00:{content}"
yield f"3:技能结果:{part.content}"
case "error": # 错误分片
yield f"4:{str(part.error)}"
case _:
continue continue
match part.part_kind:
case "text":
content = part.content.strip()
if content:
yield f"00:{content}"
case "thinking":
content = part.content.strip()
if content:
yield f"01:{content}"
case "tool-call" | "builtin-tool-call":
yield f"02:技能名称:{part.tool_name},调用参数:{part.args}"
case "tool-return" | "builtin-tool-return":
yield f"03:技能结果:{part.content}"
case _:
continue
self.agent_memory.create_new_messages( self.agent_memory.create_new_messages(
session_id=self.session_id, session_id=self.session_id,
new_messages=result.new_messages(), new_messages=result.new_messages(),

Binary file not shown.

View File

@ -5,12 +5,12 @@
import reflex import reflex
from reflex.constants.colors import ColorType from reflex.constants.colors import ColorType
from ..state import State, Turn from ..state import State, MessageBlockType, MessageBlock, Turn
def message_bubble(message: str, color: ColorType) -> reflex.Component: def input_bubble(message: str, color: ColorType) -> reflex.Component:
""" """
对话组件中一个消息气泡组件 输入气泡组件
:param message: 消息 :param message: 消息
:param color: 颜色 :param color: 颜色
:return: Component :return: Component
@ -25,6 +25,45 @@ def message_bubble(message: str, color: ColorType) -> reflex.Component:
) )
def output_bubble(message_block: MessageBlock) -> reflex.Component:
"""
输出气泡组件
:param message_block: 消息块
:return: 气泡组件
"""
color = reflex.cond(
message_block.type == MessageBlockType.content,
"accent",
reflex.cond(
message_block.type == MessageBlockType.thinking,
"iris",
reflex.cond(
message_block.type == MessageBlockType.tool_call,
"orange",
reflex.cond(
message_block.type == MessageBlockType.tool_result,
"teal",
reflex.cond(
message_block.type == MessageBlockType.error,
"red",
"mauve", # 兜底
),
),
),
),
)
return reflex.markdown(
message_block.content,
color=reflex.color(color=color, shade=12),
background_color=reflex.color(color=color, shade=4),
display="inline-block",
padding_inline="1em",
padding_block="0.5em",
border_radius="8px",
margin_bottom="4px",
)
def turn(turn: Turn) -> reflex.Component: def turn(turn: Turn) -> reflex.Component:
""" """
对话组件 对话组件
@ -33,12 +72,12 @@ def turn(turn: Turn) -> reflex.Component:
""" """
return reflex.box( return reflex.box(
reflex.box( reflex.box(
message_bubble(message=turn.input, color="mauve"), input_bubble(message=turn.input, color="mauve"),
text_align="right", text_align="right",
margin_bottom="8px", margin_bottom="8px",
), ),
reflex.box( reflex.box(
message_bubble(message=turn.output, color="accent"), reflex.foreach(turn.output, output_bubble),
text_align="left", text_align="left",
margin_bottom="8px", margin_bottom="8px",
), ),

View File

@ -5,7 +5,7 @@
from typing import Any, AsyncGenerator, Dict, List, Literal from typing import Any, AsyncGenerator, Dict, List, Literal
from uuid import uuid4 from uuid import uuid4
from enum import StrEnum
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import reflex import reflex
from pathlib import Path from pathlib import Path
@ -27,17 +27,29 @@ def retrieve_agent(state) -> Agent:
if current_session_name not in agents: if current_session_name not in agents:
agents[current_session_name] = Agent( agents[current_session_name] = Agent(
session_id=state.sessions[current_session_name].id, session_id=state.sessions[current_session_name].id,
instructions="You are a friendly chatbot named Reflex. Respond in markdown.", instructions="You are a friendly chatbot",
) )
return agents[current_session_name] return agents[current_session_name]
# 消息块类型
class MessageBlockType(StrEnum):
content = "content"
thinking = "thinking"
tool_call = "tool_call"
tool_result = "tool_result"
error = "error"
# 消息块类型前缀映射
MESSAGE_BLOCK_TYPE_PREFIX_MAP = {f"{i:02d}:": m for i, m in enumerate(MessageBlockType)}
class MessageBlock(BaseModel): class MessageBlock(BaseModel):
"""消息块数据模型,包含类型和内容""" """消息块数据模型,包含类型和内容"""
type: Literal[ type: MessageBlockType = Field(..., description="类型")
"thinking", "content", "skill_call", "skill_result", "skill_error"
] = Field(..., description="类型")
content: str = Field(default="", description="内容") content: str = Field(default="", description="内容")
@ -170,7 +182,7 @@ class State(reflex.State):
:param form_data: 对话表单数据 :param form_data: 对话表单数据
:return: AsyncGenerator :return: AsyncGenerator
""" """
input = form_data["input_message"].strip() input = form_data["input"].strip()
if not input: if not input:
return return
@ -193,17 +205,36 @@ class State(reflex.State):
input=input, input=input,
) )
) )
yield # 通知前端更新状态(显示用户输入)
# 当前对话 # 当前对话
current_turn = current_session.turns[-1] current_turn = current_session.turns[-1]
yield
# 获取当前会话绑定的智能体
agent = retrieve_agent(self) agent = retrieve_agent(self)
async for chunk in agent.output_message_streamed(user_prompt=input): async for chunk in agent.output_message_streamed(user_prompt=input):
# 跳过空分块
if not chunk: if not chunk:
yield
continue continue
current_session.turns[-1].output_message += chunk
yield # 匹配消息块类型
prefix_matched = next(
(t for t in MESSAGE_BLOCK_TYPE_PREFIX_MAP if chunk.startswith(t)), None
)
# 跳过未匹配分块
if not prefix_matched:
yield
continue
# 消息块类型
type = MESSAGE_BLOCK_TYPE_PREFIX_MAP[prefix_matched]
# 若当前对话输出为空或当前消息块类型和上一个消息块类型不一致则创建消息块
if not current_turn.output or current_turn.output[-1].type != type:
current_turn.output.append(MessageBlock(type=type))
current_turn.output[-1].content += chunk.removeprefix(prefix_matched)
yield # 通知前端更新状态(打字机效果显示输出)
# 当前会话处理完成 # 当前会话处理完成
current_session.is_processing = False current_session.is_processing = False