Add user feedback to compute() function (#697)

* Add user feedback to compute()

* ADd user feedback to compute(

* Add log wrapper

* Fix explanations

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

* Fix imports

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>

---------

Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
This commit is contained in:
Gaurav Gupta 2023-05-30 15:29:39 -07:00 коммит произвёл GitHub
Родитель 0bbd92f885
Коммит b2d98f82a2
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 50 добавлений и 0 удалений

Просмотреть файл

@ -22,6 +22,7 @@ from responsibleai._tools.shared.state_directory_management import \
DirectoryManager
from responsibleai.feature_metadata import FeatureMetadata
from responsibleai.managers.base_manager import BaseManager
from responsibleai.utils import _measure_time
class CausalManager(BaseManager):
@ -353,9 +354,12 @@ class CausalManager(BaseManager):
result = filtered[0]
return result._global_cohort_policy(X_test)
@_measure_time
def compute(self):
"""Computes the causal effects by running the causal
configuration."""
print("Causal Effects")
print('Current Status: Generating Causal Effects.')
is_classification = self._task_type == ModelTask.CLASSIFICATION
for result in self._results:
causal_config = result.config
@ -431,6 +435,7 @@ class CausalManager(BaseManager):
result.policies.append(policy)
result._validate_schema()
print('Current Status: Finished generating causal effects.')
def get(self):
"""Get the computed causal insights."""

Просмотреть файл

@ -29,6 +29,7 @@ from responsibleai._tools.shared.state_directory_management import \
from responsibleai.exceptions import (DuplicateManagerConfigException,
SchemaErrorException)
from responsibleai.managers.base_manager import BaseManager
from responsibleai.utils import _measure_time
class CounterfactualConstants:
@ -539,13 +540,19 @@ class CounterfactualManager(BaseManager):
self._add_counterfactual_config(counterfactual_config)
@_measure_time
def compute(self):
"""Computes the counterfactual examples by running the counterfactual
configuration."""
print("Counterfactual")
for cf_config in self._counterfactual_config_list:
if not cf_config.is_computed:
cf_config.is_computed = True
try:
print("Current Status: Generating {0} counterfactuals"
" for {1} samples".format(
cf_config.total_CFs, len(self._test)))
cf_config.explainer = self._create_diceml_explainer(
method=cf_config.method,
continuous_features=cf_config.continuous_features)
@ -578,6 +585,9 @@ class CounterfactualManager(BaseManager):
cf_config.counterfactual_obj = counterfactual_obj
print('Current Status: Generated {0} counterfactuals'
' for {1} samples.'.format(
cf_config.total_CFs, len(self._test)))
except Exception as e:
cf_config.has_computation_failed = True
cf_config.failure_reason = str(e)

Просмотреть файл

@ -26,6 +26,7 @@ from responsibleai._tools.shared.state_directory_management import \
from responsibleai.exceptions import (ConfigAndResultMismatchException,
DuplicateManagerConfigException)
from responsibleai.managers.base_manager import BaseManager
from responsibleai.utils import _measure_time
REPORTS = 'reports'
CONFIG = 'config'
@ -303,9 +304,12 @@ class ErrorAnalysisManager(BaseManager):
else:
self._ea_config_list.append(ea_config)
@_measure_time
def compute(self):
"""Creates an ErrorReport by running the error analyzer on the model.
"""
print("Error Analysis")
print('Current Status: Generating error analysis reports.')
for config in self._ea_config_list:
if config.is_computed:
continue
@ -327,6 +331,7 @@ class ErrorAnalysisManager(BaseManager):
json.loads(report.to_json()), schema)
self._ea_report_list.append(report)
print('Current Status: Finished generating error analysis reports.')
def get(self):
"""Get the computed error reports.

Просмотреть файл

@ -31,6 +31,7 @@ from responsibleai._internal.constants import (ExplanationKeys, ListProperties,
from responsibleai._tools.shared.state_directory_management import \
DirectoryManager
from responsibleai.managers.base_manager import BaseManager
from responsibleai.utils import _measure_time
SPARSE_NUM_FEATURES_THRESHOLD = 1000
IS_RUN = 'is_run'
@ -153,8 +154,13 @@ class ExplainerManager(BaseManager):
return explainer.explain_global(data[
0:MAXIMUM_ROWS_FOR_GLOBAL_EXPLANATIONS], include_local=local)
@_measure_time
def compute(self):
"""Creates an explanation by running the explainer on the model."""
print("Explanations")
print('Current Status: Explaining {0} features'.format(
len(self._features)))
if not self._is_added:
return
if self._is_run:
@ -165,6 +171,9 @@ class ExplainerManager(BaseManager):
)
self._is_run = True
print('Current Status: Explained {0} features.'.format(
len(self._features)))
def get(self):
"""Get the computed explanation.

Просмотреть файл

@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.
import timeit
def _measure_time(manager_compute_func):
def compute_wrapper(*args, **kwargs):
_separator(80)
start_time = timeit.default_timer()
manager_compute_func(*args, **kwargs)
elapsed = timeit.default_timer() - start_time
m, s = divmod(elapsed, 60)
print('Time taken: {0} min {1} sec'.format(
m, s))
_separator(80)
return compute_wrapper
def _separator(max_len):
print('=' * max_len)