Python——通过plt库绘制可视化的混淆矩阵

under Python  机器学习  tag     Published on May 5th , 2020 at 10:20 pm

前言

混淆矩阵(confusion matrix)也称误差矩阵,是表示精度评价的一种标准格式,用n行n列的方阵。为了让混淆矩阵更加直观,可以考虑可视化混淆矩阵。本文介绍训练后的模型通过sklearn.metrics.confusion_matrix()方法获取混淆矩阵,然后通过plt绘图,使混淆矩阵可视化。

获取混淆矩阵

混淆矩阵可以通过sklearn库获取,具体方法如下:

from sklearn import metrics

confusion_mat = metrics.confusion_matrix(y_true, y_predict, labels=None, sample_weight=None)

参数说明

  • y_true:真实标签
  • y_predict:预测标签
  • labels:是所给出的类别,通过这个可对类别进行选择
  • sample_weight : 样本权重

混淆矩阵可视化

matploylib.pyplot是python中数据可视化的常用库,可以通过plt对混淆矩阵可视化,具体方法如下:

def plot_confusion_matrix(confusion_mat):
    # 画混淆矩阵图,配色风格使用cm.Greens
    plt.imshow(confusion_mat,interpolation='nearest',cmap=plt.cm.Greens)
    
    # 显示colorbar
    plt.colorbar()
    
    # 使用annotate在图中显示混淆矩阵的数据
    for x in range(len(confusion_mat)):
        for y in range(len(confusion_mat)):
            plt.annotate(confusion_mat[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center')
            # 第一个参数是注释的内容
            # xy设置箭头尖的坐标
            # horizontalalignment水平对齐
            # verticalalignment垂直对齐
            # 其余常用参数如下:
            # xytext设置注释内容显示的起始位置
            # arrowprops 用来设置箭头
            # facecolor 设置箭头的颜色
            # headlength 箭头的头的长度
            # headwidth 箭头的宽度
            # width 箭身的宽度

    plt.title('Confusion Matrix')     # 图标title
    plt.ylabel('True label')          # 坐标轴标签
    plt.xlabel('Predicted label')     # 坐标轴标签
    
    tick_marks = np.arange(2)
    plt.xticks(tick_marks, tick_marks)
    plt.yticks(tick_marks, tick_marks)
    plt.show()

cm为配色风格,默认有:


本文由simyng创作, 采用知识共享署名4.0 国际许可协议进行许可,转载前请务必署名
  文章最后更新时间为:May 5th , 2020 at 02:20 pm