260610更新
This commit is contained in:
parent
d8edc97ca8
commit
75a1cd998e
|
|
@ -118,19 +118,21 @@ class Agent:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
instructions: str,
|
||||
output_type: OutputSpec = str,
|
||||
capabilities: Optional[List[AgentCapability]] = None,
|
||||
):
|
||||
"""
|
||||
初始化智能体
|
||||
:param session_id: 会话唯一标识
|
||||
:param instructions: 指令
|
||||
:param skills: 智能体技能列表,默认为不使用技能
|
||||
:param output_type: 输出类型
|
||||
:return: 智能体实例
|
||||
"""
|
||||
# 生成会话唯一标识
|
||||
self.session_id = uuid4().hex.lower()
|
||||
# 会话唯一标识
|
||||
self.session_id = session_id
|
||||
|
||||
# 实例智能体的记忆体
|
||||
self.agent_memory = AgentMemory()
|
||||
|
|
@ -150,6 +152,8 @@ class Agent:
|
|||
retries=1,
|
||||
)
|
||||
|
||||
self.agent.to_web()
|
||||
|
||||
async def output_message_streamed(
|
||||
self, user_prompt: str | List[str]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
|
|
@ -167,7 +171,8 @@ class Agent:
|
|||
async with self.agent.run_stream(
|
||||
user_prompt=user_prompt, message_history=message_history
|
||||
) as result:
|
||||
async for chunk in result.stream_text():
|
||||
async for chunk in result.stream_text(delta=True): # 只返回新增内容
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
self.agent_memory.create_new_messages(
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1,28 +1,10 @@
|
|||
"""The main Chat app."""
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
主模块
|
||||
"""
|
||||
|
||||
from reflex import App, Component, vstack as VStack, color as Color, theme as Theme
|
||||
from .components import session, navbar
|
||||
import reflex
|
||||
from .pages.index import index
|
||||
|
||||
|
||||
def index() -> Component:
|
||||
"""The main app."""
|
||||
return VStack(
|
||||
navbar.navbar(),
|
||||
session.session(),
|
||||
session.action_bar(),
|
||||
background_color=Color(color="mauve", shade=1),
|
||||
color=Color(color="mauve", shade=12),
|
||||
height="100dvh",
|
||||
align_items="stretch",
|
||||
spacing="0",
|
||||
)
|
||||
|
||||
|
||||
# Add state and page to the app.
|
||||
app = App(
|
||||
theme=Theme(
|
||||
appearance="dark",
|
||||
accent_color="purple",
|
||||
),
|
||||
)
|
||||
app = reflex.App()
|
||||
app.add_page(index)
|
||||
|
|
|
|||
|
|
@ -1,142 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
导航栏组件
|
||||
"""
|
||||
|
||||
from reflex import (
|
||||
Component,
|
||||
badge as Badge,
|
||||
button as Button,
|
||||
color as Color,
|
||||
dialog as Dialog,
|
||||
divider as Divider,
|
||||
drawer as Drawer,
|
||||
foreach as Foreach,
|
||||
form as Form,
|
||||
heading as Heading,
|
||||
hstack as HStack,
|
||||
icon as Icon,
|
||||
input as Input,
|
||||
tooltip as Tooltip,
|
||||
vstack as VStack,
|
||||
)
|
||||
|
||||
from ..state import State
|
||||
|
||||
|
||||
def sidebar_session(session_name: str) -> Component:
|
||||
"""
|
||||
侧边栏会话组件
|
||||
:param session_name: 会话名称
|
||||
:return: Component
|
||||
"""
|
||||
return Drawer.close(
|
||||
HStack(
|
||||
Button(
|
||||
session_name,
|
||||
on_click=lambda: State.set_current_session(session_name),
|
||||
width="80%",
|
||||
variant="surface",
|
||||
),
|
||||
Button(
|
||||
Icon(
|
||||
tag="trash",
|
||||
on_click=lambda: State.delete_session(session_name),
|
||||
stroke_width=1,
|
||||
),
|
||||
width="20%",
|
||||
variant="surface",
|
||||
color_scheme="red",
|
||||
),
|
||||
width="100%",
|
||||
),
|
||||
key=session_name,
|
||||
)
|
||||
|
||||
|
||||
def sidebar(trigger) -> Component:
|
||||
"""
|
||||
侧边栏组件
|
||||
"""
|
||||
return Drawer.root(
|
||||
Drawer.trigger(trigger),
|
||||
Drawer.overlay(),
|
||||
Drawer.portal(
|
||||
Drawer.content(
|
||||
VStack(
|
||||
Heading("Chats", color=Color("mauve", 11)),
|
||||
Divider(),
|
||||
Foreach(
|
||||
State.get_session_names,
|
||||
lambda session_name: sidebar_session(session_name),
|
||||
),
|
||||
align_items="stretch",
|
||||
width="100%",
|
||||
),
|
||||
top="auto",
|
||||
right="auto",
|
||||
height="100%",
|
||||
width="20em",
|
||||
padding="2em",
|
||||
background_color=Color("mauve", 2),
|
||||
outline="none",
|
||||
)
|
||||
),
|
||||
direction="left",
|
||||
)
|
||||
|
||||
|
||||
def modal(trigger) -> Component:
|
||||
"""A modal to create a new chat."""
|
||||
return Dialog.root(
|
||||
Dialog.trigger(trigger),
|
||||
Dialog.content(
|
||||
Form(
|
||||
HStack(
|
||||
Input(
|
||||
placeholder="Chat name",
|
||||
name="session_name",
|
||||
flex="1",
|
||||
min_width="20ch",
|
||||
),
|
||||
Button("Create chat"),
|
||||
spacing="2",
|
||||
wrap="wrap",
|
||||
width="100%",
|
||||
),
|
||||
on_submit=State.create_session,
|
||||
),
|
||||
background_color=Color("mauve", 1),
|
||||
),
|
||||
open=State.create_session_modal_is_open,
|
||||
on_open_change=State.toggle_create_session_modal,
|
||||
)
|
||||
|
||||
|
||||
def navbar():
|
||||
return HStack(
|
||||
Badge(
|
||||
State.current_session_name,
|
||||
Tooltip(
|
||||
Icon("info", size=14),
|
||||
content="The current selected chat.",
|
||||
),
|
||||
size="3",
|
||||
variant="soft",
|
||||
margin_inline_end="auto",
|
||||
),
|
||||
modal(
|
||||
Icon("message-square-plus"),
|
||||
),
|
||||
sidebar(
|
||||
Icon(
|
||||
"messages-square",
|
||||
background_color=Color("mauve", 6),
|
||||
)
|
||||
),
|
||||
justify_content="space-between",
|
||||
align_items="center",
|
||||
padding="12px",
|
||||
border_bottom=f"1px solid {Color('mauve', 3)}",
|
||||
background_color=Color("mauve", 2),
|
||||
)
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
导航栏组件
|
||||
"""
|
||||
|
||||
import reflex
|
||||
|
||||
from ..state import State
|
||||
|
||||
|
||||
def create_session_modal(trigger) -> reflex.Component:
|
||||
"""
|
||||
创建会话模态窗组件
|
||||
"""
|
||||
return reflex.dialog.root(
|
||||
reflex.dialog.trigger(trigger),
|
||||
reflex.dialog.content(
|
||||
reflex.form(
|
||||
reflex.hstack(
|
||||
reflex.input(
|
||||
name="session_name",
|
||||
placeholder="会话名称",
|
||||
flex="auto",
|
||||
min_width="20ch",
|
||||
),
|
||||
reflex.button("创建会话"),
|
||||
spacing="2",
|
||||
wrap="wrap",
|
||||
width="100%",
|
||||
),
|
||||
on_submit=State.create_session,
|
||||
),
|
||||
background_color=reflex.color("mauve", 1),
|
||||
),
|
||||
open=State.create_session_modal_is_open,
|
||||
on_open_change=State.toggle_create_session_modal,
|
||||
)
|
||||
|
||||
|
||||
def session_history_item(session_name: str) -> reflex.Component:
|
||||
"""
|
||||
会话历史侧边栏中一次会话组件
|
||||
:param session_name: 会话名称
|
||||
:return: Component
|
||||
"""
|
||||
return reflex.drawer.close(
|
||||
reflex.hstack(
|
||||
reflex.button(
|
||||
session_name,
|
||||
on_click=lambda: State.switch_session(
|
||||
session_name
|
||||
), # 点击按钮将切换会话
|
||||
width="80%",
|
||||
variant="surface",
|
||||
),
|
||||
reflex.button(
|
||||
reflex.icon(
|
||||
tag="trash",
|
||||
on_click=lambda: State.delete_session(
|
||||
session_name
|
||||
), # 点击按钮删除会话
|
||||
stroke_width=1,
|
||||
),
|
||||
width="20%",
|
||||
variant="surface",
|
||||
color_scheme="red",
|
||||
),
|
||||
width="100%",
|
||||
),
|
||||
key=session_name, # 使用会话名称作为唯一标识(会话名称不可重复)
|
||||
)
|
||||
|
||||
|
||||
def session_history(trigger) -> reflex.Component:
|
||||
"""
|
||||
会话历史侧边栏组件
|
||||
"""
|
||||
return reflex.drawer.root(
|
||||
reflex.drawer.trigger(trigger),
|
||||
reflex.drawer.overlay(),
|
||||
reflex.drawer.portal(
|
||||
reflex.drawer.content(
|
||||
reflex.vstack(
|
||||
reflex.heading("会话列表", color=reflex.color("mauve", 11)),
|
||||
reflex.divider(),
|
||||
reflex.foreach(
|
||||
State.get_session_names, # 获取所有会话名称
|
||||
lambda session_name: session_history_item(
|
||||
session_name=session_name
|
||||
), # 创建会话组件
|
||||
),
|
||||
align_items="stretch",
|
||||
width="100%",
|
||||
),
|
||||
top="auto",
|
||||
right="auto",
|
||||
height="100%",
|
||||
width="20em",
|
||||
padding="2em",
|
||||
background_color=reflex.color("mauve", 2),
|
||||
outline="none",
|
||||
)
|
||||
),
|
||||
direction="left",
|
||||
)
|
||||
|
||||
|
||||
def navigation_bar() -> reflex.Component:
|
||||
"""
|
||||
导航栏组件
|
||||
"""
|
||||
return reflex.hstack(
|
||||
reflex.badge(
|
||||
State.current_session_name,
|
||||
size="3",
|
||||
variant="soft",
|
||||
margin_inline_end="auto",
|
||||
),
|
||||
create_session_modal(
|
||||
reflex.box(
|
||||
reflex.tooltip(reflex.icon("message-square-plus"), content="创建会话")
|
||||
)
|
||||
),
|
||||
session_history(
|
||||
reflex.box(
|
||||
reflex.tooltip(
|
||||
reflex.icon("messages-square"),
|
||||
content="会话历史",
|
||||
)
|
||||
)
|
||||
),
|
||||
justify_content="space-between",
|
||||
align_items="center",
|
||||
padding="12px",
|
||||
border_bottom=f"1px solid {reflex.color('mauve', 3)}",
|
||||
background_color=reflex.color("mauve", 2),
|
||||
)
|
||||
|
|
@ -2,61 +2,43 @@
|
|||
"""
|
||||
会话组件
|
||||
"""
|
||||
|
||||
from reflex import (
|
||||
Component,
|
||||
auto_scroll as AutoScroll,
|
||||
box as Box,
|
||||
button as Button,
|
||||
center as Center,
|
||||
color as Color,
|
||||
foreach as Foreach,
|
||||
form as Form,
|
||||
hstack as HStack,
|
||||
icon as Icon,
|
||||
input as Input,
|
||||
logo as Logo,
|
||||
markdown as Markdown,
|
||||
text as Text,
|
||||
tooltip as Tooltip,
|
||||
vstack as VStack,
|
||||
)
|
||||
import reflex
|
||||
from reflex.constants.colors import ColorType
|
||||
|
||||
from ..state import State, Turn
|
||||
|
||||
|
||||
def message(text: str, color: ColorType) -> Component:
|
||||
def message_bubble(message: str, color: ColorType) -> reflex.Component:
|
||||
"""
|
||||
消息组件
|
||||
:param text: 文本
|
||||
对话组件中一个消息气泡组件
|
||||
:param message: 消息
|
||||
:param color: 颜色
|
||||
:return: Component
|
||||
"""
|
||||
return Markdown(
|
||||
text,
|
||||
background_color=Color(color=color, shade=4),
|
||||
color=Color(color=color, shade=12),
|
||||
return reflex.markdown(
|
||||
message,
|
||||
color=reflex.color(color=color, shade=12),
|
||||
background_color=reflex.color(color=color, shade=4),
|
||||
display="inline-block",
|
||||
padding_inline="1em",
|
||||
border_radius="8px",
|
||||
)
|
||||
|
||||
|
||||
def turn(turn: Turn) -> Component:
|
||||
def turn(turn: Turn) -> reflex.Component:
|
||||
"""
|
||||
对话组件
|
||||
会话话组件中一次对话组件
|
||||
:param turn: 对话
|
||||
:return: Component
|
||||
"""
|
||||
return Box(
|
||||
Box(
|
||||
message(text=turn.input_message, color="mauve"),
|
||||
return reflex.box(
|
||||
reflex.box(
|
||||
message_bubble(message=turn.input_message, color="mauve"),
|
||||
text_align="right",
|
||||
margin_bottom="8px",
|
||||
),
|
||||
Box(
|
||||
message(text=turn.output_message, color="accent"),
|
||||
reflex.box(
|
||||
message_bubble(message=turn.output_message, color="accent"),
|
||||
text_align="left",
|
||||
margin_bottom="8px",
|
||||
),
|
||||
|
|
@ -65,58 +47,51 @@ def turn(turn: Turn) -> Component:
|
|||
)
|
||||
|
||||
|
||||
def session() -> Component:
|
||||
def session_area() -> reflex.Component:
|
||||
"""
|
||||
会话组件
|
||||
会话区域组件
|
||||
:return: Component
|
||||
"""
|
||||
return AutoScroll(
|
||||
Foreach(State.get_turns, turn),
|
||||
return reflex.auto_scroll(
|
||||
reflex.foreach(State.get_current_session_turns, turn),
|
||||
flex="1",
|
||||
padding="8px",
|
||||
overflow_y="auto",
|
||||
)
|
||||
|
||||
|
||||
def action_bar() -> Component:
|
||||
def input_bar() -> reflex.Component:
|
||||
"""
|
||||
输入发送栏组件
|
||||
输入栏组件
|
||||
"""
|
||||
return Center(
|
||||
VStack(
|
||||
Form(
|
||||
HStack(
|
||||
Input(
|
||||
Input.slot(
|
||||
Tooltip(
|
||||
Icon("info", size=18),
|
||||
content="Enter a question to get a response.",
|
||||
)
|
||||
return reflex.center(
|
||||
reflex.vstack(
|
||||
reflex.form(
|
||||
reflex.hstack(
|
||||
reflex.input(
|
||||
name="input_message",
|
||||
placeholder="请输入...",
|
||||
flex="auto",
|
||||
),
|
||||
placeholder="Type something...",
|
||||
id="input_message",
|
||||
flex="1",
|
||||
),
|
||||
Button(
|
||||
"Send",
|
||||
loading=State.current_session_processing,
|
||||
disabled=State.current_session_processing,
|
||||
reflex.button(
|
||||
"发送",
|
||||
type="submit",
|
||||
loading=State.get_current_session_status, # 正在处理中时按钮显示为 loading
|
||||
disabled=State.get_current_session_status, # 正在处理中时按钮禁用
|
||||
),
|
||||
max_width="50em",
|
||||
margin="0 auto",
|
||||
align_items="center",
|
||||
),
|
||||
reset_on_submit=True, # 提交后清空输入框
|
||||
on_submit=State.adapt_input_message,
|
||||
reset_on_submit=True, # 提交后清空输入框
|
||||
),
|
||||
Text(
|
||||
"ReflexGPT may return factually incorrect or misleading responses. Use discretion.",
|
||||
reflex.text(
|
||||
"抹茶兔兔工作室",
|
||||
text_align="center",
|
||||
font_size=".75em",
|
||||
color=Color("mauve", 10),
|
||||
color=reflex.color("mauve", 10),
|
||||
),
|
||||
Logo(margin_block="-1em"),
|
||||
width="100%",
|
||||
padding_x="16px",
|
||||
align="stretch",
|
||||
|
|
@ -127,8 +102,8 @@ def action_bar() -> Component:
|
|||
padding_y="16px",
|
||||
backdrop_filter="auto",
|
||||
backdrop_blur="lg",
|
||||
border_top=f"1px solid {Color('mauve', 3)}",
|
||||
background_color=Color("mauve", 2),
|
||||
border_top=f"1px solid {reflex.color('mauve', 3)}",
|
||||
background_color=reflex.color("mauve", 2),
|
||||
align="stretch",
|
||||
width="100%",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
会话页面
|
||||
"""
|
||||
|
||||
import reflex
|
||||
|
||||
from ..components.navigation_bar import navigation_bar
|
||||
from ..components.session import session_area, input_bar
|
||||
|
||||
|
||||
def index() -> reflex.Component:
|
||||
"""根页面"""
|
||||
return reflex.vstack(
|
||||
navigation_bar(), # 导航栏
|
||||
session_area(), # 会话区域
|
||||
input_bar(), # 输入栏
|
||||
color=reflex.color(color="mauve", shade=12),
|
||||
background_color=reflex.color(color="mauve", shade=1),
|
||||
height="100dvh",
|
||||
align_items="stretch",
|
||||
spacing="0",
|
||||
)
|
||||
|
|
@ -3,7 +3,8 @@
|
|||
应用状态管理模块
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import reflex
|
||||
|
|
@ -13,18 +14,22 @@ from sys import path
|
|||
path.append((Path(__file__).resolve().parent.parent.parent.as_posix()))
|
||||
from utils.agent import Agent
|
||||
|
||||
# 所有会话绑定的智能体
|
||||
agents: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def retrieve_agent(state) -> Agent:
|
||||
"""
|
||||
获取当前会话绑定的智能体
|
||||
:return: 当前会话绑定的智能体
|
||||
"""
|
||||
|
||||
if state.current_session_name not in state._agents:
|
||||
state._agents[state.current_session_name] = Agent(
|
||||
instructions="You are a friendly chatbot named Reflex. Respond in markdown."
|
||||
current_session_name = state.current_session_name
|
||||
if current_session_name not in agents:
|
||||
agents[current_session_name] = Agent(
|
||||
session_id=state.sessions[current_session_name].id,
|
||||
instructions="You are a friendly chatbot named Reflex. Respond in markdown.",
|
||||
)
|
||||
return state._agents[state.current_session_name]
|
||||
return agents[current_session_name]
|
||||
|
||||
|
||||
class Turn(BaseModel):
|
||||
|
|
@ -35,24 +40,22 @@ class Turn(BaseModel):
|
|||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""会话数据模型,包含会绑定智能体和会话对话列表"""
|
||||
"""会话数据模型,包含会话唯一标识和会话对话列表"""
|
||||
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="会话唯一标识")
|
||||
is_processing: bool = Field(default=False, description="会话是否正在处理中")
|
||||
turns: List[Turn] = Field(default_factory=list, description="会话对话列表")
|
||||
|
||||
|
||||
# Reflex.State 统一管理应用数据与功能状态,作为前后端交互枢纽,借助响应式特性实现页面自动更新
|
||||
class State(reflex.State):
|
||||
"""应用状态"""
|
||||
|
||||
# 当前会话名称
|
||||
current_session_name: str = "Intro"
|
||||
# 当前会话正在处理
|
||||
current_session_processing: bool = False
|
||||
current_session_name: str = "NewSession"
|
||||
|
||||
# 所有会话
|
||||
sessions: Dict[str, Session] = {"Intro": Session()}
|
||||
|
||||
# 所有会话绑定的智能体(私有变量)
|
||||
_agents: Dict[str, Any] = {}
|
||||
sessions: Dict[str, Session] = {current_session_name: Session()}
|
||||
|
||||
# 新建会话模态窗是否打开
|
||||
create_session_modal_is_open: bool = False
|
||||
|
|
@ -66,14 +69,60 @@ class State(reflex.State):
|
|||
return list(self.sessions)
|
||||
|
||||
@reflex.event
|
||||
def set_current_session(self, session_name: str) -> None:
|
||||
def switch_session(self, session_name: str) -> None:
|
||||
"""
|
||||
将所选会话设置为当前会话
|
||||
切换会话
|
||||
:param session_name: 会话名称
|
||||
:return: None
|
||||
"""
|
||||
self.current_session_name = session_name
|
||||
|
||||
@reflex.var
|
||||
def get_current_session_status(self) -> bool:
|
||||
"""
|
||||
获取当前会话状态
|
||||
:return: 当前会话状态,其中 True 表示正在处理中,False 表示处理完成
|
||||
"""
|
||||
if self.current_session_name not in self.sessions:
|
||||
return False
|
||||
return self.sessions[self.current_session_name].is_processing
|
||||
|
||||
@reflex.var
|
||||
def get_current_session_turns(self) -> List[Turn]:
|
||||
"""
|
||||
获取当前会话对话列表
|
||||
:return: 当前会话对话列表
|
||||
"""
|
||||
if self.current_session_name not in self.sessions:
|
||||
return []
|
||||
return self.sessions[self.current_session_name].turns
|
||||
|
||||
@reflex.event
|
||||
def create_session(self, form_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
创建会话
|
||||
:param form_data: 创建会话表单数据
|
||||
:return: None
|
||||
"""
|
||||
session_name = form_data["session_name"].strip()
|
||||
|
||||
# 若创建会话名称为空则默认使用"NewSession"作为会话名称
|
||||
if not session_name:
|
||||
session_name = "NewSession"
|
||||
|
||||
original_session_name = session_name
|
||||
counter = 1
|
||||
# 若会话名称重复则在会话名称后面添加序号至不重复
|
||||
while session_name in self.sessions:
|
||||
session_name = f"{original_session_name}({counter})"
|
||||
counter += 1
|
||||
|
||||
self.current_session_name = session_name
|
||||
self.sessions[session_name] = Session()
|
||||
|
||||
# 关闭新建会话模态窗
|
||||
self.create_session_modal_is_open = False
|
||||
|
||||
@reflex.event
|
||||
def delete_session(self, session_name: str) -> None:
|
||||
"""
|
||||
|
|
@ -88,59 +137,23 @@ class State(reflex.State):
|
|||
|
||||
# 若删除会话后所有会话为空则默认创建空白会话
|
||||
if not self.sessions:
|
||||
self.sessions["Intro"] = Session()
|
||||
self.sessions["NewSession"] = Session()
|
||||
|
||||
# 若删除会话后当前会话名称不存在则默认使用第一个会话名称
|
||||
# 删除会话后,若当前会话名称不存在则默认使用第一个会话名称
|
||||
if self.current_session_name not in self.sessions:
|
||||
self.current_session_name = next(iter(self.sessions))
|
||||
|
||||
@reflex.var
|
||||
def get_turns(self) -> List[Turn]:
|
||||
"""
|
||||
获取当前会话所有对话
|
||||
:return: 当前会话所有对话
|
||||
"""
|
||||
if self.current_session_name not in self.sessions:
|
||||
return (
|
||||
[]
|
||||
) # 因 reflex 实时响应状态变更(前端立刻自动刷新),故需要考虑当前会话不存在的情况
|
||||
return self.sessions[self.current_session_name].turns
|
||||
|
||||
@reflex.event
|
||||
def toggle_create_session_modal(self, is_open: bool) -> None:
|
||||
"""
|
||||
切换新建会话模态窗开关状态
|
||||
打开 / 关闭新建会话模态窗
|
||||
:param is_open: 打开或关闭新建会话模态窗
|
||||
:return: None
|
||||
"""
|
||||
self.create_session_modal_is_open = is_open
|
||||
|
||||
@reflex.event
|
||||
def create_session(self, form_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
创建会话
|
||||
:param form_data: 创建会话表单数据
|
||||
:return: None
|
||||
"""
|
||||
session_name = form_data["session_name"].strip()
|
||||
|
||||
# 若创建会话名称为空字则默认使用"Intro"作为会话名称
|
||||
if not session_name:
|
||||
session_name = "Intro"
|
||||
|
||||
original_session_name = session_name
|
||||
counter = 1
|
||||
# 若会话名称重复则在会话名称后面添加序号至不重复
|
||||
while session_name in self.sessions:
|
||||
session_name = f"{original_session_name}_{counter}"
|
||||
counter += 1
|
||||
|
||||
self.current_session_name = session_name
|
||||
self.sessions[session_name] = Session()
|
||||
self.create_session_modal_is_open = False # 关闭新建会话模态窗
|
||||
|
||||
@reflex.event
|
||||
async def adapt_input_message(self, form_data: dict[str, str]) -> AsyncGenerator:
|
||||
async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator:
|
||||
"""
|
||||
适配用户输入消息
|
||||
:param form_data: 对话表单数据
|
||||
|
|
@ -150,10 +163,10 @@ class State(reflex.State):
|
|||
if not input_message:
|
||||
return
|
||||
|
||||
async for value in self._process_input_message(input_message):
|
||||
async for value in self.process_input_message(input_message=input_message):
|
||||
yield value
|
||||
|
||||
async def _process_input_message(self, input_message: str) -> AsyncGenerator:
|
||||
async def process_input_message(self, input_message: str) -> AsyncGenerator:
|
||||
"""
|
||||
处理用户输入消息
|
||||
:param input_message: 用户输入的消息
|
||||
|
|
@ -169,7 +182,7 @@ class State(reflex.State):
|
|||
)
|
||||
|
||||
# 当前会话正在处理
|
||||
self.current_session_processing = True
|
||||
current_session.is_processing = True
|
||||
yield
|
||||
|
||||
agent = retrieve_agent(self)
|
||||
|
|
@ -180,4 +193,4 @@ class State(reflex.State):
|
|||
yield
|
||||
|
||||
# 当前会话处理完成
|
||||
self.current_session_processing = False
|
||||
current_session.is_processing = False
|
||||
|
|
|
|||
|
|
@ -1,5 +1,15 @@
|
|||
import reflex as rx
|
||||
import reflex
|
||||
from reflex.plugins import RadixThemesPlugin
|
||||
from reflex_base.plugins.sitemap import SitemapPlugin
|
||||
|
||||
config = rx.Config(
|
||||
config = reflex.Config(
|
||||
app_name="application",
|
||||
disable_plugins=[SitemapPlugin],
|
||||
plugins=[
|
||||
RadixThemesPlugin(
|
||||
theme=reflex.theme(
|
||||
appearance="dark", accent_color="purple" # 暗黑模式 # 主题色
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue