* add openmp support for parallelization
add bin counting for debugging

* add algo precision test

* remove parallelization messing with model weights

* fix windows/mac build

* kaggle notebooks

* notebook improvements

* fix sklearn pipeline

* add sklearn unit test

* externalized feature engineering
importance plots now introspect the fitted transformers (see the usage sketch below)
Markus Cozowicz 2022-05-04 17:29:10 +02:00, committed by GitHub
Parent 038b422cc5
Commit 64bc126111
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
9 changed files: 7564 additions and 236 deletions
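For orientation, a minimal sketch of the workflow this commit enables: feature engineering is externalized into standard sklearn transformers, and CBMExplainer introspects the fitted pipeline to label the importance plots. Column names and bin counts here are illustrative, borrowed from tests/test_sklearn.py below; this is a sketch, not the canonical API surface.

import pandas as pd
import cbm
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import KBinsDiscretizer

df = pd.read_csv('data/nyc_bb_bicyclist_counts.csv', parse_dates=['Date'])
X, y = df.drop('BB_COUNT', axis=1), df['BB_COUNT']

# feature engineering lives in sklearn transformers, not inside CBM
features = make_column_transformer(
    (cbm.DateEncoder('weekday'), ['Date']),
    (cbm.DateEncoder('month'), ['Date']),
    (KBinsDiscretizer(n_bins=10, encode='ordinal'), ['HIGH_T', 'LOW_T', 'PRECIP']),
)
pipeline = make_pipeline(features, cbm.CBM()).fit(X, y)

# the explainer walks the fitted transformers to recover axis labels
fig, ax = cbm.CBMExplainer(pipeline).plot_importance(figsize=(8, 16))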

cbm/CBM.py

@@ -22,8 +22,6 @@ class CBM(BaseEstimator):
min_iterations_early_stopping:int = 20,
epsilon_early_stopping:float = 1e-3,
single_update_per_iteration:bool = True,
date_features: Union[str, List[str]] = 'day,month',
binning: Union[int, Callable[[pd.Series], int]] = 10,  # Callable needs importing from typing
metric: str = 'rmse',
enable_bin_count: bool = False
) -> None:
@@ -48,88 +46,13 @@ class CBM(BaseEstimator):
self.single_update_per_iteration = single_update_per_iteration
self.enable_bin_count = enable_bin_count
# let's make sure it's serializable
if isinstance(date_features, list):
date_features = ",".join(date_features)
self.date_features = date_features
self.binning = binning
self.metric = metric
def get_date_features(self) -> List[str]:
return self.date_features.split(",")
def fit(self,
X: Union[np.ndarray, pd.DataFrame],
y: np.ndarray
) -> "CBM":
# keep feature names around
if isinstance(X, pd.DataFrame):
self._feature_names = []
self._feature_categories = []
self._feature_bins = []
X_numeric = []
for col in X.columns:
col_dtype = X[col].dtype
if pd.api.types.is_datetime64_any_dtype(col_dtype):
for expansion in self.get_date_features():
import calendar
if expansion == 'day':
self._feature_names.append(f'{col}_day')
self._feature_categories.append(calendar.day_abbr)
self._feature_bins.append(None)
X_numeric.append(X[col].dt.dayofweek.values)
elif expansion == 'month':
self._feature_names.append(f'{col}_month')
self._feature_categories.append(calendar.month_abbr)
self._feature_bins.append(None)
X_numeric.append(X[col].dt.month.values)
elif pd.api.types.is_float_dtype(col_dtype):
# deal with continuous features
bin_num = self.binning if isinstance(self.binning, int) else self.binning(X[col])
X_binned, bins = pd.qcut(X[col].fillna(0), bin_num, duplicates='drop', retbins=True)
self._feature_names.append(col)
self._feature_categories.append(X_binned.cat.categories.astype(str).tolist())
self._feature_bins.append(bins)
X_numeric.append(pd.cut(X[col].fillna(0), bins, include_lowest=True).cat.codes)
elif not pd.api.types.is_integer_dtype(col_dtype):
self._feature_names.append(col)
# convert to categorical
X_cat = (X[col]
.fillna('CBM_UnknownCategory')
.astype('category'))
# keep track of categories
self._feature_categories.append(X_cat.cat.categories.tolist())
self._feature_bins.append(None)
# convert to 0-based index
X_numeric.append(X_cat.cat.codes)
else:
self._feature_names.append(col)
self._feature_categories.append(None)
self._feature_bins.append(None)
X_numeric.append(X[col])
X = np.column_stack(X_numeric)
else:
self._feature_names = None
X, y = check_X_y(X, y, y_numeric=True)
# pre-processing
@@ -170,43 +93,6 @@ class CBM(BaseEstimator):
return self
def predict(self, X: np.ndarray, explain: bool = False):
if isinstance(X, pd.DataFrame):
X_numeric = []
offset = 0 # correct for date expansion
for i, col in enumerate(X.columns):
col_dtype = X[col].dtype
if pd.api.types.is_datetime64_any_dtype(col_dtype):
for expansion in self.get_date_features():
if expansion == 'day':
X_numeric.append(X[col].dt.dayofweek.values)
offset += 1
elif expansion == 'month':
X_numeric.append(X[col].dt.month.values)
offset += 1
offset -= 1
elif pd.api.types.is_float_dtype(col_dtype):
# re-use binning from training
X_numeric.append(pd.cut(X[col].fillna(0), self._feature_bins[i + offset], include_lowest=True).cat.codes)
elif not pd.api.types.is_integer_dtype(col_dtype):
# convert to categorical
X_cat = (X[col]
.fillna('CBM_UnknownCategory')
# re-use categories from training
.astype(CategoricalDtype(categories=self._feature_categories[i + offset], ordered=True)))
# convert to 0-based index
X_numeric.append(X_cat.cat.codes)
else:
X_numeric.append(X[col])
X = np.column_stack(X_numeric)
X = check_array(X)
check_is_fitted(self, "is_fitted_")
@@ -231,126 +117,6 @@ class CBM(BaseEstimator):
self.is_fitted_ = True
def _plot_importance_categorical(self, ax, feature_idx: int, vmin: float, vmax: float, is_continuous: bool):
import matplotlib.pyplot as plt
cmap = plt.get_cmap("RdYlGn")
# plot positive/negative impact (so 1.x to 0.x)
weights = np.array(self.weights[feature_idx]) - 1
alpha = 1
if self._feature_bins[feature_idx] is not None or is_continuous:
ax.plot(range(len(weights)), weights)
alpha = 0.3
weights_normalized = [(x - vmin) / (vmax - vmin) for x in weights]  # parenthesize before dividing: normalize to [0, 1] for the colormap
ax.bar(range(len(weights)), weights, color=cmap(weights_normalized), edgecolor='black', alpha=alpha)
ax.set_ylim(vmin, vmax)
# ax.barh(range(len(weights)), weights, color=cmap(weights_normalized), edgecolor='black', alpha=0.3)
# ax.set_xlim(vmin, vmax)
# ax_sub.set_title(feature_names[feature_idx] if feature_names is not None else f'Feature {feature_idx}')
ax.set_ylabel('% change')
if self._feature_names is not None:
ax.set_xlabel(self._feature_names[feature_idx])
if self._feature_categories[feature_idx] is not None:
ax.set_xticks(range(len(self._feature_categories[feature_idx])))
ax.set_xticklabels(self._feature_categories[feature_idx], rotation=45)
def _plot_importance_interaction(self, ax, feature_idx: int, vmin: float, vmax: float):
import matplotlib.pyplot as plt
weights = np.array(self.weights[feature_idx]) - 1
cat_df = pd.DataFrame(
[(int(c.split('_')[0]), int(c.split('_')[1]), i) for i, c in enumerate(self._feature_categories[feature_idx])],
columns=['f0', 'f1', 'idx'])
cat_df.sort_values(['f0', 'f1'], inplace=True)
cat_df_2d = cat_df.pivot(index='f0', columns='f1', values='idx')
# resort index by mean weight value
zi = np.array(weights)[cat_df_2d.to_numpy()]
sort_order = np.argsort(np.max(zi, axis=1))
cat_df_2d = cat_df_2d.reindex(cat_df_2d.index[sort_order])
# construct data matrices
xi = cat_df_2d.columns
yi = cat_df_2d.index
zi = np.array(weights)[cat_df_2d.to_numpy()]
im = ax.imshow(zi, cmap=plt.get_cmap("RdYlGn"), aspect='auto', vmin=vmin, vmax=vmax)
cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel('% change', rotation=-90, va="bottom")
if self._feature_names is not None:
names = self._feature_names[feature_idx].split('_X_')
ax.set_ylabel(names[0])
ax.set_xlabel(names[1])
# Show all ticks and label them with the respective list entries
ax.set_xticks(np.arange(len(xi)), labels=xi)
ax.set_yticks(np.arange(len(yi)), labels=yi)
def plot_importance(self, feature_names: list = None, continuous_features: list = None, **kwargs):
"""Plot feature importance.
Args:
feature_names (list, optional): Names used to label the feature plots. If the model was trained on a pandas DataFrame, the feature names are
extracted from the DataFrame automatically. If the model was trained on a numpy array, the feature names need to be supplied.
continuous_features (list, optional): Features to draw as continuous (line over bars) instead of purely categorical; changes the plot style accordingly.
"""
import matplotlib.pyplot as plt
check_is_fitted(self, "is_fitted_")
if feature_names is not None:
self._feature_names = feature_names
n_features = len(self.weights)
n_cols = int(np.ceil(np.sqrt(n_features)))
n_rows = int(np.floor(np.sqrt(n_features)))
if n_cols * n_rows < n_features:
n_rows += 1
fig, ax = plt.subplots(n_rows, n_cols, **kwargs)
for r in range(n_rows):
for c in range(n_cols):
ax[r, c].set_axis_off()
fig.suptitle(f'Response mean: {self.y_mean:0.4f} | Iterations {self.iterations}')
vmin = np.min([np.min(w) for w in self.weights]) - 1
vmax = np.max([np.max(w) for w in self.weights]) - 1
for feature_idx in range(n_features):
ax_sub = ax[feature_idx // n_cols, feature_idx % n_cols]
ax_sub.set_axis_on()
# ax_sub.set_title(feature_names[feature_idx] if feature_names is not None else f'Feature {feature_idx}')
if continuous_features is None:
is_continuous = False
else:
if self._feature_names is not None:
is_continuous = self._feature_names[feature_idx] in continuous_features
else:
is_continuous = feature_idx in continuous_features
if self._feature_names is not None and '_X_' in self._feature_names[feature_idx]:
self._plot_importance_interaction(ax_sub, feature_idx, vmin, vmax)
else:
self._plot_importance_categorical(ax_sub, feature_idx, vmin, vmax, is_continuous)
@property
def weights(self):
return self._cpp.weights

cbm/CBMExplainer.py Normal file

@@ -0,0 +1,209 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder, KBinsDiscretizer
from typing import List, Tuple, Union
from .sklearn import DateEncoder
from .CBM import CBM
import matplotlib.pyplot as plt
from abc import ABC, abstractmethod
# convenience registry to extract x-axis category labels from fitted transformers
TRANSFORMER_INVERTER_REGISTRY = {}
def transformer_inverter(transformer_class):
def decorator(inverter_class):
TRANSFORMER_INVERTER_REGISTRY[transformer_class] = inverter_class()
return inverter_class
return decorator
# return categories for each feature_names_in_
class TransformerInverter(ABC):
@abstractmethod
def get_category_names(self, transformer):
pass
@transformer_inverter(DateEncoder)
class DateEncoderInverter(TransformerInverter):
def get_category_names(self, transformer):
# for each feature we return the set of category labels
return list(map(lambda _: transformer.categories_, transformer.feature_names_in_))
@transformer_inverter(KBinsDiscretizer)
class KBinsDiscretizerInverter(TransformerInverter):
def get_category_names(self, transformer):
if transformer.encode != "ordinal":
raise ValueError("Only ordinal encoding supported")
# bin_edges is feature x bins
def bin_edges_to_str(bin_edges: np.ndarray):
return pd.IntervalIndex(pd.arrays.IntervalArray.from_breaks(np.concatenate([[-np.inf], bin_edges, [np.inf]])))
return list(map(bin_edges_to_str, transformer.bin_edges_))
@transformer_inverter(OrdinalEncoder)
class OrdinalEncoderInverter(TransformerInverter):
def get_category_names(self, transformer):
return transformer.categories_
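To illustrate the registry (this example is not part of the commit; the toy data is made up, and it reuses the numpy and KBinsDiscretizer imports already at the top of this file): the explainer looks up an inverter by transformer class and asks it for per-feature axis labels.

disc = KBinsDiscretizer(n_bins=3, encode='ordinal').fit(np.array([[0.], [1.], [2.], [3.]]))
labels = TRANSFORMER_INVERTER_REGISTRY[type(disc)].get_category_names(disc)
# one pd.IntervalIndex per input feature, padded with +/-inf:
# [(-inf, 0.0], (0.0, 1.0], (1.0, 2.0], (2.0, 3.0], (3.0, inf]]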
class CBMExplainerPlot:
feature_index_: int
feature_plots_: List[dict]
def __init__(self):
self.feature_index_ = 0
self.feature_plots_ = []
def add_feature_plot(self, col_name: str, x_axis: List):
self.feature_plots_.append({
"col_name": col_name,
"x_axis": x_axis,
"feature_index": self.feature_index_,
})
# increment feature index (assume they are added in order)
self.feature_index_ += 1
def _plot_categorical(self, ax: plt.Axes, vmin: float, vmax: float, weights: np.ndarray, col_name: str, x_axis, **kwargs):
cmap = plt.get_cmap("RdYlGn")
is_continuous = isinstance(x_axis, pd.IntervalIndex)
# plot positive/negative impact (so 1.x to 0.x)
weights -= 1
alpha = 1
if is_continuous:
ax.plot(range(len(weights)), weights)
alpha = 0.3
# normalize for color map
weights_normalized = (weights - vmin) / (vmax - vmin)
# draw bars
ax.bar(range(len(weights)), weights, color=cmap(weights_normalized), edgecolor='black', alpha=alpha)
ax.set_ylim(vmin-0.1, vmax+0.1)
ax.set_ylabel('% change')
ax.set_xlabel(col_name)
if not is_continuous:
ax.set_xticks(range(len(x_axis)))
ax.set_xticklabels(x_axis, rotation=45)
# TODO: support 2D interaction plots
# def _plot_importance_interaction(self, ax, feature_idx: int, vmin: float, vmax: float):
#     import matplotlib.pyplot as plt
#     weights = np.array(self.weights[feature_idx]) - 1
#     cat_df = pd.DataFrame(
#         [(int(c.split('_')[0]), int(c.split('_')[1]), i) for i, c in enumerate(self._feature_categories[feature_idx])],
#         columns=['f0', 'f1', 'idx'])
#     cat_df.sort_values(['f0', 'f1'], inplace=True)
#     cat_df_2d = cat_df.pivot(index='f0', columns='f1', values='idx')
#     # resort index by mean weight value
#     zi = np.array(weights)[cat_df_2d.to_numpy()]
#     sort_order = np.argsort(np.max(zi, axis=1))
#     cat_df_2d = cat_df_2d.reindex(cat_df_2d.index[sort_order])
#     # construct data matrices
#     xi = cat_df_2d.columns
#     yi = cat_df_2d.index
#     zi = np.array(weights)[cat_df_2d.to_numpy()]
#     im = ax.imshow(zi, cmap=plt.get_cmap("RdYlGn"), aspect='auto', vmin=vmin, vmax=vmax)
#     cbar = ax.figure.colorbar(im, ax=ax)
#     cbar.ax.set_ylabel('% change', rotation=-90, va="bottom")
#     if self._feature_names is not None:
#         names = self._feature_names[feature_idx].split('_X_')
#         ax.set_ylabel(names[0])
#         ax.set_xlabel(names[1])
#     # Show all ticks and label them with the respective list entries
#     ax.set_xticks(np.arange(len(xi)), labels=xi)
#     ax.set_yticks(np.arange(len(yi)), labels=yi)
def plot(self, model: CBM, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
num_plots = max(self.feature_plots_, key=lambda d: d["feature_index"])["feature_index"] + 1
n_features = len(model.weights)
if num_plots != n_features:
raise ValueError(f"Missing plots for some features ({num_plots} vs {n_features})")
# setup plot
n_rows = num_plots
n_cols = 1
fig, ax = plt.subplots(n_rows, n_cols, **kwargs)
for i in range(num_plots):
ax[i].set_axis_off()
fig.suptitle(f'Response mean: {model.y_mean:0.2f} | Iterations {model.iterations}')
# extract weights from model
weights = model.weights
# find global min/max
vmin = np.min([np.min(w) for w in weights]) - 1
vmax = np.max([np.max(w) for w in weights]) - 1
for feature_idx in range(n_features):
ax_sub = ax[feature_idx]
ax_sub.set_axis_on()
feature_weights = np.array(weights[feature_idx])
self._plot_categorical(ax_sub, vmin, vmax, feature_weights, **self.feature_plots_[feature_idx])
plt.tight_layout()
return fig, ax
class CBMExplainer:
def __init__(self, pipeline: Pipeline):
if not isinstance(pipeline, Pipeline):
raise TypeError("pipeline must be of type sklearn.pipeline.Pipeline")
self.pipeline_ = pipeline
def _plot_column_transformer(self, transformer: ColumnTransformer, plot: CBMExplainerPlot):
# need to access transformers_ (vs transformers) to get the fitted transformer instance
for (name, transformer, cols) in transformer.transformers_:
# extension methods ;)
transformer_inverter = TRANSFORMER_INVERTER_REGISTRY[type(transformer)]
category_names = transformer_inverter.get_category_names(transformer)
for (col_name, cat) in zip(cols, category_names):
plot.add_feature_plot(col_name, cat)
def plot_importance(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
plot = CBMExplainerPlot()
# iterate through pipeline
for (name, component) in self.pipeline_.steps[0:-1]:
if isinstance(component, ColumnTransformer):
self._plot_column_transformer(component, plot)
model = self.pipeline_.steps[-1][1]
return plot.plot(model, **kwargs)

cbm/__init__.py

@@ -2,6 +2,8 @@
# Licensed under the MIT License.
from .CBM import CBM
from .sklearn import DateEncoder, TemporalSplit
from .CBMExplainer import CBMExplainer
from ._version import __version__
__all__ = ['CBM', '__version__']
__all__ = ['CBM', '__version__', 'DateEncoder', 'TemporalSplit', 'CBMExplainer']

cbm/sklearn.py Normal file

@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import calendar
import numpy as np
from sklearn.model_selection import TimeSeriesSplit
from sklearn.utils import indexable
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import _check_feature_names_in
from datetime import timedelta
# TODO
class TemporalSplit(TimeSeriesSplit):
def __init__(self, step=timedelta(days=1), n_splits=5, *, max_train_size=None, test_size=None, gap=0):
super().__init__(n_splits)
self.step = step
self.max_train_size = max_train_size
self.test_size = test_size
self.gap = gap
def _create_date_ranges(self, start, end, step):
start_ = start
while start_ < end:
end_ = start_ + step
yield start_
start_ = end_
def split(self, X, y=None, groups=None):
"""Generate indices to split data into training and test set.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data, where `n_samples` is the number of samples
and `n_features` is the number of features.
y : array-like of shape (n_samples,)
Always ignored, exists for compatibility.
groups : array-like of shape (n_samples,)
Always ignored, exists for compatibility.
Yields
------
train : ndarray
The training set indices for that split.
test : ndarray
The testing set indices for that split.
"""
X, y, groups = indexable(X, y, groups)
date_range = list(self._create_date_ranges(X.index.min(), X.index.max(), self.step))
n_samples = len(date_range)
n_splits = self.n_splits
n_folds = n_splits + 1
gap = self.gap
test_size = (
self.test_size if self.test_size is not None else n_samples // n_folds
)
# Make sure we have enough samples for the given split parameters
if n_folds > n_samples:
raise ValueError(
f"Cannot have number of folds={n_folds} greater"
f" than the number of samples={n_samples}."
)
if n_samples - gap - (test_size * n_splits) <= 0:
raise ValueError(
f"Too many splits={n_splits} for number of samples"
f"={n_samples} with test_size={test_size} and gap={gap}."
)
test_starts = range(n_samples - n_splits * test_size, n_samples, test_size)
for test_start in test_starts:
train_end = test_start - gap
if self.max_train_size and self.max_train_size < train_end:
yield (
np.where(np.logical_and(X.index >= date_range[train_end - self.max_train_size], X.index <= date_range[train_end - 1]))[0],
np.where(np.logical_and(X.index >= date_range[test_start], X.index <= date_range[test_start + test_size - 1]))[0]
)
else:
yield (
np.where(X.index < date_range[train_end])[0],
np.where(np.logical_and(X.index >= date_range[test_start], X.index <= date_range[test_start + test_size - 1]))[0]
)
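A hedged usage sketch with synthetic data: unlike the row-based TimeSeriesSplit, the folds here are formed over calendar ranges of the DataFrame index, so irregularly sampled data still splits on date boundaries.

import numpy as np
import pandas as pd
from datetime import timedelta

idx = pd.date_range('2022-01-01', periods=100, freq='D')
X = pd.DataFrame({'value': np.arange(100)}, index=idx)
y = np.arange(100)

for train_idx, test_idx in TemporalSplit(step=timedelta(days=1), n_splits=3).split(X, y):
    # training dates always precede test dates
    print(X.index[train_idx].max(), '<', X.index[test_idx].min())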
# TODO: add unit test
class DateEncoder(BaseEstimator, TransformerMixin):
def __init__(self, component='month'):
if component == 'weekday':
self.categories_ = list(calendar.day_abbr)
self.column_to_ordinal_ = lambda col: col.dt.weekday.values
elif component == 'dayofyear':
self.categories_ = list(range(1, 367))  # day-of-year reaches 366 in leap years
self.column_to_ordinal_ = lambda col: col.dt.dayofyear.values
elif component == 'month':
self.categories_ = list(calendar.month_abbr)
self.column_to_ordinal_ = lambda col: col.dt.month.values
else:
raise ValueError('component must be one of weekday, dayofyear or month')
self.component = component
def fit(self, X, y = None):
self._validate_data(X, dtype="datetime64")
return self
def transform(self, X, y = None):
return X.apply(self.column_to_ordinal_, axis=0)
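A quick sketch of the encoder round-trip (the dates are made up): the ordinal codes go to the model, while categories_ supplies the x-axis labels the explainer uses.

import pandas as pd

df = pd.DataFrame({'Date': pd.to_datetime(['2022-05-02', '2022-05-04'])})
enc = DateEncoder('weekday').fit(df)
enc.transform(df)   # codes 0 (Mon) and 2 (Wed)
enc.categories_     # ['Mon', 'Tue', ..., 'Sun']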

File diff suppressed because one or more lines are too long.

Diff not shown because of its large size.

File diff suppressed because one or more lines are too long.


@@ -122,6 +122,8 @@ namespace cbm
y_sum[j][x_ij] += y[i];
if (enableBinCount)
_bin_count[j][x_ij]++;
}
@@ -144,7 +146,6 @@ namespace cbm
{
for (size_t k = 0; k <= x_max[j]; k++)
{
// TODO: check if a bin is empty. might be better to remap/exclude the bins?
if (y_sum[j][k])
{
@@ -164,6 +165,9 @@ namespace cbm
}
}
}
// update_y_hat_sum after every feature
update_y_hat_sum(y_hat, y_hat_sum, x, n_examples, n_features);
}
// prediction
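For reference, a Python sketch of what the new debug counters compute (variable names are illustrative, not from the C++ source): per feature j, _bin_count[j][b] is simply the number of training examples whose ordinal code falls into bin b.

import numpy as np

x = np.array([[0, 1], [0, 2], [1, 2]])   # n_examples x n_features, ordinal codes
bin_count = [np.bincount(x[:, j]) for j in range(x.shape[1])]
# bin_count[0] -> [2, 1]     feature 0: two examples in bin 0, one in bin 1
# bin_count[1] -> [0, 1, 2]  feature 1: bins 0, 1, 2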

tests/test_sklearn.py Normal file

@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pytest
import numpy as np
import pandas as pd
from sklearn.metrics import make_scorer, mean_squared_error
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.model_selection import GridSearchCV
import cbm
def test_nyc_bicycle_sklearn():
# read data
bic = pd.read_csv(
'data/nyc_bb_bicyclist_counts.csv',
parse_dates=['Date'])
X_train = bic.drop('BB_COUNT', axis=1)
y_train = bic['BB_COUNT']
cats = make_column_transformer(
# https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OrdinalEncoder.html
# (OrdinalEncoder(dtype='int', handle_unknown='use_encoded_value', unknown_value=-1), # +1 in CBM code
# ['store_nbr', 'item_nbr', 'onpromotion', 'family', 'class', 'perishable']),
(cbm.DateEncoder('weekday'), ['Date', 'Date']),
(cbm.DateEncoder('month'), ['Date']),
(KBinsDiscretizer(n_bins=2, encode='ordinal'), ['HIGH_T', 'LOW_T']),
(KBinsDiscretizer(n_bins=5, encode='ordinal'), ['PRECIP']),
)
cbm_model = cbm.CBM()
pipeline = make_pipeline(cats, cbm_model)
cv = GridSearchCV(
pipeline,
param_grid={'columntransformer__kbinsdiscretizer-1__n_bins': np.arange(2, 15)},
scoring=make_scorer(mean_squared_error, squared=False, greater_is_better=False),  # RMSE: lower is better
cv=3
)
cv.fit(X_train, y_train)
print(cv.cv_results_['mean_test_score'])
print(cv.best_params_)
cbm.CBMExplainer(cv.best_estimator_).plot_importance()