ROC、PR曲线绘制

PR曲线绘制

https://blog.csdn.net/xiang_gina/article/details/149962713 

三分钟带你理解ROC曲线和PR曲线

https://blog.csdn.net/Guo_Python/article/details/105820358

https://developers.google.com/machine-learning/crash-course/overfitting/interpreting-loss-curves?hl=zh-cn


 Cursor编写的示例程序:

提示词:写一个 Demo,计算 ROC 曲线、AUC 和 PR 曲线;要求在计算过程中让阈值按固定步长(0.01)逐步增大。

  1 #!/usr/bin/env python3
  2 # -*- coding: utf-8 -*-
  3 
  4 import numpy as np
  5 import matplotlib.pyplot as plt
  6 
  7 
  8 def confusion_from_threshold(y_true, y_scores, threshold):
  9     """Compute TP/FP/TN/FN for rule: predict positive if score >= threshold."""
 10     y_true = np.asarray(y_true).astype(int)
 11     y_scores = np.asarray(y_scores).astype(float)
 12 
 13     y_pred = (y_scores >= threshold).astype(int)
 14 
 15     tp = int(np.sum((y_pred == 1) & (y_true == 1)))
 16     fp = int(np.sum((y_pred == 1) & (y_true == 0)))
 17     tn = int(np.sum((y_pred == 0) & (y_true == 0)))
 18     fn = int(np.sum((y_pred == 0) & (y_true == 1)))
 19 
 20     return tp, fp, tn, fn
 21 
 22 
 23 def sweep_thresholds(y_true, y_scores, step=0.1):
 24     """Sweep threshold from 0.0 to 1.0 (inclusive) with given step."""
 25     thresholds = np.round(np.arange(0.0, 1.0 + 1e-9, step), 10)
 26 
 27     total_p = int(np.sum(y_true))
 28     total_n = int(len(y_true) - total_p)
 29 
 30     rows = []
 31     for thr in thresholds:
 32         tp, fp, tn, fn = confusion_from_threshold(y_true, y_scores, thr)
 33 
 34         tpr = tp / total_p if total_p > 0 else 0.0
 35         fpr = fp / total_n if total_n > 0 else 0.0
 36 
 37         precision = tp / (tp + fp) if (tp + fp) > 0 else 1.0
 38         recall = tpr
 39 
 40         rows.append(
 41             {
 42                 'thr': float(thr),
 43                 'tp': tp,
 44                 'fp': fp,
 45                 'tn': tn,
 46                 'fn': fn,
 47                 'tpr': float(tpr),
 48                 'fpr': float(fpr),
 49                 'precision': float(precision),
 50                 'recall': float(recall),
 51             }
 52         )
 53 
 54     return rows
 55 
 56 
 57 def trapezoid_auc(x, y):
 58     """Area under curve using trapezoidal rule after sorting by x."""
 59     x = np.asarray(x, dtype=float)
 60     y = np.asarray(y, dtype=float)
 61 
 62     order = np.argsort(x)
 63     x = x[order]
 64     y = y[order]
 65 
 66     return float(np.trapz(y, x))
 67 
 68 
 69 def main():
 70     # Built-in example (10 samples)
 71     scores = np.array([0.95, 0.85, 0.78, 0.66, 0.60, 0.55, 0.53, 0.52, 0.51, 0.40], dtype=float)
 72     y_true = np.array([1, 1, 0, 1, 0, 1, 0, 0, 1, 0], dtype=int)  # P=1, N=0
 73 
 74     step = 0.01
 75     rows = sweep_thresholds(y_true, scores, step=step)
 76 
 77     # Collect curve points
 78     fprs = [r['fpr'] for r in rows]
 79     tprs = [r['tpr'] for r in rows]
 80     recalls = [r['recall'] for r in rows]
 81     precisions = [r['precision'] for r in rows]
 82 
 83     roc_auc = trapezoid_auc(fprs, tprs)
 84 
 85     # Print table
 86     print(f"Threshold sweep step={step}")
 87     print("thr\tTP\tFP\tTN\tFN\tTPR\tFPR\tPrec\tRecall")
 88     for r in rows:
 89         print(
 90             f"{r['thr']:.1f}\t{r['tp']}\t{r['fp']}\t{r['tn']}\t{r['fn']}\t"
 91             f"{r['tpr']:.3f}\t{r['fpr']:.3f}\t{r['precision']:.3f}\t{r['recall']:.3f}"
 92         )
 93 
 94     print('\nROC AUC (threshold-sweep, trapezoid) =', roc_auc)
 95 
 96     # Plot
 97     fig, axes = plt.subplots(1, 2, figsize=(12, 5))
 98 
 99     axes[0].plot(fprs, tprs, marker='o', label=f'ROC (AUC={roc_auc:.4f})', color='tab:blue')
100     axes[0].plot([0, 1], [0, 1], linestyle='--', linewidth=1, color='gray', label='Random')
101     axes[0].set_xlabel('False Positive Rate')
102     axes[0].set_ylabel('True Positive Rate')
103     axes[0].set_title('ROC Curve (threshold sweep)')
104     axes[0].set_xlim(0, 1)
105     axes[0].set_ylim(0, 1)
106     axes[0].grid(True)
107     axes[0].legend(loc='lower right')
108 
109     axes[1].plot(recalls, precisions, marker='o', label='PR (threshold sweep)', color='tab:red')
110     axes[1].set_xlabel('Recall')
111     axes[1].set_ylabel('Precision')
112     axes[1].set_title('PR Curve (threshold sweep)')
113     axes[1].set_xlim(0, 1)
114     axes[1].set_ylim(0, 1)
115     axes[1].grid(True)
116     axes[1].legend(loc='lower left')
117 
118     plt.tight_layout()
119     plt.show()
120 
121 
122 if __name__ == '__main__':
123     main()

结果:

image

 

posted @ 2025-09-23 16:15  太一吾鱼水  阅读(21)  评论(0)    收藏  举报