跟着AI学AI - 诊断结论信息抽取 - 模型评估与调试
# 设置镜像源的环境变量
(vippython) PS D:\OpenSource\Python\VipPython> $env:HF_ENDPOINT = "https://hf-mirror.com"
# 添加依赖
(vippython) PS D:\OpenSource\Python\VipPython\information_extraction> uv add pandas
# 切换到 information_extraction 目录,否则会报文件不存在(脚本中的模型和数据路径都是相对路径)
(vippython) PS D:\OpenSource\Python\VipPython> cd D:\OpenSource\Python\VipPython\information_extraction
(vippython) PS D:\OpenSource\Python\VipPython\information_extraction> uv run .\evaluate_model.py
evaluate_model.py
# 模型评估和测试
import json
import os
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
import pandas as pd
def _align_labels(pred_ids, label_ids, attention_mask, id2label):
    """Pair predicted and gold label strings for the tokens that should be scored.

    A token is kept when it is attended to (mask value 1) and its gold id is a
    known label. Gold ids absent from *id2label* (e.g. a -100 "ignore" index,
    if the data uses one -- TODO confirm against the data builder) are skipped.
    'O' tokens are deliberately kept: seqeval needs complete IOB sequences to
    find entity boundaries.

    Returns:
        ``(pred_labels, true_labels)``: two equally long lists of label strings.
    """
    pred_labels = []
    true_labels = []
    for pred_id, label_id, attn in zip(pred_ids, label_ids, attention_mask):
        if attn != 1 or label_id not in id2label:
            continue
        pred_labels.append(id2label[pred_id])
        true_labels.append(id2label[label_id])
    return pred_labels, true_labels


def evaluate_model(model_dir, test_file, max_samples=50):
    """Evaluate a token-classification model with seqeval metrics.

    Loads the model and ``id2label.json`` from *model_dir*. Runs the model on
    the leading items of *test_file*, then prints a seqeval classification
    report plus accuracy, F1, precision and recall.

    Args:
        model_dir: Directory with the saved model and ``id2label.json``.
        test_file: JSON file containing a list of items. Each item has
            ``input_ids``, ``attention_mask`` and ``labels`` lists.
        max_samples: How many leading items to evaluate. The default of 50
            matches the previous hard-coded value; ``None`` evaluates all items.

    Returns:
        ``(predictions, references)``: per-sentence lists of label strings,
        aligned token by token.
    """
    print("=" * 60)
    print("模型评估")
    print("=" * 60)

    # 1. Load the model (inference mode: disables dropout).
    model = AutoModelForTokenClassification.from_pretrained(model_dir)
    model.eval()

    # 2. Load the test data.
    with open(test_file, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    # 3. Load the label map; JSON object keys are strings, so convert to int.
    with open(os.path.join(model_dir, 'id2label.json'), 'r', encoding='utf-8') as f:
        id2label = {int(k): v for k, v in json.load(f).items()}

    # 4. Predict, keeping the full (O-inclusive) sequence for every sentence.
    samples = test_data if max_samples is None else test_data[:max_samples]
    predictions = []
    references = []
    with torch.no_grad():
        for item in samples:
            input_ids = torch.tensor([item['input_ids']])
            attention_mask = torch.tensor([item['attention_mask']])
            logits = model(input_ids, attention_mask=attention_mask).logits
            pred_ids = logits.argmax(dim=-1)[0].tolist()
            pred_labels, true_labels = _align_labels(
                pred_ids, item['labels'], item['attention_mask'], id2label
            )
            predictions.append(pred_labels)
            references.append(true_labels)

    # seqeval's accuracy divides by the token count, so bail out when nothing
    # survived filtering instead of raising ZeroDivisionError.
    if not any(references):
        print("没有可评估的样本")
        return predictions, references

    # 5. Compute metrics.
    print("\n分类报告:")
    print(classification_report(references, predictions))
    print(f"准确率: {accuracy_score(references, predictions):.4f}")
    print(f"F1分数: {f1_score(references, predictions):.4f}")
    print(f"精确率: {precision_score(references, predictions):.4f}")
    print(f"召回率: {recall_score(references, predictions):.4f}")
    return predictions, references
if __name__ == "__main__":
    # Both paths are relative to the information_extraction directory,
    # so the script has to be launched from there.
    evaluate_model(
        model_dir='./ecg_ner_model',
        test_file='data/out/bert_training_data.json',
    )

本文来自博客园,作者:VipSoft 转载请注明原文链接:https://chuna2.787528.xyz/vipsoft/p/20012736
浙公网安备 33010602011771号