197 lines
6.2 KiB
Python
197 lines
6.2 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
应用状态管理模块
|
||
"""
|
||
|
||
from typing import Any, AsyncGenerator, Dict, List
|
||
from uuid import uuid4
|
||
|
||
from pydantic import BaseModel, Field
|
||
import reflex
|
||
from pathlib import Path
|
||
from sys import path
|
||
|
||
path.append((Path(__file__).resolve().parent.parent.parent.as_posix()))
|
||
from utils.agent import Agent
|
||
|
||
# 所有会话绑定的智能体
|
||
agents: Dict[str, Any] = {}
|
||
|
||
|
||
def retrieve_agent(state) -> Agent:
|
||
"""
|
||
获取当前会话绑定的智能体
|
||
:return: 当前会话绑定的智能体
|
||
"""
|
||
current_session_name = state.current_session_name
|
||
if current_session_name not in agents:
|
||
agents[current_session_name] = Agent(
|
||
session_id=state.sessions[current_session_name].id,
|
||
instructions="You are a friendly chatbot named Reflex. Respond in markdown.",
|
||
)
|
||
return agents[current_session_name]
|
||
|
||
|
||
class Turn(BaseModel):
|
||
"""对话数据模型,包含用户输入消息和智能体输出消息"""
|
||
|
||
input_message: str = Field(..., description="用户输入的消息")
|
||
output_message: str = Field(default="", description="智能体输出的消息")
|
||
|
||
|
||
class Session(BaseModel):
|
||
"""会话数据模型,包含会话唯一标识和会话对话列表"""
|
||
|
||
id: str = Field(default_factory=lambda: uuid4().hex, description="会话唯一标识")
|
||
is_processing: bool = Field(default=False, description="会话是否正在处理中")
|
||
turns: List[Turn] = Field(default_factory=list, description="会话对话列表")
|
||
|
||
|
||
# Reflex.State 统一管理应用数据与功能状态,作为前后端交互枢纽,借助响应式特性实现页面自动更新
|
||
class State(reflex.State):
|
||
"""应用状态"""
|
||
|
||
# 当前会话名称
|
||
current_session_name: str = "NewSession"
|
||
|
||
# 所有会话
|
||
sessions: Dict[str, Session] = {current_session_name: Session()}
|
||
|
||
# 新建会话模态窗是否打开
|
||
create_session_modal_is_open: bool = False
|
||
|
||
@reflex.var
|
||
def get_session_names(self) -> List[str]:
|
||
"""
|
||
获取所有会话名称
|
||
:return: 所有会话名称
|
||
"""
|
||
return list(self.sessions)
|
||
|
||
@reflex.event
|
||
def switch_session(self, session_name: str) -> None:
|
||
"""
|
||
切换会话
|
||
:param session_name: 会话名称
|
||
:return: None
|
||
"""
|
||
self.current_session_name = session_name
|
||
|
||
@reflex.var
|
||
def get_current_session_status(self) -> bool:
|
||
"""
|
||
获取当前会话状态
|
||
:return: 当前会话状态,其中 True 表示正在处理中,False 表示处理完成
|
||
"""
|
||
if self.current_session_name not in self.sessions:
|
||
return False
|
||
return self.sessions[self.current_session_name].is_processing
|
||
|
||
@reflex.var
|
||
def get_current_session_turns(self) -> List[Turn]:
|
||
"""
|
||
获取当前会话对话列表
|
||
:return: 当前会话对话列表
|
||
"""
|
||
if self.current_session_name not in self.sessions:
|
||
return []
|
||
return self.sessions[self.current_session_name].turns
|
||
|
||
@reflex.event
|
||
def create_session(self, form_data: Dict[str, Any]) -> None:
|
||
"""
|
||
创建会话
|
||
:param form_data: 创建会话表单数据
|
||
:return: None
|
||
"""
|
||
session_name = form_data["session_name"].strip()
|
||
|
||
# 若创建会话名称为空则默认使用"NewSession"作为会话名称
|
||
if not session_name:
|
||
session_name = "NewSession"
|
||
|
||
original_session_name = session_name
|
||
counter = 1
|
||
# 若会话名称重复则在会话名称后面添加序号至不重复
|
||
while session_name in self.sessions:
|
||
session_name = f"{original_session_name}({counter})"
|
||
counter += 1
|
||
|
||
self.current_session_name = session_name
|
||
self.sessions[session_name] = Session()
|
||
|
||
# 关闭新建会话模态窗
|
||
self.create_session_modal_is_open = False
|
||
|
||
@reflex.event
|
||
def delete_session(self, session_name: str) -> None:
|
||
"""
|
||
删除会话
|
||
:param session_name: 会话名称
|
||
:return: None
|
||
"""
|
||
if session_name not in self.sessions:
|
||
return
|
||
|
||
del self.sessions[session_name]
|
||
|
||
# 若删除会话后所有会话为空则默认创建空白会话
|
||
if not self.sessions:
|
||
self.sessions["NewSession"] = Session()
|
||
|
||
# 删除会话后,若当前会话名称不存在则默认使用第一个会话名称
|
||
if self.current_session_name not in self.sessions:
|
||
self.current_session_name = next(iter(self.sessions))
|
||
|
||
@reflex.event
|
||
def toggle_create_session_modal(self, is_open: bool) -> None:
|
||
"""
|
||
打开 / 关闭新建会话模态窗
|
||
:param is_open: 打开或关闭新建会话模态窗
|
||
:return: None
|
||
"""
|
||
self.create_session_modal_is_open = is_open
|
||
|
||
@reflex.event
|
||
async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator:
|
||
"""
|
||
适配用户输入消息
|
||
:param form_data: 对话表单数据
|
||
:return: AsyncGenerator
|
||
"""
|
||
input_message = form_data["input_message"].strip()
|
||
if not input_message:
|
||
return
|
||
|
||
async for value in self.process_input_message(input_message=input_message):
|
||
yield value
|
||
|
||
async def process_input_message(self, input_message: str) -> AsyncGenerator:
|
||
"""
|
||
处理用户输入消息
|
||
:param input_message: 用户输入的消息
|
||
:return: AsyncGenerator
|
||
"""
|
||
# 当前会话
|
||
current_session = self.sessions[self.current_session_name]
|
||
# 将用户输入消息添加到当前会话对话列表
|
||
current_session.turns.append(
|
||
Turn(
|
||
input_message=input_message,
|
||
)
|
||
)
|
||
|
||
# 当前会话正在处理
|
||
current_session.is_processing = True
|
||
yield
|
||
|
||
agent = retrieve_agent(self)
|
||
async for chunk in agent.output_message_streamed(user_prompt=input_message):
|
||
if not chunk:
|
||
continue
|
||
current_session.turns[-1].output_message += chunk
|
||
yield
|
||
|
||
# 当前会话处理完成
|
||
current_session.is_processing = False
|