# -*- coding: utf-8 -*- """ 应用状态管理模块 """ from typing import Any, AsyncGenerator, Dict, List, Optional from pydantic import BaseModel, Field import reflex from pathlib import Path from sys import path path.append((Path(__file__).resolve().parent.parent.parent.as_posix())) from utils.agent import Agent def retrieve_agent(state) -> Agent: """ 获取当前会话绑定的智能体 :return: 当前会话绑定的智能体 """ if state.current_session_name not in state._agents: state._agents[state.current_session_name] = Agent( instructions="You are a friendly chatbot named Reflex. Respond in markdown." ) return state._agents[state.current_session_name] class Turn(BaseModel): """对话数据模型,包含用户输入消息和智能体输出消息""" input_message: str = Field(..., description="用户输入的消息") output_message: str = Field(default="", description="智能体输出的消息") class Session(BaseModel): """会话数据模型,包含会绑定智能体和会话对话列表""" turns: List[Turn] = Field(default_factory=list, description="会话对话列表") class State(reflex.State): """应用状态""" # 当前会话名称 current_session_name: str = "Intro" # 当前会话正在处理 current_session_processing: bool = False # 所有会话 sessions: Dict[str, Session] = {"Intro": Session()} # 所有会话绑定的智能体(私有变量) _agents: Dict[str, Any] = {} # 新建会话模态窗是否打开 create_session_modal_is_open: bool = False @reflex.var def get_session_names(self) -> List[str]: """ 获取所有会话名称 :return: 所有会话名称 """ return list(self.sessions) @reflex.event def set_current_session(self, session_name: str) -> None: """ 将所选会话设置为当前会话 :param session_name: 会话名称 :return: None """ self.current_session_name = session_name @reflex.event def delete_session(self, session_name: str) -> None: """ 删除会话 :param session_name: 会话名称 :return: None """ if session_name not in self.sessions: return del self.sessions[session_name] # 若删除会话后所有会话为空则默认创建空白会话 if not self.sessions: self.sessions["Intro"] = Session() # 若删除会话后当前会话名称不存在则默认使用第一个会话名称 if self.current_session_name not in self.sessions: self.current_session_name = next(iter(self.sessions)) @reflex.var def get_turns(self) -> List[Turn]: """ 获取当前会话所有对话 :return: 当前会话所有对话 """ if self.current_session_name not in self.sessions: return ( [] ) # 因 reflex 实时响应状态变更(前端立刻自动刷新),故需要考虑当前会话不存在的情况 return self.sessions[self.current_session_name].turns @reflex.event def toggle_create_session_modal(self, is_open: bool) -> None: """ 切换新建会话模态窗开关状态 :param is_open: 打开或关闭新建会话模态窗 :return: None """ self.create_session_modal_is_open = is_open @reflex.event def create_session(self, form_data: Dict[str, Any]) -> None: """ 创建会话 :param form_data: 创建会话表单数据 :return: None """ session_name = form_data["session_name"].strip() # 若创建会话名称为空字则默认使用"Intro"作为会话名称 if not session_name: session_name = "Intro" original_session_name = session_name counter = 1 # 若会话名称重复则在会话名称后面添加序号至不重复 while session_name in self.sessions: session_name = f"{original_session_name}_{counter}" counter += 1 self.current_session_name = session_name self.sessions[session_name] = Session() self.create_session_modal_is_open = False # 关闭新建会话模态窗 @reflex.event async def adapt_input_message(self, form_data: dict[str, str]) -> AsyncGenerator: """ 适配用户输入消息 :param form_data: 对话表单数据 :return: AsyncGenerator """ input_message = form_data["input_message"].strip() if not input_message: return async for value in self._process_input_message(input_message): yield value async def _process_input_message(self, input_message: str) -> AsyncGenerator: """ 处理用户输入消息 :param input_message: 用户输入的消息 :return: AsyncGenerator """ # 当前会话 current_session = self.sessions[self.current_session_name] # 将用户输入消息添加到当前会话对话列表 current_session.turns.append( Turn( input_message=input_message, ) ) # 当前会话正在处理 self.current_session_processing = True yield agent = retrieve_agent(self) async for chunk in agent.output_message_streamed(user_prompt=input_message): if not chunk: continue current_session.turns[-1].output_message += chunk yield # 当前会话处理完成 self.current_session_processing = False