222 lines
6.5 KiB
Python
222 lines
6.5 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
import json
|
||
import re
|
||
from csv import DictReader, DictWriter
|
||
from pathlib import Path
|
||
from typing import List, Dict
|
||
|
||
import torch
|
||
from transformers import BertTokenizerFast, BertForTokenClassification
|
||
|
||
|
||
# 命名实体识别
|
||
class NER:
|
||
def __init__(self):
|
||
# 实体标签映射
|
||
self.label_map = {
|
||
0: "O", # 非药品命名实体
|
||
1: "B-DRUG", # 药品命名实体-开始
|
||
2: "I-DRUG", # 药品命名实体-中间
|
||
}
|
||
|
||
# 加载预训练分词器
|
||
self.tokenizer = BertTokenizerFast.from_pretrained(
|
||
pretrained_model_name_or_path=Path("./models/bert-base-chinese").resolve()
|
||
)
|
||
|
||
# 加载预训练模型
|
||
self.model = BertForTokenClassification.from_pretrained(
|
||
pretrained_model_name_or_path=Path("./models/bert-base-chinese").resolve(),
|
||
)
|
||
|
||
# 设置模型为预测模式
|
||
self.model.eval()
|
||
|
||
def recognize_drugs(self, text: str) -> List[Dict]:
|
||
"""识别药品命名实体"""
|
||
|
||
if not text.strip():
|
||
return []
|
||
|
||
# 分词编码
|
||
inputs = self.tokenizer(
|
||
text,
|
||
return_tensors="pt",
|
||
padding=True,
|
||
truncation=True,
|
||
return_offsets_mapping=True,
|
||
)
|
||
|
||
# TOKEN于文本中起止位置
|
||
offset_mapping = inputs.pop("offset_mapping")[0].cpu().numpy()
|
||
|
||
with torch.no_grad():
|
||
# 模型预测
|
||
outputs = self.model(**inputs)
|
||
# 获取TOKEN预测标签
|
||
predictions = torch.argmax(outputs.logits, dim=2)
|
||
|
||
entities = []
|
||
current_entity = None
|
||
|
||
# 遍历所有TOKEN、预测标签索引和起止索引
|
||
for token, offset, label_id in zip(
|
||
self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
|
||
offset_mapping,
|
||
predictions[0].cpu().numpy(),
|
||
):
|
||
print(label_id)
|
||
continue
|
||
|
||
# 映射TOKEN标签
|
||
label = self.label_map.get(label_id, "O")
|
||
|
||
# 若遇到特殊TOKEN则跳过
|
||
if (
|
||
token in ["[CLS]", "[SEP]", "[PAD]"]
|
||
or offset[0] == 0
|
||
and offset[1] == 0
|
||
):
|
||
continue
|
||
|
||
if label == "B-DRUG":
|
||
if current_entity:
|
||
self._combine_tokens(current_entity, text)
|
||
entities.append(current_entity)
|
||
|
||
current_entity = {
|
||
"start": offset[0],
|
||
"end": offset[1],
|
||
"tokens": [token],
|
||
"offsets": [offset],
|
||
"type": label,
|
||
}
|
||
|
||
elif label == "I-DRUG":
|
||
if current_entity:
|
||
if offset[0] == current_entity["end"]:
|
||
current_entity["end"] = offset[1]
|
||
current_entity["tokens"].append(token)
|
||
current_entity["offsets"].append(offset)
|
||
else:
|
||
self._combine_tokens(current_entity, text)
|
||
entities.append(current_entity)
|
||
current_entity = {
|
||
"start": offset[0],
|
||
"end": offset[1],
|
||
"tokens": [token],
|
||
"offsets": [offset],
|
||
"type": label,
|
||
}
|
||
|
||
else:
|
||
if current_entity:
|
||
self._combine_tokens(current_entity, text)
|
||
entities.append(current_entity)
|
||
current_entity = None
|
||
|
||
if current_entity:
|
||
self._combine_tokens(current_entity, text)
|
||
entities.append(current_entity)
|
||
|
||
return entities
|
||
|
||
@staticmethod
|
||
def _combine_tokens(current_entity: Dict, text: str):
|
||
"""合并TOKEN"""
|
||
|
||
# 从文本中提取命名实体文本
|
||
current_entity["text"] = text[current_entity["start"] : current_entity["end"]]
|
||
|
||
|
||
"""
|
||
|
||
# 使用示例(需要训练好的模型)
|
||
dl_ner = NER()
|
||
text = "患者需要硫酸吗啡缓释片治疗癌症疼痛"
|
||
entities = dl_ner.recognize_drugs(text)
|
||
print(entities)
|
||
|
||
exit()
|
||
|
||
"""
|
||
|
||
|
||
def drug_extraction(text) -> tuple[str, str | None]:
|
||
"""药品数据提取"""
|
||
|
||
# 正则匹配两个“*”之间内容作为药品类别,第二个“*”之后内容作为药品名称。
|
||
if match := re.match(
|
||
pattern=r"\*(?P<drug_type>.*?)\*(?P<drug_name>.*)",
|
||
string=(text := text.strip()),
|
||
):
|
||
# 药品类别
|
||
drug_type = match.group("drug_type").strip()
|
||
|
||
# 药品名称
|
||
drug_name = (
|
||
match.group("drug_name")
|
||
.upper() # 小写转大写
|
||
.replace("(", " ")
|
||
.replace(")", " ")
|
||
.replace("(", " ")
|
||
.replace(")", " ")
|
||
.replace("[", " ")
|
||
.replace("]", " ")
|
||
.replace("【", " ")
|
||
.replace("】", " ")
|
||
.replace(":", " ")
|
||
.replace(":", " ")
|
||
.replace(",", " ")
|
||
.replace(",", " ")
|
||
.replace("·", " ")
|
||
.replace("`", " ")
|
||
.replace("@", " ")
|
||
.replace("#", " ")
|
||
.replace("*", " ")
|
||
.replace("/", " ") # 就指定符号替换为空格
|
||
.strip()
|
||
)
|
||
|
||
# 就药品名称中多个空格替换为一个空格
|
||
drug_name = re.sub(pattern=r"\s+", repl=" ", string=drug_name)
|
||
|
||
for section in drug_name.split(" "):
|
||
print(section)
|
||
|
||
# 若匹配失败则药品类型默认为文本、药品名称默认为None
|
||
else:
|
||
drug_type, drug_name = text, None
|
||
|
||
return drug_type, drug_name
|
||
|
||
|
||
# Output columns: disease, then the drug type/name parsed from each invoice item.
# NOTE(review): the original script never filled its output rows; these column
# names reconstruct the evident intent — confirm the desired schema.
_OUTPUT_FIELDS = ("疾病", "药品类别", "药品名称")


def main(
    input_path: str = "票据查验结果和疾病对应关系.csv",
    output_path: str = "1.csv",
) -> list[dict]:
    """Clean the invoice-verification-result / disease CSV into drug rows.

    Reads ``input_path`` (column ``疾病`` = disease, column ``票据查验结果`` =
    invoice verification result as a JSON string), applies
    :func:`drug_extraction` to the ``name`` of every entry in
    ``data.details.items`` and writes one row per item to ``output_path``.
    Malformed records are reported and skipped.

    TODO: the original note says only VAT invoices that verified as genuine are
    considered for now, but no such filter is applied — confirm which response
    fields identify them and filter here.

    Args:
        input_path: Source CSV path.
        output_path: Destination CSV path (the header is always written).

    Returns:
        The extracted rows, in input order.
    """
    rows: list[dict] = []

    # newline="" is what the csv module expects (quoted fields may contain line
    # breaks); utf-8-sig also strips a leading BOM, which would otherwise be
    # glued onto the first header name and break the column lookups.
    with open(input_path, "r", encoding="utf-8-sig", newline="") as file:
        for index, record in enumerate(DictReader(file), start=1):
            try:
                disease = record["疾病"]
                response = json.loads(record["票据查验结果"])
                extracted = [
                    drug_extraction(item["name"])
                    for item in response["data"]["details"]["items"]
                ]
            except (KeyError, TypeError, AttributeError, json.JSONDecodeError) as error:
                # Report and skip a malformed record instead of aborting the run.
                print(f"Skipping data row {index}: {error!r}")
                continue

            rows.extend(
                dict(zip(_OUTPUT_FIELDS, (disease, drug_type, drug_name)))
                for drug_type, drug_name in extracted
            )

    with open(output_path, "w", newline="", encoding="utf-8") as file:
        writer = DictWriter(file, fieldnames=_OUTPUT_FIELDS)
        writer.writeheader()
        writer.writerows(rows)

    return rows


if __name__ == "__main__":
    main()