1、多分类指标 前面已经深入讨论了二分类任务的评估,下面来看一下对多分类问题的评估指标。 多分类问题的所有指标基本上都来自二分类指标,但要对所有类别进行平均。 除了精度
前面已经深入讨论了二分类任务的评估,下面来看一下对多分类问题的评估指标。
多分类问题的所有指标基本上都来自二分类指标,但要对所有类别进行平均。
除了精度,常用的工具有混淆矩阵和分类报告
sklearn.metrics.confusion_metrix
sklearn.metrics.classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.datasets import load_digits
from matplotlib import pyplot as plt
from sklearn.linear_model import LogisticRegression
#建立不平衡数据集
digits = load_digits()
#y = digits.target==9
#划分数据集
X_train,X_test,y_train,y_test = train_test_split(digits.data,digits.target,random_state=0)
#构建逻辑回归模型
lgr = LogisticRegression().fit(X_train,y_train)
pred = lgr.predict(X_test)
print("Confusion matrix:\n{}".format(confusion_matrix(y_test,pred)))
'''
```
Confusion matrix:
[[37 0 0 0 0 0 0 0 0 0]
[ 0 40 0 0 0 0 0 0 2 1]
[ 0 1 40 3 0 0 0 0 0 0]
[ 0 0 0 43 0 0 0 0 1 1]
[ 0 0 0 0 37 0 0 1 0 0]
[ 0 0 0 0 0 46 0 0 0 2]
[ 0 1 0 0 0 0 51 0 0 0]
[ 0 0 0 1 1 0 0 46 0 0]
[ 0 3 1 0 0 0 0 0 43 1]
[ 0 0 0 0 0 1 0 0 1 45]]
```
'''
#将混淆矩阵转换为热图
import mglearn
images = mglearn.tools.heatmap(confusion_matrix(y_test,pred),xlabel="predict label",ylabel="true label",
xticklabels=digits.target_names,yticklabels=digits.target_names,
cmap=plt.cm.gray_r,fmt="%d")
plt.title("Confusion metrix")
plt.gca().invert_yaxis()