在scikit-learn中可视化交叉验证行为

选择正确的交叉验证对象是正确拟合模型的关键部分。有很多方法可以将数据分为训练集和测试集,从而避免模型过度拟合,例如标准化测试集中的组数等。

本示例将几个常见的scikit学习对象的行为可视化以进行比较。

from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,
                                     StratifiedKFold, GroupShuffleSplit,
                                     GroupKFold, StratifiedShuffleSplit)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
np.random.seed(1338)
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
n_splits = 4

可视化我们的数据

首先,我们必须了解数据的结构。它包含100个随机生成的输入数据点,数据点之间标签被不均匀地划分为三类,同时我们均匀划分了10个“组”。

正如我们将看到的,一些交叉验证对象对带有标签的数据执行特定的操作,另一些对分组数据的处理方式有所不同,而另一些则不使用此信息。

首先,我们将可视化数据。

# 生成类别/组数据
n_points = 100
X = np.random.randn(10010)

percentiles_classes = [.1.3.6]
y = np.hstack([[ii] * int(100 * perc)
               for ii, perc in enumerate(percentiles_classes)])

# 间隔均匀的组重复一次
groups = np.hstack([[ii] * 10 for ii in range(10)])


def visualize_groups(classes, groups, name):
    # 可视化数据集组
    fig, ax = plt.subplots()
    ax.scatter(range(len(groups)),  [.5] * len(groups), c=groups, marker='_',
               lw=50, cmap=cmap_data)
    ax.scatter(range(len(groups)),  [3.5] * len(groups), c=classes, marker='_',
               lw=50, cmap=cmap_data)
    ax.set(ylim=[-15], yticks=[.53.5],
           yticklabels=['Data\ngroup''Data\nclass'], xlabel="Sample index")


visualize_groups(y, groups, 'no groups')

定义一个函数以可视化交叉验证行为

我们将定义一个函数,使我们可以可视化每个交叉验证对象的行为。 我们将对数据进行4次拆分。在每个分组中,我们将为训练集(蓝色)和测试集(红色)可视化选择的索引。

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """为交叉验证对象的索引创建样本图."""

    # 为每个交叉验证分组生成训练/测试可视化图像
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        # 与训练/测试组一起填写索引
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        # 可视化结果
        ax.scatter(range(len(indices)), [ii + .5] * len(indices),
                   c=indices, marker='_', lw=lw, cmap=cmap_cv,
                   vmin=-.2, vmax=1.2)

    # 将数据的分组情况和标签情况放入图像
    ax.scatter(range(len(X)), [ii + 1.5] * len(X),
               c=y, marker='_', lw=lw, cmap=cmap_data)

    ax.scatter(range(len(X)), [ii + 2.5] * len(X),
               c=group, marker='_', lw=lw, cmap=cmap_data)

    # 调整格式
    yticklabels = list(range(n_splits)) + ['class''group']
    ax.set(yticks=np.arange(n_splits+2) + .5, yticklabels=yticklabels,
           xlabel='Sample index', ylabel="CV iteration",
           ylim=[n_splits+2.2-.2], xlim=[0100])
    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)
    return ax

现在看看K折交叉验证对象可视化后效果如何:

fig, ax = plt.subplots()
cv = KFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)

输出:

<matplotlib.axes._subplots.AxesSubplot object at 0x7f96064f9190>

如您所见,默认情况下,K折交叉验证迭代器不考虑数据点类或组。我们可以像这样使用StratifiedKFold来改变它。

fig, ax = plt.subplots()
cv = StratifiedKFold(n_splits)
plot_cv_indices(cv, X, y, groups, ax, n_splits)
<matplotlib.axes._subplots.AxesSubplot object at 0x7f96042325b0>

在这种情况下,交叉验证在每个CV划分中保留相同的类比例。 接下来,我们将可视化许多CV迭代器的行为。

可视化许多CV对象的交叉验证索引

让我们直观地比较许多scikit-learn交叉验证对象的交叉验证行为。下面,我们将循环浏览几个常见的交叉验证对象,以可视化每个对象的行为。

注意有些交叉验证如何使用组/类信息,而有些交叉验证则不使用。

cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,
       GroupShuffleSplit, StratifiedShuffleSplit,
       TimeSeriesSplit]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(63))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend([Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))],
              ['Testing set''Training set'], loc=(1.02.8))
    # Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=.7)
plt.show()

脚本的总运行时间:(0分钟0.937秒)