# -*- coding: utf-8 -*-

import json
import re
from csv import DictReader, DictWriter
from pathlib import Path
from typing import List, Dict

import torch
from transformers import BertTokenizerFast, BertForTokenClassification


# Named entity recognition for drug names
class NER:
    def __init__(self):
        # Mapping from label ids to BIO entity tags
        self.label_map = {
            0: "O",  # not part of a drug entity
            1: "B-DRUG",  # beginning of a drug entity
            2: "I-DRUG",  # inside a drug entity
        }

        # Load the pretrained tokenizer
        self.tokenizer = BertTokenizerFast.from_pretrained(
            pretrained_model_name_or_path=Path("./models/bert-base-chinese").resolve()
        )

        # Load the pretrained model; num_labels must match the BIO tag set above
        self.model = BertForTokenClassification.from_pretrained(
            pretrained_model_name_or_path=Path("./models/bert-base-chinese").resolve(),
            num_labels=len(self.label_map),
        )

        # Put the model in evaluation (inference) mode
        self.model.eval()

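    # NOTE: ./models/bert-base-chinese is a base checkpoint; the DRUG labels
    # only become meaningful after fine-tuning on annotated data (see the
    # usage example at the bottom of this file).
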
    def recognize_drugs(self, text: str) -> List[Dict]:
        """Recognize drug named entities in the given text."""

        if not text.strip():
            return []

        # Tokenize and encode the input
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_offsets_mapping=True,
        )

        # Character start/end position of each token in the original text
        offset_mapping = inputs.pop("offset_mapping")[0].cpu().numpy()

        with torch.no_grad():
            # Run the model
            outputs = self.model(**inputs)
            # Take the highest-scoring label id for each token
            predictions = torch.argmax(outputs.logits, dim=2)

        entities = []
        current_entity = None

        # Iterate over all tokens with their character offsets and predicted
        # label ids
        for token, offset, label_id in zip(
            self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
            offset_mapping,
            predictions[0].cpu().numpy(),
        ):
            # Map the predicted label id to its tag
            label = self.label_map.get(label_id, "O")

            # Skip special tokens ([CLS]/[SEP]/[PAD] are mapped to a (0, 0)
            # offset)
            if token in ["[CLS]", "[SEP]", "[PAD]"] or (
                offset[0] == 0 and offset[1] == 0
            ):
                continue

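            # BIO merging: B-DRUG opens a new entity; I-DRUG extends the
            # current entity only while its character span is contiguous
            # with it, otherwise a fresh entity is started.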
            if label == "B-DRUG":
                if current_entity:
                    self._combine_tokens(current_entity, text)
                    entities.append(current_entity)

                current_entity = {
                    "start": offset[0],
                    "end": offset[1],
                    "tokens": [token],
                    "offsets": [offset],
                    "type": label,
                }

            elif label == "I-DRUG":
                if current_entity:
                    if offset[0] == current_entity["end"]:
                        current_entity["end"] = offset[1]
                        current_entity["tokens"].append(token)
                        current_entity["offsets"].append(offset)
                    else:
                        self._combine_tokens(current_entity, text)
                        entities.append(current_entity)
                        current_entity = {
                            "start": offset[0],
                            "end": offset[1],
                            "tokens": [token],
                            "offsets": [offset],
                            "type": label,
                        }

            else:
                if current_entity:
                    self._combine_tokens(current_entity, text)
                    entities.append(current_entity)
                    current_entity = None

        if current_entity:
            self._combine_tokens(current_entity, text)
            entities.append(current_entity)

        return entities

    @staticmethod
    def _combine_tokens(current_entity: Dict, text: str):
        """Merge the entity's tokens into a single surface string."""

        # Slice the entity text out of the original input using its span
        current_entity["text"] = text[current_entity["start"] : current_entity["end"]]


| """
 | ||
| 
 | ||
| # 使用示例(需要训练好的模型)
 | ||
| dl_ner = NER()
 | ||
| text = "患者需要硫酸吗啡缓释片治疗癌症疼痛"
 | ||
| entities = dl_ner.recognize_drugs(text)
 | ||
| print(entities)
 | ||
| 
 | ||
| exit()
 | ||
| 
 | ||
| """
 | ||
| 
 | ||
| 
 | ||
def drug_extraction(text) -> tuple[str, str | None]:
    """Extract the drug category and drug name from an invoice line item."""

    # The content between the two "*" marks is the drug category; everything
    # after the second "*" is the drug name.
    if match := re.match(
        pattern=r"\*(?P<drug_type>.*?)\*(?P<drug_name>.*)",
        string=(text := text.strip()),
    ):
        # Drug category
        drug_type = match.group("drug_type").strip()

        # Drug name: uppercase it, then replace the listed punctuation (both
        # ASCII and fullwidth forms) with spaces
        drug_name = re.sub(
            pattern=r"[()()\[\]【】::,,·`@#*/]",
            repl=" ",
            string=match.group("drug_name").upper(),
        ).strip()

        # Collapse runs of whitespace in the drug name into a single space
        drug_name = re.sub(pattern=r"\s+", repl=" ", string=drug_name)

    # If the pattern does not match, fall back to the raw text as the
    # category and None as the name
    else:
        drug_type, drug_name = text, None

    return drug_type, drug_name


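# A quick sketch of the expected behavior (illustrative inputs only, not
# taken from real invoice data):
#
#   drug_extraction("*西药*硫酸吗啡缓释片(30mg)")  ->  ("西药", "硫酸吗啡缓释片 30MG")
#   drug_extraction("材料费")                      ->  ("材料费", None)
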
dataframe = []

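# Expected input shape (inferred from the lookups below, not a documented
# schema): each CSV row carries a 疾病 column and a 票据查验结果 column whose
# JSON payload looks roughly like
#   {"data": {"details": {"items": [{"name": "*西药*...", ...}, ...]}}}
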
# Clean the invoice-verification-result / disease mapping data
# (for now only genuine VAT invoices are considered)
with open("票据查验结果和疾病对应关系.csv", "r", encoding="utf-8") as file:
    for row in DictReader(file):
        try:
            disease = row["疾病"]

            response = json.loads(row["票据查验结果"])

            # Iterate over the invoice line items
            for item in response["data"]["details"]["items"]:

                name = item["name"]

                drug_type, drug_name = drug_extraction(name)

                # Collect one cleaned row per line item (the output column
                # names are an assumption; adjust them to the downstream
                # schema)
                dataframe.append(
                    {
                        "疾病": disease,
                        "药品类别": drug_type,
                        "药品名称": drug_name,
                    }
                )

        except Exception as e:
            # Log the failure and move on to the next row instead of aborting
            print(e)
            continue

# Write the cleaned rows out; skip if nothing was extracted, since the header
# is derived from the first row
if dataframe:
    with open("1.csv", "w", newline="", encoding="utf-8") as file:
        writer = DictWriter(file, fieldnames=dataframe[0].keys())
        writer.writeheader()
        writer.writerows(dataframe)