This commit is contained in:
parent
70c39ea7e2
commit
f18b3a9c82
|
|
@ -169,11 +169,26 @@ class Agent:
|
|||
)
|
||||
|
||||
async with self.agent.run_stream(
|
||||
user_prompt=user_prompt, message_history=message_history
|
||||
user_prompt=user_prompt,
|
||||
message_history=message_history,
|
||||
) as result:
|
||||
async for chunk in result.stream_text(delta=True): # 只返回新增内容
|
||||
if chunk:
|
||||
yield chunk
|
||||
async for part in result.stream_output(): # 全量结构化输出分片
|
||||
match part.part_kind:
|
||||
case "text": # 就文本分片拆分 content 和 thinking
|
||||
if not (part_content := part.content.strip()):
|
||||
continue
|
||||
if not part.is_reasoning:
|
||||
yield f"0:{part_content}" # content
|
||||
else:
|
||||
yield f"1:{part_content}" # thinking
|
||||
case "tool-call": # 技能调用分片
|
||||
yield f"2:技能名称:{part.tool_name},调用参数:{part.args}"
|
||||
case "tool-return": # 技能返回分片
|
||||
yield f"3:技能结果:{part.content}"
|
||||
case "error": # 错误分片
|
||||
yield f"4:{str(part.error)}"
|
||||
case _:
|
||||
continue
|
||||
|
||||
self.agent_memory.create_new_messages(
|
||||
session_id=self.session_id,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
会话组件
|
||||
会话相关组件
|
||||
"""
|
||||
import reflex
|
||||
from reflex.constants.colors import ColorType
|
||||
|
|
@ -33,12 +33,12 @@ def turn(turn: Turn) -> reflex.Component:
|
|||
"""
|
||||
return reflex.box(
|
||||
reflex.box(
|
||||
message_bubble(message=turn.input_message, color="mauve"),
|
||||
message_bubble(message=turn.input, color="mauve"),
|
||||
text_align="right",
|
||||
margin_bottom="8px",
|
||||
),
|
||||
reflex.box(
|
||||
message_bubble(message=turn.output_message, color="accent"),
|
||||
message_bubble(message=turn.output, color="accent"),
|
||||
text_align="left",
|
||||
margin_bottom="8px",
|
||||
),
|
||||
|
|
@ -69,7 +69,7 @@ def input_bar() -> reflex.Component:
|
|||
reflex.form(
|
||||
reflex.hstack(
|
||||
reflex.input(
|
||||
name="input_message",
|
||||
name="input",
|
||||
placeholder="请输入...",
|
||||
flex="auto",
|
||||
),
|
||||
|
|
@ -83,7 +83,7 @@ def input_bar() -> reflex.Component:
|
|||
margin="0 auto",
|
||||
align_items="center",
|
||||
),
|
||||
on_submit=State.adapt_input_message,
|
||||
on_submit=State.adapt_input,
|
||||
reset_on_submit=True, # 提交后清空输入框
|
||||
),
|
||||
reflex.text(
|
||||
|
|
@ -106,4 +106,4 @@ def input_bar() -> reflex.Component:
|
|||
background_color=reflex.color("mauve", 2),
|
||||
align="stretch",
|
||||
width="100%",
|
||||
)
|
||||
) # reflex.center 等价 reflex.box(display="flex", align_items="center", justify_content="center")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
应用状态管理模块
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -32,11 +32,22 @@ def retrieve_agent(state) -> Agent:
|
|||
return agents[current_session_name]
|
||||
|
||||
|
||||
class MessageBlock(BaseModel):
|
||||
"""消息块数据模型,包含类型和内容"""
|
||||
|
||||
type: Literal[
|
||||
"thinking", "content", "skill_call", "skill_result", "skill_error"
|
||||
] = Field(..., description="类型")
|
||||
content: str = Field(default="", description="内容")
|
||||
|
||||
|
||||
class Turn(BaseModel):
|
||||
"""对话数据模型,包含用户输入消息和智能体输出消息"""
|
||||
|
||||
input_message: str = Field(..., description="用户输入的消息")
|
||||
output_message: str = Field(default="", description="智能体输出的消息")
|
||||
input: str = Field(..., description="用户输入的消息")
|
||||
output: List[MessageBlock] = Field(
|
||||
default_factory=list, description="智能体输出的消息"
|
||||
)
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
|
|
@ -153,43 +164,45 @@ class State(reflex.State):
|
|||
self.create_session_modal_is_open = is_open
|
||||
|
||||
@reflex.event
|
||||
async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator:
|
||||
async def adapt_input(self, form_data: dict[str, Any]) -> AsyncGenerator:
|
||||
"""
|
||||
适配用户输入消息
|
||||
适配用户输入
|
||||
:param form_data: 对话表单数据
|
||||
:return: AsyncGenerator
|
||||
"""
|
||||
input_message = form_data["input_message"].strip()
|
||||
if not input_message:
|
||||
input = form_data["input_message"].strip()
|
||||
if not input:
|
||||
return
|
||||
|
||||
async for value in self.process_input_message(input_message=input_message):
|
||||
async for value in self.process_input(input=input):
|
||||
yield value
|
||||
|
||||
async def process_input_message(self, input_message: str) -> AsyncGenerator:
|
||||
async def process_input(self, input: str) -> AsyncGenerator:
|
||||
"""
|
||||
处理用户输入消息
|
||||
:param input_message: 用户输入的消息
|
||||
处理用户输入
|
||||
:param input: 用户输入
|
||||
:return: AsyncGenerator
|
||||
"""
|
||||
# 当前会话
|
||||
current_session = self.sessions[self.current_session_name]
|
||||
# 将用户输入消息添加到当前会话对话列表
|
||||
current_session.turns.append(
|
||||
Turn(
|
||||
input_message=input_message,
|
||||
)
|
||||
)
|
||||
|
||||
# 当前会话正在处理
|
||||
current_session.is_processing = True
|
||||
# 将用户输入添加到当前会话对话列表
|
||||
current_session.turns.append(
|
||||
Turn(
|
||||
input=input,
|
||||
)
|
||||
)
|
||||
# 当前对话
|
||||
current_turn = current_session.turns[-1]
|
||||
yield
|
||||
|
||||
agent = retrieve_agent(self)
|
||||
async for chunk in agent.output_message_streamed(user_prompt=input_message):
|
||||
async for chunk in agent.output_message_streamed(user_prompt=input):
|
||||
if not chunk:
|
||||
continue
|
||||
current_session.turns[-1].output_message += chunk
|
||||
|
||||
yield
|
||||
|
||||
# 当前会话处理完成
|
||||
|
|
|
|||
Loading…
Reference in New Issue