diff --git a/产品需求文档AI生成/ai_agent/__init__.py b/utils/__init__.py similarity index 100% rename from 产品需求文档AI生成/ai_agent/__init__.py rename to utils/__init__.py diff --git a/utils/agent.py b/utils/agent.py index 41c3390..05a0ed6 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -5,25 +5,108 @@ # 列举导入模块 from pathlib import Path -from sys import path -from typing import List, Optional, Union, cast -from typing import TypeVar +import time +from typing import AsyncGenerator, List, Optional from uuid import uuid4 -from pydantic_ai import Agent as BaseAgent, AgentRunResult +from pydantic_ai import Agent as PydanticAIAgent, ModelMessage from pydantic_ai.capabilities import AgentCapability +from pydantic_ai.messages import ModelMessagesTypeAdapter from pydantic_ai.models.openai import OpenAIChatModel -from pydantic_ai.native_tools import AbstractNativeTool from pydantic_ai.output import OutputSpec from pydantic_ai.providers.openai import OpenAIProvider -from starlette.applications import Starlette -from starlette.routing import Mount, Route -path.append(Path(__file__).resolve().parent.as_posix()) -from memory import Memory +from sqlite import SQLite -AgentDepsT = TypeVar("AgentDepsT") -OutputDataT = TypeVar("OutputDataT") + +class AgentMemory(SQLite): + """ + 智能体的记忆体,支持: + create:新增对话消息 + read:查询会话历史消息 + """ + + def __init__(self): + """ + 初始化智能体的记忆体 + """ + # 构建智能体的记忆体的数据库路径 + super().__init__(database=Path(__file__).parent.resolve() / "agent_memory.db") + + try: + with self: + self.execute( + sql=""" + CREATE TABLE IF NOT EXISTS new_messages + ( + --唯一标识 + id TEXT PRIMARY KEY, + --会话唯一标识 + session_id TEXT NOT NULL, + --新对话消息 + new_messages TEXT NOT NULL, + --时间戳(毫秒) + timestamp INTEGER NOT NULL + ) + """ + ) + except Exception as exception: + raise RuntimeError( + f"初始化智能体的记忆体发生异常:{str(exception)}" + ) from exception + + def create_new_messages( + self, session_id: str, new_messages: List[ModelMessage] + ) -> bool: + """ + 新增新对话消息 + :param session_id: 会话唯一标识 + :param new_messages: 新对话消息 + :return: 新增是否成功 + """ + try: + with self: + return self.execute( + sql=""" + INSERT INTO new_messages (id, session_id, new_messages, timestamp) VALUES (?, ?, ?, ?) + """, + parameters=( + uuid4().hex.lower(), + session_id, + ModelMessagesTypeAdapter.dump_json(new_messages), + int(time.time() * 1000), + ), + ) + except Exception as exception: + raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception + + def read_message_history(self, session_id: str) -> List[ModelMessage]: + """ + 查询对话消息历史 + :param session_id: 会话唯一标识 + :return: 对话消息历史 + """ + try: + with self: + result = self.query_all( + sql=""" + SELECT new_messages + FROM new_messages + WHERE session_id = ? + ORDER BY timestamp ASC + """, + parameters=(session_id,), + ) + message_history = [] + for row in result: + message_history.extend( + ModelMessagesTypeAdapter.validate_json(row["new_message"]) + ) + return message_history + except Exception as exception: + raise RuntimeError( + f"查询会话历史消息发生异常:{str(exception)}" + ) from exception class Agent: @@ -49,31 +132,11 @@ class Agent: # 生成会话唯一标识 self.session_id = uuid4().hex.lower() - # 创建智能体 - self.agent = self._create_agent( - instructions=instructions, - capabilities=capabilities, - output_type=output_type, - ) - a = self.agent.to_web() + # 实例智能体的记忆体 + self.agent_memory = AgentMemory() - # 实例记忆体 - self.memory = Memory() - - def _create_agent( - self, - instructions: str, - capabilities: Optional[List[AgentCapability]], - output_type: OutputSpec, - ) -> BaseAgent: - """ - 创建智能体 - :param instructions: 指令 - :param capabilities: 智能体能力列表 - :param output_type: 输出类型 - :return: Agent 实例 - """ - agent = BaseAgent( + # 实例智能体 + self.agent = PydanticAIAgent( model=OpenAIChatModel( model_name="deepseek-v4-flash", provider=OpenAIProvider( @@ -86,165 +149,28 @@ class Agent: output_type=output_type, retries=1, ) - return agent - async def run(self, user_prompt: str | List[str]) -> AgentRunResult: + async def output_message_streamed( + self, user_prompt: str | List[str] + ) -> AsyncGenerator[str, None]: """ - 异步运行 - :param user_prompt: 用户提示词 - :return: 智能体回复 + 智能体流式输出消息 + :param user_prompt: 用户提示词(用户输入消息) + :return: 流式消息 """ - # 查询会话历史消息 - message_history = self.memory.read(session_id=self.session_id) - result = await self.agent.run( + """定义:一次会话(session)包含若干论对话(turn),每一轮对话由用户输入消息(message)和智能体输出消息组成""" + # 查询该会话历史对话消息列表 + message_history = self.agent_memory.read_message_history( + session_id=self.session_id + ) + + async with self.agent.run_stream( user_prompt=user_prompt, message_history=message_history - ) - # 记录会话历史消息 - self.memory.create( + ) as result: + async for chunk in result.stream_text(): + yield chunk + + self.agent_memory.create_new_messages( session_id=self.session_id, - dialogue_message=result.new_messages(), + new_messages=result.new_messages(), ) - return result - - def run_agent_application( - self, - native_tools: Optional[List[AbstractNativeTool]] = None, - html_source: Optional[Union[str, Path]] = None, - ) -> Starlette: - """ - 启动智能体应用 - :param native_tools: 前端可选工具列表 - :return: Starlette 实例 - """ - from starlette.requests import Request - from starlette.responses import JSONResponse, HTMLResponse, Response - - def api( - agent: BaseAgent[AgentDepsT, OutputDataT], - native_tools: Optional[List[AbstractNativeTool]] = None, - ) -> Starlette: - """ - 创建处理请求应用 - :param agent: 智能体 - :param native_tools: 前端可选原生工具列表 - :return: Starlette 实例 - """ - from pydantic_ai.models import Model - from pydantic_ai.ui._web.api import ModelInfo - from pydantic_ai.ui._web.api import ( - ConfigureFrontend, - BuiltinToolInfo, - ChatRequestExtra, - validate_request_options, - ) - from pydantic_ai.ui.vercel_ai import VercelAIAdapter - from pydantic_ai.capabilities import NativeTool - - async def chat_options(request: Request) -> Response: - """处理跨域预检请求""" - return Response() - - agent_model = cast(Model, agent.model) - # 前端可选原生工具列表 - frontend_native_tools = [ - t - for t in (native_tools or []) - if t.unique_id - not in { - t.unique_id - for t in agent._cap_native_tools - if isinstance(t, AbstractNativeTool) - } - ] - - async def configurations(request: Request) -> Response: - """处理前端模型与工具配置请求""" - configurations = ConfigureFrontend( - models=[ - ModelInfo( - id=agent_model.model_id, - name=agent_model.label, - builtin_tools=[ - t.unique_id - for t in frontend_native_tools - if type(t) in agent_model.profile.supported_native_tools - ], - ) - ], # 前端仅可选择智能体已配置的模型 - builtin_tools=[ - BuiltinToolInfo(id=t.unique_id, name=t.label) - for t in frontend_native_tools - ], - ) - return JSONResponse(content=configurations.model_dump(by_alias=True)) - - async def chat_post(request: Request) -> Response: - """处理对话请求""" - # 实例 Vercel AI 适配器 - adapter = await VercelAIAdapter[ - TypeVar("AgentDepsT"), TypeVar("OutputDataT") - ].from_request(request=request, agent=self.agent) - - # 解析请求中额外数据,包括前端选择的模型标识、原生工具标识和其它配置等 - extra_data = ChatRequestExtra.model_validate( - adapter.run_input.__pydantic_extra__ - ) - if error := validate_request_options( - extra_data=extra_data, - model_ids={agent_model.model_id}, # 前端仅可选择智能体已配置的模型 - builtin_tool_ids={t.unique_id for t in frontend_native_tools}, - ): - return JSONResponse(content={"error": error}, status_code=400) - - streaming_response = await VercelAIAdapter[ - TypeVar("AgentDepsT"), TypeVar("OutputDataT") - ].dispatch_request( - request=request, - agent=self.agent, - capabilities=[ - NativeTool(t) - for t in frontend_native_tools - if t.unique_id in extra_data.builtin_tools - ], - ) - return streaming_response - - async def health(request: Request) -> Response: - """处理健康检查请求""" - return JSONResponse(content={"ok": True}) - - return Starlette( - routes=[ - Route(path="/chat", endpoint=chat_options, methods=["OPTIONS"]), - Route(path="/configure", endpoint=configurations, methods=["GET"]), - Route(path="/chat", endpoint=chat_post, methods=["POST"]), - Route(path="/health", endpoint=health, methods=["GET"]), - ] - ) - - application = Starlette( - routes=[ - Mount( - "/api", - app=api(agent=self.agent, native_tools=native_tools), - ) - ] - ) - - async def index(request: Request) -> Response: - """S处理聊天界面请求""" - from pydantic_ai.ui._web.app import _get_ui_html - - content = await _get_ui_html(html_source) - - return HTMLResponse( - content=content, - headers={ - "Cache-Control": "public, max-age=3600", - }, - ) - - application.router.add_route(path="/", endpoint=index, methods=["GET"]) - application.router.add_route(path="/{id}", endpoint=index, methods=["GET"]) - - return application diff --git a/utils/memory.py b/utils/memory.py deleted file mode 100644 index 2a70185..0000000 --- a/utils/memory.py +++ /dev/null @@ -1,105 +0,0 @@ -# -*- coding: utf-8 -*- -""" -记忆模块 -""" - -# 列举导入模块 -from pathlib import Path -import sys -import time -from typing import List -from uuid import uuid4 - -from pydantic_ai import ModelMessage -from pydantic_ai.messages import ModelMessagesTypeAdapter - -sys.path.append(Path(__file__).resolve().parent.as_posix()) -from sqlite import SQLite - - -class Memory(SQLite): - """ - 记忆体,支持: - create:新增对话消息 - read:查询会话历史消息 - """ - - def __init__(self): - """ - 初始化记忆体 - """ - # 构建数据库路径 - super().__init__( - database=Path(__file__).parent.resolve() / "memory_database.db" - ) - - try: - with self: - self.execute( - sql=""" - CREATE TABLE IF NOT EXISTS messages - ( - --唯一标识 - id TEXT PRIMARY KEY, - --会话唯一标识 - session_id TEXT NOT NULL, - --对话消息 - dialogue_message TEXT NOT NULL, - --时间戳(毫秒) - timestamp INTEGER NOT NULL - ) - """ - ) - except Exception as exception: - raise RuntimeError(f"初始化记忆体发生异常:{str(exception)}") from exception - - def create(self, session_id: str, dialogue_message: List[ModelMessage]) -> bool: - """ - 新增对话消息 - :param session_id: 会话唯一标识 - :param dialogue_message: 对话消息 - :return: 新增是否成功 - """ - try: - with self: - return self.execute( - sql=""" - INSERT INTO messages (id, session_id, dialogue_message, timestamp) VALUES (?, ?, ?, ?) - """, - parameters=( - uuid4().hex.lower(), - session_id, - ModelMessagesTypeAdapter.dump_json(dialogue_message), - int(time.time() * 1000), - ), - ) - except Exception as exception: - raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception - - def read(self, session_id: str) -> List[ModelMessage]: - """ - 查询会话历史消息 - :param session_id: 会话唯一标识 - :return: 会话历史消息 - """ - try: - with self: - result = self.query_all( - sql=""" - SELECT dialogue_message - FROM messages - WHERE session_id = ? - ORDER BY timestamp ASC - """, - parameters=(session_id,), - ) - message_history = [] - for row in result: - message_history.extend( - ModelMessagesTypeAdapter.validate_json(row["dialogue_message"]) - ) - return message_history - except Exception as exception: - raise RuntimeError( - f"查询会话历史消息发生异常:{str(exception)}" - ) from exception diff --git a/utils/memory_database.db b/utils/memory_database.db deleted file mode 100644 index 94e2cec..0000000 Binary files a/utils/memory_database.db and /dev/null differ diff --git a/产品需求文档AI生成/ai_agent/ai_agent.py b/产品需求文档AI生成/ai_agent/ai_agent.py deleted file mode 100644 index 995473d..0000000 --- a/产品需求文档AI生成/ai_agent/ai_agent.py +++ /dev/null @@ -1,282 +0,0 @@ -import reflex as rx -import reflex as rx - - -# ============================================== -# 全局状态(已修复所有类型错误) -# ============================================== -class State(rx.State): - # 登录状态 - is_logged_in: bool = False - username: str = "" - password: str = "" - - # 当前页面 - current_page: str = "chat" - - # 会话列表 - sessions: list[dict] = [ - {"id": 1, "title": "会话 1"}, - {"id": 2, "title": "会话 2"}, - {"id": 3, "title": "会话 3"}, - ] - current_session_id: int | None = 1 - - # 登录弹窗控制 - show_login_modal: bool = True - - # ------------------------------ - # 显式定义 setter(修复报错) - # ------------------------------ - @rx.event - def set_username(self, value: str): - self.username = value - - @rx.event - def set_password(self, value: str): - self.password = value - - # ------------------------------ - # 登录 / 退出 - # ------------------------------ - @rx.event - def login(self): - if self.username.strip() and self.password.strip(): - self.is_logged_in = True - self.show_login_modal = False - - @rx.event - def logout(self): - self.is_logged_in = False - self.show_login_modal = True - self.username = "" - self.password = "" - - # ------------------------------ - # 页面切换 - # ------------------------------ - @rx.event - def set_page(self, page: str): - self.current_page = page - - # ------------------------------ - # 会话操作 - # ------------------------------ - @rx.event - def new_session(self): - new_id = max([s["id"] for s in self.sessions], default=0) + 1 - self.sessions.append({"id": new_id, "title": f"会话 {new_id}"}) - self.current_session_id = new_id - - @rx.event - def delete_session(self, session_id: int): - self.sessions = [s for s in self.sessions if s["id"] != session_id] - if self.current_session_id == session_id: - if self.sessions: - self.current_session_id = self.sessions[0]["id"] - else: - self.current_session_id = None - - @rx.event - def set_current_session_id(self, sid: int): - self.current_session_id = sid - - -# ============================================== -# 登录弹窗(已修复) -# ============================================== -def login_modal(): - return rx.cond( - State.show_login_modal, - rx.box( - rx.box( - rx.vstack( - rx.heading("登录"), - rx.input( - placeholder="用户名", - value=State.username, - on_change=State.set_username, - width="100%", - ), - rx.input( - placeholder="密码", - type="password", - value=State.password, - on_change=State.set_password, - width="100%", - ), - rx.button( - "登录", - on_click=State.login, - width="100%", - bg="blue", - color="white", - ), - spacing="4", - padding="6", - width="350px", - ), - bg="white", - padding="8", - border_radius="lg", - shadow="lg", - ), - position="fixed", - inset="0", - bg="rgba(0,0,0,0.5)", - display="flex", - align_items="center", - justify_content="center", - z_index="1000", - ), - ) - - -# ============================================== -# 左侧边栏 -# ============================================== -def sidebar(): - return rx.box( - rx.vstack( - rx.heading("AI 智能体", font_size="lg"), - rx.button( - "💬 聊天", - on_click=State.set_page("chat"), - width="100%", - bg=rx.cond(State.current_page == "chat", "blue", "gray"), - color="white", - ), - rx.button( - "📚 知识库", - on_click=State.set_page("knowledge"), - width="100%", - bg=rx.cond(State.current_page == "knowledge", "blue", "gray"), - color="white", - ), - rx.button( - "⚙️ 设置", - on_click=State.set_page("settings"), - width="100%", - bg=rx.cond(State.current_page == "settings", "blue", "gray"), - color="white", - ), - rx.spacer(), - rx.button( - "🚪 退出登录", - on_click=State.logout, - width="100%", - bg="red", - color="white", - ), - padding="4", - height="100vh", - spacing="3", - ), - width="220px", - bg="white", - shadow="md", - ) - - -# ============================================== -# 聊天页面 -# ============================================== -def chat_page(): - return rx.hstack( - rx.box( - rx.vstack( - rx.button( - "➕ 新建会话", - on_click=State.new_session, - width="100%", - bg="green", - color="white", - ), - rx.foreach( - State.sessions, - lambda s: rx.hstack( - rx.text( - s["title"], - flex=1, - cursor="pointer", - on_click=lambda: State.set_current_session_id(s["id"]), - ), - rx.icon( - "trash", - size=16, - color="red", - cursor="pointer", - on_click=lambda: State.delete_session(s["id"]), - ), - width="100%", - padding="2", - bg=rx.cond( - State.current_session_id == s["id"], - "gray.100", - "transparent", - ), - border_radius="sm", - ), - ), - spacing="2", - align_items="start", - width="100%", - ), - width="220px", - padding="4", - border_right="1px solid #e2e8f0", - ), - rx.box( - rx.heading("聊天内容"), - flex=1, - padding="4", - ), - flex=1, - ) - - -# ============================================== -# 内容区域 -# ============================================== -def content_area(): - return rx.box( - rx.cond( - State.current_page == "chat", - chat_page(), - rx.cond( - State.current_page == "knowledge", - rx.heading("知识库管理"), - rx.heading("系统设置"), - ), - ), - flex=1, - ) - - -# ============================================== -# 主页面 -# ============================================== -def index(): - return rx.box( - login_modal(), - rx.cond( - State.is_logged_in, - rx.hstack( - sidebar(), - content_area(), - ), - rx.center( - rx.heading("请登录使用系统"), - height="100vh", - ), - ), - min_height="100vh", - bg="gray.50", - ) - - -# ============================================== -# 启动 -# ============================================== -app = rx.App() -app.add_page(index) diff --git a/产品需求文档AI生成/application/__init__.py b/产品需求文档AI生成/application/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/产品需求文档AI生成/application/application.py b/产品需求文档AI生成/application/application.py new file mode 100644 index 0000000..45ce875 --- /dev/null +++ b/产品需求文档AI生成/application/application.py @@ -0,0 +1,28 @@ +"""The main Chat app.""" + +import reflex as rx +from application.components import chat, navbar + + +def index() -> rx.Component: + """The main app.""" + return rx.vstack( + navbar.navbar(), + chat.chat(), + chat.action_bar(), + background_color=rx.color("mauve", 1), + color=rx.color("mauve", 12), + height="100dvh", + align_items="stretch", + spacing="0", + ) + + +# Add state and page to the app. +app = rx.App( + theme=rx.theme( + appearance="dark", + accent_color="purple", + ), +) +app.add_page(index) diff --git a/产品需求文档AI生成/application/components/__init__.py b/产品需求文档AI生成/application/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/产品需求文档AI生成/application/components/chat.py b/产品需求文档AI生成/application/components/chat.py new file mode 100644 index 0000000..d56587f --- /dev/null +++ b/产品需求文档AI生成/application/components/chat.py @@ -0,0 +1,112 @@ +import reflex as rx + +from application.state import Turn, State +from reflex.constants.colors import ColorType + + +def message_content(text: str, color: ColorType) -> rx.Component: + """Create a message content component. + + Args: + text: The text to display. + color: The color of the message. + + Returns: + A component displaying the message. + """ + return rx.markdown( + text, + background_color=rx.color(color, 4), + color=rx.color(color, 12), + display="inline-block", + padding_inline="1em", + border_radius="8px", + ) + + +def message(turn: Turn) -> rx.Component: + """A single question/answer message. + + Args: + turn: The question/answer pair. + + Returns: + A component displaying the question/answer pair. + """ + return rx.box( + rx.box( + message_content(turn["question"], "mauve"), + text_align="right", + margin_bottom="8px", + ), + rx.box( + message_content(turn["answer"], "accent"), + text_align="left", + margin_bottom="8px", + ), + max_width="50em", + margin_inline="auto", + ) + + +def chat() -> rx.Component: + """List all the messages in a single conversation.""" + return rx.auto_scroll( + rx.foreach(State.selected_chat, message), + flex="1", + padding="8px", + ) + + +def action_bar() -> rx.Component: + """The action bar to send a new message.""" + return rx.center( + rx.vstack( + rx.form( + rx.hstack( + rx.input( + rx.input.slot( + rx.tooltip( + rx.icon("info", size=18), + content="Enter a question to get a response.", + ) + ), + placeholder="Type something...", + id="question", + flex="1", + ), + rx.button( + "Send", + loading=State.processing, + disabled=State.processing, + type="submit", + ), + max_width="50em", + margin="0 auto", + align_items="center", + ), + reset_on_submit=True, + on_submit=State.process_question, + ), + rx.text( + "ReflexGPT may return factually incorrect or misleading responses. Use discretion.", + text_align="center", + font_size=".75em", + color=rx.color("mauve", 10), + ), + rx.logo(margin_block="-1em"), + width="100%", + padding_x="16px", + align="stretch", + ), + position="sticky", + bottom="0", + left="0", + padding_y="16px", + backdrop_filter="auto", + backdrop_blur="lg", + border_top=f"1px solid {rx.color('mauve', 3)}", + background_color=rx.color("mauve", 2), + align="stretch", + width="100%", + ) diff --git a/产品需求文档AI生成/application/components/navbar.py b/产品需求文档AI生成/application/components/navbar.py new file mode 100644 index 0000000..ae6fa62 --- /dev/null +++ b/产品需求文档AI生成/application/components/navbar.py @@ -0,0 +1,115 @@ +import reflex as rx +from application.state import State + + +def sidebar_chat(chat: str) -> rx.Component: + """A sidebar chat item. + + Args: + chat: The chat item. + """ + return rx.drawer.close( + rx.hstack( + rx.button( + chat, + on_click=lambda: State.set_chat(chat), + width="80%", + variant="surface", + ), + rx.button( + rx.icon( + tag="trash", + on_click=State.delete_chat(chat), + stroke_width=1, + ), + width="20%", + variant="surface", + color_scheme="red", + ), + width="100%", + ), + key=chat, + ) + + +def sidebar(trigger) -> rx.Component: + """The sidebar component.""" + return rx.drawer.root( + rx.drawer.trigger(trigger), + rx.drawer.overlay(), + rx.drawer.portal( + rx.drawer.content( + rx.vstack( + rx.heading("Chats", color=rx.color("mauve", 11)), + rx.divider(), + rx.foreach(State.chat_titles, lambda chat: sidebar_chat(chat)), + align_items="stretch", + width="100%", + ), + top="auto", + right="auto", + height="100%", + width="20em", + padding="2em", + background_color=rx.color("mauve", 2), + outline="none", + ) + ), + direction="left", + ) + + +def modal(trigger) -> rx.Component: + """A modal to create a new chat.""" + return rx.dialog.root( + rx.dialog.trigger(trigger), + rx.dialog.content( + rx.form( + rx.hstack( + rx.input( + placeholder="Chat name", + name="new_chat_name", + flex="1", + min_width="20ch", + ), + rx.button("Create chat"), + spacing="2", + wrap="wrap", + width="100%", + ), + on_submit=State.create_chat, + ), + background_color=rx.color("mauve", 1), + ), + open=State.is_modal_open, + on_open_change=State.set_is_modal_open, + ) + + +def navbar(): + return rx.hstack( + rx.badge( + State.current_chat, + rx.tooltip( + rx.icon("info", size=14), + content="The current selected chat.", + ), + size="3", + variant="soft", + margin_inline_end="auto", + ), + modal( + rx.icon_button("message-square-plus"), + ), + sidebar( + rx.icon_button( + "messages-square", + background_color=rx.color("mauve", 6), + ) + ), + justify_content="space-between", + align_items="center", + padding="12px", + border_bottom=f"1px solid {rx.color('mauve', 3)}", + background_color=rx.color("mauve", 2), + ) diff --git a/产品需求文档AI生成/application/state.py b/产品需求文档AI生成/application/state.py new file mode 100644 index 0000000..af16aa8 --- /dev/null +++ b/产品需求文档AI生成/application/state.py @@ -0,0 +1,189 @@ +import os +from typing import Any, Optional, TypedDict +import reflex +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam +from sys import path +from pathlib import Path +from dataclasses import dataclass + +path.append(str(Path(__file__).resolve().parent.parent.parent)) +from utils.agent import Agent +from pydantic import BaseModel, Field + + +class Turn(BaseModel): + """一轮对话中用户输入消息和智能体输出消息""" + + input_message: str = Field(description="用户输入的消息") + output_message: str + + +class Session(BaseModel): + """会话数据模型""" + + uuid: str = Field(..., description="会话唯一标识") + agent: Optional[Agent] = Field(default=None, description="智能体实例") # 鼓励 + message_history: List[Turn] = Field(default=[], description="会话消息历史") + + +class State(reflex.State): + """实例状态管理""" + + agent: Agent | None = None + + def _init_agent(self) -> Agent: + """初始化智能体""" + if not self.agent: + self.agent = Agent( + instructions="You are a friendly chatbot named Reflex. Respond in markdown." + ) + return self.agent + + current_chat_name = "Intros" # 当前聊天名称 + + sessions: dict[str, list[Turn]] = { + current_chat_name: [], + } # 会话列表 + + # 智能体是否正在输出 + outputting: bool = False + + # 新会话弹窗是否打开 + is_modal_open: bool = False + + @reflex.event + def create_chat(self, form_data: dict[str, Any]): + """Create a new chat.""" + # Add the new chat to the list of chats. + new_chat_name = form_data["new_chat_name"] + self.current_chat = new_chat_name + self._chats[new_chat_name] = [] + self.is_modal_open = False + + @reflex.event + def set_is_modal_open(self, is_open: bool): + """Set the new chat modal open state. + + Args: + is_open: Whether the modal is open. + """ + self.is_modal_open = is_open + + @reflex.var + def selected_chat(self) -> list[TurnMessages]: + """Get the list of turns for the current chat. + + Returns: + The list of turns. + """ + return ( + self._chats[self.current_chat] if self.current_chat in self._chats else [] + ) + + @reflex.event + def delete_chat(self, chat_name: str): + """Delete the current chat.""" + if chat_name not in self._chats: + return + del self._chats[chat_name] + if len(self._chats) == 0: + self._chats = { + "Intros": [], + } + if self.current_chat not in self._chats: + self.current_chat = list(self._chats.keys())[0] + + @reflex.event + def set_chat(self, chat_name: str): + """Set the name of the current chat. + + Args: + chat_name: The name of the chat. + """ + self.current_chat = chat_name + + @reflex.event + def set_new_chat_name(self, new_chat_name: str): + """Set the name of the new chat. + + Args: + new_chat_name: The name of the new chat. + """ + self.new_chat_name = new_chat_name + + @reflex.var + def chat_titles(self) -> list[str]: + """Get the list of chat titles. + + Returns: + The list of chat names. + """ + return list(self._chats.keys()) + + @reflex.event + async def process_question(self, form_data: dict[str, Any]): + # Get the question from the form + question = form_data["question"] + + # Check if the question is empty + if not question: + return + + async for value in self.openai_process_question(question): + yield value + + @reflex.event + async def openai_process_question(self, question: str): + """Get the response from the API. + + Args: + form_data: A dict with the current question. + """ + + # Add the question to the list of questions. + qa = TurnMessages(input_message=question, output_message="") + self._chats[self.current_chat].append(qa) + + # Clear the input and start the processing. + self.processing = True + yield + + # Build the messages. + messages: list[ChatCompletionMessageParam] = [ + { + "role": "system", + "content": "You are a friendly chatbot named Reflex. Respond in markdown.", + } + ] + for qa in self._chats[self.current_chat]: + messages.append({"role": "user", "content": qa["input_message"]}) + messages.append({"role": "assistant", "content": qa["output_message"]}) + + # Remove the last mock answer. + messages = messages[:-1] + + # Start a new session to answer the question. + session = OpenAI().chat.completions.create( + model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"), + messages=messages, + stream=True, + ) + + # Stream the results, yielding after every word. + for item in session: + if hasattr(item.choices[0].delta, "content"): + answer_text = item.choices[0].delta.content + # Ensure answer_text is not None before concatenation + if answer_text is not None: + self._chats[self.current_chat][-1]["answer"] += answer_text + else: + # Handle the case where answer_text is None, perhaps log it or assign a default value + # For example, assigning an empty string if answer_text is None + answer_text = "" + self._chats[self.current_chat][-1]["answer"] += answer_text + self._chats = self._chats + yield + + # Toggle the processing flag. + self.processing = False diff --git a/产品需求文档AI生成/rxconfig.py b/产品需求文档AI生成/rxconfig.py index 8aa1902..64e4ac0 100644 --- a/产品需求文档AI生成/rxconfig.py +++ b/产品需求文档AI生成/rxconfig.py @@ -1,5 +1,5 @@ import reflex as rx config = rx.Config( - app_name="ai_agent", + app_name="application", )