diff --git a/utils/agent.py b/utils/agent.py index 4ede8cf..bcabeae 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -8,10 +8,10 @@ from pathlib import Path import time from typing import AsyncGenerator, List, Optional from uuid import uuid4 - +from pydantic_ai.messages import ModelMessage from pydantic_ai import Agent as PydanticAIAgent, ModelMessage from pydantic_ai.capabilities import AgentCapability -from pydantic_ai.messages import ModelMessagesTypeAdapter +from pydantic_ai.messages import ModelMessagesTypeAdapter, ModelResponse from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.output import OutputSpec from pydantic_ai.providers.openai import OpenAIProvider @@ -172,24 +172,36 @@ class Agent: user_prompt=user_prompt, message_history=message_history, ) as result: - async for part in result.stream_output(): # 全量结构化输出分片 - match part.part_kind: - case "text": # 就文本分片拆分 content 和 thinking - if not (part_content := part.content.strip()): - continue - if not part.is_reasoning: - yield f"0:{part_content}" # content - else: - yield f"1:{part_content}" # thinking - case "tool-call": # 技能调用分片 - yield f"2:技能名称:{part.tool_name},调用参数:{part.args}" - case "tool-return": # 技能返回分片 - yield f"3:技能结果:{part.content}" - case "error": # 错误分片 - yield f"4:{str(part.error)}" - case _: + async for response in result.stream_response(): # 完整结构化相应对象 + if not isinstance(response, ModelResponse): + continue + + if not hasattr(response, "parts") or not response.parts: + continue + + for part in response.parts: + if isinstance(part, str): + content = part.strip() + if content: + yield f"00:{content}" continue + match part.part_kind: + case "text": + content = part.content.strip() + if content: + yield f"00:{content}" + case "thinking": + content = part.content.strip() + if content: + yield f"01:{content}" + case "tool-call" | "builtin-tool-call": + yield f"02:技能名称:{part.tool_name},调用参数:{part.args}" + case "tool-return" | "builtin-tool-return": + yield f"03:技能结果:{part.content}" + case _: + continue + self.agent_memory.create_new_messages( session_id=self.session_id, new_messages=result.new_messages(), diff --git a/utils/agent_memory.db b/utils/agent_memory.db index 6ad6cd4..bcc6c61 100644 Binary files a/utils/agent_memory.db and b/utils/agent_memory.db differ diff --git a/产品需求文档AI生成/application/components/session.py b/产品需求文档AI生成/application/components/session.py index dfd4606..40b1100 100644 --- a/产品需求文档AI生成/application/components/session.py +++ b/产品需求文档AI生成/application/components/session.py @@ -5,12 +5,12 @@ import reflex from reflex.constants.colors import ColorType -from ..state import State, Turn +from ..state import State, MessageBlockType, MessageBlock, Turn -def message_bubble(message: str, color: ColorType) -> reflex.Component: +def input_bubble(message: str, color: ColorType) -> reflex.Component: """ - 对话组件中一个消息气泡组件 + 输入气泡组件 :param message: 消息 :param color: 颜色 :return: Component @@ -25,6 +25,45 @@ def message_bubble(message: str, color: ColorType) -> reflex.Component: ) +def output_bubble(message_block: MessageBlock) -> reflex.Component: + """ + 输出气泡组件 + :param message_block: 消息块 + :return: 气泡组件 + """ + color = reflex.cond( + message_block.type == MessageBlockType.content, + "accent", + reflex.cond( + message_block.type == MessageBlockType.thinking, + "iris", + reflex.cond( + message_block.type == MessageBlockType.tool_call, + "orange", + reflex.cond( + message_block.type == MessageBlockType.tool_result, + "teal", + reflex.cond( + message_block.type == MessageBlockType.error, + "red", + "mauve", # 兜底 + ), + ), + ), + ), + ) + return reflex.markdown( + message_block.content, + color=reflex.color(color=color, shade=12), + background_color=reflex.color(color=color, shade=4), + display="inline-block", + padding_inline="1em", + padding_block="0.5em", + border_radius="8px", + margin_bottom="4px", + ) + + def turn(turn: Turn) -> reflex.Component: """ 对话组件 @@ -33,12 +72,12 @@ def turn(turn: Turn) -> reflex.Component: """ return reflex.box( reflex.box( - message_bubble(message=turn.input, color="mauve"), + input_bubble(message=turn.input, color="mauve"), text_align="right", margin_bottom="8px", ), reflex.box( - message_bubble(message=turn.output, color="accent"), + reflex.foreach(turn.output, output_bubble), text_align="left", margin_bottom="8px", ), diff --git a/产品需求文档AI生成/application/state.py b/产品需求文档AI生成/application/state.py index 9ae30fd..05f0fd1 100644 --- a/产品需求文档AI生成/application/state.py +++ b/产品需求文档AI生成/application/state.py @@ -5,7 +5,7 @@ from typing import Any, AsyncGenerator, Dict, List, Literal from uuid import uuid4 - +from enum import StrEnum from pydantic import BaseModel, Field import reflex from pathlib import Path @@ -27,17 +27,29 @@ def retrieve_agent(state) -> Agent: if current_session_name not in agents: agents[current_session_name] = Agent( session_id=state.sessions[current_session_name].id, - instructions="You are a friendly chatbot named Reflex. Respond in markdown.", + instructions="You are a friendly chatbot", ) return agents[current_session_name] +# 消息块类型 +class MessageBlockType(StrEnum): + + content = "content" + thinking = "thinking" + tool_call = "tool_call" + tool_result = "tool_result" + error = "error" + + +# 消息块类型前缀映射 +MESSAGE_BLOCK_TYPE_PREFIX_MAP = {f"{i:02d}:": m for i, m in enumerate(MessageBlockType)} + + class MessageBlock(BaseModel): """消息块数据模型,包含类型和内容""" - type: Literal[ - "thinking", "content", "skill_call", "skill_result", "skill_error" - ] = Field(..., description="类型") + type: MessageBlockType = Field(..., description="类型") content: str = Field(default="", description="内容") @@ -170,7 +182,7 @@ class State(reflex.State): :param form_data: 对话表单数据 :return: AsyncGenerator """ - input = form_data["input_message"].strip() + input = form_data["input"].strip() if not input: return @@ -193,17 +205,36 @@ class State(reflex.State): input=input, ) ) + yield # 通知前端更新状态(显示用户输入) + # 当前对话 current_turn = current_session.turns[-1] - yield + # 获取当前会话绑定的智能体 agent = retrieve_agent(self) async for chunk in agent.output_message_streamed(user_prompt=input): + # 跳过空分块 if not chunk: + yield continue - current_session.turns[-1].output_message += chunk - yield + # 匹配消息块类型 + prefix_matched = next( + (t for t in MESSAGE_BLOCK_TYPE_PREFIX_MAP if chunk.startswith(t)), None + ) + # 跳过未匹配分块 + if not prefix_matched: + yield + continue + + # 消息块类型 + type = MESSAGE_BLOCK_TYPE_PREFIX_MAP[prefix_matched] + # 若当前对话输出为空或当前消息块类型和上一个消息块类型不一致则创建消息块 + if not current_turn.output or current_turn.output[-1].type != type: + current_turn.output.append(MessageBlock(type=type)) + current_turn.output[-1].content += chunk.removeprefix(prefix_matched) + + yield # 通知前端更新状态(打字机效果显示输出) # 当前会话处理完成 current_session.is_processing = False