带有可视化API的ROC曲线

Scikit-learn定义了一个简单的API,用于创建用于机器学习的可视化。 该API的主要功能是无需重新计算即可进行快速绘图和视觉调整。 在此示例中,我们将通过比较ROC曲线来演示如何使用可视化API。

导入数据并训练一个支持向量机

首先,我们加载红酒数据集并将其转换为二分类分类问题。 然后,我们在训练数据集上训练支持向量分类器。

import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

X, y = load_wine(return_X_y=True)
y = y == 2

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)

输出:

绘制ROC曲线

接下来,我们通过一次调用sklearn.metrics.plot_roc_curve绘制ROC曲线。 返回的svc_disp对象使我们可以在以后的图中继续使用已经计算出的SVC ROC曲线。

svc_disp = plot_roc_curve(svc, X_test, y_test)
plt.show()

输出:

训练随机森林并绘制ROC曲线

我们训练一个随机森林分类器,并创建一个将其与SVC ROC曲线进行比较的图。 注意svc_disp如何使用plot来绘制SVC ROC曲线,而无需重新计算roc曲线本身的值。 此外,我们将alpha = 0.8传递给绘图函数以调整曲线的alpha值。

rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)
plt.show()

输出:

脚本的总运行时间:0分钟0.233秒