diff --git a/utils/agent.py b/utils/agent.py index 9270f4c..4ede8cf 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -169,11 +169,26 @@ class Agent: ) async with self.agent.run_stream( - user_prompt=user_prompt, message_history=message_history + user_prompt=user_prompt, + message_history=message_history, ) as result: - async for chunk in result.stream_text(delta=True): # 只返回新增内容 - if chunk: - yield chunk + async for part in result.stream_output(): # 全量结构化输出分片 + match part.part_kind: + case "text": # 就文本分片拆分 content 和 thinking + if not (part_content := part.content.strip()): + continue + if not part.is_reasoning: + yield f"0:{part_content}" # content + else: + yield f"1:{part_content}" # thinking + case "tool-call": # 技能调用分片 + yield f"2:技能名称:{part.tool_name},调用参数:{part.args}" + case "tool-return": # 技能返回分片 + yield f"3:技能结果:{part.content}" + case "error": # 错误分片 + yield f"4:{str(part.error)}" + case _: + continue self.agent_memory.create_new_messages( session_id=self.session_id, diff --git a/产品需求文档AI生成/application/components/session.py b/产品需求文档AI生成/application/components/session.py index c5a593f..dfd4606 100644 --- a/产品需求文档AI生成/application/components/session.py +++ b/产品需求文档AI生成/application/components/session.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -会话组件 +会话相关组件 """ import reflex from reflex.constants.colors import ColorType @@ -33,12 +33,12 @@ def turn(turn: Turn) -> reflex.Component: """ return reflex.box( reflex.box( - message_bubble(message=turn.input_message, color="mauve"), + message_bubble(message=turn.input, color="mauve"), text_align="right", margin_bottom="8px", ), reflex.box( - message_bubble(message=turn.output_message, color="accent"), + message_bubble(message=turn.output, color="accent"), text_align="left", margin_bottom="8px", ), @@ -69,7 +69,7 @@ def input_bar() -> reflex.Component: reflex.form( reflex.hstack( reflex.input( - name="input_message", + name="input", placeholder="请输入...", flex="auto", ), @@ -83,7 +83,7 @@ def input_bar() -> reflex.Component: margin="0 auto", align_items="center", ), - on_submit=State.adapt_input_message, + on_submit=State.adapt_input, reset_on_submit=True, # 提交后清空输入框 ), reflex.text( @@ -106,4 +106,4 @@ def input_bar() -> reflex.Component: background_color=reflex.color("mauve", 2), align="stretch", width="100%", - ) + ) # reflex.center 等价 reflex.box(display="flex", align_items="center", justify_content="center") diff --git a/产品需求文档AI生成/application/state.py b/产品需求文档AI生成/application/state.py index 82570fa..9ae30fd 100644 --- a/产品需求文档AI生成/application/state.py +++ b/产品需求文档AI生成/application/state.py @@ -3,7 +3,7 @@ 应用状态管理模块 """ -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator, Dict, List, Literal from uuid import uuid4 from pydantic import BaseModel, Field @@ -32,11 +32,22 @@ def retrieve_agent(state) -> Agent: return agents[current_session_name] +class MessageBlock(BaseModel): + """消息块数据模型,包含类型和内容""" + + type: Literal[ + "thinking", "content", "skill_call", "skill_result", "skill_error" + ] = Field(..., description="类型") + content: str = Field(default="", description="内容") + + class Turn(BaseModel): """对话数据模型,包含用户输入消息和智能体输出消息""" - input_message: str = Field(..., description="用户输入的消息") - output_message: str = Field(default="", description="智能体输出的消息") + input: str = Field(..., description="用户输入的消息") + output: List[MessageBlock] = Field( + default_factory=list, description="智能体输出的消息" + ) class Session(BaseModel): @@ -153,43 +164,45 @@ class State(reflex.State): self.create_session_modal_is_open = is_open @reflex.event - async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator: + async def adapt_input(self, form_data: dict[str, Any]) -> AsyncGenerator: """ - 适配用户输入消息 + 适配用户输入 :param form_data: 对话表单数据 :return: AsyncGenerator """ - input_message = form_data["input_message"].strip() - if not input_message: + input = form_data["input_message"].strip() + if not input: return - async for value in self.process_input_message(input_message=input_message): + async for value in self.process_input(input=input): yield value - async def process_input_message(self, input_message: str) -> AsyncGenerator: + async def process_input(self, input: str) -> AsyncGenerator: """ - 处理用户输入消息 - :param input_message: 用户输入的消息 + 处理用户输入 + :param input: 用户输入 :return: AsyncGenerator """ # 当前会话 current_session = self.sessions[self.current_session_name] - # 将用户输入消息添加到当前会话对话列表 - current_session.turns.append( - Turn( - input_message=input_message, - ) - ) - # 当前会话正在处理 current_session.is_processing = True + # 将用户输入添加到当前会话对话列表 + current_session.turns.append( + Turn( + input=input, + ) + ) + # 当前对话 + current_turn = current_session.turns[-1] yield agent = retrieve_agent(self) - async for chunk in agent.output_message_streamed(user_prompt=input_message): + async for chunk in agent.output_message_streamed(user_prompt=input): if not chunk: continue current_session.turns[-1].output_message += chunk + yield # 当前会话处理完成