112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
记忆模块
|
||
"""
|
||
|
||
# 列举导入模块
|
||
from pathlib import Path
|
||
import sys
|
||
import time
|
||
from typing import List
|
||
from uuid import uuid4
|
||
|
||
from pydantic_ai import ModelMessage
|
||
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
||
|
||
# Append the directory three levels above this file to the module search path
# so that the project-local ``from utils.sqlite import SQLite`` import that
# follows can be resolved (presumably this is the project root - confirm
# against the repository layout). Appending (not inserting at index 0) means
# earlier sys.path entries take precedence on name clashes.
sys.path.append(Path(__file__).parent.parent.parent.as_posix())
|
||
from utils.sqlite import SQLite
|
||
|
||
|
||
class Memory(SQLite):
    """
    Conversation memory persisted in a local SQLite database.

    Supports:
        create: store the messages of one dialogue round of a session
        read: load a session's full message history, oldest first
    """

    def __init__(self):
        """
        Initialize the memory store.

        Points the underlying SQLite wrapper at ``memory_database.db`` located
        next to this module and creates the ``messages`` table if it does not
        exist yet.

        :raises RuntimeError: if the table cannot be created.
        """
        # The database file lives alongside this module.
        super().__init__(
            database=Path(__file__).parent.resolve() / "memory_database.db"
        )

        try:
            with self:
                self.execute(
                    sql="""
                    CREATE TABLE IF NOT EXISTS messages
                    (
                        --唯一标识
                        id TEXT PRIMARY KEY,
                        --会话唯一标识
                        session_id TEXT NOT NULL,
                        --对话轮次
                        dialogue_round INTEGER NOT NULL,
                        --对话消息
                        dialogue_message TEXT NOT NULL,
                        --时间戳(毫秒)
                        timestamp INTEGER NOT NULL
                    )
                    """
                )
        except Exception as exception:
            raise RuntimeError(f"初始化记忆体发生异常:{str(exception)}") from exception

    def create(
        self, session_id: str, dialogue_round: int, dialogue_message: List[ModelMessage]
    ) -> bool:
        """
        Store the messages produced in one dialogue round.

        The whole list is serialized to a single JSON document and written as
        one row, tagged with a random id and the current time in milliseconds.

        :param session_id: unique identifier of the session
        :param dialogue_round: dialogue round (turn) number
        :param dialogue_message: messages of this round
        :return: the value reported by ``SQLite.execute`` for the insert
            (annotated as a success flag)
        :raises RuntimeError: if serialization or the insert fails.
        """
        try:
            # dump_json returns UTF-8 encoded bytes; sqlite3 would store raw
            # bytes as a BLOB even though the column is declared TEXT, so
            # decode to str to keep the stored type consistent with the schema.
            serialized = ModelMessagesTypeAdapter.dump_json(dialogue_message).decode(
                "utf-8"
            )
            with self:
                return self.execute(
                    sql="""
                    INSERT INTO messages (id, session_id, dialogue_round, dialogue_message, timestamp) VALUES (?, ?, ?, ?, ?)
                    """,
                    parameters=(
                        uuid4().hex,  # .hex is already lower-case
                        session_id,
                        dialogue_round,
                        serialized,
                        int(time.time() * 1000),  # milliseconds since the epoch
                    ),
                )
        except Exception as exception:
            raise RuntimeError(f"新增对话消息发生异常:{str(exception)}") from exception

    def read(self, session_id: str) -> List[ModelMessage]:
        """
        Load the full message history of a session.

        Rows are returned oldest first by timestamp; rows written within the
        same millisecond are kept in insertion order via ``rowid`` so the
        history is deterministic.

        :param session_id: unique identifier of the session
        :return: all stored messages of the session flattened into one list
            (empty when the session has no rows)
        :raises RuntimeError: if the query or deserialization fails.
        """
        try:
            with self:
                rows = self.query_all(
                    sql="""
                    SELECT dialogue_message
                    FROM messages
                    WHERE session_id = ?
                    ORDER BY timestamp ASC, rowid ASC
                    """,
                    parameters=(session_id,),
                )

                message_history: List[ModelMessage] = []
                for row in rows:
                    # Each row holds the JSON-serialized message list of one
                    # round; validate_json accepts both TEXT (str) and legacy
                    # BLOB (bytes) values.
                    message_history.extend(
                        ModelMessagesTypeAdapter.validate_json(row["dialogue_message"])
                    )

                return message_history
        except Exception as exception:
            raise RuntimeError(
                f"查询会话历史消息发生异常:{str(exception)}"
            ) from exception
|