Add user feedback to compute() function (#697)
* Add user feedback to compute() * Add user feedback to compute() * Add log wrapper * Fix explanations Signed-off-by: Gaurav Gupta &lt;gaugup@microsoft.com&gt; * Fix imports Signed-off-by: Gaurav Gupta &lt;gaugup@microsoft.com&gt; --------- Signed-off-by: Gaurav Gupta &lt;gaugup@microsoft.com&gt;
This commit is contained in:
Parent
0bbd92f885
Commit
b2d98f82a2
|
@ -22,6 +22,7 @@ from responsibleai._tools.shared.state_directory_management import \
|
|||
DirectoryManager
|
||||
from responsibleai.feature_metadata import FeatureMetadata
|
||||
from responsibleai.managers.base_manager import BaseManager
|
||||
from responsibleai.utils import _measure_time
|
||||
|
||||
|
||||
class CausalManager(BaseManager):
|
||||
|
@ -353,9 +354,12 @@ class CausalManager(BaseManager):
|
|||
result = filtered[0]
|
||||
return result._global_cohort_policy(X_test)
|
||||
|
||||
@_measure_time
|
||||
def compute(self):
|
||||
"""Computes the causal effects by running the causal
|
||||
configuration."""
|
||||
print("Causal Effects")
|
||||
print('Current Status: Generating Causal Effects.')
|
||||
is_classification = self._task_type == ModelTask.CLASSIFICATION
|
||||
for result in self._results:
|
||||
causal_config = result.config
|
||||
|
@ -431,6 +435,7 @@ class CausalManager(BaseManager):
|
|||
result.policies.append(policy)
|
||||
|
||||
result._validate_schema()
|
||||
print('Current Status: Finished generating causal effects.')
|
||||
|
||||
def get(self):
|
||||
"""Get the computed causal insights."""
|
||||
|
|
|
@ -29,6 +29,7 @@ from responsibleai._tools.shared.state_directory_management import \
|
|||
from responsibleai.exceptions import (DuplicateManagerConfigException,
|
||||
SchemaErrorException)
|
||||
from responsibleai.managers.base_manager import BaseManager
|
||||
from responsibleai.utils import _measure_time
|
||||
|
||||
|
||||
class CounterfactualConstants:
|
||||
|
@ -539,13 +540,19 @@ class CounterfactualManager(BaseManager):
|
|||
|
||||
self._add_counterfactual_config(counterfactual_config)
|
||||
|
||||
@_measure_time
|
||||
def compute(self):
|
||||
"""Computes the counterfactual examples by running the counterfactual
|
||||
configuration."""
|
||||
print("Counterfactual")
|
||||
for cf_config in self._counterfactual_config_list:
|
||||
if not cf_config.is_computed:
|
||||
cf_config.is_computed = True
|
||||
try:
|
||||
print("Current Status: Generating {0} counterfactuals"
|
||||
" for {1} samples".format(
|
||||
cf_config.total_CFs, len(self._test)))
|
||||
|
||||
cf_config.explainer = self._create_diceml_explainer(
|
||||
method=cf_config.method,
|
||||
continuous_features=cf_config.continuous_features)
|
||||
|
@ -578,6 +585,9 @@ class CounterfactualManager(BaseManager):
|
|||
|
||||
cf_config.counterfactual_obj = counterfactual_obj
|
||||
|
||||
print('Current Status: Generated {0} counterfactuals'
|
||||
' for {1} samples.'.format(
|
||||
cf_config.total_CFs, len(self._test)))
|
||||
except Exception as e:
|
||||
cf_config.has_computation_failed = True
|
||||
cf_config.failure_reason = str(e)
|
||||
|
|
|
@ -26,6 +26,7 @@ from responsibleai._tools.shared.state_directory_management import \
|
|||
from responsibleai.exceptions import (ConfigAndResultMismatchException,
|
||||
DuplicateManagerConfigException)
|
||||
from responsibleai.managers.base_manager import BaseManager
|
||||
from responsibleai.utils import _measure_time
|
||||
|
||||
REPORTS = 'reports'
|
||||
CONFIG = 'config'
|
||||
|
@ -303,9 +304,12 @@ class ErrorAnalysisManager(BaseManager):
|
|||
else:
|
||||
self._ea_config_list.append(ea_config)
|
||||
|
||||
@_measure_time
|
||||
def compute(self):
|
||||
"""Creates an ErrorReport by running the error analyzer on the model.
|
||||
"""
|
||||
print("Error Analysis")
|
||||
print('Current Status: Generating error analysis reports.')
|
||||
for config in self._ea_config_list:
|
||||
if config.is_computed:
|
||||
continue
|
||||
|
@ -327,6 +331,7 @@ class ErrorAnalysisManager(BaseManager):
|
|||
json.loads(report.to_json()), schema)
|
||||
|
||||
self._ea_report_list.append(report)
|
||||
print('Current Status: Finished generating error analysis reports.')
|
||||
|
||||
def get(self):
|
||||
"""Get the computed error reports.
|
||||
|
|
|
@ -31,6 +31,7 @@ from responsibleai._internal.constants import (ExplanationKeys, ListProperties,
|
|||
from responsibleai._tools.shared.state_directory_management import \
|
||||
DirectoryManager
|
||||
from responsibleai.managers.base_manager import BaseManager
|
||||
from responsibleai.utils import _measure_time
|
||||
|
||||
SPARSE_NUM_FEATURES_THRESHOLD = 1000
|
||||
IS_RUN = 'is_run'
|
||||
|
@ -153,8 +154,13 @@ class ExplainerManager(BaseManager):
|
|||
return explainer.explain_global(data[
|
||||
0:MAXIMUM_ROWS_FOR_GLOBAL_EXPLANATIONS], include_local=local)
|
||||
|
||||
@_measure_time
|
||||
def compute(self):
|
||||
"""Creates an explanation by running the explainer on the model."""
|
||||
|
||||
print("Explanations")
|
||||
print('Current Status: Explaining {0} features'.format(
|
||||
len(self._features)))
|
||||
if not self._is_added:
|
||||
return
|
||||
if self._is_run:
|
||||
|
@ -165,6 +171,9 @@ class ExplainerManager(BaseManager):
|
|||
)
|
||||
self._is_run = True
|
||||
|
||||
print('Current Status: Explained {0} features.'.format(
|
||||
len(self._features)))
|
||||
|
||||
def get(self):
|
||||
"""Get the computed explanation.
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Microsoft Corporation
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import functools
import timeit
|
||||
|
||||
|
||||
def _measure_time(manager_compute_func):
|
||||
def compute_wrapper(*args, **kwargs):
|
||||
_separator(80)
|
||||
start_time = timeit.default_timer()
|
||||
manager_compute_func(*args, **kwargs)
|
||||
elapsed = timeit.default_timer() - start_time
|
||||
m, s = divmod(elapsed, 60)
|
||||
print('Time taken: {0} min {1} sec'.format(
|
||||
m, s))
|
||||
_separator(80)
|
||||
return compute_wrapper
|
||||
|
||||
|
||||
def _separator(max_len):
|
||||
print('=' * max_len)
|
Loading…
Link in new issue