{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Feature Importance in NimbusML\n",
"It is often desirable not only to obtain predictions from a machine learning model, but also to get some sort of 'explanation': why did the model make this prediction? Which features affected the predictions the most?\n",
"\n",
"This is especially relevant when business or regulatory requirements call for explainable decisions, for example explaining the most important factors behind a denied credit application.\n",
"\n",
"In addition, this information helps the experimenter understand the model better, check for overfitting, and verify the quality of features. NimbusML offers model analysis mechanisms that provide both model-wide and example-level feature importances.\n",
"\n",
"\n",
"### Model-wide Analysis: Permutation Feature Importance (PFI)\n",
"Permutation Feature Importance is a technique that calculates how much each feature 'matters' to the predictions. Namely, how much does the model's quality change if we randomly permute the values of one feature across the evaluation set? If the quality doesn't change much, the feature is not very important. If the quality drops drastically, the feature is really important. NimbusML implements PFI in the `permutation_feature_importance()` method, available on the `Pipeline()` object and on individual prediction estimators.\n",
"\n",
"\n",
"### Example-level Analysis: Feature Contributions\n",
"Example-level feature importances explain which features were most important when making a *specific* prediction. When predictions are made on a dataset, a score is produced for each example. For classification, this score is converted to a probability to make a prediction; for regression, the score is the prediction itself. To understand and explain these predictions, it can be useful to inspect which features influenced them most significantly.\n",
"\n",
"The `get_feature_contributions()` method, available on the NimbusML `Pipeline()` object and on individual prediction estimators, computes per-feature contributions to the score for each example. These contributions can be positive (they make the score higher) or negative (they make the score lower). Feature contributions are implemented for **linear and tree models** in NimbusML."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tutorial\n",
"The following tutorial shows how to use model-level and example-level feature importances in NimbusML, using the UCI Adult Income dataset as an example. The dataset poses a binary classification problem where the label indicates whether or not an individual's income is over $50,000.\n",
"\n",
"#### Loading Data"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from nimbusml import Pipeline, FileDataStream\n",
"from nimbusml.datasets import get_dataset\n",
"from nimbusml.ensemble import FastTreesBinaryClassifier\n",
"from nimbusml.feature_extraction.categorical import OneHotVectorizer\n",
"from nimbusml.linear_model import LogisticRegressionBinaryClassifier\n",
"from nimbusml.preprocessing.schema import ColumnSelector"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train data file path: train-500.uciadult.sample.csv\n",
"Test data file path: test-100.uciadult.sample.csv\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>workclass</th>\n",
" <th>education</th>\n",
" <th>marital-status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>ethnicity</th>\n",
" <th>sex</th>\n",
" <th>native-country-region</th>\n",
" <th>age</th>\n",
" <th>fnlwgt</th>\n",
" <th>education-num</th>\n",
" <th>capital-gain</th>\n",
" <th>capital-loss</th>\n",
" <th>hours-per-week</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>Private</td>\n",
" <td>11th</td>\n",
" <td>Never-married</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Own-child</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>United-States</td>\n",
" <td>25</td>\n",
" <td>226802</td>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>Private</td>\n",
" <td>HS-grad</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Farming-fishing</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>United-States</td>\n",
" <td>38</td>\n",
" <td>89814</td>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>Local-gov</td>\n",
" <td>Assoc-acdm</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Protective-serv</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>United-States</td>\n",
" <td>28</td>\n",
" <td>336951</td>\n",
" <td>12</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>Private</td>\n",
" <td>Some-college</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Machine-op-inspct</td>\n",
" <td>Husband</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>United-States</td>\n",
" <td>44</td>\n",
" <td>160323</td>\n",
" <td>10</td>\n",
" <td>7688</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>?</td>\n",
" <td>Some-college</td>\n",
" <td>Never-married</td>\n",
" <td>?</td>\n",
" <td>Own-child</td>\n",
" <td>White</td>\n",
" <td>Female</td>\n",
" <td>United-States</td>\n",
" <td>18</td>\n",
" <td>103497</td>\n",
" <td>10</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>30</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" label workclass education marital-status occupation \\\n",
"0 0 Private 11th Never-married Machine-op-inspct \n",
"1 0 Private HS-grad Married-civ-spouse Farming-fishing \n",
"2 1 Local-gov Assoc-acdm Married-civ-spouse Protective-serv \n",
"3 1 Private Some-college Married-civ-spouse Machine-op-inspct \n",
"4 0 ? Some-college Never-married ? \n",
"\n",
" relationship ethnicity sex native-country-region age fnlwgt \\\n",
"0 Own-child Black Male United-States 25 226802 \n",
"1 Husband White Male United-States 38 89814 \n",
"2 Husband White Male United-States 28 336951 \n",
"3 Husband Black Male United-States 44 160323 \n",
"4 Own-child White Female United-States 18 103497 \n",
"\n",
" education-num capital-gain capital-loss hours-per-week \n",
"0 7 0 0 40 \n",
"1 9 0 0 50 \n",
"2 12 0 0 40 \n",
"3 10 7688 0 40 \n",
"4 10 0 0 30 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_path = get_dataset('uciadult_train').as_filepath()\n",
"test_path = get_dataset('uciadult_test').as_filepath()\n",
"print(\"Train data file path: \" + str(os.path.basename(train_path)))\n",
"print(\"Test data file path: \" + str(os.path.basename(test_path)))\n",
"\n",
"train_data = FileDataStream.read_csv(train_path)\n",
"test_data = FileDataStream.read_csv(test_path)\n",
"\n",
"train_data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Train linear and tree binary classifiers"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"# suppress training output as it is not relevant to a discussion of feature importances\n",
"\n",
"feature_columns = ['age', 'capital-gain', 'hours-per-week',\n",
"                   'education', 'marital-status', 'ethnicity', 'sex']\n",
"\n",
"# one-hot encode the categorical columns\n",
"cat = OneHotVectorizer(columns=['education', 'marital-status', 'ethnicity', 'sex'])\n",
"\n",
"# linear model: logistic regression\n",
"linear_clf = LogisticRegressionBinaryClassifier(feature=feature_columns, label='label')\n",
"linear_model = Pipeline(steps=[cat, linear_clf])\n",
"linear_model.fit(train_data)\n",
"\n",
"# tree model: boosted decision trees\n",
"tree_clf = FastTreesBinaryClassifier(feature=feature_columns, label='label')\n",
"tree_model = Pipeline(steps=[cat, tree_clf])\n",
"tree_model.fit(train_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Permutation Feature Importance (PFI)\n",
"Evaluate PFI for the linear model on the test data to get feature importance when making predictions. The training data can be used similarly to analyze important features during training.\n",
"\n",
"Here, we permute each of the `Features.*` columns 5 times and report the mean change in each metric, along with the standard error of the mean. Note that the most important features can differ from metric to metric. It is up to the user to decide which metric(s) they care about most and to look at the PFI for that metric.\n",
"\n",
"Let's look at the most important features with respect to Area Under ROC Curve (AUC). Since a higher AUC is better, the features whose permutation decreased AUC the most are the most important."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>FeatureName</th>\n",
" <th>AreaUnderRocCurve</th>\n",
" <th>AreaUnderRocCurve.StdErr</th>\n",
" <th>Accuracy</th>\n",
" <th>Accuracy.StdErr</th>\n",
" <th>PositivePrecision</th>\n",
" <th>PositivePrecision.StdErr</th>\n",
" <th>PositiveRecall</th>\n",
" <th>PositiveRecall.StdErr</th>\n",
" <th>NegativePrecision</th>\n",
" <th>NegativePrecision.StdErr</th>\n",
" <th>NegativeRecall</th>\n",
" <th>NegativeRecall.StdErr</th>\n",
" <th>F1Score</th>\n",
" <th>F1Score.StdErr</th>\n",
" <th>AreaUnderPrecisionRecallCurve</th>\n",
" <th>AreaUnderPrecisionRecallCurve.StdErr</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>marital-status.Married-civ-spouse</td>\n",
" <td>-0.153399</td>\n",
" <td>0.019996</td>\n",
" <td>-0.042</td>\n",
" <td>0.005831</td>\n",
" <td>-0.060563</td>\n",
" <td>0.048983</td>\n",
" <td>-0.200000</td>\n",
" <td>0.015590</td>\n",
" <td>-0.041178</td>\n",
" <td>0.003361</td>\n",
" <td>0.007895</td>\n",
" <td>0.005263</td>\n",
" <td>-0.231310</td>\n",
" <td>0.022871</td>\n",
" <td>-0.239532</td>\n",
" <td>0.043835</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>marital-status.Never-married</td>\n",
" <td>-0.047752</td>\n",
" <td>0.011941</td>\n",
" <td>-0.022</td>\n",
" <td>0.008000</td>\n",
" <td>-0.024848</td>\n",
" <td>0.039564</td>\n",
" <td>-0.108333</td>\n",
" <td>0.028260</td>\n",
" <td>-0.022613</td>\n",
" <td>0.006113</td>\n",
" <td>0.005263</td>\n",
" <td>0.003223</td>\n",
" <td>-0.116835</td>\n",
" <td>0.036170</td>\n",
" <td>-0.107423</td>\n",
" <td>0.028333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>capital-gain</td>\n",
" <td>-0.022643</td>\n",
" <td>0.002877</td>\n",
" <td>-0.016</td>\n",
" <td>0.002449</td>\n",
" <td>-0.041616</td>\n",
" <td>0.013287</td>\n",
" <td>-0.058333</td>\n",
" <td>0.010206</td>\n",
" <td>-0.013252</td>\n",
" <td>0.002082</td>\n",
" <td>-0.002632</td>\n",
" <td>0.002632</td>\n",
" <td>-0.064925</td>\n",
" <td>0.010544</td>\n",
" <td>-0.078153</td>\n",
" <td>0.022393</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>education.Masters</td>\n",
" <td>-0.016941</td>\n",
" <td>0.003446</td>\n",
" <td>-0.020</td>\n",
" <td>0.003162</td>\n",
" <td>-0.061616</td>\n",
" <td>0.018499</td>\n",
" <td>-0.066667</td>\n",
" <td>0.010206</td>\n",
" <td>-0.015474</td>\n",
" <td>0.002194</td>\n",
" <td>-0.005263</td>\n",
" <td>0.003223</td>\n",
" <td>-0.076690</td>\n",
" <td>0.011168</td>\n",
" <td>-0.053875</td>\n",
" <td>0.007807</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>education.Doctorate</td>\n",
" <td>-0.013268</td>\n",
" <td>0.001894</td>\n",
" <td>-0.012</td>\n",
" <td>0.002000</td>\n",
" <td>-0.032727</td>\n",
" <td>0.014545</td>\n",
" <td>-0.041667</td>\n",
" <td>0.000000</td>\n",
" <td>-0.009638</td>\n",
" <td>0.000400</td>\n",
" <td>-0.002632</td>\n",
" <td>0.002632</td>\n",
" <td>-0.046387</td>\n",
" <td>0.002689</td>\n",
" <td>-0.044245</td>\n",
" <td>0.016611</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" FeatureName AreaUnderRocCurve \\\n",
"19 marital-status.Married-civ-spouse -0.153399 \n",
"18 marital-status.Never-married -0.047752 \n",
"1 capital-gain -0.022643 \n",
"11 education.Masters -0.016941 \n",
"12 education.Doctorate -0.013268 \n",
"\n",
" AreaUnderRocCurve.StdErr Accuracy Accuracy.StdErr PositivePrecision \\\n",
"19 0.019996 -0.042 0.005831 -0.060563 \n",
"18 0.011941 -0.022 0.008000 -0.024848 \n",
"1 0.002877 -0.016 0.002449 -0.041616 \n",
"11 0.003446 -0.020 0.003162 -0.061616 \n",
"12 0.001894 -0.012 0.002000 -0.032727 \n",
"\n",
" PositivePrecision.StdErr PositiveRecall PositiveRecall.StdErr \\\n",
"19 0.048983 -0.200000 0.015590 \n",
"18 0.039564 -0.108333 0.028260 \n",
"1 0.013287 -0.058333 0.010206 \n",
"11 0.018499 -0.066667 0.010206 \n",
"12 0.014545 -0.041667 0.000000 \n",
"\n",
" NegativePrecision NegativePrecision.StdErr NegativeRecall \\\n",
"19 -0.041178 0.003361 0.007895 \n",
"18 -0.022613 0.006113 0.005263 \n",
"1 -0.013252 0.002082 -0.002632 \n",
"11 -0.015474 0.002194 -0.005263 \n",
"12 -0.009638 0.000400 -0.002632 \n",
"\n",
" NegativeRecall.StdErr F1Score F1Score.StdErr \\\n",
"19 0.005263 -0.231310 0.022871 \n",
"18 0.003223 -0.116835 0.036170 \n",
"1 0.002632 -0.064925 0.010544 \n",
"11 0.003223 -0.076690 0.011168 \n",
"12 0.002632 -0.046387 0.002689 \n",
"\n",
" AreaUnderPrecisionRecallCurve AreaUnderPrecisionRecallCurve.StdErr \n",
"19 -0.239532 0.043835 \n",
"18 -0.107423 0.028333 \n",
"1 -0.078153 0.022393 \n",
"11 -0.053875 0.007807 \n",
"12 -0.044245 0.016611 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pfi = linear_model.permutation_feature_importance(test_data, permutation_count=5)\n",
"pfi.sort_values('AreaUnderRocCurve').head()"
]
},
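{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the permutation idea concrete, here is a minimal sketch in plain Python. It is not how NimbusML implements `permutation_feature_importance()`: the helper names (`accuracy`, `permutation_importance`), the toy model, and the toy data are all assumptions made for illustration, and accuracy stands in for the metrics reported above. A feature the model ignores gets a mean change of exactly zero, while permuting a feature the model relies on lowers accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def accuracy(model, rows, labels):\n",
"    # Fraction of rows for which the model predicts the true label.\n",
"    correct = sum(1 for row, y in zip(rows, labels) if model(row) == y)\n",
"    return correct / len(rows)\n",
"\n",
"def permutation_importance(model, rows, labels, n_features, permutation_count=5, seed=0):\n",
"    # Mean change in accuracy after shuffling each feature column in turn.\n",
"    rng = random.Random(seed)\n",
"    baseline = accuracy(model, rows, labels)\n",
"    mean_changes = []\n",
"    for j in range(n_features):\n",
"        changes = []\n",
"        for _ in range(permutation_count):\n",
"            column = [row[j] for row in rows]\n",
"            rng.shuffle(column)\n",
"            # Rebuild the rows with only column j permuted.\n",
"            permuted = [row[:j] + [v] + row[j + 1:] for row, v in zip(rows, column)]\n",
"            changes.append(accuracy(model, permuted, labels) - baseline)\n",
"        mean_changes.append(sum(changes) / len(changes))\n",
"    return mean_changes\n",
"\n",
"# Hypothetical toy model: predicts 1 when feature 0 exceeds 0.5 and ignores feature 1.\n",
"def toy_model(row):\n",
"    return 1 if row[0] > 0.5 else 0\n",
"\n",
"data_rng = random.Random(42)\n",
"toy_rows = [[data_rng.random(), data_rng.random()] for _ in range(200)]\n",
"toy_labels = [toy_model(row) for row in toy_rows]\n",
"\n",
"toy_pfi = permutation_importance(toy_model, toy_rows, toy_labels, n_features=2)\n",
"print(toy_pfi)"
]
},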
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example-level Feature Contributions (Linear Models)\n",
"Let's look at feature contributions for individual predictions on the test data using the linear model. For linear models, each feature's contribution to the score is equal to the feature value multiplied by the corresponding weight."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>PredictedLabel</th>\n",
" <th>Score</th>\n",
" <th>Probability</th>\n",
" <th>FeatureContributions.age</th>\n",
" <th>FeatureContributions.capital-gain</th>\n",
" <th>FeatureContributions.hours-per-week</th>\n",
" <th>FeatureContributions.education.11th</th>\n",
" <th>FeatureContributions.education.HS-grad</th>\n",
" <th>FeatureContributions.education.Assoc-acdm</th>\n",
" <th>...</th>\n",
" <th>FeatureContributions.marital-status.Separated</th>\n",
" <th>FeatureContributions.marital-status.Married-spouse-absent</th>\n",
" <th>FeatureContributions.marital-status.Married-AF-spouse</th>\n",
" <th>FeatureContributions.ethnicity.Black</th>\n",
" <th>FeatureContributions.ethnicity.White</th>\n",
" <th>FeatureContributions.ethnicity.Asian-Pac-Islander</th>\n",
" <th>FeatureContributions.ethnicity.Other</th>\n",
" <th>FeatureContributions.ethnicity.Amer-Indian-Inuit</th>\n",
" <th>FeatureContributions.sex.Male</th>\n",
" <th>FeatureContributions.sex.Female</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>-4.047609</td>\n",
" <td>0.017164</td>\n",
" <td>0.030594</td>\n",
" <td>0.000000</td>\n",
" <td>0.360155</td>\n",
" <td>-0.59735</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.059523</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.053660</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>-0.463503</td>\n",
" <td>0.386155</td>\n",
" <td>0.029750</td>\n",
" <td>0.000000</td>\n",
" <td>0.288005</td>\n",
" <td>0.00000</td>\n",
" <td>-0.181008</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.148458</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.034328</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>-0.059127</td>\n",
" <td>0.485223</td>\n",
" <td>0.021921</td>\n",
" <td>0.000000</td>\n",
" <td>0.230404</td>\n",
" <td>0.00000</td>\n",
" <td>0.000000</td>\n",
" <td>0.112218</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.148458</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.034328</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>-0.618038</td>\n",
" <td>0.350228</td>\n",
" <td>0.034447</td>\n",
" <td>0.072675</td>\n",
" <td>0.230404</td>\n",
" <td>0.00000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.038079</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.034328</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>-3.723567</td>\n",
" <td>0.023578</td>\n",
" <td>0.022028</td>\n",
" <td>0.000000</td>\n",
" <td>0.270116</td>\n",
" <td>0.00000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.232061</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>-0.054896</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 36 columns</p>\n",
"</div>"
],
"text/plain": [
" label PredictedLabel Score Probability FeatureContributions.age \\\n",
"0 0 0 -4.047609 0.017164 0.030594 \n",
"1 0 0 -0.463503 0.386155 0.029750 \n",
"2 1 0 -0.059127 0.485223 0.021921 \n",
"3 1 0 -0.618038 0.350228 0.034447 \n",
"4 0 0 -3.723567 0.023578 0.022028 \n",
"\n",
" FeatureContributions.capital-gain FeatureContributions.hours-per-week \\\n",
"0 0.000000 0.360155 \n",
"1 0.000000 0.288005 \n",
"2 0.000000 0.230404 \n",
"3 0.072675 0.230404 \n",
"4 0.000000 0.270116 \n",
"\n",
" FeatureContributions.education.11th \\\n",
"0 -0.59735 \n",
"1 0.00000 \n",
"2 0.00000 \n",
"3 0.00000 \n",
"4 0.00000 \n",
"\n",
" FeatureContributions.education.HS-grad \\\n",
"0 0.000000 \n",
"1 -0.181008 \n",
"2 0.000000 \n",
"3 0.000000 \n",
"4 0.000000 \n",
"\n",
" FeatureContributions.education.Assoc-acdm ... \\\n",
"0 0.000000 ... \n",
"1 0.000000 ... \n",
"2 0.112218 ... \n",
"3 0.000000 ... \n",
"4 0.000000 ... \n",
"\n",
" FeatureContributions.marital-status.Separated \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.marital-status.Married-spouse-absent \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.marital-status.Married-AF-spouse \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.ethnicity.Black FeatureContributions.ethnicity.White \\\n",
"0 0.059523 0.000000 \n",
"1 0.000000 0.148458 \n",
"2 0.000000 0.148458 \n",
"3 0.038079 0.000000 \n",
"4 0.000000 0.232061 \n",
"\n",
" FeatureContributions.ethnicity.Asian-Pac-Islander \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.ethnicity.Other \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.ethnicity.Amer-Indian-Inuit \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.sex.Male FeatureContributions.sex.Female \n",
"0 0.053660 0.000000 \n",
"1 0.034328 0.000000 \n",
"2 0.034328 0.000000 \n",
"3 0.034328 0.000000 \n",
"4 0.000000 -0.054896 \n",
"\n",
"[5 rows x 36 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"linear_fc = linear_model.get_feature_contributions(test_data)\n",
"linear_fc.filter(regex='label|PredictedLabel|Score|Probability|FeatureContributions').head()"
]
},
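{
"cell_type": "markdown",
"metadata": {},
"source": [
"The product rule for linear models can be sketched in a few lines of plain Python. The weights, bias, and example values below are hypothetical and are not taken from the trained model; the sketch also does not try to reproduce any scaling NimbusML may apply to the contributions it reports. The last step maps the score to a probability with the logistic function, which is consistent with the `Score` and `Probability` columns shown above (for example, a score of -0.463503 maps to about 0.386)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Hypothetical weights and bias for a three-feature linear model.\n",
"weights = {'age': 0.02, 'hours-per-week': 0.01, 'sex.Male': 0.3}\n",
"bias = -1.5\n",
"\n",
"# Hypothetical example (sex.Male is a one-hot indicator).\n",
"example = {'age': 40.0, 'hours-per-week': 50.0, 'sex.Male': 1.0}\n",
"\n",
"# Raw contribution of each feature: feature value times weight.\n",
"contributions = {name: example[name] * weights[name] for name in weights}\n",
"\n",
"# The score is the bias plus the sum of the raw contributions.\n",
"score = bias + sum(contributions.values())\n",
"\n",
"# Logistic function: converts the score to a probability.\n",
"probability = 1.0 / (1.0 + math.exp(-score))\n",
"\n",
"print(contributions)\n",
"print(score, probability)"
]
},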
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example-level Feature Contributions (Tree Models)\n",
"Feature contributions for tree models are determined based on which splits in the decision trees have the most impact on the final score. The calculation is done by evaluating the score we would have gotten *if we had taken the opposite split* every time we encountered a given feature. The contribution of this feature is then given by the difference between this score and the original score."
]
},
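{
"cell_type": "markdown",
"metadata": {},
"source": [
"The opposite-split idea can be sketched on a single toy tree. This is an illustration under stated assumptions, not NimbusML's implementation: the tree, the `tree_score` helper, and the sign convention (original score minus the score along the opposite splits) are hypothetical, and a real boosted-tree model sums the outputs of many trees."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tree_score(node, example, flipped_feature=None):\n",
"    # Leaves are plain numbers; internal nodes are dicts.\n",
"    while isinstance(node, dict):\n",
"        go_left = example[node['feature']] <= node['threshold']\n",
"        if node['feature'] == flipped_feature:\n",
"            go_left = not go_left  # take the opposite split for this feature\n",
"        node = node['left'] if go_left else node['right']\n",
"    return node\n",
"\n",
"# Hypothetical two-level tree over two of the tutorial's feature names.\n",
"toy_tree = {\n",
"    'feature': 'capital-gain', 'threshold': 5000.0,\n",
"    'left': {'feature': 'age', 'threshold': 30.0, 'left': -2.0, 'right': -0.5},\n",
"    'right': 3.0,\n",
"}\n",
"toy_example = {'age': 25.0, 'capital-gain': 0.0}\n",
"\n",
"original = tree_score(toy_tree, toy_example)\n",
"\n",
"# Assumed sign convention: original score minus the opposite-split score,\n",
"# so a negative value means the feature pushed the score lower.\n",
"toy_contributions = {\n",
"    feature: original - tree_score(toy_tree, toy_example, flipped_feature=feature)\n",
"    for feature in ('age', 'capital-gain')\n",
"}\n",
"print(original, toy_contributions)"
]
},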
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>PredictedLabel</th>\n",
" <th>Score</th>\n",
" <th>Probability</th>\n",
" <th>FeatureContributions.age</th>\n",
" <th>FeatureContributions.capital-gain</th>\n",
" <th>FeatureContributions.hours-per-week</th>\n",
" <th>FeatureContributions.education.11th</th>\n",
" <th>FeatureContributions.education.HS-grad</th>\n",
" <th>FeatureContributions.education.Assoc-acdm</th>\n",
" <th>...</th>\n",
" <th>FeatureContributions.marital-status.Separated</th>\n",
" <th>FeatureContributions.marital-status.Married-spouse-absent</th>\n",
" <th>FeatureContributions.marital-status.Married-AF-spouse</th>\n",
" <th>FeatureContributions.ethnicity.Black</th>\n",
" <th>FeatureContributions.ethnicity.White</th>\n",
" <th>FeatureContributions.ethnicity.Asian-Pac-Islander</th>\n",
" <th>FeatureContributions.ethnicity.Other</th>\n",
" <th>FeatureContributions.ethnicity.Amer-Indian-Inuit</th>\n",
" <th>FeatureContributions.sex.Male</th>\n",
" <th>FeatureContributions.sex.Female</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>-25.577993</td>\n",
" <td>0.000036</td>\n",
" <td>-0.728820</td>\n",
" <td>-1.000000</td>\n",
" <td>-0.367933</td>\n",
" <td>-0.162649</td>\n",
" <td>-0.027093</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>-0.154559</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.111279</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>-0.035480</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>-10.135474</td>\n",
" <td>0.017054</td>\n",
" <td>-1.000000</td>\n",
" <td>-0.668069</td>\n",
" <td>-0.490714</td>\n",
" <td>0.225786</td>\n",
" <td>-0.101992</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>-0.111768</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>-0.096452</td>\n",
" <td>0.069662</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.009991</td>\n",
" <td>-0.013131</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2.394207</td>\n",
" <td>0.722658</td>\n",
" <td>0.892277</td>\n",
" <td>-0.850473</td>\n",
" <td>-0.321354</td>\n",
" <td>0.745953</td>\n",
" <td>0.117656</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>-0.172352</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>-0.150833</td>\n",
" <td>0.756565</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>18.900896</td>\n",
" <td>0.999480</td>\n",
" <td>0.236660</td>\n",
" <td>1.000000</td>\n",
" <td>0.417579</td>\n",
" <td>0.437051</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.335541</td>\n",
" <td>0.000000</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>-0.005970</td>\n",
" <td>-0.007641</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>-26.024494</td>\n",
" <td>0.000030</td>\n",
" <td>-1.000000</td>\n",
" <td>-0.902150</td>\n",
" <td>-0.912701</td>\n",
" <td>0.162933</td>\n",
" <td>0.114411</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>-0.138418</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>-0.031655</td>\n",
" <td>0.040154</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 36 columns</p>\n",
"</div>"
],
"text/plain": [
" label PredictedLabel Score Probability FeatureContributions.age \\\n",
"0 0 0 -25.577993 0.000036 -0.728820 \n",
"1 0 0 -10.135474 0.017054 -1.000000 \n",
"2 1 1 2.394207 0.722658 0.892277 \n",
"3 1 1 18.900896 0.999480 0.236660 \n",
"4 0 0 -26.024494 0.000030 -1.000000 \n",
"\n",
" FeatureContributions.capital-gain FeatureContributions.hours-per-week \\\n",
"0 -1.000000 -0.367933 \n",
"1 -0.668069 -0.490714 \n",
"2 -0.850473 -0.321354 \n",
"3 1.000000 0.417579 \n",
"4 -0.902150 -0.912701 \n",
"\n",
" FeatureContributions.education.11th \\\n",
"0 -0.162649 \n",
"1 0.225786 \n",
"2 0.745953 \n",
"3 0.437051 \n",
"4 0.162933 \n",
"\n",
" FeatureContributions.education.HS-grad \\\n",
"0 -0.027093 \n",
"1 -0.101992 \n",
"2 0.117656 \n",
"3 0.000000 \n",
"4 0.114411 \n",
"\n",
" FeatureContributions.education.Assoc-acdm ... \\\n",
"0 0.0 ... \n",
"1 0.0 ... \n",
"2 0.0 ... \n",
"3 0.0 ... \n",
"4 0.0 ... \n",
"\n",
" FeatureContributions.marital-status.Separated \\\n",
"0 -0.154559 \n",
"1 -0.111768 \n",
"2 -0.172352 \n",
"3 0.000000 \n",
"4 -0.138418 \n",
"\n",
" FeatureContributions.marital-status.Married-spouse-absent \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.marital-status.Married-AF-spouse \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.ethnicity.Black FeatureContributions.ethnicity.White \\\n",
"0 0.111279 0.000000 \n",
"1 -0.096452 0.069662 \n",
"2 -0.150833 0.756565 \n",
"3 0.335541 0.000000 \n",
"4 -0.031655 0.040154 \n",
"\n",
" FeatureContributions.ethnicity.Asian-Pac-Islander \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.ethnicity.Other \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.ethnicity.Amer-Indian-Inuit \\\n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"4 0.0 \n",
"\n",
" FeatureContributions.sex.Male FeatureContributions.sex.Female \n",
"0 0.000000 -0.035480 \n",
"1 0.009991 -0.013131 \n",
"2 0.000000 0.000000 \n",
"3 -0.005970 -0.007641 \n",
"4 0.000000 0.000000 \n",
"\n",
"[5 rows x 36 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_fc = tree_model.get_feature_contributions(test_data)\n",
"tree_fc.filter(regex='label|PredictedLabel|Score|Probability|FeatureContributions').head()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}