260610更新
This commit is contained in:
parent
d8edc97ca8
commit
75a1cd998e
|
|
@ -118,19 +118,21 @@ class Agent:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
session_id: str,
|
||||||
instructions: str,
|
instructions: str,
|
||||||
output_type: OutputSpec = str,
|
output_type: OutputSpec = str,
|
||||||
capabilities: Optional[List[AgentCapability]] = None,
|
capabilities: Optional[List[AgentCapability]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化智能体
|
初始化智能体
|
||||||
|
:param session_id: 会话唯一标识
|
||||||
:param instructions: 指令
|
:param instructions: 指令
|
||||||
:param skills: 智能体技能列表,默认为不使用技能
|
:param skills: 智能体技能列表,默认为不使用技能
|
||||||
:param output_type: 输出类型
|
:param output_type: 输出类型
|
||||||
:return: 智能体实例
|
:return: 智能体实例
|
||||||
"""
|
"""
|
||||||
# 生成会话唯一标识
|
# 会话唯一标识
|
||||||
self.session_id = uuid4().hex.lower()
|
self.session_id = session_id
|
||||||
|
|
||||||
# 实例智能体的记忆体
|
# 实例智能体的记忆体
|
||||||
self.agent_memory = AgentMemory()
|
self.agent_memory = AgentMemory()
|
||||||
|
|
@ -150,6 +152,8 @@ class Agent:
|
||||||
retries=1,
|
retries=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.agent.to_web()
|
||||||
|
|
||||||
async def output_message_streamed(
|
async def output_message_streamed(
|
||||||
self, user_prompt: str | List[str]
|
self, user_prompt: str | List[str]
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
|
|
@ -167,7 +171,8 @@ class Agent:
|
||||||
async with self.agent.run_stream(
|
async with self.agent.run_stream(
|
||||||
user_prompt=user_prompt, message_history=message_history
|
user_prompt=user_prompt, message_history=message_history
|
||||||
) as result:
|
) as result:
|
||||||
async for chunk in result.stream_text():
|
async for chunk in result.stream_text(delta=True): # 只返回新增内容
|
||||||
|
if chunk:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
self.agent_memory.create_new_messages(
|
self.agent_memory.create_new_messages(
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -1,28 +1,10 @@
|
||||||
"""The main Chat app."""
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
主模块
|
||||||
|
"""
|
||||||
|
|
||||||
from reflex import App, Component, vstack as VStack, color as Color, theme as Theme
|
import reflex
|
||||||
from .components import session, navbar
|
from .pages.index import index
|
||||||
|
|
||||||
|
app = reflex.App()
|
||||||
def index() -> Component:
|
|
||||||
"""The main app."""
|
|
||||||
return VStack(
|
|
||||||
navbar.navbar(),
|
|
||||||
session.session(),
|
|
||||||
session.action_bar(),
|
|
||||||
background_color=Color(color="mauve", shade=1),
|
|
||||||
color=Color(color="mauve", shade=12),
|
|
||||||
height="100dvh",
|
|
||||||
align_items="stretch",
|
|
||||||
spacing="0",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Add state and page to the app.
|
|
||||||
app = App(
|
|
||||||
theme=Theme(
|
|
||||||
appearance="dark",
|
|
||||||
accent_color="purple",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
app.add_page(index)
|
app.add_page(index)
|
||||||
|
|
|
||||||
|
|
@ -1,142 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
导航栏组件
|
|
||||||
"""
|
|
||||||
|
|
||||||
from reflex import (
|
|
||||||
Component,
|
|
||||||
badge as Badge,
|
|
||||||
button as Button,
|
|
||||||
color as Color,
|
|
||||||
dialog as Dialog,
|
|
||||||
divider as Divider,
|
|
||||||
drawer as Drawer,
|
|
||||||
foreach as Foreach,
|
|
||||||
form as Form,
|
|
||||||
heading as Heading,
|
|
||||||
hstack as HStack,
|
|
||||||
icon as Icon,
|
|
||||||
input as Input,
|
|
||||||
tooltip as Tooltip,
|
|
||||||
vstack as VStack,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..state import State
|
|
||||||
|
|
||||||
|
|
||||||
def sidebar_session(session_name: str) -> Component:
|
|
||||||
"""
|
|
||||||
侧边栏会话组件
|
|
||||||
:param session_name: 会话名称
|
|
||||||
:return: Component
|
|
||||||
"""
|
|
||||||
return Drawer.close(
|
|
||||||
HStack(
|
|
||||||
Button(
|
|
||||||
session_name,
|
|
||||||
on_click=lambda: State.set_current_session(session_name),
|
|
||||||
width="80%",
|
|
||||||
variant="surface",
|
|
||||||
),
|
|
||||||
Button(
|
|
||||||
Icon(
|
|
||||||
tag="trash",
|
|
||||||
on_click=lambda: State.delete_session(session_name),
|
|
||||||
stroke_width=1,
|
|
||||||
),
|
|
||||||
width="20%",
|
|
||||||
variant="surface",
|
|
||||||
color_scheme="red",
|
|
||||||
),
|
|
||||||
width="100%",
|
|
||||||
),
|
|
||||||
key=session_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sidebar(trigger) -> Component:
|
|
||||||
"""
|
|
||||||
侧边栏组件
|
|
||||||
"""
|
|
||||||
return Drawer.root(
|
|
||||||
Drawer.trigger(trigger),
|
|
||||||
Drawer.overlay(),
|
|
||||||
Drawer.portal(
|
|
||||||
Drawer.content(
|
|
||||||
VStack(
|
|
||||||
Heading("Chats", color=Color("mauve", 11)),
|
|
||||||
Divider(),
|
|
||||||
Foreach(
|
|
||||||
State.get_session_names,
|
|
||||||
lambda session_name: sidebar_session(session_name),
|
|
||||||
),
|
|
||||||
align_items="stretch",
|
|
||||||
width="100%",
|
|
||||||
),
|
|
||||||
top="auto",
|
|
||||||
right="auto",
|
|
||||||
height="100%",
|
|
||||||
width="20em",
|
|
||||||
padding="2em",
|
|
||||||
background_color=Color("mauve", 2),
|
|
||||||
outline="none",
|
|
||||||
)
|
|
||||||
),
|
|
||||||
direction="left",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def modal(trigger) -> Component:
|
|
||||||
"""A modal to create a new chat."""
|
|
||||||
return Dialog.root(
|
|
||||||
Dialog.trigger(trigger),
|
|
||||||
Dialog.content(
|
|
||||||
Form(
|
|
||||||
HStack(
|
|
||||||
Input(
|
|
||||||
placeholder="Chat name",
|
|
||||||
name="session_name",
|
|
||||||
flex="1",
|
|
||||||
min_width="20ch",
|
|
||||||
),
|
|
||||||
Button("Create chat"),
|
|
||||||
spacing="2",
|
|
||||||
wrap="wrap",
|
|
||||||
width="100%",
|
|
||||||
),
|
|
||||||
on_submit=State.create_session,
|
|
||||||
),
|
|
||||||
background_color=Color("mauve", 1),
|
|
||||||
),
|
|
||||||
open=State.create_session_modal_is_open,
|
|
||||||
on_open_change=State.toggle_create_session_modal,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def navbar():
|
|
||||||
return HStack(
|
|
||||||
Badge(
|
|
||||||
State.current_session_name,
|
|
||||||
Tooltip(
|
|
||||||
Icon("info", size=14),
|
|
||||||
content="The current selected chat.",
|
|
||||||
),
|
|
||||||
size="3",
|
|
||||||
variant="soft",
|
|
||||||
margin_inline_end="auto",
|
|
||||||
),
|
|
||||||
modal(
|
|
||||||
Icon("message-square-plus"),
|
|
||||||
),
|
|
||||||
sidebar(
|
|
||||||
Icon(
|
|
||||||
"messages-square",
|
|
||||||
background_color=Color("mauve", 6),
|
|
||||||
)
|
|
||||||
),
|
|
||||||
justify_content="space-between",
|
|
||||||
align_items="center",
|
|
||||||
padding="12px",
|
|
||||||
border_bottom=f"1px solid {Color('mauve', 3)}",
|
|
||||||
background_color=Color("mauve", 2),
|
|
||||||
)
|
|
||||||
|
|
@ -0,0 +1,137 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
导航栏组件
|
||||||
|
"""
|
||||||
|
|
||||||
|
import reflex
|
||||||
|
|
||||||
|
from ..state import State
|
||||||
|
|
||||||
|
|
||||||
|
def create_session_modal(trigger) -> reflex.Component:
|
||||||
|
"""
|
||||||
|
创建会话模态窗组件
|
||||||
|
"""
|
||||||
|
return reflex.dialog.root(
|
||||||
|
reflex.dialog.trigger(trigger),
|
||||||
|
reflex.dialog.content(
|
||||||
|
reflex.form(
|
||||||
|
reflex.hstack(
|
||||||
|
reflex.input(
|
||||||
|
name="session_name",
|
||||||
|
placeholder="会话名称",
|
||||||
|
flex="auto",
|
||||||
|
min_width="20ch",
|
||||||
|
),
|
||||||
|
reflex.button("创建会话"),
|
||||||
|
spacing="2",
|
||||||
|
wrap="wrap",
|
||||||
|
width="100%",
|
||||||
|
),
|
||||||
|
on_submit=State.create_session,
|
||||||
|
),
|
||||||
|
background_color=reflex.color("mauve", 1),
|
||||||
|
),
|
||||||
|
open=State.create_session_modal_is_open,
|
||||||
|
on_open_change=State.toggle_create_session_modal,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def session_history_item(session_name: str) -> reflex.Component:
|
||||||
|
"""
|
||||||
|
会话历史侧边栏中一次会话组件
|
||||||
|
:param session_name: 会话名称
|
||||||
|
:return: Component
|
||||||
|
"""
|
||||||
|
return reflex.drawer.close(
|
||||||
|
reflex.hstack(
|
||||||
|
reflex.button(
|
||||||
|
session_name,
|
||||||
|
on_click=lambda: State.switch_session(
|
||||||
|
session_name
|
||||||
|
), # 点击按钮将切换会话
|
||||||
|
width="80%",
|
||||||
|
variant="surface",
|
||||||
|
),
|
||||||
|
reflex.button(
|
||||||
|
reflex.icon(
|
||||||
|
tag="trash",
|
||||||
|
on_click=lambda: State.delete_session(
|
||||||
|
session_name
|
||||||
|
), # 点击按钮删除会话
|
||||||
|
stroke_width=1,
|
||||||
|
),
|
||||||
|
width="20%",
|
||||||
|
variant="surface",
|
||||||
|
color_scheme="red",
|
||||||
|
),
|
||||||
|
width="100%",
|
||||||
|
),
|
||||||
|
key=session_name, # 使用会话名称作为唯一标识(会话名称不可重复)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def session_history(trigger) -> reflex.Component:
|
||||||
|
"""
|
||||||
|
会话历史侧边栏组件
|
||||||
|
"""
|
||||||
|
return reflex.drawer.root(
|
||||||
|
reflex.drawer.trigger(trigger),
|
||||||
|
reflex.drawer.overlay(),
|
||||||
|
reflex.drawer.portal(
|
||||||
|
reflex.drawer.content(
|
||||||
|
reflex.vstack(
|
||||||
|
reflex.heading("会话列表", color=reflex.color("mauve", 11)),
|
||||||
|
reflex.divider(),
|
||||||
|
reflex.foreach(
|
||||||
|
State.get_session_names, # 获取所有会话名称
|
||||||
|
lambda session_name: session_history_item(
|
||||||
|
session_name=session_name
|
||||||
|
), # 创建会话组件
|
||||||
|
),
|
||||||
|
align_items="stretch",
|
||||||
|
width="100%",
|
||||||
|
),
|
||||||
|
top="auto",
|
||||||
|
right="auto",
|
||||||
|
height="100%",
|
||||||
|
width="20em",
|
||||||
|
padding="2em",
|
||||||
|
background_color=reflex.color("mauve", 2),
|
||||||
|
outline="none",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
direction="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def navigation_bar() -> reflex.Component:
|
||||||
|
"""
|
||||||
|
导航栏组件
|
||||||
|
"""
|
||||||
|
return reflex.hstack(
|
||||||
|
reflex.badge(
|
||||||
|
State.current_session_name,
|
||||||
|
size="3",
|
||||||
|
variant="soft",
|
||||||
|
margin_inline_end="auto",
|
||||||
|
),
|
||||||
|
create_session_modal(
|
||||||
|
reflex.box(
|
||||||
|
reflex.tooltip(reflex.icon("message-square-plus"), content="创建会话")
|
||||||
|
)
|
||||||
|
),
|
||||||
|
session_history(
|
||||||
|
reflex.box(
|
||||||
|
reflex.tooltip(
|
||||||
|
reflex.icon("messages-square"),
|
||||||
|
content="会话历史",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
justify_content="space-between",
|
||||||
|
align_items="center",
|
||||||
|
padding="12px",
|
||||||
|
border_bottom=f"1px solid {reflex.color('mauve', 3)}",
|
||||||
|
background_color=reflex.color("mauve", 2),
|
||||||
|
)
|
||||||
|
|
@ -2,61 +2,43 @@
|
||||||
"""
|
"""
|
||||||
会话组件
|
会话组件
|
||||||
"""
|
"""
|
||||||
|
import reflex
|
||||||
from reflex import (
|
|
||||||
Component,
|
|
||||||
auto_scroll as AutoScroll,
|
|
||||||
box as Box,
|
|
||||||
button as Button,
|
|
||||||
center as Center,
|
|
||||||
color as Color,
|
|
||||||
foreach as Foreach,
|
|
||||||
form as Form,
|
|
||||||
hstack as HStack,
|
|
||||||
icon as Icon,
|
|
||||||
input as Input,
|
|
||||||
logo as Logo,
|
|
||||||
markdown as Markdown,
|
|
||||||
text as Text,
|
|
||||||
tooltip as Tooltip,
|
|
||||||
vstack as VStack,
|
|
||||||
)
|
|
||||||
from reflex.constants.colors import ColorType
|
from reflex.constants.colors import ColorType
|
||||||
|
|
||||||
from ..state import State, Turn
|
from ..state import State, Turn
|
||||||
|
|
||||||
|
|
||||||
def message(text: str, color: ColorType) -> Component:
|
def message_bubble(message: str, color: ColorType) -> reflex.Component:
|
||||||
"""
|
"""
|
||||||
消息组件
|
对话组件中一个消息气泡组件
|
||||||
:param text: 文本
|
:param message: 消息
|
||||||
:param color: 颜色
|
:param color: 颜色
|
||||||
:return: Component
|
:return: Component
|
||||||
"""
|
"""
|
||||||
return Markdown(
|
return reflex.markdown(
|
||||||
text,
|
message,
|
||||||
background_color=Color(color=color, shade=4),
|
color=reflex.color(color=color, shade=12),
|
||||||
color=Color(color=color, shade=12),
|
background_color=reflex.color(color=color, shade=4),
|
||||||
display="inline-block",
|
display="inline-block",
|
||||||
padding_inline="1em",
|
padding_inline="1em",
|
||||||
border_radius="8px",
|
border_radius="8px",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def turn(turn: Turn) -> Component:
|
def turn(turn: Turn) -> reflex.Component:
|
||||||
"""
|
"""
|
||||||
对话组件
|
会话话组件中一次对话组件
|
||||||
:param turn: 对话
|
:param turn: 对话
|
||||||
:return: Component
|
:return: Component
|
||||||
"""
|
"""
|
||||||
return Box(
|
return reflex.box(
|
||||||
Box(
|
reflex.box(
|
||||||
message(text=turn.input_message, color="mauve"),
|
message_bubble(message=turn.input_message, color="mauve"),
|
||||||
text_align="right",
|
text_align="right",
|
||||||
margin_bottom="8px",
|
margin_bottom="8px",
|
||||||
),
|
),
|
||||||
Box(
|
reflex.box(
|
||||||
message(text=turn.output_message, color="accent"),
|
message_bubble(message=turn.output_message, color="accent"),
|
||||||
text_align="left",
|
text_align="left",
|
||||||
margin_bottom="8px",
|
margin_bottom="8px",
|
||||||
),
|
),
|
||||||
|
|
@ -65,58 +47,51 @@ def turn(turn: Turn) -> Component:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def session() -> Component:
|
def session_area() -> reflex.Component:
|
||||||
"""
|
"""
|
||||||
会话组件
|
会话区域组件
|
||||||
:return: Component
|
:return: Component
|
||||||
"""
|
"""
|
||||||
return AutoScroll(
|
return reflex.auto_scroll(
|
||||||
Foreach(State.get_turns, turn),
|
reflex.foreach(State.get_current_session_turns, turn),
|
||||||
flex="1",
|
flex="1",
|
||||||
padding="8px",
|
padding="8px",
|
||||||
overflow_y="auto",
|
overflow_y="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def action_bar() -> Component:
|
def input_bar() -> reflex.Component:
|
||||||
"""
|
"""
|
||||||
输入发送栏组件
|
输入栏组件
|
||||||
"""
|
"""
|
||||||
return Center(
|
return reflex.center(
|
||||||
VStack(
|
reflex.vstack(
|
||||||
Form(
|
reflex.form(
|
||||||
HStack(
|
reflex.hstack(
|
||||||
Input(
|
reflex.input(
|
||||||
Input.slot(
|
name="input_message",
|
||||||
Tooltip(
|
placeholder="请输入...",
|
||||||
Icon("info", size=18),
|
flex="auto",
|
||||||
content="Enter a question to get a response.",
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
placeholder="Type something...",
|
reflex.button(
|
||||||
id="input_message",
|
"发送",
|
||||||
flex="1",
|
|
||||||
),
|
|
||||||
Button(
|
|
||||||
"Send",
|
|
||||||
loading=State.current_session_processing,
|
|
||||||
disabled=State.current_session_processing,
|
|
||||||
type="submit",
|
type="submit",
|
||||||
|
loading=State.get_current_session_status, # 正在处理中时按钮显示为 loading
|
||||||
|
disabled=State.get_current_session_status, # 正在处理中时按钮禁用
|
||||||
),
|
),
|
||||||
max_width="50em",
|
max_width="50em",
|
||||||
margin="0 auto",
|
margin="0 auto",
|
||||||
align_items="center",
|
align_items="center",
|
||||||
),
|
),
|
||||||
reset_on_submit=True, # 提交后清空输入框
|
|
||||||
on_submit=State.adapt_input_message,
|
on_submit=State.adapt_input_message,
|
||||||
|
reset_on_submit=True, # 提交后清空输入框
|
||||||
),
|
),
|
||||||
Text(
|
reflex.text(
|
||||||
"ReflexGPT may return factually incorrect or misleading responses. Use discretion.",
|
"抹茶兔兔工作室",
|
||||||
text_align="center",
|
text_align="center",
|
||||||
font_size=".75em",
|
font_size=".75em",
|
||||||
color=Color("mauve", 10),
|
color=reflex.color("mauve", 10),
|
||||||
),
|
),
|
||||||
Logo(margin_block="-1em"),
|
|
||||||
width="100%",
|
width="100%",
|
||||||
padding_x="16px",
|
padding_x="16px",
|
||||||
align="stretch",
|
align="stretch",
|
||||||
|
|
@ -127,8 +102,8 @@ def action_bar() -> Component:
|
||||||
padding_y="16px",
|
padding_y="16px",
|
||||||
backdrop_filter="auto",
|
backdrop_filter="auto",
|
||||||
backdrop_blur="lg",
|
backdrop_blur="lg",
|
||||||
border_top=f"1px solid {Color('mauve', 3)}",
|
border_top=f"1px solid {reflex.color('mauve', 3)}",
|
||||||
background_color=Color("mauve", 2),
|
background_color=reflex.color("mauve", 2),
|
||||||
align="stretch",
|
align="stretch",
|
||||||
width="100%",
|
width="100%",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
会话页面
|
||||||
|
"""
|
||||||
|
|
||||||
|
import reflex
|
||||||
|
|
||||||
|
from ..components.navigation_bar import navigation_bar
|
||||||
|
from ..components.session import session_area, input_bar
|
||||||
|
|
||||||
|
|
||||||
|
def index() -> reflex.Component:
|
||||||
|
"""根页面"""
|
||||||
|
return reflex.vstack(
|
||||||
|
navigation_bar(), # 导航栏
|
||||||
|
session_area(), # 会话区域
|
||||||
|
input_bar(), # 输入栏
|
||||||
|
color=reflex.color(color="mauve", shade=12),
|
||||||
|
background_color=reflex.color(color="mauve", shade=1),
|
||||||
|
height="100dvh",
|
||||||
|
align_items="stretch",
|
||||||
|
spacing="0",
|
||||||
|
)
|
||||||
|
|
@ -3,7 +3,8 @@
|
||||||
应用状态管理模块
|
应用状态管理模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import reflex
|
import reflex
|
||||||
|
|
@ -13,18 +14,22 @@ from sys import path
|
||||||
path.append((Path(__file__).resolve().parent.parent.parent.as_posix()))
|
path.append((Path(__file__).resolve().parent.parent.parent.as_posix()))
|
||||||
from utils.agent import Agent
|
from utils.agent import Agent
|
||||||
|
|
||||||
|
# 所有会话绑定的智能体
|
||||||
|
agents: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
def retrieve_agent(state) -> Agent:
|
def retrieve_agent(state) -> Agent:
|
||||||
"""
|
"""
|
||||||
获取当前会话绑定的智能体
|
获取当前会话绑定的智能体
|
||||||
:return: 当前会话绑定的智能体
|
:return: 当前会话绑定的智能体
|
||||||
"""
|
"""
|
||||||
|
current_session_name = state.current_session_name
|
||||||
if state.current_session_name not in state._agents:
|
if current_session_name not in agents:
|
||||||
state._agents[state.current_session_name] = Agent(
|
agents[current_session_name] = Agent(
|
||||||
instructions="You are a friendly chatbot named Reflex. Respond in markdown."
|
session_id=state.sessions[current_session_name].id,
|
||||||
|
instructions="You are a friendly chatbot named Reflex. Respond in markdown.",
|
||||||
)
|
)
|
||||||
return state._agents[state.current_session_name]
|
return agents[current_session_name]
|
||||||
|
|
||||||
|
|
||||||
class Turn(BaseModel):
|
class Turn(BaseModel):
|
||||||
|
|
@ -35,24 +40,22 @@ class Turn(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Session(BaseModel):
|
class Session(BaseModel):
|
||||||
"""会话数据模型,包含会绑定智能体和会话对话列表"""
|
"""会话数据模型,包含会话唯一标识和会话对话列表"""
|
||||||
|
|
||||||
|
id: str = Field(default_factory=lambda: uuid4().hex, description="会话唯一标识")
|
||||||
|
is_processing: bool = Field(default=False, description="会话是否正在处理中")
|
||||||
turns: List[Turn] = Field(default_factory=list, description="会话对话列表")
|
turns: List[Turn] = Field(default_factory=list, description="会话对话列表")
|
||||||
|
|
||||||
|
|
||||||
|
# Reflex.State 统一管理应用数据与功能状态,作为前后端交互枢纽,借助响应式特性实现页面自动更新
|
||||||
class State(reflex.State):
|
class State(reflex.State):
|
||||||
"""应用状态"""
|
"""应用状态"""
|
||||||
|
|
||||||
# 当前会话名称
|
# 当前会话名称
|
||||||
current_session_name: str = "Intro"
|
current_session_name: str = "NewSession"
|
||||||
# 当前会话正在处理
|
|
||||||
current_session_processing: bool = False
|
|
||||||
|
|
||||||
# 所有会话
|
# 所有会话
|
||||||
sessions: Dict[str, Session] = {"Intro": Session()}
|
sessions: Dict[str, Session] = {current_session_name: Session()}
|
||||||
|
|
||||||
# 所有会话绑定的智能体(私有变量)
|
|
||||||
_agents: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
# 新建会话模态窗是否打开
|
# 新建会话模态窗是否打开
|
||||||
create_session_modal_is_open: bool = False
|
create_session_modal_is_open: bool = False
|
||||||
|
|
@ -66,14 +69,60 @@ class State(reflex.State):
|
||||||
return list(self.sessions)
|
return list(self.sessions)
|
||||||
|
|
||||||
@reflex.event
|
@reflex.event
|
||||||
def set_current_session(self, session_name: str) -> None:
|
def switch_session(self, session_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
将所选会话设置为当前会话
|
切换会话
|
||||||
:param session_name: 会话名称
|
:param session_name: 会话名称
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
self.current_session_name = session_name
|
self.current_session_name = session_name
|
||||||
|
|
||||||
|
@reflex.var
|
||||||
|
def get_current_session_status(self) -> bool:
|
||||||
|
"""
|
||||||
|
获取当前会话状态
|
||||||
|
:return: 当前会话状态,其中 True 表示正在处理中,False 表示处理完成
|
||||||
|
"""
|
||||||
|
if self.current_session_name not in self.sessions:
|
||||||
|
return False
|
||||||
|
return self.sessions[self.current_session_name].is_processing
|
||||||
|
|
||||||
|
@reflex.var
|
||||||
|
def get_current_session_turns(self) -> List[Turn]:
|
||||||
|
"""
|
||||||
|
获取当前会话对话列表
|
||||||
|
:return: 当前会话对话列表
|
||||||
|
"""
|
||||||
|
if self.current_session_name not in self.sessions:
|
||||||
|
return []
|
||||||
|
return self.sessions[self.current_session_name].turns
|
||||||
|
|
||||||
|
@reflex.event
|
||||||
|
def create_session(self, form_data: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
创建会话
|
||||||
|
:param form_data: 创建会话表单数据
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
session_name = form_data["session_name"].strip()
|
||||||
|
|
||||||
|
# 若创建会话名称为空则默认使用"NewSession"作为会话名称
|
||||||
|
if not session_name:
|
||||||
|
session_name = "NewSession"
|
||||||
|
|
||||||
|
original_session_name = session_name
|
||||||
|
counter = 1
|
||||||
|
# 若会话名称重复则在会话名称后面添加序号至不重复
|
||||||
|
while session_name in self.sessions:
|
||||||
|
session_name = f"{original_session_name}({counter})"
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
self.current_session_name = session_name
|
||||||
|
self.sessions[session_name] = Session()
|
||||||
|
|
||||||
|
# 关闭新建会话模态窗
|
||||||
|
self.create_session_modal_is_open = False
|
||||||
|
|
||||||
@reflex.event
|
@reflex.event
|
||||||
def delete_session(self, session_name: str) -> None:
|
def delete_session(self, session_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -88,59 +137,23 @@ class State(reflex.State):
|
||||||
|
|
||||||
# 若删除会话后所有会话为空则默认创建空白会话
|
# 若删除会话后所有会话为空则默认创建空白会话
|
||||||
if not self.sessions:
|
if not self.sessions:
|
||||||
self.sessions["Intro"] = Session()
|
self.sessions["NewSession"] = Session()
|
||||||
|
|
||||||
# 若删除会话后当前会话名称不存在则默认使用第一个会话名称
|
# 删除会话后,若当前会话名称不存在则默认使用第一个会话名称
|
||||||
if self.current_session_name not in self.sessions:
|
if self.current_session_name not in self.sessions:
|
||||||
self.current_session_name = next(iter(self.sessions))
|
self.current_session_name = next(iter(self.sessions))
|
||||||
|
|
||||||
@reflex.var
|
|
||||||
def get_turns(self) -> List[Turn]:
|
|
||||||
"""
|
|
||||||
获取当前会话所有对话
|
|
||||||
:return: 当前会话所有对话
|
|
||||||
"""
|
|
||||||
if self.current_session_name not in self.sessions:
|
|
||||||
return (
|
|
||||||
[]
|
|
||||||
) # 因 reflex 实时响应状态变更(前端立刻自动刷新),故需要考虑当前会话不存在的情况
|
|
||||||
return self.sessions[self.current_session_name].turns
|
|
||||||
|
|
||||||
@reflex.event
|
@reflex.event
|
||||||
def toggle_create_session_modal(self, is_open: bool) -> None:
|
def toggle_create_session_modal(self, is_open: bool) -> None:
|
||||||
"""
|
"""
|
||||||
切换新建会话模态窗开关状态
|
打开 / 关闭新建会话模态窗
|
||||||
:param is_open: 打开或关闭新建会话模态窗
|
:param is_open: 打开或关闭新建会话模态窗
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
self.create_session_modal_is_open = is_open
|
self.create_session_modal_is_open = is_open
|
||||||
|
|
||||||
@reflex.event
|
@reflex.event
|
||||||
def create_session(self, form_data: Dict[str, Any]) -> None:
|
async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator:
|
||||||
"""
|
|
||||||
创建会话
|
|
||||||
:param form_data: 创建会话表单数据
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
session_name = form_data["session_name"].strip()
|
|
||||||
|
|
||||||
# 若创建会话名称为空字则默认使用"Intro"作为会话名称
|
|
||||||
if not session_name:
|
|
||||||
session_name = "Intro"
|
|
||||||
|
|
||||||
original_session_name = session_name
|
|
||||||
counter = 1
|
|
||||||
# 若会话名称重复则在会话名称后面添加序号至不重复
|
|
||||||
while session_name in self.sessions:
|
|
||||||
session_name = f"{original_session_name}_{counter}"
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
self.current_session_name = session_name
|
|
||||||
self.sessions[session_name] = Session()
|
|
||||||
self.create_session_modal_is_open = False # 关闭新建会话模态窗
|
|
||||||
|
|
||||||
@reflex.event
|
|
||||||
async def adapt_input_message(self, form_data: dict[str, str]) -> AsyncGenerator:
|
|
||||||
"""
|
"""
|
||||||
适配用户输入消息
|
适配用户输入消息
|
||||||
:param form_data: 对话表单数据
|
:param form_data: 对话表单数据
|
||||||
|
|
@ -150,10 +163,10 @@ class State(reflex.State):
|
||||||
if not input_message:
|
if not input_message:
|
||||||
return
|
return
|
||||||
|
|
||||||
async for value in self._process_input_message(input_message):
|
async for value in self.process_input_message(input_message=input_message):
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
async def _process_input_message(self, input_message: str) -> AsyncGenerator:
|
async def process_input_message(self, input_message: str) -> AsyncGenerator:
|
||||||
"""
|
"""
|
||||||
处理用户输入消息
|
处理用户输入消息
|
||||||
:param input_message: 用户输入的消息
|
:param input_message: 用户输入的消息
|
||||||
|
|
@ -169,7 +182,7 @@ class State(reflex.State):
|
||||||
)
|
)
|
||||||
|
|
||||||
# 当前会话正在处理
|
# 当前会话正在处理
|
||||||
self.current_session_processing = True
|
current_session.is_processing = True
|
||||||
yield
|
yield
|
||||||
|
|
||||||
agent = retrieve_agent(self)
|
agent = retrieve_agent(self)
|
||||||
|
|
@ -180,4 +193,4 @@ class State(reflex.State):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# 当前会话处理完成
|
# 当前会话处理完成
|
||||||
self.current_session_processing = False
|
current_session.is_processing = False
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,15 @@
|
||||||
import reflex as rx
|
import reflex
|
||||||
|
from reflex.plugins import RadixThemesPlugin
|
||||||
|
from reflex_base.plugins.sitemap import SitemapPlugin
|
||||||
|
|
||||||
config = rx.Config(
|
config = reflex.Config(
|
||||||
app_name="application",
|
app_name="application",
|
||||||
|
disable_plugins=[SitemapPlugin],
|
||||||
|
plugins=[
|
||||||
|
RadixThemesPlugin(
|
||||||
|
theme=reflex.theme(
|
||||||
|
appearance="dark", accent_color="purple" # 暗黑模式 # 主题色
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue