使用XGBoost的算法往往能在Kaggle等数据科学竞赛中取得好成绩,因此广受欢迎。本文使用特定的数据集来分析XGBoost机器学习模型的预测过程。通过使用可视化手段展示结果,可以更好地理解模型的预测过程。随着机器学习的工业应用不断发展,理解、解释和定义机器学习模型的工作原理似乎是一个越来越明显的趋势。对于非深度学习类型的机器学习分类问题,XGBoost是最受欢迎的库。由于XGBoost可以很好地扩展到大型数据集并支持多种语言,因此它在商业环境中特别有用。例如,使用XGBoost可以很容易地在Python中训练模型并将模型部署到Java生产环境。虽然XGBoost可以达到很高的准确率,但对于XGBoost如何做出决策以达到如此高的准确率,它仍然不够透明。在将结果直接交给客户时,这种不透明性可能是一个严重的缺陷。了解事情发生的原因很有用。转向应用机器学习来理解数据的公司也需要理解模型的预测。这变得越来越重要。例如,没有人希望征信机构使用机器学习模型来预测用户的信用度,而无法解释做出这些预测的过程。另一个例子是如果我们的机器学习模型说婚姻档案和出生档案与同一个人相关(文件关联任务),但是档案上的日期暗示婚姻的双方都是一个很老的人和一个非常年轻的人,我们可能会质疑为什么模特会联想到他们。在这种情况下,理解模型为什么做出这样的预测是非常有价值的。结果可能是该模型考虑了名称和位置的唯一性并做出了正确的预测。但也可能是模特的特征没有正确考虑个人资料上的年龄差距。在这种情况下,了解模型预测可以帮助我们找到提高模型性能的方法。在本文中,我们将介绍一些技术以更好地理解XGBoost的预测过程。这使我们能够利用梯度提升的力量,同时仍然了解模型的决策过程。为了解释这些技术,我们将使用泰坦尼克号数据集。该数据集包含有关每位泰坦尼克号乘客的信息(包括乘客是否幸存)。我们的目标是预测乘客是否会幸存,并了解做出该预测的过程。即使有了这些数据,我们也可以看出理解模型决策的重要性。想象一下,如果我们有最近沉船事故的乘客数据集。建立这样一个预测模型的目的实际上并不是预测结果本身,而是了解预测过程可以帮助我们了解如何在事故中最大化幸存者的数量。从xgboost导入XGBClassifier从sklearn.model_selection导入train_test_split从sklearn.metrics导入accuracy_scoreimportoperatorimportmatplotlib.pyplot作为pltimportseaborn作为snsimportlime.lime_tabular从sklearn.pipeline导入Pipelinefromsklearn.preprocessing我们要做的第一件事是查看我们的数据,您可以在Kaggle(https://www.kaggle.com/c/titanic/data)上找到这些数据。拿到数据集后,我们就简单的清理一下数据。即:清洗namesandpassengerIDsconvertingcategoricalvariablestodummyvariablespaddingandremovingdatawithmedians这些清洗技术很简单,本文的目的不是讨论数据清洗,而是解释XGBoost,所以这些都是快速合理的清洗用于训练模型。data=pd.read_csv("./data/titantic/train.csv")y=data.SurvivedX=data.drop(["Survived","Name","PassengerId"],1)X=pd.get_dummies(X)现在让我们将数据集分成训练集和测试集。X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.33,random_state=42)并构建具有少量超参数测试的训练管道。pipeline=Pipeline([('imputer',Imputer(strategy='median')),('model',XGBClassifier())])parameters=dict(model__max_depth=[3,5,7],model__learning_rate=[.01,.1],model__n_estimators=[100,500])cv=GridSearchCV(pipeline,param_grid=parameters)cv.fit(X_train,y_train)然后查看测试结果。为简单起见,我们将使用与Kaggle相同的指标:准确性。test_predictions=cv.predict(X_test)print("TestAccuracy:{}".format(accuracy_score(y_test,test_predictions)))TestAccuracy:0.8101694915254237目前为止我们得到了相当不错的准确率,大约9000场比赛在Kaggle中排名前500名。所以还有进一步改进的空间,但这留给读者作为练习。我们继续讨论了解模型学到了什么。一种常见的方法是使用XGBoost提供的特征重要性。特征重要性的级别越高,特征对改进模型预测的贡献就越大。接下来我们将使用重要性参数对特征进行排序并比较相对重要性。fi=list(zip(X.columns,cv.best_estimator_.named_steps['model'].feature_importances_))fi.sort(key=operator.itemgetter(1),reverse=True)top_10=fi[:10]x=[x[0]forxintop_10]y=[x[1]forxintop_10]top_10_chart=sns.barplot(x,y)plt.setp(top_10_chart.get_xticklabels(),rotation=90)从上图可以看出票价和年龄是重要的特征。我们可以进一步看survival/death和fare的相对分布:sns.barplot(y_train,X_train['Fare'])我们可以清楚的看到,那些幸存者的平均fare比受害者高很多,所以它将票价视为一个重要特征可能是合理的。特征重要性通常是理解特征重要性的好方法。如果存在这样一种特殊情况,即模型预测票价高的乘客将无法生存,那么我们可以得出结论,票价高并不一定会导致生存。接下来,我们将分析其他可能导致模型得出乘客无法幸存的结论的因素。特征。这种个人层面的分析对于生产机器学习系统非常有用。考虑另一个示例,其中模型用于预测某人是否有资格获得贷款。我们知道信用评分将是模型的一个重要特征,但是信用评分高的客户被模型拒绝了,我们将如何向客户解释?以及如何向管理者解释?幸运的是,华盛顿大学最近有一项关于解释任意分类器预测过程的研究。他们的方法称为LIME,已在GitHub(https://github.com/marcotcr/lime)上开源。本文不打算讨论这个,可以参考论文(https://arxiv.org/pdf/1602.04938.pdf)接下来我们尝试在模型中应用LIME。基本上,我们首先需要定义一个解释器来处理训练数据(我们需要确保传递给解释器的估计训练数据集正是将要训练的数据集):X_train_imputed=cv.best_estimator_.named_steps['imputer'].transform(X_train)explainer=lime.lime_tabular.LimeTabularExplainer(X_train_imputed,feature_names=X_train.columns.tolist(),class_names=["NotSurvived","Survived"],discretize_continuous=True)那么你必须定义一个函数将特征数组作为变量,并返回一个数组和每个类别的概率:model=cv.best_estimator_.named_steps['model']defxgb_prediction(X_array_in):iflen(X_array_in.shape)<2:X_array_in=np.expand_dims(X_array_in,0)returnmodel.predict_proba(X_array_in)最后,我们通过一个例子让解释器使用你的函数输出特征和标签的数量:X_test_imputed=cv.best_estimator_.named_steps['imputer'].transform(X_test)exp=解释器。explain_instance(X_test_imputed[1],xgb_prediction,num_features=5,top_labels=1)exp.show_in_notebook(show_table=True,show_all=False)这里我们有一个实例,有76%的几率不存在。我们还想看看哪个特征对哪个类别的贡献最大,以及它的重要性。例如,当Sex=Female时,存活几率更大。让我们看看条形图:sns.barplot(X_train['Sex_female'],y_train)所以这似乎是有道理的。如果您是女性,这会大大提高您在训练数据中幸存下来的机会。那么为什么预测结果是“没活下来”呢?Pclass=2.0似乎大大降低了生存率。看一下:sns.barplot(X_train['Pclass'],y_train)好像Pclass等于2的存活率比较低,所以对我们的预测结果有了更多的了解。看LIME上显示的top5特征,好像这个人还活着,我们看它的标签:y_test.values[0]>>>1这个人确实活下来了,所以我们的模型是错误的!感谢LIME,我们可以深入了解问题的原因:看起来Pclass可能需要删除。这种方法可以帮助我们并希望找到一些改进模型的方法。本文为读者提供了一种简单有效的方式来理解XGBoost。希望这些方法可以帮助您合理使用XGBoost,让您的模型做出更好的推理。
