sklearn.model_selection.PredefinedSplit

class sklearn.model_selection.PredefinedSplit(test_fold)

[源码]

预定义的切分交叉验证器。

提供训练集或测试集的索引,使用用户通过test_fold参数指定的预定义方案将数据分为训练集或测试集。

用户指南中阅读更多内容。

版本0.16中的新功能。

参数 说明
test_fold array-like of shape (n_samples,)
输入test_fold[i]表示样本i所属的测试集的索引。通过将test_fold[i]值设置为-1,可以从任何测试集中排除样本i(即在每个训练集中包括样本i)。

示例

>>> import numpy as np
>>> from sklearn.model_selection import PredefinedSplit
>>> X = np.array([[12], [34], [12], [34]])
>>> y = np.array([0011])
>>> test_fold = [01-11]
>>> ps = PredefinedSplit(test_fold)
>>> ps.get_n_splits()
2
>>> print(ps)
PredefinedSplit(test_fold=array([ 0,  1-1,  1]))
>>> for train_index, test_index in ps.split():
...     print("TRAIN:", train_index, "TEST:", test_index)
...     X_train, X_test = X[train_index], X[test_index]
...     y_train, y_test = y[train_index], y[test_index]
TRAIN: [1 2 3] TEST: [0]
TRAIN: [0 2] TEST: [1 3]

方法

方法 说明
get_n_splits(self[, X, y, groups]) 返回交叉验证器中的拆分迭代次数。
split(self[, X, y, groups]) 生成索引以将数据分为训练集和测试集。
__init__(self, test_fold)

[源码]

初始化self。详情可参阅 type(self)的帮助。

get_n_splits(self,X = None,y = None,groups = None )

[源码]

返回交叉验证器中的切分迭代次数。

参数 说明
X object
始终被忽略,为了兼容性而存在。
y object
始终被忽略,为了兼容性而存在。
groups object
始终被忽略,为了兼容性而存在。
返回值 说明
n_splits int
返回交叉验证器中拆分迭代的次数。
split(self,X = None,y = None,groups = None )

[源码]

生成索引以将数据分为训练和测试集。

参数 说明
X object
始终被忽略,为了兼容性而存在。
y object
始终被忽略,为了兼容性而存在。
groups object
始终被忽略,为了兼容性而存在。
输出 说明
train ndarray
切分的训练集索引。
test ndarray
切分的测试集索引。