# -*- coding: utf-8 -*- """ 应用状态管理模块 """ from typing import Any, AsyncGenerator, Dict, List from uuid import uuid4 from pydantic import BaseModel, Field import reflex from pathlib import Path from sys import path path.append((Path(__file__).resolve().parent.parent.parent.as_posix())) from utils.agent import Agent # 所有会话绑定的智能体 agents: Dict[str, Any] = {} def retrieve_agent(state) -> Agent: """ 获取当前会话绑定的智能体 :return: 当前会话绑定的智能体 """ current_session_name = state.current_session_name if current_session_name not in agents: agents[current_session_name] = Agent( session_id=state.sessions[current_session_name].id, instructions="You are a friendly chatbot named Reflex. Respond in markdown.", ) return agents[current_session_name] class Turn(BaseModel): """对话数据模型,包含用户输入消息和智能体输出消息""" input_message: str = Field(..., description="用户输入的消息") output_message: str = Field(default="", description="智能体输出的消息") class Session(BaseModel): """会话数据模型,包含会话唯一标识和会话对话列表""" id: str = Field(default_factory=lambda: uuid4().hex, description="会话唯一标识") is_processing: bool = Field(default=False, description="会话是否正在处理中") turns: List[Turn] = Field(default_factory=list, description="会话对话列表") # Reflex.State 统一管理应用数据与功能状态,作为前后端交互枢纽,借助响应式特性实现页面自动更新 class State(reflex.State): """应用状态""" # 当前会话名称 current_session_name: str = "NewSession" # 所有会话 sessions: Dict[str, Session] = {current_session_name: Session()} # 新建会话模态窗是否打开 create_session_modal_is_open: bool = False @reflex.var def get_session_names(self) -> List[str]: """ 获取所有会话名称 :return: 所有会话名称 """ return list(self.sessions) @reflex.event def switch_session(self, session_name: str) -> None: """ 切换会话 :param session_name: 会话名称 :return: None """ self.current_session_name = session_name @reflex.var def get_current_session_status(self) -> bool: """ 获取当前会话状态 :return: 当前会话状态,其中 True 表示正在处理中,False 表示处理完成 """ if self.current_session_name not in self.sessions: return False return self.sessions[self.current_session_name].is_processing @reflex.var def get_current_session_turns(self) -> List[Turn]: """ 获取当前会话对话列表 :return: 当前会话对话列表 """ if self.current_session_name not in self.sessions: return [] return self.sessions[self.current_session_name].turns @reflex.event def create_session(self, form_data: Dict[str, Any]) -> None: """ 创建会话 :param form_data: 创建会话表单数据 :return: None """ session_name = form_data["session_name"].strip() # 若创建会话名称为空则默认使用"NewSession"作为会话名称 if not session_name: session_name = "NewSession" original_session_name = session_name counter = 1 # 若会话名称重复则在会话名称后面添加序号至不重复 while session_name in self.sessions: session_name = f"{original_session_name}({counter})" counter += 1 self.current_session_name = session_name self.sessions[session_name] = Session() # 关闭新建会话模态窗 self.create_session_modal_is_open = False @reflex.event def delete_session(self, session_name: str) -> None: """ 删除会话 :param session_name: 会话名称 :return: None """ if session_name not in self.sessions: return del self.sessions[session_name] # 若删除会话后所有会话为空则默认创建空白会话 if not self.sessions: self.sessions["NewSession"] = Session() # 删除会话后,若当前会话名称不存在则默认使用第一个会话名称 if self.current_session_name not in self.sessions: self.current_session_name = next(iter(self.sessions)) @reflex.event def toggle_create_session_modal(self, is_open: bool) -> None: """ 打开 / 关闭新建会话模态窗 :param is_open: 打开或关闭新建会话模态窗 :return: None """ self.create_session_modal_is_open = is_open @reflex.event async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator: """ 适配用户输入消息 :param form_data: 对话表单数据 :return: AsyncGenerator """ input_message = form_data["input_message"].strip() if not input_message: return async for value in self.process_input_message(input_message=input_message): yield value async def process_input_message(self, input_message: str) -> AsyncGenerator: """ 处理用户输入消息 :param input_message: 用户输入的消息 :return: AsyncGenerator """ # 当前会话 current_session = self.sessions[self.current_session_name] # 将用户输入消息添加到当前会话对话列表 current_session.turns.append( Turn( input_message=input_message, ) ) # 当前会话正在处理 current_session.is_processing = True yield agent = retrieve_agent(self) async for chunk in agent.output_message_streamed(user_prompt=input_message): if not chunk: continue current_session.turns[-1].output_message += chunk yield # 当前会话处理完成 current_session.is_processing = False