跟着AI学AI - 诊断结论信息抽取 - 模型评估与调试

# 设置镜像源的环境变量
(vippython) PS D:\OpenSource\Python\VipPython> $env:HF_ENDPOINT = "https://hf-mirror.com"
# 添加依赖
(vippython) PS D:\OpenSource\Python\VipPython\information_extraction> uv add pandas
# 先切换到 information_extraction 目录，否则脚本中使用的相对路径会报“文件不存在”
(vippython) PS D:\OpenSource\Python\VipPython> cd D:\OpenSource\Python\VipPython\information_extraction
(vippython) PS D:\OpenSource\Python\VipPython\information_extraction> uv run .\evaluate_model.py

evaluate_model.py

# 模型评估和测试
import json
import os
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
import pandas as pd


def evaluate_model(model_dir, test_file, max_samples=50):
    """Evaluate a token-classification (NER) model on pre-tokenized samples.

    Args:
        model_dir: Directory holding a Hugging Face token-classification model
            plus an ``id2label.json`` file mapping label ids to tag strings.
        test_file: JSON file containing a list of samples; each sample is read
            for ``input_ids``, ``attention_mask`` and ``labels`` (token-aligned
            lists -- assumed to be the same length, TODO confirm with the
            data-prep step).
        max_samples: Only the first ``max_samples`` samples are evaluated
            (default 50, as before). Pass ``None`` to evaluate all of them.

    Returns:
        ``(predictions, references)``: two lists of per-sample tag sequences,
        aligned position by position, in the list-of-lists form seqeval expects.
    """
    print("=" * 60)
    print("模型评估")
    print("=" * 60)

    # 1. Load the model (inference mode: disables dropout).
    model = AutoModelForTokenClassification.from_pretrained(model_dir)
    model.eval()

    # 2. Load the pre-tokenized test data.
    with open(test_file, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    # 3. Load the label map; JSON object keys are strings, so convert to int.
    with open(os.path.join(model_dir, 'id2label.json'), 'r', encoding='utf-8') as f:
        id2label = {int(k): v for k, v in json.load(f).items()}

    samples = test_data if max_samples is None else test_data[:max_samples]

    # 4. Predict and align predictions with gold tags.
    predictions = []
    references = []

    with torch.no_grad():
        for item in samples:
            input_ids = torch.tensor(item['input_ids']).unsqueeze(0)
            attention_mask = torch.tensor(item['attention_mask']).unsqueeze(0)

            logits = model(input_ids, attention_mask=attention_mask).logits
            pred_ids = torch.argmax(logits, dim=-1)[0].tolist()
            mask = attention_mask[0].tolist()

            pred_seq = []
            true_seq = []
            for pred_id, label_id, attn in zip(pred_ids, item['labels'], mask):
                # Skip padding, and positions without a real gold tag (ids not
                # in the label map, e.g. -100 on special tokens / sub-words).
                # 'O' positions are KEPT: dropping them would hide every false
                # positive and inflate precision/accuracy.
                if attn != 1 or label_id not in id2label:
                    continue
                pred_seq.append(id2label[pred_id])
                true_seq.append(id2label[label_id])

            predictions.append(pred_seq)
            references.append(true_seq)

    if not references:
        # Nothing to score; skip seqeval rather than call it with no sequences.
        print("没有可评估的样本")
        return predictions, references

    # 5. Compute entity-level metrics (seqeval) and token accuracy.
    print("\n分类报告:")
    print(classification_report(references, predictions))

    print(f"准确率: {accuracy_score(references, predictions):.4f}")
    print(f"F1分数: {f1_score(references, predictions):.4f}")
    print(f"精确率: {precision_score(references, predictions):.4f}")
    print(f"召回率: {recall_score(references, predictions):.4f}")

    return predictions, references

if __name__ == "__main__":
    # Both paths are relative to the current working directory, so run this
    # script from the directory that contains ecg_ner_model/ and data/.
    # NOTE(review): data_path points at the *training* data file; metrics on it
    # will be optimistic -- switch to a held-out split if one exists.
    model_path = './ecg_ner_model'
    data_path = 'data/out/bert_training_data.json'
    evaluate_model(model_path, data_path)

image

posted @ 2026-05-14 13:52  VipSoft  阅读(13)  评论(0)    收藏  举报