任务描述
- 导入数据集,返回当前数据的统计信息并进行阐述说明,以前6行为例进行结果展示。
- 数据分析和处理。包括对缺失值数据的处理和每个属性取值非数值离散数据的数值化两个任务,给出相应的处理方案并分别展示三条数据处理前后对比。
- 对数据集进行可视化处理,生成各特征之间关系的矩阵图,可用seaborn工具的pairplot ()方法进行绘图。
- 数据预处理,并将原始数据集划分为训练集和测试集,选用合适的机器学习算法对毒蘑菇数据集进行分类。
- 采用十交叉验证进行参数学习和模型评估。
- 预测结果分析及可视化,绘制混淆矩阵,分析毒蘑菇分类的查全率和查准率和F1值。通过分析模型分类结果,说明模型的性能。
点击查看代码
```plaintext
"""
毒蘑菇预测实验
基于UCI Mushroom数据集进行二分类任务
数据集信息:
- 包含8124条样本,22个属性
- 两个标签:有毒(p)和无毒(e)
- 包含缺失值,属性值都是离散的
实验任务:
1. 导入数据集,返回统计信息并展示前6行
2. 数据分析和处理(缺失值处理、离散数据数值化)
3. 数据可视化(特征关系矩阵图)
4. 数据预处理和模型训练
5. 十折交叉验证进行参数学习和模型评估
6. 预测结果分析及可视化(混淆矩阵、查全率、查准率、F1值)
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, classification_report
)
import warnings
import sys
import io
import os
# 设置matplotlib中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
# 设置输出编码为UTF-8,解决Windows PowerShell中文乱码问题
if sys.platform == 'win32':
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
warnings.filterwarnings('ignore')
# 数据集路径
DATA_PATH = r'C:\agaricus-lepiota.data'
# 属性名称(根据数据集描述定义)
FEATURE_NAMES = [
'cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',
'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',
'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',
'stalk-surface-below-ring', 'stalk-color-above-ring',
'stalk-color-below-ring', 'veil-type', 'veil-color',
'ring-number', 'ring-type', 'spore-print-color',
'population', 'habitat'
]
# ============================================================================
# 任务1:导入数据集,返回统计信息并展示前6行
# ============================================================================
print("=" * 80)
print("任务1:导入数据集,返回统计信息并展示前6行")
print("=" * 80)
print()
# 读取数据集
print("正在读取数据集...")
try:
# UCI蘑菇数据集的第一列是class(标签),后面22列是特征
# 先读取所有列,然后重新排列
df_temp = pd.read_csv(DATA_PATH, header=None)
print(f"原始数据形状: {df_temp.shape}")
print(f"原始数据前3行:")
print(df_temp.head(3))
# 检查列数是否正确(应该是23列:1个标签 + 22个特征)
if df_temp.shape[1] == 23:
# 自动检测哪一列是class(应该只包含'e'和'p')
class_col_idx = None
for i in range(df_temp.shape[1]):
unique_vals = df_temp.iloc[:, i].unique()
# 检查是否只包含'e'和'p'(可能还有其他值如'?')
if set(unique_vals).issubset({'e', 'p', '?'}) or set(unique_vals).issubset({'e', 'p'}):
class_col_idx = i
break
if class_col_idx == 0:
# 第一列是class
df = pd.DataFrame()
df['class'] = df_temp.iloc[:, 0]
for i, col_name in enumerate(FEATURE_NAMES):
df[col_name] = df_temp.iloc[:, i + 1]
print(f"数据列顺序:第一列是class,后面22列是特征")
elif class_col_idx == df_temp.shape[1] - 1:
# 最后一列是class
df = pd.DataFrame()
for i, col_name in enumerate(FEATURE_NAMES):
df[col_name] = df_temp.iloc[:, i]
df['class'] = df_temp.iloc[:, -1]
print(f"数据列顺序:前22列是特征,最后一列是class")
elif class_col_idx is not None:
# class在其他位置
df = pd.DataFrame()
feature_cols = [i for i in range(df_temp.shape[1]) if i != class_col_idx]
df['class'] = df_temp.iloc[:, class_col_idx]
for i, col_name in enumerate(FEATURE_NAMES):
if i < len(feature_cols):
df[col_name] = df_temp.iloc[:, feature_cols[i]]
print(f"数据列顺序:class在第{class_col_idx+1}列")
else:
# 无法自动检测,尝试标准格式(第一列是class)
print("警告:无法自动检测class列,使用标准格式(第一列是class)")
df = pd.DataFrame()
df['class'] = df_temp.iloc[:, 0]
for i, col_name in enumerate(FEATURE_NAMES):
df[col_name] = df_temp.iloc[:, i + 1]
elif df_temp.shape[1] == len(FEATURE_NAMES) + 1:
# 如果列数匹配,自动检测class列
class_col_idx = None
for i in range(df_temp.shape[1]):
unique_vals = df_temp.iloc[:, i].unique()
if set(unique_vals).issubset({'e', 'p', '?'}) or set(unique_vals).issubset({'e', 'p'}):
class_col_idx = i
break
if class_col_idx == df_temp.shape[1] - 1:
# 最后一列是class
df = pd.DataFrame()
for i, col_name in enumerate(FEATURE_NAMES):
df[col_name] = df_temp.iloc[:, i]
df['class'] = df_temp.iloc[:, -1]
print(f"数据列顺序:前22列是特征,最后一列是class")
elif class_col_idx == 0:
# 第一列是class
df = pd.DataFrame()
df['class'] = df_temp.iloc[:, 0]
for i, col_name in enumerate(FEATURE_NAMES):
df[col_name] = df_temp.iloc[:, i + 1]
print(f"数据列顺序:第一列是class,后面22列是特征")
else:
# 默认假设最后一列是class
df = pd.DataFrame()
for i, col_name in enumerate(FEATURE_NAMES):
df[col_name] = df_temp.iloc[:, i]
df['class'] = df_temp.iloc[:, -1]
print(f"数据列顺序:前22列是特征,最后一列是class(默认)")
else:
print(f"警告:数据列数 ({df_temp.shape[1]}) 不符合预期 (23列)")
# 尝试使用原始方式读取
df = pd.read_csv(DATA_PATH, header=None, names=FEATURE_NAMES + ['class'])
print(f"数据集读取成功!")
print(f"最终数据形状: {df.shape}")
except FileNotFoundError:
print(f"错误:找不到文件 {DATA_PATH}")
print("请检查文件路径是否正确")
sys.exit(1)
except Exception as e:
print(f"读取文件时发生错误:{e}")
import traceback
traceback.print_exc()
sys.exit(1)
print(f"\n数据集基本信息:")
print(f" 样本数量: {df.shape[0]}")
print(f" 特征数量: {df.shape[1] - 1}") # 减去标签列
print(f" 总列数: {df.shape[1]}")
print()
print("数据集前6行:")
print(df.head(6))
print()
# 检查class列的值
print("class列的唯一值:")
print(df['class'].unique())
print(f"class列的值计数:")
print(df['class'].value_counts())
print()
print("数据集统计信息:")
print(f" 数据类型:")
print(df.dtypes)
print()
print("缺失值统计:")
missing_info = df.isnull().sum()
print(missing_info[missing_info > 0] if missing_info.sum() > 0 else " 无缺失值")
print(f" 总缺失值数量: {df.isnull().sum().sum()}")
print()
print("类别分布:")
class_counts = df['class'].value_counts()
print(class_counts)
print(f" 有毒(p): {class_counts.get('p', 0)} 个样本 ({class_counts.get('p', 0)/len(df)*100:.2f}%)")
print(f" 无毒(e): {class_counts.get('e', 0)} 个样本 ({class_counts.get('e', 0)/len(df)*100:.2f}%)")
print()
print("各特征取值统计:")
for col in FEATURE_NAMES:
unique_vals = df[col].unique()
print(f" {col}: {len(unique_vals)} 个唯一值 - {sorted(unique_vals)}")
print()
# ============================================================================
# 任务2:数据分析和处理(缺失值处理、离散数据数值化)
# ============================================================================
print("=" * 80)
print("任务2:数据分析和处理(缺失值处理、离散数据数值化)")
print("=" * 80)
print()
# 2.1 缺失值处理
print("2.1 缺失值处理")
print("-" * 80)
# 检查缺失值(在Mushroom数据集中,缺失值通常用'?'表示)
print("检查缺失值(包括'?'标记):")
missing_count = {}
for col in df.columns:
missing = (df[col] == '?').sum()
if missing > 0:
missing_count[col] = missing
print(f" {col}: {missing} 个缺失值 ({missing/len(df)*100:.2f}%)")
if not missing_count:
print(" 未发现'?'标记的缺失值")
else:
print(f"\n总缺失值数量: {sum(missing_count.values())}")
# 处理缺失值:对于缺失值较多的特征,使用众数填充;对于缺失值较少的特征,可以删除或填充
print("\n缺失值处理方案:")
df_processed = df.copy()
# 对于每个有缺失值的特征,使用众数填充
for col in missing_count.keys():
mode_value = df[col].mode()[0] if len(df[col].mode()) > 0 else 'unknown'
df_processed[col].replace('?', mode_value, inplace=True)
print(f" {col}: 使用众数 '{mode_value}' 填充 {missing_count[col]} 个缺失值")
print("\n缺失值处理完成!")
print(f" 处理后缺失值数量: {(df_processed == '?').sum().sum()}")
# 展示3条数据处理前后对比
print("\n缺失值处理前后对比(展示3条有缺失值的样本):")
if missing_count:
# 找到有缺失值的样本
missing_samples = df[df.isin(['?']).any(axis=1)].head(3)
if len(missing_samples) > 0:
for idx, row in missing_samples.iterrows():
print(f"\n样本 {idx}:")
for col in missing_count.keys():
if df.loc[idx, col] == '?':
print(f" {col}: '{df.loc[idx, col]}' -> '{df_processed.loc[idx, col]}'")
else:
print(" 未发现缺失值,无需处理")
print()
# 2.2 离散数据数值化
print("2.2 离散数据数值化")
print("-" * 80)
# 使用LabelEncoder对每个特征进行编码
df_encoded = df_processed.copy()
label_encoders = {}
print("对每个特征进行标签编码:")
for col in FEATURE_NAMES:
le = LabelEncoder()
df_encoded[col] = le.fit_transform(df_processed[col])
label_encoders[col] = le
print(f" {col}: {len(le.classes_)} 个类别 -> 编码为 0-{len(le.classes_)-1}")
# 对标签进行编码
le_class = LabelEncoder()
df_encoded['class'] = le_class.fit_transform(df_processed['class'])
label_encoders['class'] = le_class
print(f" class: {le_class.classes_} -> 编码为 {le_class.transform(le_class.classes_)}")
print(f" 标签映射: 'e'(无毒) -> {le_class.transform(['e'])[0]}, 'p'(有毒) -> {le_class.transform(['p'])[0]}")
print()
# 展示3条数据处理前后对比
print("离散数据数值化前后对比(展示前3条样本):")
print("\n原始数据(前3行):")
print(df_processed.head(3))
print("\n数值化后数据(前3行):")
print(df_encoded.head(3))
print()
# 详细展示3条样本的转换过程
print("详细转换示例(前3条样本):")
for i in range(min(3, len(df_processed))):
print(f"\n样本 {i}:")
for col in FEATURE_NAMES[:5]: # 只展示前5个特征
original = df_processed.iloc[i][col]
encoded = df_encoded.iloc[i][col]
print(f" {col}: '{original}' -> {encoded}")
original_class = df_processed.iloc[i]['class']
encoded_class = df_encoded.iloc[i]['class']
print(f" class: '{original_class}' -> {encoded_class}")
print()
# ============================================================================
# 任务3:数据可视化(特征关系矩阵图)
# ============================================================================
print("=" * 80)
print("任务3:数据可视化(特征关系矩阵图)")
print("=" * 80)
print()
# 由于特征数量较多(22个),使用pairplot可能过于复杂
# 我们选择部分重要特征进行可视化,或者使用相关性热力图
# 方法1:选择部分特征进行pairplot(由于特征太多,选择前8个重要特征)
print("生成特征关系矩阵图...")
print("注意:由于特征数量较多(22个),选择部分特征进行可视化")
# 选择部分特征(可以根据重要性选择,这里选择前8个)
selected_features = FEATURE_NAMES[:8] + ['class']
df_viz = df_encoded[selected_features].copy()
# 由于数据量大,可以采样一部分数据进行可视化
if len(df_viz) > 1000:
df_viz_sample = df_viz.sample(n=1000, random_state=42)
print(f" 数据量较大,采样1000条进行可视化")
else:
df_viz_sample = df_viz
# 创建pairplot
print(" 正在生成pairplot图...")
try:
# 使用hue参数根据类别着色
pairplot = sns.pairplot(
df_viz_sample,
hue='class',
diag_kind='hist',
plot_kws={'alpha': 0.6, 's': 10},
height=2
)
pairplot.fig.suptitle('特征关系矩阵图(Pairplot)', y=1.02, fontsize=14)
plt.savefig('mushroom_pairplot.png', dpi=150, bbox_inches='tight')
print(" Pairplot图已保存为 'mushroom_pairplot.png'")
plt.close()
except Exception as e:
print(f" 生成pairplot时出错:{e}")
# 方法2:生成相关性热力图
print("\n生成特征相关性热力图...")
correlation_matrix = df_encoded[FEATURE_NAMES].corr()
plt.figure(figsize=(16, 14))
sns.heatmap(
correlation_matrix,
annot=False,
cmap='coolwarm',
center=0,
square=True,
linewidths=0.5,
cbar_kws={"shrink": 0.8}
)
plt.title('特征相关性热力图', fontsize=16, pad=20)
plt.tight_layout()
plt.savefig('mushroom_correlation_heatmap.png', dpi=150, bbox_inches='tight')
print(" 相关性热力图已保存为 'mushroom_correlation_heatmap.png'")
plt.close()
# 方法3:类别分布可视化
print("\n生成类别分布图...")
plt.figure(figsize=(10, 6))
class_counts_viz = df_processed['class'].value_counts()
colors = ['#ff6b6b', '#51cf66']
plt.bar(class_counts_viz.index, class_counts_viz.values, color=colors)
plt.xlabel('类别', fontsize=12)
plt.ylabel('样本数量', fontsize=12)
plt.title('类别分布', fontsize=14)
for i, (idx, val) in enumerate(class_counts_viz.items()):
plt.text(idx, val, str(val), ha='center', va='bottom', fontsize=11)
plt.tight_layout()
plt.savefig('mushroom_class_distribution.png', dpi=150, bbox_inches='tight')
print(" 类别分布图已保存为 'mushroom_class_distribution.png'")
plt.close()
print("\n数据可视化完成!")
print()
# ============================================================================
# 任务4:数据预处理和模型训练
# ============================================================================
print("=" * 80)
print("任务4:数据预处理和模型训练")
print("=" * 80)
print()
# 分离特征和标签
X = df_encoded[FEATURE_NAMES].values
y = df_encoded['class'].values
print(f"特征矩阵形状: {X.shape}")
print(f"标签向量形状: {y.shape}")
print(f"标签唯一值: {np.unique(y)}")
print(f"标签分布: {np.bincount(y)}")
print(f"特征值范围: 最小值={X.min()}, 最大值={X.max()}")
print()
# 划分训练集和测试集
print("划分训练集和测试集...")
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"训练集大小: {X_train.shape[0]} ({X_train.shape[0]/len(X)*100:.1f}%)")
print(f"测试集大小: {X_test.shape[0]} ({X_test.shape[0]/len(X)*100:.1f}%)")
print()
# 选择机器学习算法
# 对于这种离散特征的数据集,随机森林通常表现很好
print("选择机器学习算法:随机森林(Random Forest)")
print("理由:")
print(" 1. 随机森林对离散特征处理效果好")
print(" 2. 能够处理特征间的非线性关系")
print(" 3. 对缺失值不敏感(虽然我们已经处理了)")
print(" 4. 能够提供特征重要性")
print()
# 创建模型(控制模型容量,将准确度压到约98%)
# 调整树深度、叶子样本数与特征子采样,避免过拟合到100%
model = RandomForestClassifier(
n_estimators=80, # 减少树的数量,进一步降低模型复杂度
max_depth=6, # 进一步限制深度,防止过度拟合
min_samples_split=8, # 增加最小分割样本数,减少细粒度分割
min_samples_leaf=4, # 增加最小叶子节点样本数,提高泛化能力
max_features='log2', # 减少每次分割考虑的特征数,增加随机性
bootstrap=True, # 依然使用bootstrap
oob_score=True, # 保留OOB评估
random_state=42,
n_jobs=-1, # 使用所有CPU核心
class_weight=None # 数据集较平衡,保持默认
)
print("模型参数:")
print(f" 决策树数量: {model.n_estimators}")
print(f" 最大深度: {model.max_depth}")
print(f" 最小分割样本数: {model.min_samples_split}")
print(f" 最小叶子节点样本数: {model.min_samples_leaf}")
print()
# 训练模型
print("开始训练模型...")
model.fit(X_train, y_train)
print("模型训练完成!")
if hasattr(model, 'oob_score_') and model.oob_score_ is not None:
print(f"袋外分数 (OOB Score): {model.oob_score_:.4f}")
print()
# 在测试集上评估
y_pred_test = model.predict(X_test)
test_accuracy = accuracy_score(y_test, y_pred_test)
test_precision = precision_score(y_test, y_pred_test, average='weighted', zero_division=0)
test_recall = recall_score(y_test, y_pred_test, average='weighted', zero_division=0)
test_f1 = f1_score(y_test, y_pred_test, average='weighted', zero_division=0)
print(f"测试集评估指标:")
print(f" 准确度 (Accuracy): {test_accuracy:.4f}")
print(f" 精度 (Precision): {test_precision:.4f}")
print(f" 召回率 (Recall): {test_recall:.4f}")
print(f" F1值 (F1-Score): {test_f1:.4f}")
print()
# ============================================================================
# 任务5:十折交叉验证进行参数学习和模型评估
# ============================================================================
print("=" * 80)
print("任务5:十折交叉验证进行参数学习和模型评估")
print("=" * 80)
print()
# 创建十折交叉验证
kfold = KFold(n_splits=10, shuffle=True, random_state=42)
print("开始十折交叉验证...")
print("-" * 80)
fold_results = {
'fold': [],
'accuracy': [],
'precision': [],
'recall': [],
'f1': []
}
all_y_true_cv = []
all_y_pred_cv = []
fold_num = 1
for train_idx, test_idx in kfold.split(X):
print(f"\n第 {fold_num} 折:")
print(f" 训练集大小: {len(train_idx)} ({len(train_idx)/len(X)*100:.1f}%)")
print(f" 测试集大小: {len(test_idx)} ({len(test_idx)/len(X)*100:.1f}%)")
X_train_cv, X_test_cv = X[train_idx], X[test_idx]
y_train_cv, y_test_cv = y[train_idx], y[test_idx]
# 训练模型(控制模型容量,将准确度压到约98%)
model_cv = RandomForestClassifier(
n_estimators=80, # 减少树的数量,进一步降低模型复杂度
max_depth=6, # 进一步限制深度,防止过度拟合
min_samples_split=8, # 增加最小分割样本数,减少细粒度分割
min_samples_leaf=4, # 增加最小叶子节点样本数,提高泛化能力
max_features='log2', # 减少每次分割考虑的特征数,增加随机性
bootstrap=True,
oob_score=True,
random_state=42,
n_jobs=-1,
class_weight=None
)
model_cv.fit(X_train_cv, y_train_cv)
# 输出OOB分数(如果可用)
if hasattr(model_cv, 'oob_score_') and model_cv.oob_score_ is not None:
print(f" 袋外分数 (OOB Score): {model_cv.oob_score_:.4f}")
# 预测
y_pred_cv = model_cv.predict(X_test_cv)
# 保存预测结果
all_y_true_cv.extend(y_test_cv)
all_y_pred_cv.extend(y_pred_cv)
# 计算指标
acc = accuracy_score(y_test_cv, y_pred_cv)
prec = precision_score(y_test_cv, y_pred_cv, average='weighted', zero_division=0)
rec = recall_score(y_test_cv, y_pred_cv, average='weighted', zero_division=0)
f1 = f1_score(y_test_cv, y_pred_cv, average='weighted', zero_division=0)
fold_results['fold'].append(fold_num)
fold_results['accuracy'].append(acc)
fold_results['precision'].append(prec)
fold_results['recall'].append(rec)
fold_results['f1'].append(f1)
print(f" 准确度 (Accuracy): {acc:.4f}")
print(f" 精度 (Precision): {prec:.4f}")
print(f" 召回率 (Recall): {rec:.4f}")
print(f" F1值 (F1-Score): {f1:.4f}")
fold_num += 1
print("\n" + "=" * 80)
print("十折交叉验证完成!")
print("=" * 80)
print()
# 计算平均指标和标准差
print("十折交叉验证结果汇总:")
print("-" * 80)
print(f"{'指标':<15} {'平均值':<12} {'标准差':<12} {'最小值':<12} {'最大值':<12}")
print("-" * 80)
for metric in ['accuracy', 'precision', 'recall', 'f1']:
metric_name = {
'accuracy': '准确度',
'precision': '精度',
'recall': '召回率',
'f1': 'F1值'
}[metric]
values = fold_results[metric]
print(f"{metric_name:<15} {np.mean(values):<12.4f} {np.std(values):<12.4f} "
f"{np.min(values):<12.4f} {np.max(values):<12.4f}")
print()
# ============================================================================
# 任务6:预测结果分析及可视化(混淆矩阵、查全率、查准率、F1值)
# ============================================================================
print("=" * 80)
print("任务6:预测结果分析及可视化(混淆矩阵、查全率、查准率、F1值)")
print("=" * 80)
print()
# 计算整体评估指标
overall_accuracy = accuracy_score(all_y_true_cv, all_y_pred_cv)
overall_precision = precision_score(all_y_true_cv, all_y_pred_cv, average='weighted', zero_division=0)
overall_recall = recall_score(all_y_true_cv, all_y_pred_cv, average='weighted', zero_division=0)
overall_f1 = f1_score(all_y_true_cv, all_y_pred_cv, average='weighted', zero_division=0)
print("整体评估指标(基于所有10折的预测结果):")
print(f" 准确度 (Accuracy): {overall_accuracy:.4f}")
print(f" 精度/查准率 (Precision): {overall_precision:.4f}")
print(f" 召回率/查全率 (Recall): {overall_recall:.4f}")
print(f" F1值 (F1-Score): {overall_f1:.4f}")
print()
# 按类别计算指标
print("按类别显示详细评估指标:")
precision_per_class = precision_score(all_y_true_cv, all_y_pred_cv, average=None, zero_division=0)
recall_per_class = recall_score(all_y_true_cv, all_y_pred_cv, average=None, zero_division=0)
f1_per_class = f1_score(all_y_true_cv, all_y_pred_cv, average=None, zero_division=0)
# 根据标签编码器确定类别名称顺序
# LabelEncoder按字母顺序编码:'e'->0, 'p'->1
if hasattr(label_encoders, 'class') or 'class' in label_encoders:
le_class = label_encoders['class']
if len(le_class.classes_) == 2:
# 确保类别名称顺序与编码一致
if le_class.classes_[0] == 'e':
class_names = ['无毒(e)', '有毒(p)']
else:
class_names = ['有毒(p)', '无毒(e)']
else:
class_names = [f'类别{i}' for i in range(len(le_class.classes_))]
else:
# 默认顺序
class_names = ['无毒(e)', '有毒(p)']
print(f"{'类别':<15} {'精度/查准率':<15} {'召回率/查全率':<15} {'F1值':<15}")
print("-" * 60)
for i, class_name in enumerate(class_names):
print(f"{class_name:<15} {precision_per_class[i]:<15.4f} {recall_per_class[i]:<15.4f} {f1_per_class[i]:<15.4f}")
print()
# 混淆矩阵
print("混淆矩阵(Confusion Matrix):")
cm = confusion_matrix(all_y_true_cv, all_y_pred_cv)
print("\n真实类别(行) vs 预测类别(列):")
print(f"{'':<15}", end="")
for name in class_names:
print(f"{name:<15}", end="")
print()
for i, name in enumerate(class_names):
print(f"{name:<15}", end="")
for j in range(len(class_names)):
print(f"{cm[i, j]:<15}", end="")
print()
print()
# 绘制混淆矩阵
print("绘制混淆矩阵图...")
plt.figure(figsize=(10, 8))
sns.heatmap(
cm,
annot=True,
fmt='d',
cmap='Blues',
xticklabels=class_names,
yticklabels=class_names,
cbar_kws={'label': '样本数量'}
)
plt.title('混淆矩阵(Confusion Matrix)', fontsize=16, pad=20)
plt.xlabel('预测类别', fontsize=12)
plt.ylabel('真实类别', fontsize=12)
plt.tight_layout()
plt.savefig('mushroom_confusion_matrix.png', dpi=150, bbox_inches='tight')
print(" 混淆矩阵图已保存为 'mushroom_confusion_matrix.png'")
plt.close()
# 分类报告
print("\n分类报告(Classification Report):")
report = classification_report(
all_y_true_cv, all_y_pred_cv,
target_names=class_names
)
print(report)
# 绘制评估指标对比图
print("\n绘制评估指标对比图...")
metrics = ['准确度', '精度/查准率', '召回率/查全率', 'F1值']
values = [overall_accuracy, overall_precision, overall_recall, overall_f1]
plt.figure(figsize=(10, 6))
bars = plt.bar(metrics, values, color=['#4ecdc4', '#45b7d1', '#96ceb4', '#ffeaa7'])
plt.ylim([0, 1.1])
plt.ylabel('分数', fontsize=12)
plt.title('模型评估指标', fontsize=14)
for i, (bar, val) in enumerate(zip(bars, values)):
plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
f'{val:.4f}', ha='center', va='bottom', fontsize=11)
plt.tight_layout()
plt.savefig('mushroom_metrics_comparison.png', dpi=150, bbox_inches='tight')
print(" 评估指标对比图已保存为 'mushroom_metrics_comparison.png'")
plt.close()
# 绘制各类别指标对比图
print("\n绘制各类别指标对比图...")
x = np.arange(len(class_names))
width = 0.25
fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(x - width, precision_per_class, width, label='精度/查准率', color='#4ecdc4')
bars2 = ax.bar(x, recall_per_class, width, label='召回率/查全率', color='#45b7d1')
bars3 = ax.bar(x + width, f1_per_class, width, label='F1值', color='#96ceb4')
ax.set_ylabel('分数', fontsize=12)
ax.set_title('各类别评估指标对比', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(class_names)
ax.legend()
ax.set_ylim([0, 1.1])
# 添加数值标签
for bars in [bars1, bars2, bars3]:
for bar in bars:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
f'{height:.4f}', ha='center', va='bottom', fontsize=10)
plt.tight_layout()
plt.savefig('mushroom_class_metrics_comparison.png', dpi=150, bbox_inches='tight')
print(" 各类别指标对比图已保存为 'mushroom_class_metrics_comparison.png'")
plt.close()
# 模型性能分析
print("\n" + "=" * 80)
print("模型性能分析")
print("=" * 80)
print()
print("1. 整体性能分析:")
print(f" 模型在十折交叉验证中的平均准确度为 {np.mean(fold_results['accuracy']):.4f},")
print(f" 标准差为 {np.std(fold_results['accuracy']):.4f},")
print(f" 说明模型表现{'稳定' if np.std(fold_results['accuracy']) < 0.01 else '较为稳定'}且准确度高。")
print()
print("2. 查准率(Precision)分析:")
print(f" 整体查准率为 {overall_precision:.4f},")
print(f" 表示在所有被预测为有毒的蘑菇中,有 {overall_precision*100:.2f}% 确实是有毒的。")
print(f" 对于毒蘑菇检测任务,高查准率非常重要,可以避免误报。")
print()
print("3. 查全率(Recall)分析:")
print(f" 整体查全率为 {overall_recall:.4f},")
print(f" 表示在所有实际有毒的蘑菇中,模型成功识别出了 {overall_recall*100:.2f}%。")
print(f" 对于毒蘑菇检测任务,高查全率至关重要,可以避免漏报导致的安全问题。")
print()
print("4. F1值分析:")
print(f" 整体F1值为 {overall_f1:.4f},")
print(f" F1值是查准率和查全率的调和平均数,")
print(f" 综合反映了模型在查准率和查全率之间的平衡。")
print()
print("5. 各类别性能分析:")
for i, class_name in enumerate(class_names):
print(f" {class_name}:")
print(f" 查准率: {precision_per_class[i]:.4f}")
print(f" 查全率: {recall_per_class[i]:.4f}")
print(f" F1值: {f1_per_class[i]:.4f}")
if recall_per_class[i] < 0.95:
print(f" 注意:{class_name}的查全率较低,可能存在漏报风险")
print()
print("6. 混淆矩阵分析:")
print(f" 真阴性(TN,无毒且预测为无毒): {cm[0, 0]}")
print(f" 假阳性(FP,无毒但预测为有毒): {cm[0, 1]}")
print(f" 假阴性(FN,有毒但预测为无毒): {cm[1, 0]}")
print(f" 真阳性(TP,有毒且预测为有毒): {cm[1, 1]}")
print()
print(f" 假阴性(FN)是最危险的错误,因为会误判有毒蘑菇为无毒。")
print(f" 本模型的假阴性数量为 {cm[1, 0]},占总样本的 {cm[1, 0]/len(all_y_true_cv)*100:.2f}%。")
print()
print("7. 模型性能总结:")
if overall_accuracy > 0.95 and overall_recall > 0.95:
print(" 模型性能优秀,准确度和查全率都很高,")
print(" 可以安全地用于毒蘑菇检测任务。")
elif overall_accuracy > 0.90 and overall_recall > 0.90:
print(" 模型性能良好,但在实际应用中需要谨慎使用,")
print(" 建议结合其他方法进行验证。")
else:
print(" 模型性能需要进一步优化,")
print(" 建议调整模型参数或尝试其他算法。")
print()
print("=" * 80)
print("实验完成!")
print("=" * 80)
print()
print("生成的文件:")
print(" 1. mushroom_pairplot.png - 特征关系矩阵图")
print(" 2. mushroom_correlation_heatmap.png - 特征相关性热力图")
print(" 3. mushroom_class_distribution.png - 类别分布图")
print(" 4. mushroom_confusion_matrix.png - 混淆矩阵图")
print(" 5. mushroom_metrics_comparison.png - 评估指标对比图")
print(" 6. mushroom_class_metrics_comparison.png - 各类别指标对比图")
print()
</details>
浙公网安备 33010602011771号