403 lines
13 KiB
Python
403 lines
13 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
请求客户端模块
|
||
"""
|
||
|
||
import json
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Generator, Literal, Optional, Tuple, Union
|
||
from xml.etree import ElementTree
|
||
|
||
from pydantic import BaseModel, Field, HttpUrl, model_validator
|
||
from requests import Response, Session
|
||
from requests.adapters import HTTPAdapter
|
||
from urllib3.util.retry import Retry
|
||
|
||
sys.path.append(Path(__file__).parent.as_posix())
|
||
from sqlite import SQLite
|
||
|
||
|
||
# Accepted shapes for one uploaded file:
# (filename, content), (filename, content, content_type) or
# (filename, content, content_type, extra_headers).
_FileSpec = Union[
    Tuple[str, bytes],
    Tuple[str, bytes, str],
    Tuple[str, bytes, str, Dict[str, str]],
]


class Parameters(BaseModel):
    """
    Validated set of arguments for a single HTTP request.

    ``data`` and ``json`` are mutually exclusive, and ``files`` cannot be
    combined with a truthy ``stream_enabled``; both rules are enforced by the
    model validators below.
    """

    # Target URL; validated by pydantic's HttpUrl type.
    url: HttpUrl = Field(default=..., description="统一资源定位符,基于HttpUrl自动校验")
    # Query-string parameters.
    params: Optional[Dict[str, Any]] = Field(
        default=None, description="统一资源定位符的查询参数"
    )
    # Request headers.
    headers: Optional[Dict[str, str]] = Field(default=None, description="请求头")
    # Form payload.
    data: Optional[Dict[str, Any]] = Field(default=None, description="表单数据")
    # JSON payload; populated through the "json" alias.
    json_: Optional[Dict[str, Any]] = Field(
        default=None, alias="json", description="JSON数据"
    )
    # Files to upload, keyed by form field name.
    files: Optional[Dict[str, _FileSpec]] = Field(
        default=None,
        description="上传文件,{字段名: (文件名, 字节数据, 内容类型, 请求头)}",
    )
    # Whether the response should be consumed as a stream.
    stream_enabled: Optional[bool] = Field(default=None, description="使用流式传输")
    # Cache key for this request.
    guid: Optional[str] = Field(default=None, description="缓存全局唯一标识")

    @model_validator(mode="after")
    def validate_data(self):
        """Reject a request that carries both a form payload and a JSON payload."""
        if self.data is None or self.json_ is None:
            return self
        raise ValueError("表单数据和JSON数据不能同时使用")

    @model_validator(mode="after")
    def validate_files(self):
        """Reject a request that uploads files while streaming is enabled."""
        if self.files is None or not self.stream_enabled:
            return self
        raise ValueError("上传文件和使用流式传输不能同时使用")
|
||
|
||
|
||
class RequestException(Exception):
    """Describes a failed request by status, error code and message."""

    def __init__(
        self,
        status: Optional[int] = 400,
        code: int = 0,
        message: str = "请求发生异常",
    ):
        """
        :param status: status code of the failure (may be None)
        :param code: error code
        :param message: error message; also used as the exception text
        """
        super().__init__(message)
        # Assigned in this order so ``__dict__`` lists status, code, message.
        self.status, self.code, self.message = status, code, message
|
||
|
||
|
||
class Caches(SQLite):
    """
    Cache stored in the ``caches`` table of ``caches.db`` next to this module.

    Operations:
        query: look up a single cache entry that is still within its TTL
        update: insert a cache entry or replace the existing one
    """

    def __init__(self, cache_ttl: int):
        """
        Open the cache database and make sure the table and index exist.

        Expired entries are not purged here.

        :param cache_ttl: cache time-to-live in seconds
        :raises RuntimeError: if creating the table or index fails
        """
        super().__init__(database=Path(__file__).parent.resolve() / "caches.db")
        self.cache_ttl = cache_ttl

        # Schema statements, executed in order inside one ``with self`` block.
        statements = (
            """
            CREATE TABLE IF NOT EXISTS caches
            (
                --缓存唯一标识
                guid TEXT PRIMARY KEY,
                --缓存(JSON序列化)
                cache TEXT NOT NULL,
                --缓存时间
                timestamp REAL NOT NULL
            )
            """,
            """
            CREATE INDEX IF NOT EXISTS idx_timestamp ON caches(timestamp)
            """,
        )
        try:
            with self:
                for statement in statements:
                    self.execute(sql=statement)
        except Exception as error:
            raise RuntimeError(f"初始化缓存表发生异常:{str(error)}") from error

    def query(self, guid: str) -> Optional[Dict[str, Any]]:
        """
        Return the cache entry stored under ``guid`` if it has not expired.

        An entry counts as live when its timestamp is no older than
        ``cache_ttl`` seconds before now.

        :param guid: cache key
        :return: the JSON-decoded cache entry, or None when no live row exists
        :raises RuntimeError: if the lookup or the JSON decoding fails
        """
        try:
            with self:
                oldest_allowed = time.time() - self.cache_ttl
                row = self.query_one(
                    sql="""
                        SELECT cache
                        FROM caches
                        WHERE guid = ? AND timestamp >= ?
                    """,
                    parameters=(guid, oldest_allowed),
                )
                if row is None:
                    return None
                return json.loads(row["cache"])
        except Exception as error:
            raise RuntimeError(
                f"查询并获取单条缓存发生异常:{str(error)}"
            ) from error

    def update(self, guid: str, cache: Dict) -> Optional[bool]:
        """
        Insert the cache entry for ``guid``, replacing any existing row.

        The entry is JSON-serialised and stamped with the current time.

        :param guid: cache key
        :param cache: cache entry to store (must be JSON-serialisable)
        :return: whatever the underlying ``execute`` call returns
        :raises RuntimeError: if serialisation or the write fails
        """
        try:
            with self:
                row_values = (
                    guid,
                    json.dumps(cache, ensure_ascii=False),
                    time.time(),
                )
                return self.execute(
                    sql="""
                        INSERT OR REPLACE INTO caches (guid, cache, timestamp) VALUES (?, ?, ?)
                    """,
                    parameters=row_values,
                )
        except Exception as error:
            raise RuntimeError("新增或更新缓存发生异常") from error
|
||
|
||
|
||
class Request:
|
||
"""
|
||
请求客户端,支持:
|
||
get:GET请求
|
||
post:POST请求
|
||
download:下载
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
default_headers: Optional[Dict[str, str]] = None,
|
||
total: int = 3,
|
||
backoff_factor: float = 0.5,
|
||
timeout: int = 60,
|
||
cache_enabled: bool = False,
|
||
cache_ttl: int = 360,
|
||
):
|
||
"""
|
||
:param default_headers: 默认请求头
|
||
:param total: 最大重试次数,默认 3
|
||
:param backoff_factor: 重试间隔退避因子,默认 0.5
|
||
:param timeout: 超时时间(单位为秒),默认为 60
|
||
:param cache_enabled: 使用缓存,默认 False
|
||
:param cache_ttl: 缓存生存时间(单位为天),默认为 360
|
||
"""
|
||
# 创建请求会话并挂载适配器
|
||
self.session = self._create_session(
|
||
default_headers=default_headers, total=total, backoff_factor=backoff_factor
|
||
)
|
||
# 初始化超时时间
|
||
self.timeout = timeout
|
||
# 实例化缓存
|
||
self.caches = Caches(cache_ttl=cache_ttl * 86400) if cache_enabled else None
|
||
|
||
@staticmethod
|
||
def _create_session(
|
||
total: int,
|
||
backoff_factor: float,
|
||
default_headers: Optional[Dict[str, str]] = None,
|
||
) -> Session:
|
||
"""
|
||
创建请求会话并挂载适配器
|
||
:param default_headers: 默认请求头
|
||
:param total: 最大重试次数
|
||
:param backoff_factor: 重试间隔退避因子
|
||
:return Session: 请求会话实例
|
||
"""
|
||
# 实例化请求会话
|
||
session = Session()
|
||
|
||
# 设置默认请求头
|
||
if default_headers:
|
||
session.headers.update(default_headers)
|
||
|
||
# 设置重试策略并挂载适配器
|
||
adapter = HTTPAdapter(
|
||
max_retries=Retry(
|
||
allowed_methods=["HEAD", "GET", "POST", "PUT", "DELETE", "PATCH"],
|
||
status_forcelist=[
|
||
408,
|
||
502,
|
||
503,
|
||
504,
|
||
], # 408为请求超时,502为网关错误,503为服务不可用,504为网关超时
|
||
total=total,
|
||
respect_retry_after_header=True,
|
||
backoff_factor=backoff_factor,
|
||
)
|
||
)
|
||
session.mount("http://", adapter)
|
||
session.mount("https://", adapter)
|
||
|
||
return session
|
||
|
||
def get(self, **kwargs) -> Any:
|
||
"""
|
||
GET请求
|
||
:param kwargs: 请求参数
|
||
:return: 响应内容
|
||
"""
|
||
return self._request(method="GET", parameters=Parameters(**kwargs))
|
||
|
||
def post(self, **kwargs) -> Any:
|
||
"""
|
||
POST请求
|
||
:param kwargs: 请求参数
|
||
:return: 响应内容
|
||
"""
|
||
return self._request(method="POST", parameters=Parameters(**kwargs))
|
||
|
||
def download(
|
||
self, stream_enabled: bool = False, chunk_size: int = 1024, **kwargs
|
||
) -> Any:
|
||
"""
|
||
下载
|
||
:param stream_enabled: 使用流式传输
|
||
:param chunk_size: 流式传输的分块大小
|
||
:param kwargs: 请求参数
|
||
:return: 响应内容
|
||
"""
|
||
response = self._request(
|
||
method="GET",
|
||
parameters=Parameters(**{"stream_enabled": stream_enabled, **kwargs}),
|
||
)
|
||
# 若使用流式传输则处理流式传输响应
|
||
if stream_enabled:
|
||
return self._process_stream_response(
|
||
response=response, chunk_size=chunk_size
|
||
)
|
||
return response
|
||
|
||
def _request(self, method: Literal["GET", "POST"], parameters: Parameters) -> Any:
|
||
"""
|
||
请求
|
||
:param method: 请求方法
|
||
:param parameters: 请求参数模型
|
||
:return: 响应内容
|
||
"""
|
||
# 将请求参数模型转为请求参数字典
|
||
kwargs = parameters.model_dump(exclude_none=True, by_alias=True)
|
||
|
||
# 将统一资源定位符转为字符串
|
||
url = str(kwargs.pop("url"))
|
||
|
||
# 过滤表单数据中空值
|
||
if kwargs.get("data"):
|
||
kwargs["data"] = {k: v for k, v in kwargs["data"].items() if v}
|
||
|
||
# 过滤JSON数据中空值
|
||
if kwargs.get("json"):
|
||
kwargs["json"] = {k: v for k, v in kwargs["json"].items() if v}
|
||
|
||
# 使用流式传输
|
||
stream_enabled = kwargs.pop("stream_enabled", False)
|
||
|
||
# 缓存全局唯一标识
|
||
guid = kwargs.pop("guid", None)
|
||
# 若缓存非空且缓存全局唯一标识非空则查询并获取单条缓存
|
||
if self.caches and guid:
|
||
cache = self.caches.query(guid)
|
||
if cache:
|
||
return cache
|
||
|
||
# 发送请求并处理响应
|
||
try:
|
||
response = self.session.request(
|
||
method=method, url=url, timeout=self.timeout, **kwargs
|
||
)
|
||
response.raise_for_status() # 若返回非2??状态码则抛出异常
|
||
|
||
# 若使用流式传输则直接返回响应对象(不缓存)
|
||
if stream_enabled:
|
||
return response
|
||
|
||
# 处理响应对象
|
||
response = self._process_response(response=response)
|
||
# 若使用缓存且缓存全局唯一标识非空则新增或更新缓存
|
||
if self.caches and guid:
|
||
self.caches.update(guid, response)
|
||
|
||
return response
|
||
|
||
# 重构异常信息
|
||
except Exception as exception:
|
||
try:
|
||
response = getattr(exception, "response", None)
|
||
status = (
|
||
response.json().get("status", response.status_code)
|
||
if response
|
||
else None
|
||
)
|
||
message = (
|
||
response.json().get("message", response.text)
|
||
if response
|
||
else str(exception).splitlines()[0]
|
||
)
|
||
except Exception:
|
||
status = None
|
||
message = f"{method} {kwargs["url"]} 请求发生异常:{str(exception).splitlines()[0]}"
|
||
return RequestException(status=status, message=message).__dict__
|
||
|
||
@staticmethod
|
||
def _process_response(
|
||
response: Response,
|
||
) -> Any:
|
||
"""
|
||
处理响应对象
|
||
:param response: 响应对象
|
||
:return: 响应内容
|
||
"""
|
||
content = response.content
|
||
if not content:
|
||
return None
|
||
|
||
# 响应类型
|
||
response_type = (
|
||
response.headers.get("Content-Type", "").split(";")[0].strip().lower()
|
||
)
|
||
# 根据响应类型匹配响应内容解析方法并返回
|
||
match response_type:
|
||
# 若为JSON则反序列化
|
||
case "application/json" | "text/json":
|
||
return response.json()
|
||
# 若为XML解析为Element对象
|
||
case "application/xml" | "text/xml":
|
||
return ElementTree.fromstring(content)
|
||
# 若为影像件格式则返回影像件格式和响应内容
|
||
case _ if response_type.startswith("image/"):
|
||
return response_type.split(sep="/", maxsplit=1)[1], content
|
||
case _:
|
||
try:
|
||
return content.decode("utf-8")
|
||
except UnicodeDecodeError:
|
||
return content
|
||
|
||
# 处理流式传输响应
|
||
@staticmethod
|
||
def _process_stream_response(
|
||
response: Response, chunk_size: int
|
||
) -> Generator[bytes, None, None]:
|
||
"""
|
||
处理流式响应
|
||
:param response: 响应对象
|
||
:param chunk_size: 分块大小
|
||
:return: 响应内容迭代器
|
||
"""
|
||
try:
|
||
for chunk in response.iter_content(chunk_size=chunk_size):
|
||
if chunk:
|
||
yield chunk
|
||
finally:
|
||
response.close()
|