зеркало из https://github.com/py-why/EconML.git
CATE uplift validation methods (#836)
Added additional functionality to the DRTester validation class to include AUTOC validation metric, with associated inference methods. Also included ability to plot uplift curve methods (both QINI and TOC curves) and cleaned up handling of multiple treatments.
This commit is contained in:
Родитель
7793184ed3
Коммит
67eef1e191
|
@ -88,13 +88,15 @@ class TestDRTester(unittest.TestCase):
|
|||
res = my_dr_tester.evaluate_all(Xval, Xtrain)
|
||||
res_df = res.summary()
|
||||
|
||||
for k in range(3):
|
||||
if k == 0:
|
||||
with self.assertRaises(Exception) as exc:
|
||||
res.plot_cal(k)
|
||||
self.assertTrue(str(exc.exception) == 'Plotting only supported for treated units (not controls)')
|
||||
else:
|
||||
for k in range(4):
|
||||
if k in [0, 3]:
|
||||
self.assertRaises(ValueError, res.plot_cal, k)
|
||||
self.assertRaises(ValueError, res.plot_qini, k)
|
||||
self.assertRaises(ValueError, res.plot_toc, k)
|
||||
else: # real treatments, k = 1 or 2
|
||||
self.assertTrue(res.plot_cal(k) is not None)
|
||||
self.assertTrue(res.plot_qini(k) is not None)
|
||||
self.assertTrue(res.plot_toc(k) is not None)
|
||||
|
||||
self.assertGreater(res_df.blp_pval.values[0], 0.1) # no heterogeneity
|
||||
self.assertLess(res_df.blp_pval.values[1], 0.05) # heterogeneity
|
||||
|
@ -103,6 +105,7 @@ class TestDRTester(unittest.TestCase):
|
|||
self.assertGreater(res_df.cal_r_squared.values[1], 0) # good R2
|
||||
|
||||
self.assertLess(res_df.qini_pval.values[1], res_df.qini_pval.values[0])
|
||||
self.assertLess(res_df.autoc_pval.values[1], res_df.autoc_pval.values[0])
|
||||
|
||||
def test_binary(self):
|
||||
Xtrain, Dtrain, Ytrain, Xval, Dval, Yval = self._get_data(num_treatments=1)
|
||||
|
@ -136,17 +139,20 @@ class TestDRTester(unittest.TestCase):
|
|||
res = my_dr_tester.evaluate_all(Xval, Xtrain)
|
||||
res_df = res.summary()
|
||||
|
||||
for k in range(2):
|
||||
if k == 0:
|
||||
with self.assertRaises(Exception) as exc:
|
||||
res.plot_cal(k)
|
||||
self.assertTrue(str(exc.exception) == 'Plotting only supported for treated units (not controls)')
|
||||
else:
|
||||
for k in range(3):
|
||||
if k in [0, 2]:
|
||||
self.assertRaises(ValueError, res.plot_cal, k)
|
||||
self.assertRaises(ValueError, res.plot_qini, k)
|
||||
self.assertRaises(ValueError, res.plot_toc, k)
|
||||
else: # real treatment, k = 1
|
||||
self.assertTrue(res.plot_cal(k) is not None)
|
||||
self.assertTrue(res.plot_qini(k) is not None)
|
||||
self.assertTrue(res.plot_toc(k) is not None)
|
||||
|
||||
self.assertLess(res_df.blp_pval.values[0], 0.05) # heterogeneity
|
||||
self.assertGreater(res_df.cal_r_squared.values[0], 0) # good R2
|
||||
self.assertLess(res_df.qini_pval.values[0], 0.05) # heterogeneity
|
||||
self.assertLess(res_df.autoc_pval.values[0], 0.05) # heterogeneity
|
||||
|
||||
def test_nuisance_val_fit(self):
|
||||
Xtrain, Dtrain, Ytrain, Xval, Dval, Yval = self._get_data(num_treatments=1)
|
||||
|
@ -209,7 +215,7 @@ class TestDRTester(unittest.TestCase):
|
|||
)
|
||||
|
||||
# fit nothing
|
||||
for func in [my_dr_tester.evaluate_blp, my_dr_tester.evaluate_cal, my_dr_tester.evaluate_qini]:
|
||||
for func in [my_dr_tester.evaluate_blp, my_dr_tester.evaluate_cal, my_dr_tester.evaluate_uplift]:
|
||||
with self.assertRaises(Exception) as exc:
|
||||
func()
|
||||
if func.__name__ == 'evaluate_cal':
|
||||
|
@ -226,7 +232,7 @@ class TestDRTester(unittest.TestCase):
|
|||
for func in [
|
||||
my_dr_tester.evaluate_blp,
|
||||
my_dr_tester.evaluate_cal,
|
||||
my_dr_tester.evaluate_qini,
|
||||
my_dr_tester.evaluate_uplift,
|
||||
my_dr_tester.evaluate_all
|
||||
]:
|
||||
with self.assertRaises(Exception) as exc:
|
||||
|
@ -241,7 +247,7 @@ class TestDRTester(unittest.TestCase):
|
|||
|
||||
for func in [
|
||||
my_dr_tester.evaluate_cal,
|
||||
my_dr_tester.evaluate_qini,
|
||||
my_dr_tester.evaluate_uplift,
|
||||
my_dr_tester.evaluate_all
|
||||
]:
|
||||
with self.assertRaises(Exception) as exc:
|
||||
|
@ -252,6 +258,12 @@ class TestDRTester(unittest.TestCase):
|
|||
cal_res = my_dr_tester.evaluate_cal(Xval, Xtrain)
|
||||
self.assertGreater(cal_res.cal_r_squared[0], 0) # good R2
|
||||
|
||||
with self.assertRaises(Exception) as exc:
|
||||
my_dr_tester.evaluate_uplift(metric='blah')
|
||||
self.assertTrue(
|
||||
str(exc.exception) == "Unsupported metric - must be one of ['toc', 'qini']"
|
||||
)
|
||||
|
||||
my_dr_tester = DRtester(
|
||||
model_regression=reg_y,
|
||||
model_propensity=reg_t,
|
||||
|
@ -259,5 +271,8 @@ class TestDRTester(unittest.TestCase):
|
|||
).fit_nuisance(
|
||||
Xval, Dval, Yval, Xtrain, Dtrain, Ytrain
|
||||
)
|
||||
qini_res = my_dr_tester.evaluate_qini(Xval, Xtrain)
|
||||
qini_res = my_dr_tester.evaluate_uplift(Xval, Xtrain)
|
||||
self.assertLess(qini_res.pvals[0], 0.05)
|
||||
|
||||
autoc_res = my_dr_tester.evaluate_uplift(Xval, Xtrain, metric='toc')
|
||||
self.assertLess(autoc_res.pvals[0], 0.05)
|
||||
|
|
|
@ -8,8 +8,8 @@ from sklearn.model_selection import cross_val_predict, StratifiedKFold, KFold
|
|||
from statsmodels.api import OLS
|
||||
from statsmodels.tools import add_constant
|
||||
|
||||
from .results import CalibrationEvaluationResults, BLPEvaluationResults, QiniEvaluationResults, EvaluationResults
|
||||
from .utils import calculate_dr_outcomes, calc_qini_coeff
|
||||
from .results import CalibrationEvaluationResults, BLPEvaluationResults, UpliftEvaluationResults, EvaluationResults
|
||||
from .utils import calculate_dr_outcomes, calc_uplift
|
||||
|
||||
|
||||
class DRtester:
|
||||
|
@ -382,7 +382,7 @@ class DRtester:
|
|||
self.get_cate_preds(Xval, Xtrain)
|
||||
|
||||
cal_r_squared = np.zeros(self.n_treat)
|
||||
df_plot = pd.DataFrame()
|
||||
plot_data_dict = dict()
|
||||
for k in range(self.n_treat):
|
||||
cuts = np.quantile(self.cate_preds_train_[:, k], np.linspace(0, 1, n_groups + 1))
|
||||
probs = np.zeros(n_groups)
|
||||
|
@ -409,15 +409,19 @@ class DRtester:
|
|||
# Calculate R-square calibration score
|
||||
cal_r_squared[k] = 1 - (cal_score_g / cal_score_o)
|
||||
|
||||
df_plot1 = pd.DataFrame({'ind': np.array(range(n_groups)),
|
||||
'gate': gate, 'se_gate': se_gate,
|
||||
'g_cate': g_cate, 'se_g_cate': se_g_cate})
|
||||
df_plot1['tmt'] = self.treatments[k + 1]
|
||||
df_plot = pd.concat((df_plot, df_plot1))
|
||||
df_plot = pd.DataFrame({
|
||||
'ind': np.array(range(n_groups)),
|
||||
'gate': gate,
|
||||
'se_gate': se_gate,
|
||||
'g_cate': g_cate,
|
||||
'se_g_cate': se_g_cate
|
||||
})
|
||||
|
||||
plot_data_dict[self.treatments[k + 1]] = df_plot
|
||||
|
||||
self.cal_res = CalibrationEvaluationResults(
|
||||
cal_r_squared=cal_r_squared,
|
||||
df_plot=df_plot,
|
||||
plot_data_dict=plot_data_dict,
|
||||
treatments=self.treatments
|
||||
)
|
||||
|
||||
|
@ -480,12 +484,13 @@ class DRtester:
|
|||
|
||||
return self.blp_res
|
||||
|
||||
def evaluate_qini(
|
||||
def evaluate_uplift(
|
||||
self,
|
||||
Xval: np.array = None,
|
||||
Xtrain: np.array = None,
|
||||
percentiles: np.array = np.linspace(5, 95, 50)
|
||||
) -> QiniEvaluationResults:
|
||||
percentiles: np.array = np.linspace(5, 95, 50),
|
||||
metric: str = 'qini'
|
||||
) -> UpliftEvaluationResults:
|
||||
"""
|
||||
Calculates QINI coefficient for the given model as in Radcliffe (2007), where units are ordered by predicted
|
||||
CATE values and a running measure of the average treatment effect in each cohort is kept as we progress
|
||||
|
@ -505,10 +510,12 @@ class DRtester:
|
|||
percentiles: one-dimensional array, default ``np.linspace(5, 95, 50)''
|
||||
Array of percentiles over which the QINI curve should be constructed. Defaults to 5%-95% in intervals of
|
||||
5%.
|
||||
metric: string, default 'qini'
|
||||
Which type of uplift curve to evaluate. Must be one of ['toc', 'qini']
|
||||
|
||||
Returns
|
||||
-------
|
||||
QiniEvaluationResults object showing the results of the QINI fit
|
||||
UpliftEvaluationResults object showing the fitted results
|
||||
"""
|
||||
if not hasattr(self, 'dr_val_'):
|
||||
raise Exception("Must fit nuisances before evaluating")
|
||||
|
@ -518,39 +525,44 @@ class DRtester:
|
|||
raise Exception('CATE predictions not yet calculated - must provide both Xval, Xtrain')
|
||||
self.get_cate_preds(Xval, Xtrain)
|
||||
|
||||
curve_data_dict = dict()
|
||||
if self.n_treat == 1:
|
||||
qini, qini_err = calc_qini_coeff(
|
||||
coeff, err, curve_df = calc_uplift(
|
||||
self.cate_preds_train_,
|
||||
self.cate_preds_val_,
|
||||
self.dr_val_,
|
||||
percentiles
|
||||
percentiles,
|
||||
metric
|
||||
)
|
||||
qinis = [qini]
|
||||
errs = [qini_err]
|
||||
coeffs = [coeff]
|
||||
errs = [err]
|
||||
curve_data_dict[self.treatments[1]] = curve_df
|
||||
else:
|
||||
qinis = []
|
||||
coeffs = []
|
||||
errs = []
|
||||
for k in range(self.n_treat):
|
||||
qini, qini_err = calc_qini_coeff(
|
||||
coeff, err, curve_df = calc_uplift(
|
||||
self.cate_preds_train_[:, k],
|
||||
self.cate_preds_val_[:, k],
|
||||
self.dr_val_[:, k],
|
||||
percentiles
|
||||
percentiles,
|
||||
metric
|
||||
)
|
||||
coeffs.append(coeff)
|
||||
errs.append(err)
|
||||
curve_data_dict[self.treatments[k + 1]] = curve_df
|
||||
|
||||
qinis.append(qini)
|
||||
errs.append(qini_err)
|
||||
pvals = [st.norm.sf(abs(q / e)) for q, e in zip(coeffs, errs)]
|
||||
|
||||
pvals = [st.norm.sf(abs(q / e)) for q, e in zip(qinis, errs)]
|
||||
|
||||
self.qini_res = QiniEvaluationResults(
|
||||
params=qinis,
|
||||
self.uplift_res = UpliftEvaluationResults(
|
||||
params=coeffs,
|
||||
errs=errs,
|
||||
pvals=pvals,
|
||||
treatments=self.treatments
|
||||
treatments=self.treatments,
|
||||
curve_data_dict=curve_data_dict
|
||||
)
|
||||
|
||||
return self.qini_res
|
||||
return self.uplift_res
|
||||
|
||||
def evaluate_all(
|
||||
self,
|
||||
|
@ -559,8 +571,8 @@ class DRtester:
|
|||
n_groups: int = 4
|
||||
) -> EvaluationResults:
|
||||
"""
|
||||
Implements the best linear prediction (`evaluate_blp'), calibration (`evaluate_cal') and QINI coefficient
|
||||
(`evaluate_qini') methods.
|
||||
Implements the best linear prediction (`evaluate_blp'), calibration (`evaluate_cal'), uplift curve
|
||||
('evaluate_uplift') methods
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -583,12 +595,14 @@ class DRtester:
|
|||
|
||||
blp_res = self.evaluate_blp()
|
||||
cal_res = self.evaluate_cal(n_groups=n_groups)
|
||||
qini_res = self.evaluate_qini()
|
||||
qini_res = self.evaluate_uplift(metric='qini')
|
||||
toc_res = self.evaluate_uplift(metric='toc')
|
||||
|
||||
self.res = EvaluationResults(
|
||||
blp_res=blp_res,
|
||||
cal_res=cal_res,
|
||||
qini_res=qini_res
|
||||
qini_res=qini_res,
|
||||
toc_res=toc_res
|
||||
)
|
||||
|
||||
return self.res
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import List
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
class CalibrationEvaluationResults:
|
||||
|
@ -13,8 +13,9 @@ class CalibrationEvaluationResults:
|
|||
cal_r_squared: list or numpy array of floats
|
||||
Sequence of calibration R^2 values
|
||||
|
||||
df_plot: pandas dataframe
|
||||
Dataframe containing necessary data for plotting calibration test GATE results
|
||||
plot_data_dict: dict
|
||||
Dictionary mapping treatment levels to dataframes containing necessary
|
||||
data for plotting calibration test GATE results
|
||||
|
||||
treatments: list or numpy array of floats
|
||||
Sequence of treatment labels
|
||||
|
@ -22,11 +23,11 @@ class CalibrationEvaluationResults:
|
|||
def __init__(
|
||||
self,
|
||||
cal_r_squared: np.array,
|
||||
df_plot: pd.DataFrame,
|
||||
plot_data_dict: Dict[Any, pd.DataFrame],
|
||||
treatments: np.array
|
||||
):
|
||||
self.cal_r_squared = cal_r_squared
|
||||
self.df_plot = df_plot
|
||||
self.plot_data_dict = plot_data_dict
|
||||
self.treatments = treatments
|
||||
|
||||
def summary(self) -> pd.DataFrame:
|
||||
|
@ -48,24 +49,23 @@ class CalibrationEvaluationResults:
|
|||
}).round(3)
|
||||
return res
|
||||
|
||||
def plot_cal(self, tmt: int):
|
||||
def plot_cal(self, tmt: Any):
|
||||
"""
|
||||
Plots group average treatment effects (GATEs) and predicted GATEs by quantile-based group in validation sample.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tmt: integer
|
||||
Treatment level to plot
|
||||
tmt: Any
|
||||
Name of treatment to plot
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib plot with predicted GATE on x-axis and GATE (and 95% CI) on y-axis
|
||||
"""
|
||||
if tmt == 0:
|
||||
raise Exception('Plotting only supported for treated units (not controls)')
|
||||
if tmt not in self.treatments[1:]:
|
||||
raise ValueError(f'Invalid treatment; must be one of {self.treatments[1:]}')
|
||||
|
||||
df = self.df_plot
|
||||
df = df[df.tmt == tmt].copy()
|
||||
df = self.plot_data_dict[tmt].copy()
|
||||
rsq = round(self.cal_r_squared[np.where(self.treatments == tmt)[0][0] - 1], 3)
|
||||
df['95_err'] = 1.96 * df['se_gate']
|
||||
fig = df.plot(
|
||||
|
@ -132,9 +132,9 @@ class BLPEvaluationResults:
|
|||
return res
|
||||
|
||||
|
||||
class QiniEvaluationResults:
|
||||
class UpliftEvaluationResults:
|
||||
"""
|
||||
Results class for QINI test.
|
||||
Results class for uplift curve-based tests.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -149,18 +149,24 @@ class QiniEvaluationResults:
|
|||
|
||||
treatments: list or numpy array of floats
|
||||
Sequence of treatment labels
|
||||
|
||||
curve_data_dict: dict
|
||||
Dictionary mapping treatment levels to dataframes containing
|
||||
necessary data for plotting uplift curves
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
params: List[float],
|
||||
errs: List[float],
|
||||
pvals: List[float],
|
||||
treatments: np.array
|
||||
treatments: np.array,
|
||||
curve_data_dict: Dict[Any, pd.DataFrame]
|
||||
):
|
||||
self.params = params
|
||||
self.errs = errs
|
||||
self.pvals = pvals
|
||||
self.treatments = treatments
|
||||
self.curves = curve_data_dict
|
||||
|
||||
def summary(self):
|
||||
"""
|
||||
|
@ -176,12 +182,44 @@ class QiniEvaluationResults:
|
|||
"""
|
||||
res = pd.DataFrame({
|
||||
'treatment': self.treatments[1:],
|
||||
'qini_est': self.params,
|
||||
'qini_se': self.errs,
|
||||
'qini_pval': self.pvals
|
||||
'est': self.params,
|
||||
'se': self.errs,
|
||||
'pval': self.pvals
|
||||
}).round(3)
|
||||
return res
|
||||
|
||||
def plot_uplift(self, tmt: Any):
|
||||
"""
|
||||
Plots uplift curves.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tmt: any (sortable)
|
||||
Name of treatment to plot.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib plot with percentage treated on x-axis and uplift metric (and 95% CI) on y-axis
|
||||
"""
|
||||
if tmt not in self.treatments[1:]:
|
||||
raise ValueError(f'Invalid treatment; must be one of {self.treatments[1:]}')
|
||||
|
||||
df = self.curves[tmt].copy()
|
||||
df['95_err'] = 1.96 * df['err']
|
||||
res = self.summary()
|
||||
coeff = round(res.loc[res['treatment'] == tmt]['est'].values[0], 3)
|
||||
err = round(res.loc[res['treatment'] == tmt]['se'].values[0], 3)
|
||||
fig = df.plot(
|
||||
kind='scatter',
|
||||
x='Percentage treated',
|
||||
y='value',
|
||||
yerr='95_err',
|
||||
ylabel='Gain over Random',
|
||||
title=f"Treatment = {tmt}, Integral = {coeff} +/- {err}"
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
class EvaluationResults:
|
||||
"""
|
||||
|
@ -195,18 +233,23 @@ class EvaluationResults:
|
|||
blp_res: BLPEvaluationResults object
|
||||
Results object for BLP test
|
||||
|
||||
qini_res: QiniEvaluationResults object
|
||||
qini_res: UpliftEvaluationResults object
|
||||
Results object for QINI test
|
||||
|
||||
toc_res: UpliftEvaluationResults object
|
||||
Results object for TOC test
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
cal_res: CalibrationEvaluationResults,
|
||||
blp_res: BLPEvaluationResults,
|
||||
qini_res: QiniEvaluationResults
|
||||
qini_res: UpliftEvaluationResults,
|
||||
toc_res: UpliftEvaluationResults
|
||||
):
|
||||
self.cal = cal_res
|
||||
self.blp = blp_res
|
||||
self.qini = qini_res
|
||||
self.toc = toc_res
|
||||
|
||||
def summary(self):
|
||||
"""
|
||||
|
@ -221,7 +264,10 @@ class EvaluationResults:
|
|||
pandas dataframe containing summary of all test results
|
||||
"""
|
||||
res = self.blp.summary().merge(
|
||||
self.qini.summary(),
|
||||
self.qini.summary().rename({'est': 'qini_est', 'se': 'qini_se', 'pval': 'qini_pval'}, axis=1),
|
||||
on='treatment'
|
||||
).merge(
|
||||
self.toc.summary().rename({'est': 'autoc_est', 'se': 'autoc_se', 'pval': 'autoc_pval'}, axis=1),
|
||||
on='treatment'
|
||||
).merge(
|
||||
self.cal.summary(),
|
||||
|
@ -243,3 +289,33 @@ class EvaluationResults:
|
|||
matplotlib plot with predicted GATE on x-axis and GATE (and 95% CI) on y-axis
|
||||
"""
|
||||
return self.cal.plot_cal(tmt)
|
||||
|
||||
def plot_qini(self, tmt: int):
|
||||
"""
|
||||
Plots QINI curves.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tmt: integer
|
||||
Treatment level to plot
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib plot with percentage treated on x-axis and QINI value (and 95% CI) on y-axis
|
||||
"""
|
||||
return self.qini.plot_uplift(tmt)
|
||||
|
||||
def plot_toc(self, tmt: int):
|
||||
"""
|
||||
Plots TOC curves.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tmt: integer
|
||||
Treatment level to plot
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib plot with percentage treated on x-axis and TOC value (and 95% CI) on y-axis
|
||||
"""
|
||||
return self.toc.plot_uplift(tmt)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def calculate_dr_outcomes(
|
||||
|
@ -47,15 +48,16 @@ def calculate_dr_outcomes(
|
|||
return dr
|
||||
|
||||
|
||||
def calc_qini_coeff(
|
||||
def calc_uplift(
|
||||
cate_preds_train: np.array,
|
||||
cate_preds_val: np.array,
|
||||
dr_val: np.array,
|
||||
percentiles: np.array
|
||||
) -> Tuple[float, float]:
|
||||
percentiles: np.array,
|
||||
metric: str
|
||||
) -> Tuple[float, float, pd.DataFrame]:
|
||||
"""
|
||||
Helper function for QINI coefficient calculation. See documentation for "evaluate_qini" method
|
||||
for more details.
|
||||
Helper function for QINI curve generation and QINI coefficient calculation.
|
||||
See documentation for "evaluate_qini" method for more details.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -68,10 +70,12 @@ def calc_qini_coeff(
|
|||
control, e.g. for treatment k the value is Y(k) - Y(0), where 0 signifies no treatment.
|
||||
percentiles: one-dimensional array
|
||||
Array of percentiles over which the QINI curve should be constructed. Defaults to 5%-95% in intervals of 5%.
|
||||
metric: string
|
||||
String indicating whether to calculate TOC or QINI; should be one of ['toc', 'qini']
|
||||
|
||||
Returns
|
||||
-------
|
||||
QINI coefficient and associated standard error.
|
||||
Uplift coefficient and associated standard error, as well as associated curve.
|
||||
"""
|
||||
qs = np.percentile(cate_preds_train, percentiles)
|
||||
toc, toc_std, group_prob = np.zeros(len(qs)), np.zeros(len(qs)), np.zeros(len(qs))
|
||||
|
@ -81,14 +85,27 @@ def calc_qini_coeff(
|
|||
for it in range(len(qs)):
|
||||
inds = (qs[it] <= cate_preds_val) # group with larger CATE prediction than the q-th quantile
|
||||
group_prob = np.sum(inds) / n # fraction of population in this group
|
||||
toc[it] = group_prob * (
|
||||
np.mean(dr_val[inds]) - ate) # tau(q) = q * E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)]
|
||||
toc_psi[it, :] = np.squeeze(
|
||||
(dr_val - ate) * (inds - group_prob) - toc[it]) # influence function for the tau(q)
|
||||
if metric == 'qini':
|
||||
toc[it] = group_prob * (
|
||||
np.mean(dr_val[inds]) - ate) # tau(q) = q * E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)]
|
||||
toc_psi[it, :] = np.squeeze(
|
||||
(dr_val - ate) * (inds - group_prob) - toc[it]) # influence function for the tau(q)
|
||||
elif metric == 'toc':
|
||||
toc[it] = np.mean(dr_val[inds]) - ate # tau(q) := E[Y(1) - Y(0) | tau(X) >= q[it]] - E[Y(1) - Y(0)]
|
||||
toc_psi[it, :] = np.squeeze((dr_val - ate) * (inds / group_prob - 1) - toc[it])
|
||||
else:
|
||||
raise ValueError("Unsupported metric - must be one of ['toc', 'qini']")
|
||||
|
||||
toc_std[it] = np.sqrt(np.mean(toc_psi[it] ** 2) / n) # standard error of tau(q)
|
||||
|
||||
qini_psi = np.sum(toc_psi[:-1] * np.diff(percentiles).reshape(-1, 1) / 100, 0)
|
||||
qini = np.sum(toc[:-1] * np.diff(percentiles) / 100)
|
||||
qini_stderr = np.sqrt(np.mean(qini_psi ** 2) / n)
|
||||
coeff_psi = np.sum(toc_psi[:-1] * np.diff(percentiles).reshape(-1, 1) / 100, 0)
|
||||
coeff = np.sum(toc[:-1] * np.diff(percentiles) / 100)
|
||||
coeff_stderr = np.sqrt(np.mean(coeff_psi ** 2) / n)
|
||||
|
||||
return qini, qini_stderr
|
||||
curve_df = pd.DataFrame({
|
||||
'Percentage treated': 100 - percentiles,
|
||||
'value': toc,
|
||||
'err': toc_std
|
||||
})
|
||||
|
||||
return coeff, coeff_stderr, curve_df
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Загрузка…
Ссылка в новой задаче