Mirror of https://github.com/microsoft/LightGBM.git
Fix feature attributions for regression models and add Python bindings (#861)
* Fix feature attributions for regression models and add Python bindings
* Address pylint issue
* Lazy fix missing tree depth info
This commit is contained in:
Parent: 5543979b0e
Commit: 67c2bdf905
include/LightGBM/tree.h
@@ -111,7 +111,7 @@ public:
   inline int PredictLeafIndex(const double* feature_values) const;
 
-  inline void PredictContrib(const double* feature_values, int num_features, double* output) const;
+  inline void PredictContrib(const double* feature_values, int num_features, double* output);
 
   /*! \brief Get Number of leaves*/
   inline int num_leaves() const { return num_leaves_; }
 
@@ -299,9 +299,12 @@ private:
   /*! \brief Serialize one node to if-else statement*/
   std::string NodeToIfElse(int index, bool is_predict_leaf_index) const;
 
-  double ExpectedValue(int node) const;
+  double ExpectedValue() const;
 
-  int MaxDepth() const;
+  int MaxDepth();
+
+  /*! \brief This is used to fill in leaf_depth_ after reloading a model*/
+  inline void RecomputeLeafDepths(int node = 0, int depth = 0);
 
   /*!
   * \brief Used by TreeSHAP for data we keep about our decision path
@@ -431,12 +434,25 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const {
   }
 }
 
-inline void Tree::PredictContrib(const double* feature_values, int num_features, double* output) const {
-  output[num_features] += ExpectedValue(0);
+inline void Tree::PredictContrib(const double* feature_values, int num_features, double* output) {
+  output[num_features] += ExpectedValue();
   // Run the recursion with preallocated space for the unique path data
-  const int max_path_len = MaxDepth() + 1;
-  std::vector<PathElement> unique_path_data((max_path_len*(max_path_len + 1)) / 2);
-  TreeSHAP(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
+  if (num_leaves_ > 1) {
+    const int max_path_len = MaxDepth()+1;
+    PathElement *unique_path_data = new PathElement[(max_path_len*(max_path_len+1))/2];
+    TreeSHAP(feature_values, output, 0, 0, unique_path_data, 1, 1, -1);
+    delete[] unique_path_data;
+  }
 }
+
+inline void Tree::RecomputeLeafDepths(int node, int depth) {
+  if (node == 0) leaf_depth_.resize(num_leaves());
+  if (node < 0) {
+    leaf_depth_[~node] = depth;
+  } else {
+    RecomputeLeafDepths(left_child_[node], depth+1);
+    RecomputeLeafDepths(right_child_[node], depth+1);
+  }
+}
 
 inline int Tree::GetLeaf(const double* feature_values) const {
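A note on the tree.h changes above (editorial, not part of the commit): leaf_depth_ is not stored in the serialized model, so MaxDepth() now lazily rebuilds it through RecomputeLeafDepths() on first use (the "Lazy fix missing tree depth info" in the commit message); that mutation is why PredictContrib() and MaxDepth() drop their const qualifiers. A minimal Python sketch of the two ideas, with illustrative names only:

    # Hedged Python model of RecomputeLeafDepths: LightGBM encodes a leaf
    # child as the bitwise complement of the leaf index, so node < 0 marks
    # a leaf and ~node recovers its index.
    def recompute_leaf_depths(left, right, leaf_depth, node=0, depth=0):
        if node < 0:
            leaf_depth[~node] = depth
        else:
            recompute_leaf_depths(left, right, leaf_depth, left[node], depth + 1)
            recompute_leaf_depths(left, right, leaf_depth, right[node], depth + 1)

    # The TreeSHAP recursion at depth k extends a unique path of k + 1
    # elements, so one buffer of 1 + 2 + ... + m elements (m = MaxDepth() + 1)
    # covers every level: the m*(m+1)/2 allocation in PredictContrib above.
    def unique_path_buffer_size(max_depth):
        m = max_depth + 1
        return m * (m + 1) // 2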
python-package/lightgbm/basic.py
@@ -170,6 +170,7 @@ C_API_IS_ROW_MAJOR = 1
 C_API_PREDICT_NORMAL = 0
 C_API_PREDICT_RAW_SCORE = 1
 C_API_PREDICT_LEAF_INDEX = 2
+C_API_PREDICT_CONTRIB = 3
 
 """data type of data field"""
 FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
@@ -351,7 +352,7 @@ class _InnerPredictor(object):
         return this
 
     def predict(self, data, num_iteration=-1,
-                raw_score=False, pred_leaf=False, data_has_header=False,
+                raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False,
                 is_reshape=True):
         """
         Predict logic
@@ -367,6 +368,8 @@ class _InnerPredictor(object):
             True for predict raw score
         pred_leaf : bool
            True for predict leaf index
+        pred_contrib : bool
+            True for predict feature contributions
         data_has_header : bool
             Used for txt data, True if txt data has header
         is_reshape : bool
@@ -384,6 +387,8 @@ class _InnerPredictor(object):
             predict_type = C_API_PREDICT_RAW_SCORE
         if pred_leaf:
             predict_type = C_API_PREDICT_LEAF_INDEX
+        if pred_contrib:
+            predict_type = C_API_PREDICT_CONTRIB
         int_data_has_header = 1 if data_has_header else 0
         if num_iteration > self.num_total_iteration:
             num_iteration = self.num_total_iteration
@@ -1653,7 +1658,7 @@ class Booster(object):
                                                  ptr_string_buffer))
         return json.loads(string_buffer.value.decode())
 
-    def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False,
+    def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, pred_contrib=False,
                 data_has_header=False, is_reshape=True, pred_parameter=None):
         """Make a prediction.
 
@@ -1669,6 +1674,8 @@ class Booster(object):
             Whether to predict raw scores.
         pred_leaf : bool, optional (default=False)
             Whether to predict leaf index.
+        pred_contrib : bool, optional (default=False)
+            Whether to predict feature contributions.
         data_has_header : bool, optional (default=False)
             Whether the data has header.
             Used only if data is string.
@@ -1685,7 +1692,7 @@ class Booster(object):
         predictor = self._to_predictor(pred_parameter)
         if num_iteration <= 0:
             num_iteration = self.best_iteration
-        return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape)
+        return predictor.predict(data, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, is_reshape)
 
     def get_leaf_output(self, tree_id, leaf_id):
         """Get the output of a leaf.
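The basic.py changes thread the single new pred_contrib keyword from Booster.predict() down to the C API predict type. A hedged usage sketch (gbm, X_test, and the tolerance are illustrative; the trailing column and the row-sum identity follow from output[num_features] += ExpectedValue() above and from the new test below):

    import numpy as np
    contribs = gbm.predict(X_test, pred_contrib=True)  # shape: (n_rows, n_features + 1)
    raw = gbm.predict(X_test, raw_score=True)
    # last column is the model's expected value; per-row sums recover raw scores
    assert np.allclose(contribs.sum(axis=1), raw, atol=1e-4)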
src/io/tree.cpp
@@ -589,7 +589,6 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
   if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
   ExtendPath(unique_path, unique_depth, parent_zero_fraction,
              parent_one_fraction, parent_feature_index);
-  const int split_index = split_feature_[node];
 
   // leaf node
   if (node < 0) {
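The hunk above removes the eager const int split_index = split_feature_[node]; which executed before the leaf check. Leaves are encoded as negative node ids, so that lookup indexed split_feature_ with a negative subscript on every leaf visit; the follow-up hunks instead fetch split_feature_[node] only on the internal-node branch. A tiny illustration of the hazard (hypothetical values):

    split_feature = [2, 0, 1]   # per-internal-node split features
    node = -1                   # a leaf, encoded as a bitwise complement
    # In Python, split_feature[node] silently wraps to the last entry; the
    # equivalent negative index in C++ is out of bounds (undefined behavior),
    # so the fix defers the lookup until node >= 0 is known.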
@@ -601,7 +600,7 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
 
   // internal node
   } else {
-    const int hot_index = Decision(feature_values[split_index], node);
+    const int hot_index = Decision(feature_values[split_feature_[node]], node);
     const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
     const double w = data_count(node);
     const double hot_zero_fraction = data_count(hot_index) / w;
@@ -613,7 +612,7 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
     // if so we undo that split so we can redo it for this node
     int path_index = 0;
     for (; path_index <= unique_depth; ++path_index) {
-      if (unique_path[path_index].feature_index == split_index) break;
+      if (unique_path[path_index].feature_index == split_feature_[node]) break;
     }
     if (path_index != unique_depth + 1) {
       incoming_zero_fraction = unique_path[path_index].zero_fraction;
@@ -623,25 +622,26 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
     }
 
     TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
-             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_index);
+             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
 
     TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
-             cold_zero_fraction*incoming_zero_fraction, 0, split_index);
+             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
   }
 }
 
-double Tree::ExpectedValue(int node) const {
-  if (node >= 0) {
-    const int l = left_child_[node];
-    const int r = right_child_[node];
-    return (data_count(l)*ExpectedValue(l) + data_count(r)*ExpectedValue(r)) / data_count(node);
-  } else {
-    return LeafOutput(~node);
+double Tree::ExpectedValue() const {
+  if (num_leaves_ == 1) return LeafOutput(0);
+  const double total_count = internal_count_[0];
+  double exp_value = 0.0;
+  for (int i = 0; i < num_leaves(); ++i) {
+    exp_value += (leaf_count_[i]/total_count)*LeafOutput(i);
   }
+  return exp_value;
 }
 
-int Tree::MaxDepth() const {
+int Tree::MaxDepth() {
+  if (leaf_depth_.size() == 0) RecomputeLeafDepths();
   if (num_leaves_ == 1) return 0;
   int max_depth = 0;
   for (int i = 0; i < num_leaves(); ++i) {
     if (max_depth < leaf_depth_[i]) max_depth = leaf_depth_[i];
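The rewritten ExpectedValue() replaces the child-count recursion with a direct pass over the leaves, and its num_leaves_ == 1 guard (like the num_leaves_ > 1 guard in PredictContrib) handles constant single-leaf trees, which have no internal nodes for the old recursion to descend through and which plausibly caused the regression-model failure named in the commit title. A minimal sketch, assuming the leaf counts sum to the root record count:

    # Hedged sketch of the new ExpectedValue(): a leaf-count-weighted mean of
    # leaf outputs; total mirrors internal_count_[0], the count at the root.
    def expected_value(leaf_output, leaf_count):
        if len(leaf_output) == 1:       # constant tree: no internal nodes
            return leaf_output[0]
        total = float(sum(leaf_count))  # internal_count_[0] in the C++ code
        return sum((c / total) * o for o, c in zip(leaf_output, leaf_count))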
tests/python_package_test/test_engine.py
@@ -462,3 +462,23 @@ class TestEngine(unittest.TestCase):
         tmp_dat_val = tmp_dat.subset(np.arange(80, 100)).subset(np.arange(18))
         params = {'objective': 'regression_l2', 'metric': 'rmse'}
         gbm = lgb.train(params, tmp_dat_train, num_boost_round=20, valid_sets=[tmp_dat_train, tmp_dat_val])
+
+    def test_contribs(self):
+        X, y = load_breast_cancer(True)
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
+        params = {
+            'objective': 'binary',
+            'metric': 'binary_logloss',
+            'verbose': -1,
+            'num_iteration': 50  # test num_iteration in dict here
+        }
+        lgb_train = lgb.Dataset(X_train, y_train)
+        lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
+        evals_result = {}
+        gbm = lgb.train(params, lgb_train,
+                        num_boost_round=20,
+                        valid_sets=lgb_eval,
+                        verbose_eval=False,
+                        evals_result=evals_result)
+
+        self.assertLess(np.linalg.norm(gbm.predict(X_test, raw_score=True) - np.sum(gbm.predict(X_test, pred_contrib=True), axis=1)), 1e-4)
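The new test exercises only the binary objective; a companion check for the regression case named in the commit title might look like the following (a sketch, not part of the commit, reusing the test's imports and train/test split):

    params = {'objective': 'regression_l2', 'verbose': -1}
    gbm_reg = lgb.train(params, lgb.Dataset(X_train, y_train), num_boost_round=20)
    contribs = gbm_reg.predict(X_test, pred_contrib=True)
    # for regression, raw scores are the predictions themselves
    assert np.allclose(contribs.sum(axis=1), gbm_reg.predict(X_test), atol=1e-4)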