diff --git a/utils/agent.py b/utils/agent.py index 4ebafdd..9270f4c 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -118,19 +118,21 @@ class Agent: def __init__( self, + session_id: str, instructions: str, output_type: OutputSpec = str, capabilities: Optional[List[AgentCapability]] = None, ): """ 初始化智能体 + :param session_id: 会话唯一标识 :param instructions: 指令 :param skills: 智能体技能列表,默认为不使用技能 :param output_type: 输出类型 :return: 智能体实例 """ - # 生成会话唯一标识 - self.session_id = uuid4().hex.lower() + # 会话唯一标识 + self.session_id = session_id # 实例智能体的记忆体 self.agent_memory = AgentMemory() @@ -150,6 +152,8 @@ class Agent: retries=1, ) + self.agent.to_web() + async def output_message_streamed( self, user_prompt: str | List[str] ) -> AsyncGenerator[str, None]: @@ -167,8 +171,9 @@ class Agent: async with self.agent.run_stream( user_prompt=user_prompt, message_history=message_history ) as result: - async for chunk in result.stream_text(): - yield chunk + async for chunk in result.stream_text(delta=True): # 只返回新增内容 + if chunk: + yield chunk self.agent_memory.create_new_messages( session_id=self.session_id, diff --git a/utils/agent_memory.db b/utils/agent_memory.db index 54008da..6ad6cd4 100644 Binary files a/utils/agent_memory.db and b/utils/agent_memory.db differ diff --git a/产品需求文档AI生成/application/application.py b/产品需求文档AI生成/application/application.py index 5325487..b9cd3ce 100644 --- a/产品需求文档AI生成/application/application.py +++ b/产品需求文档AI生成/application/application.py @@ -1,28 +1,10 @@ -"""The main Chat app.""" +# -*- coding: utf-8 -*- +""" +主模块 +""" -from reflex import App, Component, vstack as VStack, color as Color, theme as Theme -from .components import session, navbar +import reflex +from .pages.index import index - -def index() -> Component: - """The main app.""" - return VStack( - navbar.navbar(), - session.session(), - session.action_bar(), - background_color=Color(color="mauve", shade=1), - color=Color(color="mauve", shade=12), - height="100dvh", - align_items="stretch", - spacing="0", - ) - - -# Add state and page to the app. -app = App( - theme=Theme( - appearance="dark", - accent_color="purple", - ), -) +app = reflex.App() app.add_page(index) diff --git a/产品需求文档AI生成/application/components/navbar.py b/产品需求文档AI生成/application/components/navbar.py deleted file mode 100644 index a3f78d0..0000000 --- a/产品需求文档AI生成/application/components/navbar.py +++ /dev/null @@ -1,142 +0,0 @@ -# -*- coding: utf-8 -*- -""" -导航栏组件 -""" - -from reflex import ( - Component, - badge as Badge, - button as Button, - color as Color, - dialog as Dialog, - divider as Divider, - drawer as Drawer, - foreach as Foreach, - form as Form, - heading as Heading, - hstack as HStack, - icon as Icon, - input as Input, - tooltip as Tooltip, - vstack as VStack, -) - -from ..state import State - - -def sidebar_session(session_name: str) -> Component: - """ - 侧边栏会话组件 - :param session_name: 会话名称 - :return: Component - """ - return Drawer.close( - HStack( - Button( - session_name, - on_click=lambda: State.set_current_session(session_name), - width="80%", - variant="surface", - ), - Button( - Icon( - tag="trash", - on_click=lambda: State.delete_session(session_name), - stroke_width=1, - ), - width="20%", - variant="surface", - color_scheme="red", - ), - width="100%", - ), - key=session_name, - ) - - -def sidebar(trigger) -> Component: - """ - 侧边栏组件 - """ - return Drawer.root( - Drawer.trigger(trigger), - Drawer.overlay(), - Drawer.portal( - Drawer.content( - VStack( - Heading("Chats", color=Color("mauve", 11)), - Divider(), - Foreach( - State.get_session_names, - lambda session_name: sidebar_session(session_name), - ), - align_items="stretch", - width="100%", - ), - top="auto", - right="auto", - height="100%", - width="20em", - padding="2em", - background_color=Color("mauve", 2), - outline="none", - ) - ), - direction="left", - ) - - -def modal(trigger) -> Component: - """A modal to create a new chat.""" - return Dialog.root( - Dialog.trigger(trigger), - Dialog.content( - Form( - HStack( - Input( - placeholder="Chat name", - name="session_name", - flex="1", - min_width="20ch", - ), - Button("Create chat"), - spacing="2", - wrap="wrap", - width="100%", - ), - on_submit=State.create_session, - ), - background_color=Color("mauve", 1), - ), - open=State.create_session_modal_is_open, - on_open_change=State.toggle_create_session_modal, - ) - - -def navbar(): - return HStack( - Badge( - State.current_session_name, - Tooltip( - Icon("info", size=14), - content="The current selected chat.", - ), - size="3", - variant="soft", - margin_inline_end="auto", - ), - modal( - Icon("message-square-plus"), - ), - sidebar( - Icon( - "messages-square", - background_color=Color("mauve", 6), - ) - ), - justify_content="space-between", - align_items="center", - padding="12px", - border_bottom=f"1px solid {Color('mauve', 3)}", - background_color=Color("mauve", 2), - ) diff --git a/产品需求文档AI生成/application/components/navigation_bar.py b/产品需求文档AI生成/application/components/navigation_bar.py new file mode 100644 index 0000000..9867829 --- /dev/null +++ b/产品需求文档AI生成/application/components/navigation_bar.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +""" +导航栏组件 +""" + +import reflex + +from ..state import State + + +def create_session_modal(trigger) -> reflex.Component: + """ + 创建会话模态窗组件 + """ + return reflex.dialog.root( + reflex.dialog.trigger(trigger), + reflex.dialog.content( + reflex.form( + reflex.hstack( + reflex.input( + name="session_name", + placeholder="会话名称", + flex="auto", + min_width="20ch", + ), + reflex.button("创建会话"), + spacing="2", + wrap="wrap", + width="100%", + ), + on_submit=State.create_session, + ), + background_color=reflex.color("mauve", 1), + ), + open=State.create_session_modal_is_open, + on_open_change=State.toggle_create_session_modal, + ) + + +def session_history_item(session_name: str) -> reflex.Component: + """ + 会话历史侧边栏中一次会话组件 + :param session_name: 会话名称 + :return: Component + """ + return reflex.drawer.close( + reflex.hstack( + reflex.button( + session_name, + on_click=lambda: State.switch_session( + session_name + ), # 点击按钮将切换会话 + width="80%", + variant="surface", + ), + reflex.button( + reflex.icon( + tag="trash", + on_click=lambda: State.delete_session( + session_name + ), # 点击按钮删除会话 + stroke_width=1, + ), + width="20%", + variant="surface", + color_scheme="red", + ), + width="100%", + ), + key=session_name, # 使用会话名称作为唯一标识(会话名称不可重复) + ) + + +def session_history(trigger) -> reflex.Component: + """ + 会话历史侧边栏组件 + """ + return reflex.drawer.root( + reflex.drawer.trigger(trigger), + reflex.drawer.overlay(), + reflex.drawer.portal( + reflex.drawer.content( + reflex.vstack( + reflex.heading("会话列表", color=reflex.color("mauve", 11)), + reflex.divider(), + reflex.foreach( + State.get_session_names, # 获取所有会话名称 + lambda session_name: session_history_item( + session_name=session_name + ), # 创建会话组件 + ), + align_items="stretch", + width="100%", + ), + top="auto", + right="auto", + height="100%", + width="20em", + padding="2em", + background_color=reflex.color("mauve", 2), + outline="none", + ) + ), + direction="left", + ) + + +def navigation_bar() -> reflex.Component: + """ + 导航栏组件 + """ + return reflex.hstack( + reflex.badge( + State.current_session_name, + size="3", + variant="soft", + margin_inline_end="auto", + ), + create_session_modal( + reflex.box( + reflex.tooltip(reflex.icon("message-square-plus"), content="创建会话") + ) + ), + session_history( + reflex.box( + reflex.tooltip( + reflex.icon("messages-square"), + content="会话历史", + ) + ) + ), + justify_content="space-between", + align_items="center", + padding="12px", + border_bottom=f"1px solid {reflex.color('mauve', 3)}", + background_color=reflex.color("mauve", 2), + ) diff --git a/产品需求文档AI生成/application/components/session.py b/产品需求文档AI生成/application/components/session.py index b41ca23..de8d5d6 100644 --- a/产品需求文档AI生成/application/components/session.py +++ b/产品需求文档AI生成/application/components/session.py @@ -2,61 +2,43 @@ """ 会话组件 """ - -from reflex import ( - Component, - auto_scroll as AutoScroll, - box as Box, - button as Button, - center as Center, - color as Color, - foreach as Foreach, - form as Form, - hstack as HStack, - icon as Icon, - input as Input, - logo as Logo, - markdown as Markdown, - text as Text, - tooltip as Tooltip, - vstack as VStack, -) +import reflex from reflex.constants.colors import ColorType from ..state import State, Turn -def message(text: str, color: ColorType) -> Component: +def message_bubble(message: str, color: ColorType) -> reflex.Component: """ - 消息组件 - :param text: 文本 + 对话组件中一个消息气泡组件 + :param message: 消息 :param color: 颜色 :return: Component """ - return Markdown( - text, - background_color=Color(color=color, shade=4), - color=Color(color=color, shade=12), + return reflex.markdown( + message, + color=reflex.color(color=color, shade=12), + background_color=reflex.color(color=color, shade=4), display="inline-block", padding_inline="1em", border_radius="8px", ) -def turn(turn: Turn) -> Component: +def turn(turn: Turn) -> reflex.Component: """ - 对话组件 + 会话话组件中一次对话组件 :param turn: 对话 :return: Component """ - return Box( - Box( - message(text=turn.input_message, color="mauve"), + return reflex.box( + reflex.box( + message_bubble(message=turn.input_message, color="mauve"), text_align="right", margin_bottom="8px", ), - Box( - message(text=turn.output_message, color="accent"), + reflex.box( + message_bubble(message=turn.output_message, color="accent"), text_align="left", margin_bottom="8px", ), @@ -65,58 +47,51 @@ def turn(turn: Turn) -> Component: ) -def session() -> Component: +def session_area() -> reflex.Component: """ - 会话组件 + 会话区域组件 :return: Component """ - return AutoScroll( - Foreach(State.get_turns, turn), + return reflex.auto_scroll( + reflex.foreach(State.get_current_session_turns, turn), flex="1", padding="8px", overflow_y="auto", ) -def action_bar() -> Component: +def input_bar() -> reflex.Component: """ - 输入发送栏组件 + 输入栏组件 """ - return Center( - VStack( - Form( - HStack( - Input( - Input.slot( - Tooltip( - Icon("info", size=18), - content="Enter a question to get a response.", - ) - ), - placeholder="Type something...", - id="input_message", - flex="1", + return reflex.center( + reflex.vstack( + reflex.form( + reflex.hstack( + reflex.input( + name="input_message", + placeholder="请输入...", + flex="auto", ), - Button( - "Send", - loading=State.current_session_processing, - disabled=State.current_session_processing, + reflex.button( + "发送", type="submit", + loading=State.get_current_session_status, # 正在处理中时按钮显示为 loading + disabled=State.get_current_session_status, # 正在处理中时按钮禁用 ), max_width="50em", margin="0 auto", align_items="center", ), - reset_on_submit=True, # 提交后清空输入框 on_submit=State.adapt_input_message, + reset_on_submit=True, # 提交后清空输入框 ), - Text( - "ReflexGPT may return factually incorrect or misleading responses. Use discretion.", + reflex.text( + "抹茶兔兔工作室", text_align="center", font_size=".75em", - color=Color("mauve", 10), + color=reflex.color("mauve", 10), ), - Logo(margin_block="-1em"), width="100%", padding_x="16px", align="stretch", @@ -127,8 +102,8 @@ def action_bar() -> Component: padding_y="16px", backdrop_filter="auto", backdrop_blur="lg", - border_top=f"1px solid {Color('mauve', 3)}", - background_color=Color("mauve", 2), + border_top=f"1px solid {reflex.color('mauve', 3)}", + background_color=reflex.color("mauve", 2), align="stretch", width="100%", ) diff --git a/产品需求文档AI生成/application/pages/__init__.py b/产品需求文档AI生成/application/pages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/产品需求文档AI生成/application/pages/index.py b/产品需求文档AI生成/application/pages/index.py new file mode 100644 index 0000000..482d95a --- /dev/null +++ b/产品需求文档AI生成/application/pages/index.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +""" +会话页面 +""" + +import reflex + +from ..components.navigation_bar import navigation_bar +from ..components.session import session_area, input_bar + + +def index() -> reflex.Component: + """根页面""" + return reflex.vstack( + navigation_bar(), # 导航栏 + session_area(), # 会话区域 + input_bar(), # 输入栏 + color=reflex.color(color="mauve", shade=12), + background_color=reflex.color(color="mauve", shade=1), + height="100dvh", + align_items="stretch", + spacing="0", + ) diff --git a/产品需求文档AI生成/application/state.py b/产品需求文档AI生成/application/state.py index cd3c420..82570fa 100644 --- a/产品需求文档AI生成/application/state.py +++ b/产品需求文档AI生成/application/state.py @@ -3,7 +3,8 @@ 应用状态管理模块 """ -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List +from uuid import uuid4 from pydantic import BaseModel, Field import reflex @@ -13,18 +14,22 @@ from sys import path path.append((Path(__file__).resolve().parent.parent.parent.as_posix())) from utils.agent import Agent +# 所有会话绑定的智能体 +agents: Dict[str, Any] = {} + def retrieve_agent(state) -> Agent: """ 获取当前会话绑定的智能体 :return: 当前会话绑定的智能体 """ - - if state.current_session_name not in state._agents: - state._agents[state.current_session_name] = Agent( - instructions="You are a friendly chatbot named Reflex. Respond in markdown." + current_session_name = state.current_session_name + if current_session_name not in agents: + agents[current_session_name] = Agent( + session_id=state.sessions[current_session_name].id, + instructions="You are a friendly chatbot named Reflex. Respond in markdown.", ) - return state._agents[state.current_session_name] + return agents[current_session_name] class Turn(BaseModel): @@ -35,24 +40,22 @@ class Turn(BaseModel): class Session(BaseModel): - """会话数据模型,包含会绑定智能体和会话对话列表""" + """会话数据模型,包含会话唯一标识和会话对话列表""" + id: str = Field(default_factory=lambda: uuid4().hex, description="会话唯一标识") + is_processing: bool = Field(default=False, description="会话是否正在处理中") turns: List[Turn] = Field(default_factory=list, description="会话对话列表") +# Reflex.State 统一管理应用数据与功能状态,作为前后端交互枢纽,借助响应式特性实现页面自动更新 class State(reflex.State): """应用状态""" # 当前会话名称 - current_session_name: str = "Intro" - # 当前会话正在处理 - current_session_processing: bool = False + current_session_name: str = "NewSession" # 所有会话 - sessions: Dict[str, Session] = {"Intro": Session()} - - # 所有会话绑定的智能体(私有变量) - _agents: Dict[str, Any] = {} + sessions: Dict[str, Session] = {current_session_name: Session()} # 新建会话模态窗是否打开 create_session_modal_is_open: bool = False @@ -66,14 +69,60 @@ class State(reflex.State): return list(self.sessions) @reflex.event - def set_current_session(self, session_name: str) -> None: + def switch_session(self, session_name: str) -> None: """ - 将所选会话设置为当前会话 + 切换会话 :param session_name: 会话名称 :return: None """ self.current_session_name = session_name + @reflex.var + def get_current_session_status(self) -> bool: + """ + 获取当前会话状态 + :return: 当前会话状态,其中 True 表示正在处理中,False 表示处理完成 + """ + if self.current_session_name not in self.sessions: + return False + return self.sessions[self.current_session_name].is_processing + + @reflex.var + def get_current_session_turns(self) -> List[Turn]: + """ + 获取当前会话对话列表 + :return: 当前会话对话列表 + """ + if self.current_session_name not in self.sessions: + return [] + return self.sessions[self.current_session_name].turns + + @reflex.event + def create_session(self, form_data: Dict[str, Any]) -> None: + """ + 创建会话 + :param form_data: 创建会话表单数据 + :return: None + """ + session_name = form_data["session_name"].strip() + + # 若创建会话名称为空则默认使用"NewSession"作为会话名称 + if not session_name: + session_name = "NewSession" + + original_session_name = session_name + counter = 1 + # 若会话名称重复则在会话名称后面添加序号至不重复 + while session_name in self.sessions: + session_name = f"{original_session_name}({counter})" + counter += 1 + + self.current_session_name = session_name + self.sessions[session_name] = Session() + + # 关闭新建会话模态窗 + self.create_session_modal_is_open = False + @reflex.event def delete_session(self, session_name: str) -> None: """ @@ -88,59 +137,23 @@ class State(reflex.State): # 若删除会话后所有会话为空则默认创建空白会话 if not self.sessions: - self.sessions["Intro"] = Session() + self.sessions["NewSession"] = Session() - # 若删除会话后当前会话名称不存在则默认使用第一个会话名称 + # 删除会话后,若当前会话名称不存在则默认使用第一个会话名称 if self.current_session_name not in self.sessions: self.current_session_name = next(iter(self.sessions)) - @reflex.var - def get_turns(self) -> List[Turn]: - """ - 获取当前会话所有对话 - :return: 当前会话所有对话 - """ - if self.current_session_name not in self.sessions: - return ( - [] - ) # 因 reflex 实时响应状态变更(前端立刻自动刷新),故需要考虑当前会话不存在的情况 - return self.sessions[self.current_session_name].turns - @reflex.event def toggle_create_session_modal(self, is_open: bool) -> None: """ - 切换新建会话模态窗开关状态 + 打开 / 关闭新建会话模态窗 :param is_open: 打开或关闭新建会话模态窗 :return: None """ self.create_session_modal_is_open = is_open @reflex.event - def create_session(self, form_data: Dict[str, Any]) -> None: - """ - 创建会话 - :param form_data: 创建会话表单数据 - :return: None - """ - session_name = form_data["session_name"].strip() - - # 若创建会话名称为空字则默认使用"Intro"作为会话名称 - if not session_name: - session_name = "Intro" - - original_session_name = session_name - counter = 1 - # 若会话名称重复则在会话名称后面添加序号至不重复 - while session_name in self.sessions: - session_name = f"{original_session_name}_{counter}" - counter += 1 - - self.current_session_name = session_name - self.sessions[session_name] = Session() - self.create_session_modal_is_open = False # 关闭新建会话模态窗 - - @reflex.event - async def adapt_input_message(self, form_data: dict[str, str]) -> AsyncGenerator: + async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator: """ 适配用户输入消息 :param form_data: 对话表单数据 @@ -150,10 +163,10 @@ class State(reflex.State): if not input_message: return - async for value in self._process_input_message(input_message): + async for value in self.process_input_message(input_message=input_message): yield value - async def _process_input_message(self, input_message: str) -> AsyncGenerator: + async def process_input_message(self, input_message: str) -> AsyncGenerator: """ 处理用户输入消息 :param input_message: 用户输入的消息 @@ -169,7 +182,7 @@ class State(reflex.State): ) # 当前会话正在处理 - self.current_session_processing = True + current_session.is_processing = True yield agent = retrieve_agent(self) @@ -180,4 +193,4 @@ class State(reflex.State): yield # 当前会话处理完成 - self.current_session_processing = False + current_session.is_processing = False diff --git a/产品需求文档AI生成/rxconfig.py b/产品需求文档AI生成/rxconfig.py index 64e4ac0..7595216 100644 --- a/产品需求文档AI生成/rxconfig.py +++ b/产品需求文档AI生成/rxconfig.py @@ -1,5 +1,15 @@ -import reflex as rx +import reflex +from reflex.plugins import RadixThemesPlugin +from reflex_base.plugins.sitemap import SitemapPlugin -config = rx.Config( +config = reflex.Config( app_name="application", + disable_plugins=[SitemapPlugin], + plugins=[ + RadixThemesPlugin( + theme=reflex.theme( + appearance="dark", accent_color="purple" # 暗黑模式 # 主题色 + ) + ) + ], )