Python/产品需求文档AI生成/application/state.py

197 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
应用状态管理模块
"""
from typing import Any, AsyncGenerator, Dict, List
from uuid import uuid4
from pydantic import BaseModel, Field
import reflex
from pathlib import Path
from sys import path
path.append((Path(__file__).resolve().parent.parent.parent.as_posix()))
from utils.agent import Agent
# 所有会话绑定的智能体
agents: Dict[str, Any] = {}
def retrieve_agent(state) -> Agent:
"""
获取当前会话绑定的智能体
:return: 当前会话绑定的智能体
"""
current_session_name = state.current_session_name
if current_session_name not in agents:
agents[current_session_name] = Agent(
session_id=state.sessions[current_session_name].id,
instructions="You are a friendly chatbot named Reflex. Respond in markdown.",
)
return agents[current_session_name]
class Turn(BaseModel):
"""对话数据模型,包含用户输入消息和智能体输出消息"""
input_message: str = Field(..., description="用户输入的消息")
output_message: str = Field(default="", description="智能体输出的消息")
class Session(BaseModel):
"""会话数据模型,包含会话唯一标识和会话对话列表"""
id: str = Field(default_factory=lambda: uuid4().hex, description="会话唯一标识")
is_processing: bool = Field(default=False, description="会话是否正在处理中")
turns: List[Turn] = Field(default_factory=list, description="会话对话列表")
# Reflex.State 统一管理应用数据与功能状态,作为前后端交互枢纽,借助响应式特性实现页面自动更新
class State(reflex.State):
"""应用状态"""
# 当前会话名称
current_session_name: str = "NewSession"
# 所有会话
sessions: Dict[str, Session] = {current_session_name: Session()}
# 新建会话模态窗是否打开
create_session_modal_is_open: bool = False
@reflex.var
def get_session_names(self) -> List[str]:
"""
获取所有会话名称
:return: 所有会话名称
"""
return list(self.sessions)
@reflex.event
def switch_session(self, session_name: str) -> None:
"""
切换会话
:param session_name: 会话名称
:return: None
"""
self.current_session_name = session_name
@reflex.var
def get_current_session_status(self) -> bool:
"""
获取当前会话状态
:return: 当前会话状态,其中 True 表示正在处理中False 表示处理完成
"""
if self.current_session_name not in self.sessions:
return False
return self.sessions[self.current_session_name].is_processing
@reflex.var
def get_current_session_turns(self) -> List[Turn]:
"""
获取当前会话对话列表
:return: 当前会话对话列表
"""
if self.current_session_name not in self.sessions:
return []
return self.sessions[self.current_session_name].turns
@reflex.event
def create_session(self, form_data: Dict[str, Any]) -> None:
"""
创建会话
:param form_data: 创建会话表单数据
:return: None
"""
session_name = form_data["session_name"].strip()
# 若创建会话名称为空则默认使用"NewSession"作为会话名称
if not session_name:
session_name = "NewSession"
original_session_name = session_name
counter = 1
# 若会话名称重复则在会话名称后面添加序号至不重复
while session_name in self.sessions:
session_name = f"{original_session_name}({counter})"
counter += 1
self.current_session_name = session_name
self.sessions[session_name] = Session()
# 关闭新建会话模态窗
self.create_session_modal_is_open = False
@reflex.event
def delete_session(self, session_name: str) -> None:
"""
删除会话
:param session_name: 会话名称
:return: None
"""
if session_name not in self.sessions:
return
del self.sessions[session_name]
# 若删除会话后所有会话为空则默认创建空白会话
if not self.sessions:
self.sessions["NewSession"] = Session()
# 删除会话后,若当前会话名称不存在则默认使用第一个会话名称
if self.current_session_name not in self.sessions:
self.current_session_name = next(iter(self.sessions))
@reflex.event
def toggle_create_session_modal(self, is_open: bool) -> None:
"""
打开 / 关闭新建会话模态窗
:param is_open: 打开或关闭新建会话模态窗
:return: None
"""
self.create_session_modal_is_open = is_open
@reflex.event
async def adapt_input_message(self, form_data: dict[str, Any]) -> AsyncGenerator:
"""
适配用户输入消息
:param form_data: 对话表单数据
:return: AsyncGenerator
"""
input_message = form_data["input_message"].strip()
if not input_message:
return
async for value in self.process_input_message(input_message=input_message):
yield value
async def process_input_message(self, input_message: str) -> AsyncGenerator:
"""
处理用户输入消息
:param input_message: 用户输入的消息
:return: AsyncGenerator
"""
# 当前会话
current_session = self.sessions[self.current_session_name]
# 将用户输入消息添加到当前会话对话列表
current_session.turns.append(
Turn(
input_message=input_message,
)
)
# 当前会话正在处理
current_session.is_processing = True
yield
agent = retrieve_agent(self)
async for chunk in agent.output_message_streamed(user_prompt=input_message):
if not chunk:
continue
current_session.turns[-1].output_message += chunk
yield
# 当前会话处理完成
current_session.is_processing = False