diff --git a/econml/_ortho_learner.py b/econml/_ortho_learner.py index fdc9e769..39a300ad 100644 --- a/econml/_ortho_learner.py +++ b/econml/_ortho_learner.py @@ -899,7 +899,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator): nuisances = [np.zeros((n_iters * n_splits,) + nuis.shape) for nuis in nuisance_temp] for it, nuis in enumerate(nuisance_temp): - nuisances[it][i * n_iters + j] = nuis + nuisances[it][j * n_iters + i] = nuis for it in range(len(nuisances)): nuisances[it] = np.mean(nuisances[it], axis=0) diff --git a/econml/tests/test_dml.py b/econml/tests/test_dml.py index 8105f7ec..57b5c3ec 100644 --- a/econml/tests/test_dml.py +++ b/econml/tests/test_dml.py @@ -1095,6 +1095,7 @@ class TestDML(unittest.TestCase): est.fit(y, T, X=X, W=W) assert len(est.nuisance_scores_t) == len(est.nuisance_scores_y) == mc_iters assert len(est.nuisance_scores_t[0]) == len(est.nuisance_scores_y[0]) == cv + est.score(y, T, X=X, W=W) def test_categories(self): dmls = [LinearDML, SparseLinearDML]