Python/产品需求文档AI生成/application/state.py

184 lines
5.6 KiB
Python

# -*- coding: utf-8 -*-
"""
应用状态管理模块
"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from pydantic import BaseModel, Field
import reflex
from pathlib import Path
from sys import path
path.append((Path(__file__).resolve().parent.parent.parent.as_posix()))
from utils.agent import Agent
def retrieve_agent(state) -> Agent:
"""
获取当前会话绑定的智能体
:return: 当前会话绑定的智能体
"""
if state.current_session_name not in state._agents:
state._agents[state.current_session_name] = Agent(
instructions="You are a friendly chatbot named Reflex. Respond in markdown."
)
return state._agents[state.current_session_name]
class Turn(BaseModel):
"""对话数据模型,包含用户输入消息和智能体输出消息"""
input_message: str = Field(..., description="用户输入的消息")
output_message: str = Field(default="", description="智能体输出的消息")
class Session(BaseModel):
"""会话数据模型,包含会绑定智能体和会话对话列表"""
turns: List[Turn] = Field(default_factory=list, description="会话对话列表")
class State(reflex.State):
"""应用状态"""
# 当前会话名称
current_session_name: str = "Intro"
# 当前会话正在处理
current_session_processing: bool = False
# 所有会话
sessions: Dict[str, Session] = {"Intro": Session()}
# 所有会话绑定的智能体(私有变量)
_agents: Dict[str, Any] = {}
# 新建会话模态窗是否打开
create_session_modal_is_open: bool = False
@reflex.var
def get_session_names(self) -> List[str]:
"""
获取所有会话名称
:return: 所有会话名称
"""
return list(self.sessions)
@reflex.event
def set_current_session(self, session_name: str) -> None:
"""
将所选会话设置为当前会话
:param session_name: 会话名称
:return: None
"""
self.current_session_name = session_name
@reflex.event
def delete_session(self, session_name: str) -> None:
"""
删除会话
:param session_name: 会话名称
:return: None
"""
if session_name not in self.sessions:
return
del self.sessions[session_name]
# 若删除会话后所有会话为空则默认创建空白会话
if not self.sessions:
self.sessions["Intro"] = Session()
# 若删除会话后当前会话名称不存在则默认使用第一个会话名称
if self.current_session_name not in self.sessions:
self.current_session_name = next(iter(self.sessions))
@reflex.var
def get_turns(self) -> List[Turn]:
"""
获取当前会话所有对话
:return: 当前会话所有对话
"""
if self.current_session_name not in self.sessions:
return (
[]
) # 因 reflex 实时响应状态变更(前端立刻自动刷新),故需要考虑当前会话不存在的情况
return self.sessions[self.current_session_name].turns
@reflex.event
def toggle_create_session_modal(self, is_open: bool) -> None:
"""
切换新建会话模态窗开关状态
:param is_open: 打开或关闭新建会话模态窗
:return: None
"""
self.create_session_modal_is_open = is_open
@reflex.event
def create_session(self, form_data: Dict[str, Any]) -> None:
"""
创建会话
:param form_data: 创建会话表单数据
:return: None
"""
session_name = form_data["session_name"].strip()
# 若创建会话名称为空字则默认使用"Intro"作为会话名称
if not session_name:
session_name = "Intro"
original_session_name = session_name
counter = 1
# 若会话名称重复则在会话名称后面添加序号至不重复
while session_name in self.sessions:
session_name = f"{original_session_name}_{counter}"
counter += 1
self.current_session_name = session_name
self.sessions[session_name] = Session()
self.create_session_modal_is_open = False # 关闭新建会话模态窗
@reflex.event
async def adapt_input_message(self, form_data: dict[str, str]) -> AsyncGenerator:
"""
适配用户输入消息
:param form_data: 对话表单数据
:return: AsyncGenerator
"""
input_message = form_data["input_message"].strip()
if not input_message:
return
async for value in self._process_input_message(input_message):
yield value
async def _process_input_message(self, input_message: str) -> AsyncGenerator:
"""
处理用户输入消息
:param input_message: 用户输入的消息
:return: AsyncGenerator
"""
# 当前会话
current_session = self.sessions[self.current_session_name]
# 将用户输入消息添加到当前会话对话列表
current_session.turns.append(
Turn(
input_message=input_message,
)
)
# 当前会话正在处理
self.current_session_processing = True
yield
agent = retrieve_agent(self)
async for chunk in agent.output_message_streamed(user_prompt=input_message):
if not chunk:
continue
current_session.turns[-1].output_message += chunk
yield
# 当前会话处理完成
self.current_session_processing = False