472 lines
14 KiB
Python
472 lines
14 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
# 导入模块
|
||
|
||
import warnings
|
||
|
||
# 过滤使用提醒
|
||
warnings.filterwarnings(
|
||
"ignore",
|
||
category=UserWarning,
|
||
)
|
||
|
||
from fuzzywuzzy import fuzz
|
||
|
||
import re
|
||
|
||
import numpy
|
||
|
||
import cv2
|
||
|
||
from decimal import Decimal, ROUND_HALF_UP
|
||
|
||
from paddleocr import PaddleOCR
|
||
|
||
"""
|
||
|
||
封装百度飞桨PADDLEOCR
|
||
|
||
"""
|
||
|
||
|
||
def fuzzy_match(
|
||
target: str, components: list, specify_key: str, return_key: str
|
||
) -> str:
|
||
"""
|
||
根据目标在组成部分列表模糊匹指定键名的键值,并返回匹配的组成部分的返回键名的键值
|
||
需要匹配的键名的键值
|
||
"""
|
||
|
||
def _get_value(component, keys):
|
||
"""根据键名递归获取键值,支持嵌套结构"""
|
||
key = keys[0]
|
||
if isinstance(component, dict) and key in component:
|
||
return (
|
||
_get_value(component[key], keys[1:])
|
||
if len(keys) > 1
|
||
else component[key]
|
||
)
|
||
return None
|
||
|
||
results = []
|
||
|
||
for component in components:
|
||
# 在组成部分根据指定键名获取对应键值
|
||
specify_value = _get_value(component, specify_key.split("."))
|
||
if specify_value is None:
|
||
continue
|
||
|
||
# 在组成部分根据返回键名获取对应键值
|
||
return_value = _get_value(component, return_key.split("."))
|
||
if return_value is not None:
|
||
results.append(
|
||
(return_value, fuzz.WRatio(target, specify_value))
|
||
) # 基于加权补偿莱文斯坦相似度算法
|
||
|
||
return max(results, key=lambda x: x[1])[0] if results else None
|
||
|
||
|
||
class PPOCR:
|
||
"""OCR客户端"""
|
||
|
||
def __init__(self):
|
||
|
||
# 初始化PADDLEOCR
|
||
self.ocr_engine = PaddleOCR(
|
||
ocr_version="PP-OCRv4",
|
||
use_doc_orientation_classify=True,
|
||
use_doc_unwarping=True,
|
||
use_textline_orientation=True,
|
||
)
|
||
|
||
@staticmethod
|
||
def _texts_sort(texts):
|
||
"""文本排序"""
|
||
|
||
texts_merged = []
|
||
|
||
for texts, coordinates in zip(
|
||
texts[0]["rec_texts"], texts[0]["rec_polys"]
|
||
): # 默认识别结果仅包含一张影像件
|
||
|
||
# 合并文本框的X/Y坐标、高度和文本
|
||
texts_merged.append(
|
||
[
|
||
# X坐标
|
||
numpy.min(coordinates[:, 0]),
|
||
# Y坐标
|
||
numpy.min(coordinates[:, 1]),
|
||
# 高度
|
||
numpy.max(coordinates[:, 1]) - numpy.min(coordinates[:, 1]),
|
||
texts,
|
||
]
|
||
)
|
||
|
||
# 按照文本框Y坐标升序(使用空间坐标算法)
|
||
texts_merged.sort(key=lambda x: x[1])
|
||
|
||
texts_sorted = []
|
||
|
||
for index, text in enumerate(texts_merged[1:]):
|
||
|
||
if index == 0:
|
||
|
||
# 初始化当前行
|
||
row = [texts_merged[0]]
|
||
|
||
continue
|
||
|
||
# 若文本框Y坐标与当前行中最后一个文本框的Y坐标差值小于阈值,则归为同一行
|
||
# noinspection PyUnboundLocalVariable
|
||
# noinspection PyTypeChecker
|
||
if (
|
||
text[1] - row[-1][1] < numpy.mean([text[2] for text in row]) * 0.5
|
||
): # 注意NUMPY.NDARRAY和LIST区别,ROW[:, 1]仅适用于NUMPY.NDARRAY,故使用列表推导式计算当前行文本框Y坐标和高度
|
||
|
||
row.append(text)
|
||
|
||
# 否则按照文本框X坐标就当前行中文本框升序
|
||
else:
|
||
|
||
row_sorted = sorted(row, key=lambda x: x[0])
|
||
|
||
texts_sorted.extend(row_sorted)
|
||
|
||
row = [text]
|
||
|
||
# 按照文本框X坐标就最后一行中文本框升序
|
||
row_sorted = sorted(row, key=lambda x: x[0])
|
||
|
||
texts_sorted.extend(row_sorted)
|
||
|
||
# 返回排序后文本
|
||
return [text_sorted[3] for text_sorted in texts_sorted]
|
||
|
||
def identity_card_recognition(self, image_path: str) -> dict:
|
||
"""居民身份证识别"""
|
||
|
||
# 读取影像件(数据类型为NUMPY.NDARRAY)
|
||
image = cv2.imread(image_path)
|
||
|
||
texts = self.ocr_engine.predict(
|
||
image,
|
||
use_doc_orientation_classify=False,
|
||
use_doc_unwarping=False,
|
||
use_textline_orientation=True,
|
||
text_rec_score_thresh=0.5,
|
||
)
|
||
|
||
# 文本排序
|
||
texts = self._texts_sort(texts)
|
||
|
||
# 居民身份证模版
|
||
result = {
|
||
"姓名": "",
|
||
"性别": "",
|
||
"民族": "",
|
||
"出生": "",
|
||
"住址": "",
|
||
"公民身份号码": "",
|
||
"有效期限": "",
|
||
"签发机关": "",
|
||
}
|
||
|
||
for text in texts: # 默认只包含一套居民身份证正反面
|
||
|
||
# 姓名
|
||
if not result["姓名"] and "姓名" in text:
|
||
|
||
result["姓名"] = text.replace("姓名", "").strip()
|
||
|
||
elif "性别" in text or "民族" in text: # 姓名和民族常同时返回
|
||
|
||
# 性别
|
||
if not result["性别"] and "性别" in text:
|
||
|
||
result["性别"] = (
|
||
text.split("性别")[-1].strip().split("民族")[0].strip()
|
||
)
|
||
|
||
# 民族
|
||
if not result["民族"] and "民族" in text:
|
||
|
||
result["民族"] = text.split("民族")[-1].strip()
|
||
|
||
# 出生
|
||
elif not result["出生"] and "出生" in text:
|
||
|
||
result["出生"] = text.replace("出生", "").strip()
|
||
|
||
# 住址
|
||
elif "住址" in text or (
|
||
(
|
||
not any(
|
||
keyword in text
|
||
for keyword in [
|
||
"姓名",
|
||
"性别",
|
||
"民族",
|
||
"出生",
|
||
"公民身份号码",
|
||
"中华人民共和国",
|
||
"居民身份证",
|
||
"签发机关",
|
||
"有效期限",
|
||
]
|
||
)
|
||
)
|
||
and not re.fullmatch(
|
||
r"^(\d{4}[.]\d{2}[.]\d{2})$", text.split("-")[0].strip()
|
||
)
|
||
):
|
||
|
||
if not result["住址"] and "住址" in text:
|
||
|
||
result["住址"] = text.replace("住址", "").strip()
|
||
|
||
if result["住址"] and not "住址" in text:
|
||
|
||
result["住址"] += text.strip()
|
||
|
||
# 公民身份号码
|
||
elif not result["公民身份号码"] and ("公民身份号码" in text):
|
||
|
||
result["公民身份号码"] = text.replace("公民身份号码", "").strip()
|
||
|
||
# 有效期限
|
||
elif not result["有效期限"] and (
|
||
"有效期限" in text
|
||
or re.fullmatch(
|
||
r"^(\d{4}[.]\d{2}[.]\d{2})$", text.split("-")[0].strip()
|
||
)
|
||
):
|
||
|
||
result["有效期限"] = text.replace("有效期限", "").strip()
|
||
|
||
# 签发机关
|
||
elif not result["签发机关"] and "签发机关" in text:
|
||
|
||
result["签发机关"] = text.replace("签发机关", "").strip()
|
||
|
||
return result
|
||
|
||
def invoice_recognition(self, image_path: str) -> dict:
|
||
"""增值税发票识别"""
|
||
|
||
# 读取影像件(数据类型为NUMPY.NDARRAY)
|
||
image = cv2.imread(image_path)
|
||
|
||
texts = self.ocr_engine.predict(
|
||
image,
|
||
use_doc_orientation_classify=False,
|
||
use_doc_unwarping=False,
|
||
use_textline_orientation=False,
|
||
text_rec_score_thresh=0.5,
|
||
)
|
||
|
||
# 文本排序
|
||
texts = self._texts_sort(texts)
|
||
|
||
print(texts)
|
||
|
||
# 增值税发票模版
|
||
result = {
|
||
"票据类型": "",
|
||
"票据号码": "",
|
||
"票据代码": "",
|
||
"开票日期": "",
|
||
"票据金额": "",
|
||
"校验码": "",
|
||
"收款方": "",
|
||
"付款方": "",
|
||
"项目": [],
|
||
}
|
||
|
||
for i, text in enumerate(texts):
|
||
|
||
if not result["票据类型"] and "电子发票" in text:
|
||
|
||
result["票据类型"] = "数电发票"
|
||
|
||
elif not result["票据号码"] and "发票号码" in text:
|
||
|
||
result["票据号码"] = (
|
||
text.replace("发票号码", "")
|
||
.replace(":", "")
|
||
.replace(":", "")
|
||
.strip()
|
||
)
|
||
|
||
elif not result["开票日期"] and "开票日期" in text:
|
||
|
||
result["开票日期"] = (
|
||
text.replace("开票日期", "")
|
||
.replace(":", "")
|
||
.replace(":", "")
|
||
.strip()
|
||
)
|
||
|
||
elif not result["票据金额"] and "小写" in text:
|
||
|
||
if re.match(
|
||
r"^-?\d+(\.\d+)?$", text.replace("¥", "¥").split("¥")[-1].strip()
|
||
):
|
||
|
||
result["票据金额"] = text.replace("¥", "¥").split("¥")[-1].strip()
|
||
|
||
elif re.match(
|
||
r"^-?\d+(\.\d+)?$",
|
||
texts[i + 1].replace("¥", "¥").split("¥")[-1].strip(),
|
||
):
|
||
|
||
result["票据金额"] = (
|
||
texts[i + 1].replace("¥", "¥").split("¥")[-1].strip()
|
||
)
|
||
|
||
elif "名称" in text and not "项目名称" in text:
|
||
|
||
if not result["付款方"]:
|
||
|
||
result["付款方"] = (
|
||
text.replace("名称", "")
|
||
.replace(":", "")
|
||
.replace(":", "")
|
||
.strip()
|
||
)
|
||
|
||
else:
|
||
|
||
result["收款方"] = (
|
||
text.replace("名称", "")
|
||
.replace(":", "")
|
||
.replace(":", "")
|
||
.strip()
|
||
)
|
||
|
||
# 项目
|
||
items = []
|
||
|
||
for i, text in enumerate(texts):
|
||
|
||
# 通过首位为星号定位名称、规格和单位
|
||
if text.startswith("*"):
|
||
|
||
# 项目模版
|
||
# noinspection PyDictCreation
|
||
item = {
|
||
"名称": "",
|
||
"规格": "",
|
||
"单位": "",
|
||
"数量": "",
|
||
"单价": "",
|
||
"金额": "",
|
||
"税率": "",
|
||
"税额": "",
|
||
}
|
||
|
||
item["名称"] = text.strip("")
|
||
|
||
# 若非数值则名称后一项为规格
|
||
if not re.match(
|
||
r"^-?\d+(\.\d+)?$",
|
||
texts[i + 1].replace("%", "").strip(),
|
||
):
|
||
|
||
item["规格"] = texts[i + 1].strip()
|
||
|
||
# 若非数值则名称后二项为单位
|
||
if not re.match(
|
||
r"^-?\d+(\.\d+)?$",
|
||
texts[i + 2].replace("%", "").strip(),
|
||
):
|
||
|
||
item["单位"] = texts[i + 2].strip()
|
||
|
||
for j, text_ in enumerate(texts):
|
||
|
||
# 若内循环索引小于等于外循环索引则跳过
|
||
if j <= i:
|
||
|
||
continue
|
||
|
||
# 若内循环首位为星号或为小计则将识别结果添加至项目并停止内循环
|
||
if j > i and (
|
||
text_.startswith("*") or text_ in "小计" or text_ in "合计"
|
||
):
|
||
|
||
items.append(item)
|
||
|
||
break
|
||
|
||
# 通过包含百分号定位税率、税额、数量、单价和金额
|
||
if "%" in text_ and re.match(
|
||
r"^\d+(\.\d+)?$",
|
||
texts[j].replace("%", "").strip(),
|
||
):
|
||
|
||
item["税率"] = texts[j].replace("%", "").strip() + "%"
|
||
|
||
# 税率后一项为税额
|
||
if re.match(
|
||
r"^-?\d+(\.\d+)?$",
|
||
texts[j + 1].strip(),
|
||
):
|
||
|
||
item["税额"] = texts[j + 1].strip()
|
||
|
||
# 税率前一项为金额
|
||
if re.match(
|
||
r"^-?\d+(\.\d+)?$",
|
||
texts[j - 1].strip(),
|
||
):
|
||
|
||
item["金额"] = texts[j - 1].strip()
|
||
|
||
# 若金额包含负号,税率前二项为单价、前三项为数量
|
||
if not "-" in item["金额"]:
|
||
|
||
if re.match(
|
||
r"^\d+(\.\d+)?$",
|
||
texts[j - 2].strip(),
|
||
):
|
||
|
||
item["单价"] = texts[j - 2].strip()
|
||
|
||
if texts[j - 3].strip().isdigit():
|
||
|
||
item["数量"] = texts[j - 3].strip()
|
||
|
||
elif j > i + 2 and not re.match(
|
||
r"^-?\d+(\.\d+)?$",
|
||
text_.replace("%", "").strip(),
|
||
):
|
||
|
||
item["名称"] += texts[j].strip()
|
||
|
||
# 数值修正
|
||
for item in items:
|
||
|
||
if (
|
||
not item["数量"]
|
||
and item["金额"]
|
||
and not "-" in item["金额"]
|
||
and item["单价"]
|
||
):
|
||
|
||
item["数量"] = (
|
||
""
|
||
if (
|
||
quantity := int(
|
||
(Decimal(item["金额"]) / Decimal(item["单价"])).quantize(
|
||
Decimal("0"), rounding=ROUND_HALF_UP
|
||
)
|
||
)
|
||
)
|
||
== 0
|
||
else str(quantity)
|
||
)
|
||
|
||
result["项目"] = items
|
||
|
||
return result
|