This commit is contained in:
liubiren 2026-06-11 09:29:38 +08:00
parent 70c39ea7e2
commit f18b3a9c82
3 changed files with 57 additions and 29 deletions

View File

@ -169,11 +169,26 @@ class Agent:
) )
async with self.agent.run_stream( async with self.agent.run_stream(
user_prompt=user_prompt, message_history=message_history user_prompt=user_prompt,
message_history=message_history,
) as result: ) as result:
async for chunk in result.stream_text(delta=True): # 只返回新增内容 async for part in result.stream_output(): # 全量结构化输出分片
if chunk: match part.part_kind:
yield chunk case "text": # 就文本分片拆分 content 和 thinking
if not (part_content := part.content.strip()):
continue
if not part.is_reasoning:
yield f"0:{part_content}" # content
else:
yield f"1:{part_content}" # thinking
case "tool-call": # 技能调用分片
yield f"2:技能名称:{part.tool_name},调用参数:{part.args}"
case "tool-return": # 技能返回分片
yield f"3:技能结果:{part.content}"
case "error": # 错误分片
yield f"4:{str(part.error)}"
case _:
continue
self.agent_memory.create_new_messages( self.agent_memory.create_new_messages(
session_id=self.session_id, session_id=self.session_id,

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
会话组件 会话相关组件
""" """
import reflex import reflex
from reflex.constants.colors import ColorType from reflex.constants.colors import ColorType
@ -33,12 +33,12 @@ def turn(turn: Turn) -> reflex.Component:
""" """
return reflex.box( return reflex.box(
reflex.box( reflex.box(
message_bubble(message=turn.input_message, color="mauve"), message_bubble(message=turn.input, color="mauve"),
text_align="right", text_align="right",
margin_bottom="8px", margin_bottom="8px",
), ),
reflex.box( reflex.box(
message_bubble(message=turn.output_message, color="accent"), message_bubble(message=turn.output, color="accent"),
text_align="left", text_align="left",
margin_bottom="8px", margin_bottom="8px",
), ),
@ -69,7 +69,7 @@ def input_bar() -> reflex.Component:
reflex.form( reflex.form(
reflex.hstack( reflex.hstack(
reflex.input( reflex.input(
name="input_message", name="input",
placeholder="请输入...", placeholder="请输入...",
flex="auto", flex="auto",
), ),
@ -83,7 +83,7 @@ def input_bar() -> reflex.Component:
margin="0 auto", margin="0 auto",
align_items="center", align_items="center",
), ),
on_submit=State.adapt_input_message, on_submit=State.adapt_input,
reset_on_submit=True, # 提交后清空输入框 reset_on_submit=True, # 提交后清空输入框
), ),
reflex.text( reflex.text(
@ -106,4 +106,4 @@ def input_bar() -> reflex.Component:
background_color=reflex.color("mauve", 2), background_color=reflex.color("mauve", 2),
align="stretch", align="stretch",
width="100%", width="100%",
) ) # reflex.center 等价 reflex.box(display="flex", align_items="center", justify_content="center")

View File

@ -3,7 +3,7 @@
应用状态管理模块 应用状态管理模块
""" """
from typing import Any, AsyncGenerator, Dict, List from typing import Any, AsyncGenerator, Dict, List, Literal
from uuid import uuid4 from uuid import uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -32,11 +32,22 @@ def retrieve_agent(state) -> Agent:
return agents[current_session_name] return agents[current_session_name]
class MessageBlock(BaseModel):
"""消息块数据模型,包含类型和内容"""
type: Literal[
"thinking", "content", "skill_call", "skill_result", "skill_error"
] = Field(..., description="类型")
content: str = Field(default="", description="内容")
class Turn(BaseModel): class Turn(BaseModel):
"""对话数据模型,包含用户输入消息和智能体输出消息""" """对话数据模型,包含用户输入消息和智能体输出消息"""
input_message: str = Field(..., description="用户输入的消息") input: str = Field(..., description="用户输入的消息")
output_message: str = Field(default="", description="智能体输出的消息") output: List[MessageBlock] = Field(
default_factory=list, description="智能体输出的消息"
)
class Session(BaseModel): class Session(BaseModel):
@ -153,43 +164,45 @@ class State(reflex.State):
self.create_session_modal_is_open = is_open self.create_session_modal_is_open = is_open
@reflex.event @reflex.event
async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator: async def adapt_input(self, form_data: dict[str, Any]) -> AsyncGenerator:
""" """
适配用户输入消息 适配用户输入
:param form_data: 对话表单数据 :param form_data: 对话表单数据
:return: AsyncGenerator :return: AsyncGenerator
""" """
input_message = form_data["input_message"].strip() input = form_data["input_message"].strip()
if not input_message: if not input:
return return
async for value in self.process_input_message(input_message=input_message): async for value in self.process_input(input=input):
yield value yield value
async def process_input_message(self, input_message: str) -> AsyncGenerator: async def process_input(self, input: str) -> AsyncGenerator:
""" """
处理用户输入消息 处理用户输入
:param input_message: 用户输入的消息 :param input: 用户输入
:return: AsyncGenerator :return: AsyncGenerator
""" """
# 当前会话 # 当前会话
current_session = self.sessions[self.current_session_name] current_session = self.sessions[self.current_session_name]
# 将用户输入消息添加到当前会话对话列表
current_session.turns.append(
Turn(
input_message=input_message,
)
)
# 当前会话正在处理 # 当前会话正在处理
current_session.is_processing = True current_session.is_processing = True
# 将用户输入添加到当前会话对话列表
current_session.turns.append(
Turn(
input=input,
)
)
# 当前对话
current_turn = current_session.turns[-1]
yield yield
agent = retrieve_agent(self) agent = retrieve_agent(self)
async for chunk in agent.output_message_streamed(user_prompt=input_message): async for chunk in agent.output_message_streamed(user_prompt=input):
if not chunk: if not chunk:
continue continue
current_session.turns[-1].output_message += chunk current_session.turns[-1].output_message += chunk
yield yield
# 当前会话处理完成 # 当前会话处理完成