Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
This commit is contained in:
Keith Battocchi 2023-04-26 12:38:18 -04:00 коммит произвёл Keith Battocchi
Родитель c39cc10116
Коммит bc289f1e0d
2 изменённых файлов: 2 добавлений и 1 удалений

Просмотреть файл

@ -899,7 +899,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
nuisances = [np.zeros((n_iters * n_splits,) + nuis.shape) for nuis in nuisance_temp]
for it, nuis in enumerate(nuisance_temp):
nuisances[it][i * n_iters + j] = nuis
nuisances[it][j * n_iters + i] = nuis
for it in range(len(nuisances)):
nuisances[it] = np.mean(nuisances[it], axis=0)

Просмотреть файл

@ -1095,6 +1095,7 @@ class TestDML(unittest.TestCase):
est.fit(y, T, X=X, W=W)
assert len(est.nuisance_scores_t) == len(est.nuisance_scores_y) == mc_iters
assert len(est.nuisance_scores_t[0]) == len(est.nuisance_scores_y[0]) == cv
est.score(y, T, X=X, W=W)
def test_categories(self):
dmls = [LinearDML, SparseLinearDML]