Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
This commit is contained in:
Keith Battocchi 2024-08-10 13:04:18 -04:00 коммит произвёл Keith Battocchi
Родитель 1fbeb76f82
Коммит 8b29a8dafc
7 изменённых файлов: 18 добавлений и 14 удалений

Просмотреть файл

@ -694,7 +694,7 @@ class _SingleTreeExporterMixin(metaclass=abc.ABCMeta):
own_file = False
try:
if isinstance(out_file, str):
out_file = open(out_file, "w", encoding="utf-8")
out_file = open(out_file, "w", encoding="utf-8") # noqa: SIM115, we close explicitly by design
own_file = True
return_string = out_file is None

Просмотреть файл

@ -763,7 +763,7 @@ class CausalForestDML(_BaseDML):
else:
# If custom param grid, check that only estimator parameters are being altered
estimator_param_names = self.tunable_params
for key in params.keys():
for key in params:
if key not in estimator_param_names:
raise ValueError(f"Parameter `{key}` is not an tunable causal forest parameter.")

Просмотреть файл

@ -721,9 +721,9 @@ class _DMLOrthoForest_nuisance_estimator_generator:
def __call__(self, Y, T, X, W, sample_weight=None, split_indices=None):
if self.global_residualization:
return 0
if self.discrete_treatment:
# Check that all discrete treatments are represented
if len(np.unique(T @ np.arange(1, T.shape[1] + 1))) < T.shape[1] + 1:
# Check that all discrete treatments are represented
if (self.discrete_treatment and
len(np.unique(T @ np.arange(1, T.shape[1] + 1))) < T.shape[1] + 1):
return None
# Nuissance estimates evaluated with cross-fitting
this_random_state = check_random_state(self.random_state)

Просмотреть файл

@ -525,10 +525,10 @@ class SklearnCVSelector(SingleModelSelector):
best_model, score = SklearnCVSelector._convert_model(inner_model, args, kwargs)
return Pipeline(steps=[*model.steps[:-1], (name, best_model)]), score
if isinstance(model, GridSearchCV) or isinstance(model, RandomizedSearchCV):
if isinstance(model, (GridSearchCV, RandomizedSearchCV)):
return model.best_estimator_, model.best_score_
for known_type in SklearnCVSelector._model_mapping().keys():
for known_type in SklearnCVSelector._model_mapping():
if isinstance(model, known_type):
converter = SklearnCVSelector._model_mapping()[known_type]
return converter(model, args, kwargs)

Просмотреть файл

@ -669,7 +669,7 @@ class TestDRLearner(unittest.TestCase):
[2, 3, len(feature_names) +
(W.shape[1] if W is not None else 0)])
if isinstance(est, LinearDRLearner) or isinstance(est, SparseLinearDRLearner):
if isinstance(est, (LinearDRLearner, SparseLinearDRLearner)):
if X is not None:
for t in [1, 2]:
true_coef = np.zeros(

Просмотреть файл

@ -117,7 +117,7 @@ class TestPolicyForest(unittest.TestCase):
for sample_weight in [None, 'rand']:
for n_outcomes in n_outcome_list:
config = self._get_base_config()
config['honest'] = True if not dr else False
config['honest'] = not dr
config['criterion'] = criterion
config['max_depth'] = 2
config['min_samples_leaf'] = 5

Просмотреть файл

@ -150,14 +150,18 @@ ignore = [
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D301", # Use r""" if any backslashes in a docstring
]
extend-select = [
"D301", # Use r""" if any backslashes in a docstring,
"SIM108", # Use ternary instead of if-else (looks ugly for some of our long expressions)
"SIM300", # Yoda condition detected (these are often easier to understand in array expressions)
]
select = [
"D", # Docstring
"E501", # Line too long
"W", # Pycodestyle warnings
"E", # All Pycodestyle erros, not just the default ones
"F", # All pyflakes rules
"SIM", # Simplifification
]
extend-per-file-ignores = { "econml/tests" = ["D"] } # ignore docstring rules for tests
[tool.ruff.lint.pydocstyle]
convention = "numpy"
convention = "numpy"