This commit is contained in:
liubiren 2026-06-11 09:29:38 +08:00
parent 70c39ea7e2
commit f18b3a9c82
3 changed files with 57 additions and 29 deletions

View File

@ -169,11 +169,26 @@ class Agent:
)
async with self.agent.run_stream(
user_prompt=user_prompt, message_history=message_history
user_prompt=user_prompt,
message_history=message_history,
) as result:
async for chunk in result.stream_text(delta=True): # 只返回新增内容
if chunk:
yield chunk
async for part in result.stream_output(): # 全量结构化输出分片
match part.part_kind:
case "text": # 就文本分片拆分 content 和 thinking
if not (part_content := part.content.strip()):
continue
if not part.is_reasoning:
yield f"0:{part_content}" # content
else:
yield f"1:{part_content}" # thinking
case "tool-call": # 技能调用分片
yield f"2:技能名称:{part.tool_name},调用参数:{part.args}"
case "tool-return": # 技能返回分片
yield f"3:技能结果:{part.content}"
case "error": # 错误分片
yield f"4:{str(part.error)}"
case _:
continue
self.agent_memory.create_new_messages(
session_id=self.session_id,

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
会话组件
会话相关组件
"""
import reflex
from reflex.constants.colors import ColorType
@ -33,12 +33,12 @@ def turn(turn: Turn) -> reflex.Component:
"""
return reflex.box(
reflex.box(
message_bubble(message=turn.input_message, color="mauve"),
message_bubble(message=turn.input, color="mauve"),
text_align="right",
margin_bottom="8px",
),
reflex.box(
message_bubble(message=turn.output_message, color="accent"),
message_bubble(message=turn.output, color="accent"),
text_align="left",
margin_bottom="8px",
),
@ -69,7 +69,7 @@ def input_bar() -> reflex.Component:
reflex.form(
reflex.hstack(
reflex.input(
name="input_message",
name="input",
placeholder="请输入...",
flex="auto",
),
@ -83,7 +83,7 @@ def input_bar() -> reflex.Component:
margin="0 auto",
align_items="center",
),
on_submit=State.adapt_input_message,
on_submit=State.adapt_input,
reset_on_submit=True, # 提交后清空输入框
),
reflex.text(
@ -106,4 +106,4 @@ def input_bar() -> reflex.Component:
background_color=reflex.color("mauve", 2),
align="stretch",
width="100%",
)
) # reflex.center 等价 reflex.box(display="flex", align_items="center", justify_content="center")

View File

@ -3,7 +3,7 @@
应用状态管理模块
"""
from typing import Any, AsyncGenerator, Dict, List
from typing import Any, AsyncGenerator, Dict, List, Literal
from uuid import uuid4
from pydantic import BaseModel, Field
@ -32,11 +32,22 @@ def retrieve_agent(state) -> Agent:
return agents[current_session_name]
class MessageBlock(BaseModel):
"""消息块数据模型,包含类型和内容"""
type: Literal[
"thinking", "content", "skill_call", "skill_result", "skill_error"
] = Field(..., description="类型")
content: str = Field(default="", description="内容")
class Turn(BaseModel):
"""对话数据模型,包含用户输入消息和智能体输出消息"""
input_message: str = Field(..., description="用户输入的消息")
output_message: str = Field(default="", description="智能体输出的消息")
input: str = Field(..., description="用户输入的消息")
output: List[MessageBlock] = Field(
default_factory=list, description="智能体输出的消息"
)
class Session(BaseModel):
@ -153,43 +164,45 @@ class State(reflex.State):
self.create_session_modal_is_open = is_open
@reflex.event
async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator:
async def adapt_input(self, form_data: dict[str, Any]) -> AsyncGenerator:
"""
适配用户输入消息
适配用户输入
:param form_data: 对话表单数据
:return: AsyncGenerator
"""
input_message = form_data["input_message"].strip()
if not input_message:
input = form_data["input_message"].strip()
if not input:
return
async for value in self.process_input_message(input_message=input_message):
async for value in self.process_input(input=input):
yield value
async def process_input_message(self, input_message: str) -> AsyncGenerator:
async def process_input(self, input: str) -> AsyncGenerator:
"""
处理用户输入消息
:param input_message: 用户输入的消息
处理用户输入
:param input: 用户输入
:return: AsyncGenerator
"""
# 当前会话
current_session = self.sessions[self.current_session_name]
# 将用户输入消息添加到当前会话对话列表
current_session.turns.append(
Turn(
input_message=input_message,
)
)
# 当前会话正在处理
current_session.is_processing = True
# 将用户输入添加到当前会话对话列表
current_session.turns.append(
Turn(
input=input,
)
)
# 当前对话
current_turn = current_session.turns[-1]
yield
agent = retrieve_agent(self)
async for chunk in agent.output_message_streamed(user_prompt=input_message):
async for chunk in agent.output_message_streamed(user_prompt=input):
if not chunk:
continue
current_session.turns[-1].output_message += chunk
yield
# 当前会话处理完成