Python/普康健康自动化录入/test.py

222 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import json
import re
from csv import DictReader, DictWriter
from pathlib import Path
from typing import List, Dict
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
# 命名实体识别
class NER:
    """Recognize drug named entities with a BERT token-classification model.

    The tokenizer and the model are both loaded from the local directory
    ``./models/bert-base-chinese``. Predicted label ids are mapped to BIO tags
    through ``label_map`` and adjacent tagged tokens are merged into entities.
    """

    # Tokens inserted by the tokenizer that can never belong to an entity.
    _SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")

    def __init__(self):
        """Load the tokenizer and model, then put the model in eval mode."""
        # Label id -> BIO tag.
        self.label_map = {
            0: "O",  # not part of a drug entity
            1: "B-DRUG",  # first token of a drug entity
            2: "I-DRUG",  # continuation token of a drug entity
        }
        # Resolve the model directory once; tokenizer and model share it.
        model_dir = Path("./models/bert-base-chinese").resolve()
        self.tokenizer = BertTokenizerFast.from_pretrained(
            pretrained_model_name_or_path=model_dir
        )
        self.model = BertForTokenClassification.from_pretrained(
            pretrained_model_name_or_path=model_dir,
            # Size the classification head to the label map; without this the
            # library default head size is used and label id 2 ("I-DRUG")
            # could never be predicted.
            num_labels=len(self.label_map),
        )
        # Inference only: disables dropout and similar training behaviour.
        self.model.eval()

    def recognize_drugs(self, text: str) -> List[Dict]:
        """Return the drug entities found in ``text``.

        Args:
            text: The text to analyse.

        Returns:
            A list of entity dicts in order of appearance, each with the keys
            ``start`` / ``end`` (character offsets into ``text``), ``tokens``,
            ``offsets`` (one ``(start, end)`` pair per token), ``type`` (the
            tag of the token that opened the entity) and ``text`` (the
            matching slice of ``text``). Blank input yields an empty list
            without touching the tokenizer or the model.
        """
        if not text.strip():
            return []
        # Tokenize and encode, asking for each token's character span.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_offsets_mapping=True,
        )
        # The model does not accept offset_mapping, so remove it from the
        # inputs; .tolist() yields plain Python ints instead of numpy scalars.
        offsets = inputs.pop("offset_mapping")[0].tolist()
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Highest-scoring label id for every token of the single sequence.
        label_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        return self._collect_entities(tokens, offsets, label_ids, text)

    def _collect_entities(self, tokens, offsets, label_ids, text: str) -> List[Dict]:
        """Merge per-token BIO predictions into entity dicts."""
        entities = []
        current = None
        for token, (start, end), label_id in zip(tokens, offsets, label_ids):
            # Skip special tokens; a (0, 0) span also marks a token that has
            # no counterpart in the text.
            if token in self._SPECIAL_TOKENS or (start == 0 and end == 0):
                continue
            # Unknown label ids are treated as "outside any entity".
            label = self.label_map.get(label_id, "O")
            if label == "I-DRUG" and current is None:
                # A continuation tag with nothing to continue is ignored.
                continue
            if label == "I-DRUG" and start == current["end"]:
                # Directly adjacent continuation: grow the open entity.
                current["end"] = end
                current["tokens"].append(token)
                current["offsets"].append((start, end))
                continue
            # Every remaining case ("O", "B-DRUG", or an "I-DRUG" that is not
            # adjacent to the open entity) closes the open entity first.
            if current is not None:
                self._combine_tokens(current, text)
                entities.append(current)
                current = None
            if label != "O":
                current = {
                    "start": start,
                    "end": end,
                    "tokens": [token],
                    "offsets": [(start, end)],
                    "type": label,
                }
        # Flush an entity that runs to the end of the sequence.
        if current is not None:
            self._combine_tokens(current, text)
            entities.append(current)
        return entities

    @staticmethod
    def _combine_tokens(current_entity: Dict, text: str):
        """Store the entity's surface text under ``current_entity["text"]``.

        The text is sliced from ``text`` using the entity's ``start`` and
        ``end`` offsets; the dict is modified in place.
        """
        current_entity["text"] = text[current_entity["start"] : current_entity["end"]]


# Disabled usage example, kept as an inert module-level string literal so it
# never executes. As its own note says, it needs a trained model.
"""
# 使用示例(需要训练好的模型)
dl_ner = NER()
text = "患者需要硫酸吗啡缓释片治疗癌症疼痛"
entities = dl_ner.recognize_drugs(text)
print(entities)
exit()
"""
def drug_extraction(text) -> tuple[str, str | None]:
"""药品数据提取"""
# 正则匹配两个“*”之间内容作为药品类别,第二个“*”之后内容作为药品名称。
if match := re.match(
pattern=r"\*(?P<drug_type>.*?)\*(?P<drug_name>.*)",
string=(text := text.strip()),
):
# 药品类别
drug_type = match.group("drug_type").strip()
# 药品名称
drug_name = (
match.group("drug_name")
.upper() # 小写转大写
.replace("(", " ")
.replace(")", " ")
.replace("", " ")
.replace("", " ")
.replace("[", " ")
.replace("]", " ")
.replace("", " ")
.replace("", " ")
.replace(":", " ")
.replace("", " ")
.replace(",", " ")
.replace("", " ")
.replace("·", " ")
.replace("`", " ")
.replace("@", " ")
.replace("#", " ")
.replace("*", " ")
.replace("/", " ") # 就指定符号替换为空格
.strip()
)
# 就药品名称中多个空格替换为一个空格
drug_name = re.sub(pattern=r"\s+", repl=" ", string=drug_name)
for section in drug_name.split(" "):
print(section)
# 若匹配失败则药品类型默认为文本、药品名称默认为None
else:
drug_type, drug_name = text, None
return drug_type, drug_name
# Records passed to the CSV writer at the end of the script.
# NOTE(review): nothing in the visible code appends to this list.
dataframe = []
# Clean the invoice-verification-result / disease mapping data
# (for now only genuine VAT invoices are considered).
# NOTE(review): this listing has lost its indentation, so the nesting of the
# statements below -- in particular of the two exit() calls -- cannot be
# confirmed from here; restore it from the original file before running.
with open("票据查验结果和疾病对应关系.csv", "r", encoding="utf-8") as file:
for row in DictReader(file):
try:
# Disease column of the current row (read but not used in the visible code).
disease = row["疾病"]
# The verification-result column is parsed as a JSON document.
response = json.loads(row["票据查验结果"])
# Iterate over the invoice line items
for item in response["data"]["details"]["items"]:
name = item["name"]
# The returned (drug_type, drug_name) tuple is discarded here.
drug_extraction(name)
# NOTE(review): looks like a debugging stop -- confirm before relying on it.
exit()
except Exception as e:
print(e)
exit()
# NOTE(review): dataframe[0] raises IndexError if the list is still empty
# when this point is reached.
with open("1.csv", "w", newline="", encoding="utf-8") as file:
writer = DictWriter(file, fieldnames=dataframe[0].keys())
writer.writeheader()
writer.writerows(dataframe)