当前位置 : 主页 > 编程语言 > 其它开发 >

评估指标与评分(下):多分类指标及其他

来源:互联网 收集:自由互联 发布时间:2022-05-30
1、多分类指标 前面已经深入讨论了二分类任务的评估,下面来看一下对多分类问题的评估指标。 多分类问题的所有指标基本上都来自二分类指标,但要对所有类别进行平均。 除了精度
1、多分类指标

前面已经深入讨论了二分类任务的评估,下面来看一下对多分类问题的评估指标。

多分类问题的所有指标基本上都来自二分类指标,但要对所有类别进行平均。

除了精度,常用的工具有混淆矩阵和分类报告

sklearn.metrics.confusion_metrix

sklearn.metrics.classification_report

  from sklearn.metrics import confusion_matrix
  from sklearn.metrics import classification_report
  from sklearn.datasets import load_digits
  from matplotlib import pyplot as plt
  from sklearn.linear_model import LogisticRegression

  #建立不平衡数据集
  digits = load_digits()

  #y = digits.target==9

  #划分数据集
  X_train,X_test,y_train,y_test = train_test_split(digits.data,digits.target,random_state=0)

  #构建逻辑回归模型
  lgr = LogisticRegression().fit(X_train,y_train)
  pred = lgr.predict(X_test)

  print("Confusion matrix:\n{}".format(confusion_matrix(y_test,pred)))


  '''
  ```
  Confusion matrix:
  [[37  0  0  0  0  0  0  0  0  0]
   [ 0 40  0  0  0  0  0  0  2  1]
   [ 0  1 40  3  0  0  0  0  0  0]
   [ 0  0  0 43  0  0  0  0  1  1]
   [ 0  0  0  0 37  0  0  1  0  0]
   [ 0  0  0  0  0 46  0  0  0  2]
   [ 0  1  0  0  0  0 51  0  0  0]
   [ 0  0  0  1  1  0  0 46  0  0]
   [ 0  3  1  0  0  0  0  0 43  1]
   [ 0  0  0  0  0  1  0  0  1 45]]
  ```
  '''

  #将混淆矩阵转换为热图

  import mglearn

  images = mglearn.tools.heatmap(confusion_matrix(y_test,pred),xlabel="predict label",ylabel="true label",
                                 xticklabels=digits.target_names,yticklabels=digits.target_names,
                                cmap=plt.cm.gray_r,fmt="%d")

  plt.title("Confusion metrix")
  plt.gca().invert_yaxis()

上一篇:【Golang】关于Go中的struct{}
下一篇:没有了
网友评论