diff --git a/utils/client.py b/utils/client.py index 96a5ccc..e193b20 100644 --- a/utils/client.py +++ b/utils/client.py @@ -92,12 +92,15 @@ class CacheClient: def __init__(self, cache_ttl: int = 360, database: str = "Caches.db"): """ - :param cache_ttl: 缓存生存时间(单位:天),默认为360天 + :param cache_ttl: 缓存生存时间,单位为天 :param database: 缓存数据库名称 """ - try: - self.cache_ttl = cache_ttl + # 初始化缓存数据库连接 + self.connection: Optional[sqlite3.Connection] = None + # 初始化缓存生存时间,单位为天 + self.cache_ttl = cache_ttl + try: # 创建缓存数据库连接 self.connection = sqlite3.connect( database=database, @@ -105,149 +108,93 @@ class CacheClient: timeout=30, # 缓存数据库锁超时时间(单位:秒),默认为30秒,避免并发锁死 ) - # 创建缓存数据库连接(使用SQLite) - self.cache_connection = sqlite3.connect( - database="SQLite.db", check_same_thread=False - ) - - # 创建缓存表 - self.connection.execute( - """CREATE TABLE IF NOT EXISTS caches ( - guid TEXT PRIMARY KEY, - scene TEXT, - cache TEXT NOT NULL, - timestamp REAL NOT NULL - )""" - ) - - # 创建时间戳索引(优化过期缓存查询效率) - self.connection.execute( - """CREATE INDEX IF NOT EXISTS index_timestamp ON caches(timestamp)""" - ) - - # 删除过期缓存 - self.connection.execute( - "DELETE FROM caches WHERE timestamp < ?", - (time.time() - self.cache_ttl * 86400,), - ) - - # 提交事务 - self.connection.commit() + # 创建缓存表和索引、清理过期缓存 + with self.connection: + self.connection.execute( + """CREATE TABLE IF NOT EXISTS caches ( + guid TEXT PRIMARY KEY, + cache TEXT NOT NULL, + timestamp REAL NOT NULL + )""" + ) + self.connection.execute( + """CREATE INDEX IF NOT EXISTS idx_timestamp ON caches(timestamp)""" + ) + self.connection.execute( + "DELETE FROM caches WHERE timestamp < ?", + (time.time() - self.cache_ttl * 86400,), + ) except Exception as exception: - if self.connection: + self._disconnect() + raise f"初始缓存数据库失败:{str(exception)}" from exception + + def _disconnect(self) -> None: + """关闭缓存数据库连接""" + if self.connection: + # noinspection PyBroadException + try: self.connection.close() - raise f"{str(exception)}" from exception + except Exception: + pass - def _query_response(self, guid: str) -> Optional[Dict]: - """ - 私有方法:根据guid查询有效缓存记录(未过期) - :param guid: 记录唯一标识 - :return: 未过期的响应数据(Dict),不存在/过期/异常时返回None - """ - if not self.cache_connection: - logger.error("查询失败:缓存数据库未连接") - return None + def __enter__(self) -> "CacheClient": + """实现上下文管理""" + return self + def __exit__(self, exc_type, exc_val, exc_tb): + """退出时关闭连接""" + self._disconnect() + return False + + def query(self, guid: str) -> Optional[Dict]: + """ + 查询缓存 + :param guid: 缓存唯一标识 + :return: 缓存 + """ with threading.Lock(): # 线程锁,保证并发安全 - cursor = None - try: - cursor = self.cache_connection.cursor() - # 查询条件:guid匹配 + 未过期 - expire_time = time.time() - self.cache_ttl * 86400 - cursor.execute( - "SELECT response FROM caches WHERE guid = ? AND timestamp >= ?", - (guid, expire_time), - ) - result = cursor.fetchone() # 获取单条记录(guid唯一) - if result: - logger.info(f"查询缓存成功:guid={guid}") - return json.loads(result[0]) # JSON字符串转Dict - logger.info(f"未查询到有效缓存:guid={guid}(不存在或已过期)") - return None - except json.JSONDecodeError as e: - logger.error( - f"缓存数据解析失败(JSON格式错误):guid={guid}", exc_info=True - ) - return None - except Exception as e: - logger.error(f"查询缓存异常:guid={guid}", exc_info=True) - self.cache_connection.rollback() # 异常回滚事务 - return None - finally: - if cursor: - cursor.close() # 确保游标关闭,释放资源 + with self.connection.cursor() as cursor: + # noinspection PyBroadException + try: + # 根据缓存唯一标识查询有效缓存 + cursor.execute( + "SELECT cache FROM caches WHERE guid = ? AND timestamp >= ?", + (guid, time.time() - self.cache_ttl * 86400), + ) + if result := cursor.fetchone(): + return json.loads(result[0]) + return None + except Exception: + self.connection.rollback() + return None - def _save_response(self, guid: str, response: Dict) -> bool: + def update(self, guid: str, cache: Dict) -> bool: """ - 私有方法:添加/更新缓存记录(存在则覆盖,不存在则新增) - :param guid: 记录唯一标识 - :param response: 待保存的响应数据(Dict) - :return: 保存成功返回True,失败返回False + 更新缓存(存在则覆盖,不存在则新增) + :param guid: 缓存唯一标识 + :param cache: 缓存 + :return: 成功返回True,失败返回False """ - if not self.cache_connection: - logger.error("保存失败:缓存数据库未连接") - return False - with threading.Lock(): # 线程锁,保证并发安全 - cursor = None - try: - cursor = self.cache_connection.cursor() - # 转换Dict为JSON字符串(ensure_ascii=False支持中文) - response_str = json.dumps(response, ensure_ascii=False, indent=2) - # INSERT OR REPLACE:存在则更新,不存在则插入 - cursor.execute( - "INSERT OR REPLACE INTO caches (guid, response, timestamp) VALUES (?, ?, ?)", - (guid, response_str, time.time()), # timestamp存储当前时间戳 - ) - self.cache_connection.commit() # 提交事务 - logger.info(f"保存缓存成功:guid={guid}") - return True - except json.JSONEncoderError as e: - logger.error( - f"缓存数据序列化失败(Dict转JSON错误):guid={guid}", exc_info=True - ) - self.cache_connection.rollback() - return False - except Exception as e: - logger.error(f"保存缓存异常:guid={guid}", exc_info=True) - self.cache_connection.rollback() # 异常回滚事务 - return False - finally: - if cursor: - cursor.close() # 确保游标关闭 - - def query_or_save_response( - self, guid: str, response: Optional[Dict] = None - ) -> Optional[Dict]: - """ - 二合一公开方法:支持查询记录 / 添加/更新记录(灵活复用) - :param guid: 记录唯一标识(必填) - :param response: 待保存的响应数据(可选): - - 不传:仅查询有效记录,返回Dict/None - - 传入:添加/更新记录,返回保存后的有效记录/Dict - :return: 查询到的记录 / 保存后的记录 / 失败时返回None - """ - # 参数校验:guid不能为空 - if not guid or not isinstance(guid, str): - logger.error("guid无效:必须是非空字符串") - return None - - # 仅查询模式(未传入response) - if response is None: - return self._query_response(guid) - - # 添加/更新模式(传入response):先保存,再查询返回最新记录 - if self._save_response(guid, response): - return self._query_response(guid) - logger.error(f"保存缓存失败,无法返回记录:guid={guid}") - return None - - def close(self): - """关闭数据库连接(程序退出时调用)""" - if self.cache_connection: - self.cache_connection.close() - self.cache_connection = None + with self.connection.cursor() as cursor: + # noinspection PyBroadException + try: + # 新增或覆盖缓存 + cursor.execute( + "INSERT OR REPLACE INTO caches (guid, cache, timestamp) VALUES (?, ?, ?)", + ( + guid, + json.dumps(cache, ensure_ascii=False), + time.time(), + ), + ) + # 提交事务 + self.connection.commit() + return True + except Exception: + self.connection.rollback() # 异常回滚事务 + return False """ @@ -859,15 +806,16 @@ class Authenticator: if time.time() > expired_timestamp: # noinspection PyUnreachableCode match servicer: + # 获取深圳快瞳访问凭证 case "szkt": - # 获取深圳快瞳访问凭证 token, expired_timestamp = self._szkt_get_certification() case "feishu": token, expired_timestamp = self._feishu_get_certification() + # 获取合力亿捷访问凭证 case "hlyj": token, expired_timestamp = self._hlyj_get_certification() case _: - raise RuntimeError(f"服务商({servicer})未设置获取访问令牌方法") + raise RuntimeError(f"未设置服务商:({servicer})") # 更新服务商访问凭证 certifications[servicer] = {