190 lines
5.7 KiB
Python
190 lines
5.7 KiB
Python
import os
|
|
from typing import Any, Optional, TypedDict
|
|
import reflex
|
|
from openai import OpenAI
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
from sys import path
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
|
|
path.append(str(Path(__file__).resolve().parent.parent.parent))
|
|
from utils.agent import Agent
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class Turn(BaseModel):
    """One conversation turn: the user's input and the agent's output."""

    # Text the user submitted for this turn.
    input_message: str = Field(description="用户输入的消息")

    # Text produced by the agent in reply; required, no default.
    output_message: str
|
|
|
|
|
|
class Session(BaseModel):
    """Data model for a chat session."""

    # Unique identifier of the session (required).
    uuid: str = Field(..., description="会话唯一标识")

    # Agent instance attached to the session, if any.
    agent: Optional[Agent] = Field(default=None, description="智能体实例")

    # Turns exchanged so far. Uses the builtin ``list`` generic (``List`` is
    # not imported in this module) and a factory for the mutable default.
    message_history: list[Turn] = Field(
        default_factory=list, description="会话消息历史"
    )
|
|
|
|
|
|
class State(reflex.State):
    """Per-client chat state: the chats, the selected chat and UI flags.

    All handlers operate on the state vars declared below
    (``current_chat_name``, ``sessions``, ``outputting``, ``is_modal_open``,
    ``new_chat_name``).
    """

    # Agent instance, created on first use by ``_init_agent``.
    agent: Agent | None = None

    # Name of the currently selected chat.
    current_chat_name: str = "Intros"

    # Chat history keyed by chat name; each value is the ordered list of turns.
    sessions: dict[str, list[Turn]] = {
        current_chat_name: [],
    }

    # Whether the assistant is currently streaming an answer.
    outputting: bool = False

    # Whether the "new chat" modal is open.
    is_modal_open: bool = False

    # Name typed for a chat that is about to be created.
    new_chat_name: str = ""

    def _init_agent(self) -> Agent:
        """Create the agent on first call and return it.

        Returns:
            The agent stored on this state.
        """
        if not self.agent:
            self.agent = Agent(
                instructions="You are a friendly chatbot named Reflex. Respond in markdown."
            )
        return self.agent

    @reflex.event
    def create_chat(self, form_data: dict[str, Any]):
        """Create a new, empty chat and make it the current one.

        Args:
            form_data: Submitted form values; must contain ``"new_chat_name"``.
        """
        new_chat_name = form_data["new_chat_name"]
        self.current_chat_name = new_chat_name
        self.sessions[new_chat_name] = []
        self.is_modal_open = False

    @reflex.event
    def set_is_modal_open(self, is_open: bool):
        """Set the new chat modal open state.

        Args:
            is_open: Whether the modal is open.
        """
        self.is_modal_open = is_open

    @reflex.var
    def selected_chat(self) -> list[Turn]:
        """Get the list of turns for the current chat.

        Returns:
            The turns of the current chat, or an empty list when the current
            chat name has no entry in ``sessions``.
        """
        return self.sessions.get(self.current_chat_name, [])

    @reflex.event
    def delete_chat(self, chat_name: str):
        """Delete the chat called ``chat_name``; unknown names are ignored.

        Args:
            chat_name: The name of the chat to delete.
        """
        if chat_name not in self.sessions:
            return
        del self.sessions[chat_name]
        # Never leave the user without a chat: recreate the default one.
        if not self.sessions:
            self.sessions = {
                "Intros": [],
            }
        # If the selected chat was the one removed, fall back to the first
        # remaining chat.
        if self.current_chat_name not in self.sessions:
            self.current_chat_name = next(iter(self.sessions))

    @reflex.event
    def set_chat(self, chat_name: str):
        """Set the name of the current chat.

        Args:
            chat_name: The name of the chat.
        """
        self.current_chat_name = chat_name

    @reflex.event
    def set_new_chat_name(self, new_chat_name: str):
        """Set the name of the new chat.

        Args:
            new_chat_name: The name of the new chat.
        """
        self.new_chat_name = new_chat_name

    @reflex.var
    def chat_titles(self) -> list[str]:
        """Get the list of chat titles.

        Returns:
            The list of chat names.
        """
        return list(self.sessions)

    @reflex.event
    async def process_question(self, form_data: dict[str, Any]):
        """Handle a submitted question form and stream the answer.

        Args:
            form_data: Submitted form values; must contain ``"question"``.
        """
        question = form_data["question"]

        # Nothing to do for an empty question.
        if not question:
            return

        async for value in self.openai_process_question(question):
            yield value

    @reflex.event
    async def openai_process_question(self, question: str):
        """Get the response from the API and stream it into the current chat.

        Args:
            question: The question asked by the user.
        """
        chat_name = self.current_chat_name

        # Record the question with an empty answer that is filled in while
        # the response streams.
        self.sessions.setdefault(chat_name, []).append(
            Turn(input_message=question, output_message="")
        )

        # Flag the start of processing and push the update to the client.
        self.outputting = True
        yield

        try:
            # Build the messages from the whole history of this chat.
            messages: list[ChatCompletionMessageParam] = [
                {
                    "role": "system",
                    "content": "You are a friendly chatbot named Reflex. Respond in markdown.",
                }
            ]
            for turn in self.sessions[chat_name]:
                messages.append({"role": "user", "content": turn.input_message})
                messages.append(
                    {"role": "assistant", "content": turn.output_message}
                )

            # Remove the last mock answer (the empty placeholder added above).
            messages = messages[:-1]

            # Start a new streaming completion to answer the question.
            stream = OpenAI().chat.completions.create(
                model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
                messages=messages,
                stream=True,
            )

            # Stream the results, yielding after every chunk.
            for item in stream:
                # Skip chunks that carry no choices instead of raising
                # IndexError on ``choices[0]``.
                if not item.choices:
                    continue
                delta = item.choices[0].delta
                if not hasattr(delta, "content"):
                    continue
                # ``content`` may be None; treat that as an empty string.
                self.sessions[chat_name][-1].output_message += delta.content or ""
                # Reassign so the change to the nested turn is registered.
                self.sessions = self.sessions
                yield
        finally:
            # Always clear the flag, even if the API call fails.
            self.outputting = False
|