Python——通过plt库绘制可视化的混淆矩阵
字数统计:308 阅读时长 ≈ 1分钟前言
混淆矩阵(confusion matrix)也称误差矩阵,是表示精度评价的一种标准格式,用n行n列的方阵。为了让混淆矩阵更加直观,可以考虑可视化混淆矩阵。本文介绍训练后的模型通过sklearn.metrics.confusion_matrix()方法获取混淆矩阵,然后通过plt绘图,使混淆矩阵可视化。
获取混淆矩阵
混淆矩阵可以通过sklearn库获取,具体方法如下:
from sklearn import metrics
confusion_mat = metrics.confusion_matrix(y_true, y_predict, labels=None, sample_weight=None)
参数说明
- y_true:真实标签
- y_predict:预测标签
- labels:是所给出的类别,通过这个可对类别进行选择
- sample_weight : 样本权重
混淆矩阵可视化
matploylib.pyplot是python中数据可视化的常用库,可以通过plt对混淆矩阵可视化,具体方法如下:
def plot_confusion_matrix(confusion_mat):
# 画混淆矩阵图,配色风格使用cm.Greens
plt.imshow(confusion_mat,interpolation='nearest',cmap=plt.cm.Greens)
# 显示colorbar
plt.colorbar()
# 使用annotate在图中显示混淆矩阵的数据
for x in range(len(confusion_mat)):
for y in range(len(confusion_mat)):
plt.annotate(confusion_mat[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center')
# 第一个参数是注释的内容
# xy设置箭头尖的坐标
# horizontalalignment水平对齐
# verticalalignment垂直对齐
# 其余常用参数如下:
# xytext设置注释内容显示的起始位置
# arrowprops 用来设置箭头
# facecolor 设置箭头的颜色
# headlength 箭头的头的长度
# headwidth 箭头的宽度
# width 箭身的宽度
plt.title('Confusion Matrix') # 图标title
plt.ylabel('True label') # 坐标轴标签
plt.xlabel('Predicted label') # 坐标轴标签
tick_marks = np.arange(2)
plt.xticks(tick_marks, tick_marks)
plt.yticks(tick_marks, tick_marks)
plt.show()
cm为配色风格,默认有:
本文由simyng创作,
采用知识共享署名4.0 国际许可协议进行许可,转载前请务必署名
文章最后更新时间为:May 5th , 2020 at 02:20 pm