# Python/产品需求文档AI生成/utils/memory.py (112 lines, 3.6 KiB)
# NOTE: the repository viewer flagged this file as containing ambiguous
# Unicode characters that might be confused with other characters; if that
# is intentional, the warning can be safely ignored.
# -*- coding: utf-8 -*-
"""
记忆模块
"""
# 列举导入模块
from pathlib import Path
import sys
import time
from typing import List
from uuid import uuid4
from pydantic_ai import ModelMessage
from pydantic_ai.messages import ModelMessagesTypeAdapter
sys.path.append(Path(__file__).parent.parent.parent.as_posix())
from utils.sqlite import SQLite
class Memory(SQLite):
    """
    SQLite-backed conversation memory.

    Supports:
    - ``create``: append the messages of one dialogue round to a session.
    - ``read``: load the full message history of a session.
    """

    def __init__(self):
        """
        Initialize the memory store.

        Points the underlying SQLite wrapper at ``memory_database.db`` in this
        module's directory and ensures the ``messages`` table exists.

        :raises RuntimeError: if creating the table fails.
        """
        # The database file lives next to this module.
        super().__init__(
            database=Path(__file__).parent.resolve() / "memory_database.db"
        )
        try:
            with self:
                # Columns: id (unique row id), session_id (session identifier),
                # dialogue_round (round/turn number), dialogue_message
                # (serialized JSON of the round's messages), timestamp (ms).
                self.execute(
                    sql="""
                    CREATE TABLE IF NOT EXISTS messages
                    (
                        --唯一标识
                        id TEXT PRIMARY KEY,
                        --会话唯一标识
                        session_id TEXT NOT NULL,
                        --对话轮次
                        dialogue_round INTEGER NOT NULL,
                        --对话消息
                        dialogue_message TEXT NOT NULL,
                        --时间戳(毫秒)
                        timestamp INTEGER NOT NULL
                    )
                    """
                )
        except Exception as exception:
            raise RuntimeError(f"初始化记忆体发生异常:{exception}") from exception

    def create(
        self, session_id: str, dialogue_round: int, dialogue_message: List[ModelMessage]
    ) -> bool:
        """
        Append the messages of one dialogue round to a session.

        :param session_id: unique session identifier
        :param dialogue_round: dialogue round (turn) number
        :param dialogue_message: messages produced in this round
        :return: the value returned by ``SQLite.execute`` — presumably whether
            the insert succeeded (TODO confirm against utils.sqlite)
        :raises RuntimeError: if serialization or the insert fails
        """
        try:
            # dump_json returns bytes; sqlite3 would bind those as a BLOB even
            # though the column is declared TEXT. Decode so the stored value is
            # real TEXT. read() uses validate_json, which accepts str and bytes,
            # so rows written before this change still load.
            serialized = ModelMessagesTypeAdapter.dump_json(dialogue_message).decode(
                "utf-8"
            )
            with self:
                return self.execute(
                    sql="""
                    INSERT INTO messages (id, session_id, dialogue_round, dialogue_message, timestamp) VALUES (?, ?, ?, ?, ?)
                    """,
                    parameters=(
                        uuid4().hex,  # .hex is already lowercase hexadecimal
                        session_id,
                        dialogue_round,
                        serialized,
                        int(time.time() * 1000),  # wall-clock milliseconds
                    ),
                )
        except Exception as exception:
            raise RuntimeError(f"新增对话消息发生异常:{exception}") from exception

    def read(self, session_id: str) -> List[ModelMessage]:
        """
        Load the full message history of a session.

        Rows are ordered by timestamp. Ties (several rounds written in the same
        millisecond) are broken by dialogue round and then by insertion rowid,
        so the order is deterministic.

        :param session_id: unique session identifier
        :return: every stored message of the session, flattened into a single
            list (empty if the session has no rows)
        :raises RuntimeError: if the query or deserialization fails
        """
        try:
            with self:
                rows = self.query_all(
                    sql="""
                    SELECT dialogue_message
                    FROM messages
                    WHERE session_id = ?
                    ORDER BY timestamp ASC, dialogue_round ASC, rowid ASC
                    """,
                    parameters=(session_id,),
                )
                message_history = []
                for row in rows:
                    # Each row holds a JSON list of messages for one round.
                    message_history.extend(
                        ModelMessagesTypeAdapter.validate_json(row["dialogue_message"])
                    )
                return message_history
        except Exception as exception:
            raise RuntimeError(
                f"查询会话历史消息发生异常:{exception}"
            ) from exception