Python/utils/request.py

670 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
# 导入模块
import hashlib
import hmac
import json
import threading
import time
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Literal, Optional, Tuple, Union
from xml.etree import ElementTree
from pydantic import BaseModel, Field, HttpUrl, model_validator
from requests import Response, Session
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from sqlite import SQLite
def restrict(refill_rate: float = 5.0, max_tokens: int = 5):
    """
    Rate-limit a callable with a token bucket.

    Every decorated function gets its own bucket, created on its first call
    and starting half full. A call that finds no token sleeps and tries again;
    after 11 unsuccessful attempts it gives up.

    :param refill_rate: tokens added per second; must be greater than 0
    :param max_tokens: bucket capacity in tokens; must be at least 1
    :return: decorator that applies the limiter to a callable
    :raises ValueError: if ``refill_rate`` or ``max_tokens`` is out of range
    """
    # Reject values that would otherwise surface later as ZeroDivisionError
    # or as a bucket that can never hold a whole token.
    if refill_rate <= 0:
        raise ValueError("refill_rate must be greater than 0")
    if max_tokens < 1:
        raise ValueError("max_tokens must be at least 1")

    # Attempts per call: the first try plus ten retries.
    max_attempts = 11

    class TokenBucket:
        """Token bucket guarded by a lock."""

        # noinspection PyShadowingNames
        def __init__(self, max_tokens: int, refill_rate: float):
            """
            :param max_tokens: bucket capacity in tokens
            :param refill_rate: tokens added per second
            """
            self.max_tokens = max_tokens
            # The bucket starts half full.
            self.tokens = self.max_tokens * 0.5
            self.refill_rate = refill_rate
            # Monotonic clock: unaffected by wall-clock adjustments.
            self.refill_timestamp = time.monotonic()
            # Shared by every thread that uses this bucket.
            self.thread_lock = threading.Lock()

        def _refill(self) -> None:
            """Credit tokens for the time elapsed; the caller must hold the lock."""
            now = time.monotonic()
            elapsed = now - self.refill_timestamp
            self.tokens = min(
                self.max_tokens,
                max(0.0, self.tokens + self.refill_rate * elapsed),
            )
            self.refill_timestamp = now

        def consume(self) -> Tuple[bool, float]:
            """
            Try to take one token.

            :return: ``(True, 0.0)`` on success, otherwise ``(False, seconds)``
                where ``seconds`` is the time needed to refill the missing
                fraction of a token
            """
            # Refill and take inside one critical section so that no other
            # thread can slip in between the two steps.
            with self.thread_lock:
                self._refill()
                if self.tokens >= 1:
                    self.tokens -= 1
                    return True, 0.0
                return False, (1 - self.tokens) / self.refill_rate

    # One bucket per decorated function, created lazily under a lock so that
    # concurrent first calls cannot each build their own bucket.
    buckets: Dict[Callable, TokenBucket] = {}
    buckets_lock = threading.Lock()

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            bucket = buckets.get(func)
            if bucket is None:
                with buckets_lock:
                    bucket = buckets.get(func)
                    if bucket is None:
                        bucket = buckets[func] = TokenBucket(
                            max_tokens=max_tokens, refill_rate=refill_rate
                        )

            for attempt in range(max_attempts):
                success, wait_time = bucket.consume()
                if success:
                    return func(*args, **kwargs)
                # Sleeping after the final attempt would only delay the error.
                if attempt < max_attempts - 1:
                    time.sleep(wait_time * 2)
            raise Exception("request too frequently")

        return wrapper

    return decorator
class Request:
"""请求客户端"""
class RequestException(Exception):
"""请求异常"""
def __init__(
self,
status: Optional[int] = 400,
code: int = 0,
message: str = "请求发生异常",
):
"""
:param status: 状态码
:param code: 错误码
:param message: 错误信息
"""
self.status = status
self.code = code
self.message = message
super().__init__(self.message)
class Parameters(BaseModel):
"""
请求参数模型,支持自动校验
"""
url: HttpUrl = Field(
default=..., description="统一资源定位符基于HttpUrl自动校验"
)
params: Optional[Dict[str, Any]] = Field(
default=None, description="统一资源定位符的查询参数"
)
headers: Optional[Dict[str, str]] = Field(default=None, description="请求头")
data: Optional[Dict[str, Any]] = Field(default=None, description="表单数据")
json_: Optional[Dict[str, Any]] = Field(
default=None, alias="json", description="JSON数据"
)
files: Optional[
Dict[
str,
Union[
Tuple[str, bytes],
Tuple[str, bytes, str],
Tuple[str, bytes, str, Dict[str, str]],
],
]
] = Field(
default=None,
description="上传文件,{字段名: (文件名, 字节数据, 内容类型, 请求头)}",
)
stream_enabled: Optional[bool] = Field(default=None, description="使用流式传输")
guid: Optional[str] = Field(default=None, description="缓存全局唯一标识")
@model_validator(mode="after")
def validate_data(self):
"""校验表单数据和JSON数据互斥"""
if self.data is not None and self.json_ is not None:
raise ValueError("表单数据和JSON数据不能同时使用")
return self
@model_validator(mode="after")
def validate_files(self):
if self.files is not None and self.stream_enabled:
raise ValueError("上传文件和使用流式传输不能同时使用")
return self
class Caches(SQLite):
"""请求缓存"""
def __init__(self, cache_ttl: int):
"""
初始化
:param cache_ttl: 缓存生存时间,单位为秒
"""
# 初始化
super().__init__(database=Path(__file__).parent.resolve() / "caches.db")
# 初始化缓存生存时间,单位为秒
self.cache_ttl = cache_ttl
# 初始化缓存表和时间戳索引(不清理过期缓存)
try:
with self:
self._execute(
sql="""
CREATE TABLE IF NOT EXISTS caches
(
--缓存唯一标识
guid TEXT PRIMARY KEY,
--缓存JSON序列化
cache TEXT NOT NULL,
--缓存时间
timestamp REAL NOT NULL
)
"""
)
self._execute(
sql="""
CREATE INDEX IF NOT EXISTS idx_timestamp ON caches(timestamp)
"""
)
except Exception as exception:
raise RuntimeError(
f"初始化缓存数据库发生异常:{str(exception)}"
) from exception
# noinspection PyShadowingNames
def query(self, guid: str) -> Optional[Dict[str, Any]]:
"""
查询并获取单条缓存
:param guid: 缓存唯一标识
:return: 缓存
"""
# noinspection PyBroadException
try:
with self:
result = self._query_one(
sql="""
SELECT cache
FROM caches
WHERE guid = ? AND timestamp >= ?
""",
parameters=(guid, time.time() - self.cache_ttl),
)
return (
None if result is None else json.loads(result["cache"])
) # 就缓存JSON反序列化
except Exception as exception:
raise RuntimeError("查询并获取单条缓存发生异常") from exception
# noinspection PyShadowingNames
def update(self, guid: str, cache: Dict) -> Optional[bool]:
"""
新增或更新缓存(若无则新增缓存,若有则更新缓存)
:param guid: 缓存唯一标识
:param cache: 缓存
:return: 成功返回True失败返回False
"""
# noinspection PyBroadException
try:
with self:
return self._execute(
sql="""
INSERT OR REPLACE INTO caches (guid, cache, timestamp) VALUES (?, ?, ?)
""",
parameters=(
guid,
json.dumps(cache, ensure_ascii=False),
time.time(),
),
)
except Exception as exception:
raise RuntimeError("新增或更新缓存发生异常") from exception
def __init__(
self,
default_headers: Optional[Dict[str, str]] = None,
total: int = 3,
backoff_factor: float = 0.5,
timeout: int = 60,
cache_enabled: bool = False,
cache_ttl: int = 360,
):
"""
:param default_headers: 默认请求头
:param total: 最大重试次数
:param backoff_factor: 重试间隔退避因子
:param timeout: 超时时间,单位为秒
:param cache_enabled: 使用缓存
:param cache_ttl: 缓存生存时间,单位为天
"""
# 创建请求会话并挂载适配器
self.session = self._create_session(
default_headers=default_headers, total=total, backoff_factor=backoff_factor
)
# 初始化超时时间
self.timeout = timeout
# 初始化使用缓存
self.cache_enabled = cache_enabled
# 初始化缓存生存时间,单位由天转为秒
self.cache_ttl = cache_ttl * 86400
self.caches: Optional[Request.Caches] = None
# 若使用缓存则实例化缓存
if self.cache_enabled:
self.caches = Request.Caches(cache_ttl=self.cache_ttl)
def __del__(self):
"""析构时关闭请求会话"""
if hasattr(self, "session") and self.session:
self.session.close()
@staticmethod
def _create_session(
total: int,
backoff_factor: float,
default_headers: Optional[Dict[str, str]] = None,
) -> Session:
"""
创建请求会话并挂载适配器
:param default_headers: 默认请求头
:param total: 最大重试次数
:param backoff_factor: 重试间隔退避因子
:return Session: 请求会话实例
"""
# 实例化请求会话
session = Session()
# 设置默认请求头
if default_headers:
session.headers.update(default_headers)
# 设置重试策略并挂载适配器
adapter = HTTPAdapter(
max_retries=Retry(
allowed_methods=["HEAD", "GET", "POST", "PUT", "DELETE", "PATCH"],
status_forcelist=[
408,
502,
503,
504,
], # 408为请求超时502为网关错误503为服务不可用504为网关超时
total=total,
respect_retry_after_header=True,
backoff_factor=backoff_factor,
)
)
session.mount("http://", adapter)
session.mount("https://", adapter)
return session
def get(
self, **kwargs
) -> Any:
"""发送GET请求"""
return self._request(method="GET", parameters=self.Parameters(**kwargs))
def post(
self, **kwargs
) -> Any:
"""发送POST请求"""
return self._request(method="POST", parameters=self.Parameters(**kwargs))
def download(
self, stream_enabled: bool = False, chunk_size: int = 1024, **kwargs
) -> Any:
"""
下载文件
:param stream_enabled: 使用流式传输
:param chunk_size: 流式传输的分块大小
"""
response = self._request(
method="GET",
parameters=self.Parameters(**{"stream_enabled": stream_enabled, **kwargs}),
)
# 若使用流式传输则处理流式传输响应
if stream_enabled:
return self._process_stream_response(
response=response, chunk_size=chunk_size
)
return response
def _request(
self, method: Literal["GET", "POST"], parameters: Parameters
) -> Union[
str, Tuple[str, bytes], Dict[str, Any], ElementTree.Element, Response, None
]:
"""请求"""
# 将请求参数模型转为请求参数字典
parameters = parameters.model_dump(exclude_none=True, by_alias=True)
# 将URL由HttpUrl对象转为字符串
parameters["url"] = str(parameters["url"])
# 过滤表单数据中None值
if parameters.get("data") is not None:
parameters["data"] = {
k: v for k, v in parameters["data"].items() if v is not None
}
# 过滤JSON数据中None值
if parameters.get("json") is not None:
parameters["json"] = {
k: v for k, v in parameters["json"].items() if v is not None
}
# 使用流式传输
stream_enabled = parameters.pop("stream_enabled", False)
# 缓存全局唯一标识
guid = parameters.pop("guid", None)
# 若使用缓存且缓存全局唯一标识非空则查询并获取单条缓存
if self.cache_enabled and guid is not None:
cache = self.cache_client.query(guid)
if cache is not None:
return cache
# 发送请求并处理响应
# noinspection PyBroadException
try:
response = self.session.request(
method=method, timeout=self.timeout, **parameters
)
response.raise_for_status() # 若返回非2??状态码则抛出异常
# 若使用流式传输则直接返回(不缓存)
if stream_enabled:
return response
# 处理响应
response = self._process_response(response=response)
# 若使用缓存且缓存全局唯一标识非空则新增或更新缓存
if self.cache_enabled and guid is not None:
self.cache_client.update(guid, response)
return response
except Exception as exception:
# noinspection PyBroadException
try:
response = getattr(exception, "response", None)
status = (
response.json().get("status", response.status_code)
if response is not None
else None
)
message = (
response.json().get("message", response.text)
if response is not None
else str(exception).splitlines()[0]
)
except Exception:
status = None
message = f"{method} {parameters["url"]} 请求发生异常:{str(exception).splitlines()[0]}"
return self.RequestException(status=status, message=message).__dict__
# 处理响应
@staticmethod
def _process_response(
response: Response,
) -> Union[str, Tuple[str, bytes], Dict[str, Any], ElementTree.Element, None]:
# 若响应内容为空则返回None
content = response.content
if not content:
return None
# 响应类型
_type = response.headers.get("Content-Type", "").split(";")[0].strip().lower()
# 根据响应类型匹配响应内容解析方法并返回
# noinspection PyUnreachableCode
match _type:
# JSONJSON反序列化
case "application/json" | "text/json":
return response.json()
# XML解析为XML对象Element实例
case "application/xml" | "text/xml":
return ElementTree.fromstring(content)
# 以image/开头:返回影像件格式和响应内容
case _ if _type.startswith("image/"):
# 影像件格式
image_format = _type.split(sep="/", maxsplit=1)[1]
return image_format, content
# 其它的响应类型先UTF8解码再返回若解码发生异常则直接返回
case _:
try:
return content.decode("utf-8")
except UnicodeDecodeError:
return content
# 处理流式传输响应
@staticmethod
def _process_stream_response(
response: Response, chunk_size: int
) -> Generator[bytes, None, None]:
"""
处理流式响应
:param response: requests.Response对象
:param chunk_size: 分块大小
:return: 字节数据块生成器
"""
if not isinstance(chunk_size, int):
raise ValueError("分块大小数据类型必须为整数")
if chunk_size <= 0:
raise ValueError("分块大小必须大于0")
try:
for chunk in response.iter_content(chunk_size=chunk_size):
if chunk:
yield chunk
finally:
response.close()
class Authenticator:
    """
    Obtains access tokens for third-party services and caches them on disk.

    Tokens are kept in ``certifications.json`` next to this file and are only
    requested again once the stored expiry timestamp has passed.
    """

    # NOTE(review): the service credentials below are hard-coded in source;
    # consider loading them from configuration or the environment instead.

    # Shared by all instances, because they all read and rewrite the same file.
    _lock = threading.Lock()

    def __init__(
        self,
    ):
        """Prepare the credential store and the HTTP client."""
        self._initialize()

    def _initialize(self):
        """Create the credential file if it is missing and build the HTTP client."""
        self.certifications_path = (
            Path(__file__).parent.resolve() / "certifications.json"
        )
        if not self.certifications_path.exists():
            self._save_certifications({})
        # Request is the HTTP client defined in this module.
        self.http_client = Request()

    def _save_certifications(self, certifications: Dict[str, Any]) -> None:
        """Write the credentials of all services to the credential file."""
        with open(self.certifications_path, "w", encoding="utf-8") as file:
            json.dump(
                certifications,
                file,
                ensure_ascii=False,
            )

    def _szkt_get_certification(self) -> Tuple[str, float]:
        """
        Request a new access token from 深圳快瞳.

        :return: access token and its expiry timestamp
        :raises RuntimeError: if the service does not report success
        """
        response = self.http_client.get(
            url="https://ai.inspirvision.cn/s/api/getAccessToken?accessKey=APPID_6Gf78H59D3O2Q81u&accessSecret=947b8829d4d5d55890b304d322ac2d0d"
        )
        if not (
            isinstance(response, dict)
            and response.get("status") == 200
            and response.get("code") == 0
        ):
            raise RuntimeError("获取深圳快瞳访问凭证发生异常")
        # noinspection PyTypeChecker
        return (
            response["data"]["access_token"],
            time.time() + response["data"]["expires_in"],
        )

    def _hlyj_get_certification(self) -> Tuple[str, float]:
        """
        Request a new access token from 合力亿捷.

        :return: access token and its expiry timestamp
        :raises RuntimeError: if the service does not report success
        """
        # Corporate access key id.
        access_key_id = "25938f1c190448829dbdb5d344231e42"
        # Signing secret.
        secret_access_key = "44dc0299aff84d68ae27712f8784f173"
        # Timestamp in whole seconds.
        timestamp = int(time.time())
        # Hex HMAC-SHA256 over key id + secret + timestamp, keyed with the secret.
        signature = hmac.new(
            secret_access_key.encode("utf-8"),
            f"{access_key_id}{secret_access_key}{timestamp}".encode("utf-8"),
            hashlib.sha256,
        ).hexdigest()
        response = self.http_client.get(
            url=f"https://kms.7x24cc.com/api/v1/corp/auth/token?access_key_id={access_key_id}&timestamp={timestamp}&signature={signature}"
        )
        # A failed request yields an error dict without a "success" key.
        if not (isinstance(response, dict) and response.get("success")):
            raise RuntimeError("获取合力亿捷访问凭证发生异常")
        # noinspection PyTypeChecker
        return (
            response["data"],
            time.time() + 1 * 60 * 60,  # the token is treated as valid for one hour
        )

    def _feishu_get_certification(self) -> Tuple[str, float]:
        """
        Request a new tenant access token from 飞书.

        :return: access token and its expiry timestamp
        :raises RuntimeError: if the service does not report success
        """
        response = self.http_client.post(
            url="https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal",
            data={
                "app_id": "cli_a1587980be78500c",
                "app_secret": "vZXGZomwfmyaHXoG8s810d1YYGLsIqCA",
            },
        )
        # The client's error dict also carries code == 0, so the token itself
        # must be present for the response to count as a success.
        if not (
            isinstance(response, dict)
            and response.get("code") == 0
            and "tenant_access_token" in response
        ):
            raise RuntimeError("获取飞书访问凭证发生异常")
        # noinspection PyTypeChecker
        return (
            response["tenant_access_token"],
            time.time() + response["expire"],
        )

    def get_token(self, servicer: str) -> Optional[str]:
        """
        Return a valid access token for a service, refreshing it when expired.

        :param servicer: service key; ``"szkt"``, ``"hlyj"`` and ``"feishu"``
            are supported
        :return: access token
        :raises RuntimeError: if the credential file cannot be read, the
            service is unknown, or the token request fails
        """
        with self._lock:
            certifications: Dict[str, Any] = {}
            token, expired_timestamp = None, 0
            try:
                with open(self.certifications_path, "r", encoding="utf-8") as file:
                    certifications = json.load(file)
                certification = certifications.get(servicer, None)
                if certification is not None:
                    token = certification["token"]
                    expired_timestamp = certification["expired_timestamp"]
            except json.decoder.JSONDecodeError:
                # A corrupt file is reset; every token is then fetched afresh.
                certifications = {}
                self._save_certifications(certifications)
            except Exception as exception:
                raise RuntimeError("获取访问令牌发生异常") from exception

            if time.time() > expired_timestamp:
                fetchers = {
                    "szkt": self._szkt_get_certification,
                    "hlyj": self._hlyj_get_certification,
                    "feishu": self._feishu_get_certification,
                }
                fetch = fetchers.get(servicer)
                if fetch is None:
                    raise RuntimeError(f"未设置服务商:{servicer}获取访问凭证方法")
                token, expired_timestamp = fetch()
                # Store the refreshed credential together with the others.
                certifications[servicer] = {
                    "token": token,
                    "expired_timestamp": expired_timestamp,
                }
                self._save_certifications(certifications)
            return token