在scikit-learn中可视化交叉验证行为¶
选择正确的交叉验证对象是正确拟合模型的关键部分。有很多方法可以将数据分为训练集和测试集,从而避免模型过度拟合,例如标准化测试集中的组数等。
本示例将几个常见的scikit学习对象的行为可视化以进行比较。
from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,
StratifiedKFold, GroupShuffleSplit,
GroupKFold, StratifiedShuffleSplit)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
np.random.seed(1338)
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
n_splits = 4
可视化我们的数据
首先,我们必须了解数据的结构。它包含100个随机生成的输入数据点,数据点之间标签被不均匀地划分为三类,同时我们均匀划分了10个“组”。
正如我们将看到的,一些交叉验证对象对带有标签的数据执行特定的操作,另一些对分组数据的处理方式有所不同,而另一些则不使用此信息。
首先,我们将可视化数据。
# 生成类别/组数据
n_points = 100
X = np.random.randn(100, 10)
percentiles_classes = [.1, .3, .6]
y = np.hstack([[ii] * int(100 * perc)
for ii, perc in enumerate(percentiles_classes)])
# 间隔均匀的组重复一次
groups = np.hstack([[ii] * 10 for ii in range(10)])
def visualize_groups(classes, groups, name):
# 可视化数据集组
fig, ax = plt.subplots()
ax.scatter(range(len(groups)), [.5] * len(groups), c=groups, marker='_',
lw=50, cmap=cmap_data)
ax.scatter(range(len(groups)), [3.5] * len(groups), c=classes, marker='_',
lw=50, cmap=cmap_data)
ax.set(ylim=[-1, 5], yticks=[.5, 3.5],
yticklabels=['Data\ngroup', 'Data\nclass'], xlabel="Sample index")
visualize_groups(y, groups, 'no groups')
定义一个函数以可视化交叉验证行为
我们将定义一个函数,使我们可以可视化每个交叉验证对象的行为。 我们将对数据进行4次拆分。在每个分组中,我们将为训练集(蓝色)和测试集(红色)可视化选择的索引。
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""为交叉验证对象的索引创建样本图."""
# 为每个交叉验证分组生成训练/测试可视化图像
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
# 与训练/测试组一起填写索引
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
# 可视化结果
ax.scatter(range(len(indices)), [ii + .5] * len(indices),
c=indices, marker='_', lw=lw, cmap=cmap_cv,
vmin=-.2, vmax=1.2)
# 将数据的分组情况和标签情况放入图像
ax.scatter(range(len(X)), [ii + 1.5] * len(X),
c=y, marker='_', lw=lw, cmap=cmap_data)
ax.scatter(range(len(X)), [ii + 2.5] * len(X),
c=group, marker='_', lw=lw, cmap=cmap_data)
# 调整格式
yticklabels = list(range(n_splits)) + ['class', 'group']
ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels,
xlabel='Sample index', ylabel="CV iteration",
ylim=[n_splits+2.2, -.2], xlim=[0, 100])
ax.set_title('{}'.format(type(cv).__name__), fontsize=15)
return ax
现在看看K折交叉验证对象可视化后效果如何:
fig, ax = plt.subplots()
cv = KFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)
输出:
<matplotlib.axes._subplots.AxesSubplot object at 0x7f96064f9190>
如您所见,默认情况下,K折交叉验证迭代器不考虑数据点类或组。我们可以像这样使用StratifiedKFold来改变它。
fig, ax = plt.subplots()
cv = StratifiedKFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)
<matplotlib.axes._subplots.AxesSubplot object at 0x7f96042325b0>
在这种情况下,交叉验证在每个CV划分中保留相同的类比例。 接下来,我们将可视化许多CV迭代器的行为。
可视化许多CV对象的交叉验证索引
让我们直观地比较许多scikit-learn交叉验证对象的交叉验证行为。下面,我们将循环浏览几个常见的交叉验证对象,以可视化每个对象的行为。
注意有些交叉验证如何使用组/类信息,而有些交叉验证则不使用。
cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,
GroupShuffleSplit, StratifiedShuffleSplit,
TimeSeriesSplit]
for cv in cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))],
['Testing set', 'Training set'], loc=(1.02, .8))
# Make the legend fit
plt.tight_layout()
fig.subplots_adjust(right=.7)
plt.show()
脚本的总运行时间:(0分钟0.937秒)