184 lines
5.6 KiB
Python
184 lines
5.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
应用状态管理模块
|
|
"""
|
|
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
|
|
|
from pydantic import BaseModel, Field
|
|
import reflex
|
|
from pathlib import Path
|
|
from sys import path
|
|
|
|
path.append((Path(__file__).resolve().parent.parent.parent.as_posix()))
|
|
from utils.agent import Agent
|
|
|
|
|
|
def retrieve_agent(state) -> Agent:
|
|
"""
|
|
获取当前会话绑定的智能体
|
|
:return: 当前会话绑定的智能体
|
|
"""
|
|
|
|
if state.current_session_name not in state._agents:
|
|
state._agents[state.current_session_name] = Agent(
|
|
instructions="You are a friendly chatbot named Reflex. Respond in markdown."
|
|
)
|
|
return state._agents[state.current_session_name]
|
|
|
|
|
|
class Turn(BaseModel):
|
|
"""对话数据模型,包含用户输入消息和智能体输出消息"""
|
|
|
|
input_message: str = Field(..., description="用户输入的消息")
|
|
output_message: str = Field(default="", description="智能体输出的消息")
|
|
|
|
|
|
class Session(BaseModel):
|
|
"""会话数据模型,包含会绑定智能体和会话对话列表"""
|
|
|
|
turns: List[Turn] = Field(default_factory=list, description="会话对话列表")
|
|
|
|
|
|
class State(reflex.State):
|
|
"""应用状态"""
|
|
|
|
# 当前会话名称
|
|
current_session_name: str = "Intro"
|
|
# 当前会话正在处理
|
|
current_session_processing: bool = False
|
|
|
|
# 所有会话
|
|
sessions: Dict[str, Session] = {"Intro": Session()}
|
|
|
|
# 所有会话绑定的智能体(私有变量)
|
|
_agents: Dict[str, Any] = {}
|
|
|
|
# 新建会话模态窗是否打开
|
|
create_session_modal_is_open: bool = False
|
|
|
|
@reflex.var
|
|
def get_session_names(self) -> List[str]:
|
|
"""
|
|
获取所有会话名称
|
|
:return: 所有会话名称
|
|
"""
|
|
return list(self.sessions)
|
|
|
|
@reflex.event
|
|
def set_current_session(self, session_name: str) -> None:
|
|
"""
|
|
将所选会话设置为当前会话
|
|
:param session_name: 会话名称
|
|
:return: None
|
|
"""
|
|
self.current_session_name = session_name
|
|
|
|
@reflex.event
|
|
def delete_session(self, session_name: str) -> None:
|
|
"""
|
|
删除会话
|
|
:param session_name: 会话名称
|
|
:return: None
|
|
"""
|
|
if session_name not in self.sessions:
|
|
return
|
|
|
|
del self.sessions[session_name]
|
|
|
|
# 若删除会话后所有会话为空则默认创建空白会话
|
|
if not self.sessions:
|
|
self.sessions["Intro"] = Session()
|
|
|
|
# 若删除会话后当前会话名称不存在则默认使用第一个会话名称
|
|
if self.current_session_name not in self.sessions:
|
|
self.current_session_name = next(iter(self.sessions))
|
|
|
|
@reflex.var
|
|
def get_turns(self) -> List[Turn]:
|
|
"""
|
|
获取当前会话所有对话
|
|
:return: 当前会话所有对话
|
|
"""
|
|
if self.current_session_name not in self.sessions:
|
|
return (
|
|
[]
|
|
) # 因 reflex 实时响应状态变更(前端立刻自动刷新),故需要考虑当前会话不存在的情况
|
|
return self.sessions[self.current_session_name].turns
|
|
|
|
@reflex.event
|
|
def toggle_create_session_modal(self, is_open: bool) -> None:
|
|
"""
|
|
切换新建会话模态窗开关状态
|
|
:param is_open: 打开或关闭新建会话模态窗
|
|
:return: None
|
|
"""
|
|
self.create_session_modal_is_open = is_open
|
|
|
|
@reflex.event
|
|
def create_session(self, form_data: Dict[str, Any]) -> None:
|
|
"""
|
|
创建会话
|
|
:param form_data: 创建会话表单数据
|
|
:return: None
|
|
"""
|
|
session_name = form_data["session_name"].strip()
|
|
|
|
# 若创建会话名称为空字则默认使用"Intro"作为会话名称
|
|
if not session_name:
|
|
session_name = "Intro"
|
|
|
|
original_session_name = session_name
|
|
counter = 1
|
|
# 若会话名称重复则在会话名称后面添加序号至不重复
|
|
while session_name in self.sessions:
|
|
session_name = f"{original_session_name}_{counter}"
|
|
counter += 1
|
|
|
|
self.current_session_name = session_name
|
|
self.sessions[session_name] = Session()
|
|
self.create_session_modal_is_open = False # 关闭新建会话模态窗
|
|
|
|
@reflex.event
|
|
async def adapt_input_message(self, form_data: dict[str, str]) -> AsyncGenerator:
|
|
"""
|
|
适配用户输入消息
|
|
:param form_data: 对话表单数据
|
|
:return: AsyncGenerator
|
|
"""
|
|
input_message = form_data["input_message"].strip()
|
|
if not input_message:
|
|
return
|
|
|
|
async for value in self._process_input_message(input_message):
|
|
yield value
|
|
|
|
async def _process_input_message(self, input_message: str) -> AsyncGenerator:
|
|
"""
|
|
处理用户输入消息
|
|
:param input_message: 用户输入的消息
|
|
:return: AsyncGenerator
|
|
"""
|
|
# 当前会话
|
|
current_session = self.sessions[self.current_session_name]
|
|
# 将用户输入消息添加到当前会话对话列表
|
|
current_session.turns.append(
|
|
Turn(
|
|
input_message=input_message,
|
|
)
|
|
)
|
|
|
|
# 当前会话正在处理
|
|
self.current_session_processing = True
|
|
yield
|
|
|
|
agent = retrieve_agent(self)
|
|
async for chunk in agent.output_message_streamed(user_prompt=input_message):
|
|
if not chunk:
|
|
continue
|
|
current_session.turns[-1].output_message += chunk
|
|
yield
|
|
|
|
# 当前会话处理完成
|
|
self.current_session_processing = False
|