This commit is contained in:
parent
70c39ea7e2
commit
f18b3a9c82
|
|
@ -169,11 +169,26 @@ class Agent:
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self.agent.run_stream(
|
async with self.agent.run_stream(
|
||||||
user_prompt=user_prompt, message_history=message_history
|
user_prompt=user_prompt,
|
||||||
|
message_history=message_history,
|
||||||
) as result:
|
) as result:
|
||||||
async for chunk in result.stream_text(delta=True): # 只返回新增内容
|
async for part in result.stream_output(): # 全量结构化输出分片
|
||||||
if chunk:
|
match part.part_kind:
|
||||||
yield chunk
|
case "text": # 就文本分片拆分 content 和 thinking
|
||||||
|
if not (part_content := part.content.strip()):
|
||||||
|
continue
|
||||||
|
if not part.is_reasoning:
|
||||||
|
yield f"0:{part_content}" # content
|
||||||
|
else:
|
||||||
|
yield f"1:{part_content}" # thinking
|
||||||
|
case "tool-call": # 技能调用分片
|
||||||
|
yield f"2:技能名称:{part.tool_name},调用参数:{part.args}"
|
||||||
|
case "tool-return": # 技能返回分片
|
||||||
|
yield f"3:技能结果:{part.content}"
|
||||||
|
case "error": # 错误分片
|
||||||
|
yield f"4:{str(part.error)}"
|
||||||
|
case _:
|
||||||
|
continue
|
||||||
|
|
||||||
self.agent_memory.create_new_messages(
|
self.agent_memory.create_new_messages(
|
||||||
session_id=self.session_id,
|
session_id=self.session_id,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
会话组件
|
会话相关组件
|
||||||
"""
|
"""
|
||||||
import reflex
|
import reflex
|
||||||
from reflex.constants.colors import ColorType
|
from reflex.constants.colors import ColorType
|
||||||
|
|
@ -33,12 +33,12 @@ def turn(turn: Turn) -> reflex.Component:
|
||||||
"""
|
"""
|
||||||
return reflex.box(
|
return reflex.box(
|
||||||
reflex.box(
|
reflex.box(
|
||||||
message_bubble(message=turn.input_message, color="mauve"),
|
message_bubble(message=turn.input, color="mauve"),
|
||||||
text_align="right",
|
text_align="right",
|
||||||
margin_bottom="8px",
|
margin_bottom="8px",
|
||||||
),
|
),
|
||||||
reflex.box(
|
reflex.box(
|
||||||
message_bubble(message=turn.output_message, color="accent"),
|
message_bubble(message=turn.output, color="accent"),
|
||||||
text_align="left",
|
text_align="left",
|
||||||
margin_bottom="8px",
|
margin_bottom="8px",
|
||||||
),
|
),
|
||||||
|
|
@ -69,7 +69,7 @@ def input_bar() -> reflex.Component:
|
||||||
reflex.form(
|
reflex.form(
|
||||||
reflex.hstack(
|
reflex.hstack(
|
||||||
reflex.input(
|
reflex.input(
|
||||||
name="input_message",
|
name="input",
|
||||||
placeholder="请输入...",
|
placeholder="请输入...",
|
||||||
flex="auto",
|
flex="auto",
|
||||||
),
|
),
|
||||||
|
|
@ -83,7 +83,7 @@ def input_bar() -> reflex.Component:
|
||||||
margin="0 auto",
|
margin="0 auto",
|
||||||
align_items="center",
|
align_items="center",
|
||||||
),
|
),
|
||||||
on_submit=State.adapt_input_message,
|
on_submit=State.adapt_input,
|
||||||
reset_on_submit=True, # 提交后清空输入框
|
reset_on_submit=True, # 提交后清空输入框
|
||||||
),
|
),
|
||||||
reflex.text(
|
reflex.text(
|
||||||
|
|
@ -106,4 +106,4 @@ def input_bar() -> reflex.Component:
|
||||||
background_color=reflex.color("mauve", 2),
|
background_color=reflex.color("mauve", 2),
|
||||||
align="stretch",
|
align="stretch",
|
||||||
width="100%",
|
width="100%",
|
||||||
)
|
) # reflex.center 等价 reflex.box(display="flex", align_items="center", justify_content="center")
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
应用状态管理模块
|
应用状态管理模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List
|
from typing import Any, AsyncGenerator, Dict, List, Literal
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
@ -32,11 +32,22 @@ def retrieve_agent(state) -> Agent:
|
||||||
return agents[current_session_name]
|
return agents[current_session_name]
|
||||||
|
|
||||||
|
|
||||||
|
class MessageBlock(BaseModel):
|
||||||
|
"""消息块数据模型,包含类型和内容"""
|
||||||
|
|
||||||
|
type: Literal[
|
||||||
|
"thinking", "content", "skill_call", "skill_result", "skill_error"
|
||||||
|
] = Field(..., description="类型")
|
||||||
|
content: str = Field(default="", description="内容")
|
||||||
|
|
||||||
|
|
||||||
class Turn(BaseModel):
|
class Turn(BaseModel):
|
||||||
"""对话数据模型,包含用户输入消息和智能体输出消息"""
|
"""对话数据模型,包含用户输入消息和智能体输出消息"""
|
||||||
|
|
||||||
input_message: str = Field(..., description="用户输入的消息")
|
input: str = Field(..., description="用户输入的消息")
|
||||||
output_message: str = Field(default="", description="智能体输出的消息")
|
output: List[MessageBlock] = Field(
|
||||||
|
default_factory=list, description="智能体输出的消息"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Session(BaseModel):
|
class Session(BaseModel):
|
||||||
|
|
@ -153,43 +164,45 @@ class State(reflex.State):
|
||||||
self.create_session_modal_is_open = is_open
|
self.create_session_modal_is_open = is_open
|
||||||
|
|
||||||
@reflex.event
|
@reflex.event
|
||||||
async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator:
|
async def adapt_input(self, form_data: dict[str, Any]) -> AsyncGenerator:
|
||||||
"""
|
"""
|
||||||
适配用户输入消息
|
适配用户输入
|
||||||
:param form_data: 对话表单数据
|
:param form_data: 对话表单数据
|
||||||
:return: AsyncGenerator
|
:return: AsyncGenerator
|
||||||
"""
|
"""
|
||||||
input_message = form_data["input_message"].strip()
|
input = form_data["input_message"].strip()
|
||||||
if not input_message:
|
if not input:
|
||||||
return
|
return
|
||||||
|
|
||||||
async for value in self.process_input_message(input_message=input_message):
|
async for value in self.process_input(input=input):
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
async def process_input_message(self, input_message: str) -> AsyncGenerator:
|
async def process_input(self, input: str) -> AsyncGenerator:
|
||||||
"""
|
"""
|
||||||
处理用户输入消息
|
处理用户输入
|
||||||
:param input_message: 用户输入的消息
|
:param input: 用户输入
|
||||||
:return: AsyncGenerator
|
:return: AsyncGenerator
|
||||||
"""
|
"""
|
||||||
# 当前会话
|
# 当前会话
|
||||||
current_session = self.sessions[self.current_session_name]
|
current_session = self.sessions[self.current_session_name]
|
||||||
# 将用户输入消息添加到当前会话对话列表
|
|
||||||
current_session.turns.append(
|
|
||||||
Turn(
|
|
||||||
input_message=input_message,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 当前会话正在处理
|
# 当前会话正在处理
|
||||||
current_session.is_processing = True
|
current_session.is_processing = True
|
||||||
|
# 将用户输入添加到当前会话对话列表
|
||||||
|
current_session.turns.append(
|
||||||
|
Turn(
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 当前对话
|
||||||
|
current_turn = current_session.turns[-1]
|
||||||
yield
|
yield
|
||||||
|
|
||||||
agent = retrieve_agent(self)
|
agent = retrieve_agent(self)
|
||||||
async for chunk in agent.output_message_streamed(user_prompt=input_message):
|
async for chunk in agent.output_message_streamed(user_prompt=input):
|
||||||
if not chunk:
|
if not chunk:
|
||||||
continue
|
continue
|
||||||
current_session.turns[-1].output_message += chunk
|
current_session.turns[-1].output_message += chunk
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# 当前会话处理完成
|
# 当前会话处理完成
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue