在训练深度学习模型的时候,通常将数据集切分为训练集和验证集.Keras提供了两种评估模型性能的方法:
使用自动切分的验证集
使用手动切分的验证集
一.自动切分
在Keras中,可以从数据集中切分出一部分作为验证集,并且在每次迭代(epoch)时在验证集中评估模型的性能.
具体地,调用model.fit()训练模型时,可通过validation_split参数来指定从数据集中切分出验证集的比例.
# MLP with automatic validation set from keras.models import Sequential from keras.layers import Dense import numpy # fix random seed for reproducibility numpy.random.seed(7) # load pima indians dataset dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",") # split into input (X) and output (Y) variables X = dataset[:,0:8] Y = dataset[:,8] # create model model = Sequential() model.add(Dense(12, input_dim=8, activation='relu')) model.add(Dense(8, activation='relu')) model.add(Dense(1, activation='sigmoid')) # Compile model model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # Fit the model model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10)
validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。
注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。
二.手动切分
Keras允许在训练模型的时候手动指定验证集.
例如,用sklearn库中的train_test_split()函数将数据集进行切分,然后在keras的model.fit()的时候通过validation_data参数指定前面切分出来的验证集.
# MLP with manual validation set from keras.models import Sequential from keras.layers import Dense from sklearn.model_selection import train_test_split import numpy # fix random seed for reproducibility seed = 7 numpy.random.seed(seed) # load pima indians dataset dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",") # split into input (X) and output (Y) variables X = dataset[:,0:8] Y = dataset[:,8] # split into 67% for train and 33% for test X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=seed) # create model model = Sequential() model.add(Dense(12, input_dim=8, activation='relu')) model.add(Dense(8, activation='relu')) model.add(Dense(1, activation='sigmoid')) # Compile model model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # Fit the model model.fit(X_train, y_train, validation_data=(X_test,y_test), epochs=150, batch_size=10)
三.K折交叉验证(k-fold cross validation)
将数据集分成k份,每一轮用其中(k-1)份做训练而剩余1份做验证,以这种方式执行k轮,得到k个模型.将k次的性能取平均,作为该算法的整体性能.k一般取值为5或者10.
优点:能比较鲁棒性地评估模型在未知数据上的性能.
缺点:计算复杂度较大.因此,在数据集较大,模型复杂度较高,或者计算资源不是很充沛的情况下,可能不适用,尤其是在训练深度学习模型的时候.
sklearn.model_selection提供了KFold以及RepeatedKFold, LeaveOneOut, LeavePOut, ShuffleSplit, StratifiedKFold, GroupKFold, TimeSeriesSplit等变体.
下面的例子中用的StratifiedKFold采用的是分层抽样,它保证各类别的样本在切割后每一份小数据集中的比例都与原数据集中的比例相同.
# MLP for Pima Indians Dataset with 10-fold cross validation from keras.models import Sequential from keras.layers import Dense from sklearn.model_selection import StratifiedKFold import numpy # fix random seed for reproducibility seed = 7 numpy.random.seed(seed) # load pima indians dataset dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",") # split into input (X) and output (Y) variables X = dataset[:,0:8] Y = dataset[:,8] # define 10-fold cross validation test harness kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed) cvscores = [] for train, test in kfold.split(X, Y): # create model model = Sequential() model.add(Dense(12, input_dim=8, activation='relu')) model.add(Dense(8, activation='relu')) model.add(Dense(1, activation='sigmoid')) # Compile model model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # Fit the model model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0) # evaluate the model scores = model.evaluate(X[test], Y[test], verbose=0) print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100)) cvscores.append(scores[1] * 100) print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))
补充知识:训练集,验证集和测试集
训练集:通过最小化目标函数(损失函数 + 正则项),用来训练模型的参数。当目标函数最小化时,完成对模型的训练。
验证集:用来选择模型的阶数。目标函数最小的模型对应的阶数,为模型的最终选择的阶数。
注:
1. 验证集会在训练过程中,反复使用,机器学习中作为选择不同模型的评判标准,深度学习中作为选择网络层数和每层节点数的评判标准。
2. 验证集的使用并非必不可少,如果网络的层数和节点数已经确定,则不需要这一步操作。
测试集:评估模型的泛化能力。根据选择的已经训练好的模型,评估它的泛化能力。
注:
测试集评判的是最终训练好的模型的泛化能力,只进行一次评判。
以上这篇sklearn和keras的数据切分与交叉验证的实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持易盾网络。