Analyzer: Revise - Abstract RuleBase from DataDiagnosis (#321)
**Description:** Abstract a `RuleBase` class out of `DataDiagnosis`: the shared rule parsing, metric matching, and raw-data preprocessing move into the new base class, and `DataDiagnosis` becomes a subclass of it.
Parent: 9759527111
Commit: 1ec055e1c2
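The caller-facing change is that `run_diagnosis_rules()` now takes already-parsed rule and baseline dicts instead of file paths; reading the files moves into `RuleBase._preprocess()` and `file_handler`. A minimal sketch of the new call flow, mirroring the body of `DataDiagnosis.run()` in the diff below (the file paths are placeholders, not part of this commit):

```python
# Sketch only: the reworked call flow after this change.
from superbench.analyzer import DataDiagnosis, file_handler

diag = DataDiagnosis()
# _preprocess() reads the raw data and the rule file (paths below are placeholders)
rules = diag._preprocess('results-summary.jsonl', 'rules.yaml')
baseline = file_handler.read_baseline('baseline.json')
if rules:
    data_not_accept_df, label_df = diag.run_diagnosis_rules(rules, baseline)
```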
superbench/analyzer/__init__.py:

@@ -3,7 +3,8 @@

 """Exposes interfaces of SuperBench Analyzer."""

+from superbench.analyzer.rule_base import RuleBase
 from superbench.analyzer.data_diagnosis import DataDiagnosis
 from superbench.analyzer.diagnosis_rule_op import RuleOp, DiagnosisRuleType

-__all__ = ['DataDiagnosis', 'DiagnosisRuleType', 'RuleOp']
+__all__ = ['DataDiagnosis', 'DiagnosisRuleType', 'RuleOp', 'RuleBase']
superbench/analyzer/data_diagnosis.py:

@@ -2,8 +2,6 @@
 # Licensed under the MIT license.

 """A module for baseline-based data diagnosis."""

-import re
 from typing import Callable
 from pathlib import Path

@@ -12,38 +10,16 @@ import pandas as pd
 from superbench.common.utils import logger
 from superbench.analyzer.diagnosis_rule_op import RuleOp, DiagnosisRuleType
 from superbench.analyzer import file_handler
+from superbench.analyzer import RuleBase


-class DataDiagnosis():
+class DataDiagnosis(RuleBase):
     """The DataDiagnosis class to do the baseline-based data diagnosis."""
     def __init__(self):
         """Init function."""
-        self._sb_rules = {}
-        self._benchmark_metrics_dict = {}
+        super().__init__()

-    def _get_metrics_by_benchmarks(self, metrics_list):
-        """Get mappings of benchmarks:metrics of metrics_list.
-
-        Args:
-            metrics_list (list): list of metrics
-
-        Returns:
-            dict: metrics organized by benchmarks
-        """
-        benchmarks_metrics = {}
-        for metric in metrics_list:
-            if '/' not in metric:
-                logger.warning(
-                    'DataDiagnosis: get_metrics_by_benchmarks - {} does not have benchmark_name'.format(metric)
-                )
-            else:
-                benchmark = metric.split('/')[0]
-                if benchmark not in benchmarks_metrics:
-                    benchmarks_metrics[benchmark] = set()
-                benchmarks_metrics[benchmark].add(metric)
-        return benchmarks_metrics
-
-    def _check_rules(self, rule, name):
+    def _check_and_format_rules(self, rule, name):
         """Check the rule of the metric whether the format is valid.

         Args:

@@ -54,6 +30,7 @@ class DataDiagnosis():
             dict: the rule for the metric
         """
         # check if rule is supported
+        super()._check_and_format_rules(rule, name)
         if 'function' not in rule:
             logger.log_and_raise(exception=Exception, msg='{} lack of function'.format(name))
         if not isinstance(DiagnosisRuleType(rule['function']), DiagnosisRuleType):

@@ -63,13 +40,9 @@ class DataDiagnosis():
             logger.log_and_raise(exception=Exception, msg='{} lack of criteria'.format(name))
         if not isinstance(eval(rule['criteria']), Callable):
             logger.log_and_raise(exception=Exception, msg='invalid criteria format')
-        if 'categories' not in rule:
-            logger.log_and_raise(exception=Exception, msg='{} lack of category'.format(name))
         if rule['function'] != 'multi_rules':
             if 'metrics' not in rule:
                 logger.log_and_raise(exception=Exception, msg='{} lack of metrics'.format(name))
-            if isinstance(rule['metrics'], str):
-                rule['metrics'] = [rule['metrics']]
         if 'store' in rule and not isinstance(rule['store'], bool):
             logger.log_and_raise(exception=Exception, msg='{} store must be bool type'.format(name))
         return rule

@@ -107,26 +80,11 @@ class DataDiagnosis():
             benchmark_rules (dict): the dict of rules
             baseline (dict): the dict of baseline of metrics
         """
-        if self._sb_rules[rule]['function'] == 'multi_rules':
+        if 'function' in self._sb_rules[rule] and self._sb_rules[rule]['function'] == 'multi_rules':
             return
-        metrics_in_rule = benchmark_rules[rule]['metrics']
-        benchmark_metrics_dict_in_rule = self._get_metrics_by_benchmarks(metrics_in_rule)
-        for benchmark_name in benchmark_metrics_dict_in_rule:
-            if benchmark_name not in self._benchmark_metrics_dict:
-                logger.warning('DataDiagnosis: get criteria failed - {}'.format(benchmark_name))
-                continue
-            # get rules and criteria for each metric
-            for metric in self._benchmark_metrics_dict[benchmark_name]:
-                # metric full name in baseline
-                if metric in metrics_in_rule:
-                    self._sb_rules[rule]['metrics'][metric] = self._get_baseline_of_metric(baseline, metric)
-                    self._enable_metrics.add(metric)
-                    continue
-                # metric full name not in baseline, use regex to match
-                for metric_regex in benchmark_metrics_dict_in_rule[benchmark_name]:
-                    if re.search(metric_regex, metric):
-                        self._sb_rules[rule]['metrics'][metric] = self._get_baseline_of_metric(baseline, metric)
-                        self._enable_metrics.add(metric)
+        self._get_metrics(rule, benchmark_rules)
+        for metric in self._sb_rules[rule]['metrics']:
+            self._sb_rules[rule]['metrics'][metric] = self._get_baseline_of_metric(baseline, metric)

     def _parse_rules_and_baseline(self, rules, baseline):
         """Parse and merge rules and baseline read from file.

@@ -146,7 +104,7 @@ class DataDiagnosis():
             self._enable_metrics = set()
             benchmark_rules = rules['superbench']['rules']
             for rule in benchmark_rules:
-                benchmark_rules[rule] = self._check_rules(benchmark_rules[rule], rule)
+                benchmark_rules[rule] = self._check_and_format_rules(benchmark_rules[rule], rule)
                 self._sb_rules[rule] = {}
                 self._sb_rules[rule]['name'] = rule
                 self._sb_rules[rule]['function'] = benchmark_rules[rule]['function']

@@ -209,16 +167,16 @@ class DataDiagnosis():

         return None, None

-    def run_diagnosis_rules(self, rule_file, baseline_file):
+    def run_diagnosis_rules(self, rules, baseline):
         """Rule-based data diagnosis for multiple nodes' raw data.

-        Use the rules defined in rule_file to diagnose the raw data of each node,
+        Use the rules defined in rules to diagnose the raw data of each node,
         if the node violate any rule, label as defective node and save
         the 'Category', 'Defective Details' and processed data of defective node.

         Args:
-            rule_file (str): The path of rule yaml file
-            baseline_file (str): The path of baseline json file
+            rules (dict): rules from rule yaml file
+            baseline (dict): baseline of metrics from baseline json file

         Returns:
             data_not_accept_df (DataFrame): defective nodes's detailed information

@@ -229,13 +187,6 @@ class DataDiagnosis():
         data_not_accept_df = pd.DataFrame(columns=summary_columns)
         summary_details_df = pd.DataFrame()
         label_df = pd.DataFrame(columns=['label'])
-        # check raw data whether empty
-        if len(self._raw_data_df) == 0:
-            logger.error('DataDiagnosis: empty raw data')
-            return data_not_accept_df, label_df
         # get criteria
-        rules = file_handler.read_rules(rule_file)
-        baseline = file_handler.read_baseline(baseline_file)
         if not self._parse_rules_and_baseline(rules, baseline):
             return data_not_accept_df, label_df
         # run diagnosis rules for each node

@@ -267,10 +218,11 @@ class DataDiagnosis():
             output_format (str): the format of the output, 'excel' or 'json'
         """
         try:
-            self._raw_data_df = file_handler.read_raw_data(raw_data_file)
-            self._benchmark_metrics_dict = self._get_metrics_by_benchmarks(list(self._raw_data_df.columns))
+            rules = self._preprocess(raw_data_file, rule_file)
+            # read baseline
+            baseline = file_handler.read_baseline(baseline_file)
             logger.info('DataDiagnosis: Begin to process {} nodes'.format(len(self._raw_data_df)))
-            data_not_accept_df, label_df = self.run_diagnosis_rules(rule_file, baseline_file)
+            data_not_accept_df, label_df = self.run_diagnosis_rules(rules, baseline)
             logger.info('DataDiagnosis: Processed finished')
             output_path = ''
             if output_format == 'excel':
superbench/analyzer/rule_base.py (new file):

@@ -0,0 +1,107 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""A base module for rule-related module."""
+
+import re
+
+from superbench.common.utils import logger
+from superbench.analyzer import file_handler
+
+
+class RuleBase():
+    """RuleBase class."""
+    def __init__(self):
+        """Init function."""
+        self._sb_rules = {}
+        self._benchmark_metrics_dict = {}
+        self._enable_metrics = set()
+
+    def _get_metrics_by_benchmarks(self, metrics_list):
+        """Get mappings of benchmarks:metrics from metrics_list.
+
+        Args:
+            metrics_list (list): list of metrics
+
+        Returns:
+            dict: metrics organized by benchmarks
+        """
+        benchmarks_metrics = {}
+        for metric in metrics_list:
+            if '/' not in metric:
+                logger.warning('RuleBase: get_metrics_by_benchmarks - {} does not have benchmark_name'.format(metric))
+            else:
+                benchmark = metric.split('/')[0]
+                if benchmark not in benchmarks_metrics:
+                    benchmarks_metrics[benchmark] = set()
+                benchmarks_metrics[benchmark].add(metric)
+        return benchmarks_metrics
+
+    def _check_and_format_rules(self, rule, name):
+        """Check the rule of the metric whether the format is valid.
+
+        Args:
+            rule (dict): the rule
+            name (str): the rule name
+
+        Returns:
+            dict: the rule for the metric
+        """
+        # check if rule is supported
+        if 'categories' not in rule:
+            logger.log_and_raise(exception=Exception, msg='{} lack of category'.format(name))
+        if 'metrics' in rule:
+            if isinstance(rule['metrics'], str):
+                rule['metrics'] = [rule['metrics']]
+        return rule
+
+    def _get_metrics(self, rule, benchmark_rules):
+        """Get metrics in the rule.
+
+        Parse metric regex in the rule, and store the (metric, -1) pair
+        in _sb_rules[rule]['metrics']
+
+        Args:
+            rule (str): the name of the rule
+            benchmark_rules (dict): the dict of rules
+        """
+        metrics_in_rule = benchmark_rules[rule]['metrics']
+        benchmark_metrics_dict_in_rule = self._get_metrics_by_benchmarks(metrics_in_rule)
+        for benchmark_name in benchmark_metrics_dict_in_rule:
+            if benchmark_name not in self._benchmark_metrics_dict:
+                logger.warning('RuleBase: get metrics failed - {}'.format(benchmark_name))
+                continue
+            # get rules and criteria for each metric
+            for metric in self._benchmark_metrics_dict[benchmark_name]:
+                # metric full name in baseline
+                if metric in metrics_in_rule:
+                    self._sb_rules[rule]['metrics'][metric] = -1
+                    self._enable_metrics.add(metric)
+                    continue
+                # metric full name not in baseline, use regex to match
+                for metric_regex in benchmark_metrics_dict_in_rule[benchmark_name]:
+                    if re.search(metric_regex, metric):
+                        self._sb_rules[rule]['metrics'][metric] = -1
+                        self._enable_metrics.add(metric)
+
+    def _preprocess(self, raw_data_file, rule_file):
+        """Preprocess/preparation operations for the rules.
+
+        Args:
+            raw_data_file (str): the path of raw data file
+            rule_file (str): the path of rule file
+
+        Returns:
+            dict: dict of rules
+        """
+        # read raw data from file
+        self._raw_data_df = file_handler.read_raw_data(raw_data_file)
+        # re-organize metrics by benchmark names
+        self._benchmark_metrics_dict = self._get_metrics_by_benchmarks(list(self._raw_data_df.columns))
+        # check raw data whether empty
+        if len(self._raw_data_df) == 0:
+            logger.error('RuleBase: empty raw data')
+            return None
+        # read rules
+        rules = file_handler.read_rules(rule_file)
+        return rules
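For illustration, a minimal sketch of how another analyzer might build on the extracted base class. The subclass name and its helper method are hypothetical and not part of this commit; the `RuleBase` calls and the `rules['superbench']['rules']` layout are the ones used by the new module and its tests:

```python
# Hypothetical subclass, for illustration only; not part of this commit.
from superbench.analyzer import RuleBase


class MyRuleAnalyzer(RuleBase):
    def load_rules(self, raw_data_file, rule_file):
        """Read raw data and rules, then expand each rule's metric regexes."""
        rules = self._preprocess(raw_data_file, rule_file)    # reads raw data, returns the rule dict or None
        if not rules:
            return {}
        benchmark_rules = rules['superbench']['rules']
        for name in benchmark_rules:
            benchmark_rules[name] = self._check_and_format_rules(benchmark_rules[name], name)
            self._sb_rules[name] = {'name': name, 'metrics': {}}
            if 'metrics' in benchmark_rules[name]:
                self._get_metrics(name, benchmark_rules)      # fills self._sb_rules[name]['metrics'] with (metric, -1)
        return self._sb_rules
```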
tests/analyzer/test_data_diagnosis.py:

@@ -64,7 +64,7 @@ class TestDataDiagnosis(unittest.TestCase):
         assert (not rules)
         rules = file_handler.read_rules(test_rule_file)
         assert (rules)
-        # Test - _check_rules
+        # Test - _check_and_format_rules
         # Negative case
         false_rules = [
             {

@@ -97,7 +97,7 @@ class TestDataDiagnosis(unittest.TestCase):
         ]
         metric = 'kernel-launch/event_overhead:0'
         for rules in false_rules:
-            self.assertRaises(Exception, diag1._check_rules, rules, metric)
+            self.assertRaises(Exception, diag1._check_and_format_rules, rules, metric)
         # Positive case
         true_rules = [
             {

@@ -118,7 +118,7 @@ class TestDataDiagnosis(unittest.TestCase):
             }
         ]
         for rules in true_rules:
-            assert (diag1._check_rules(rules, metric))
+            assert (diag1._check_and_format_rules(rules, metric))
         # Test - _get_baseline_of_metric
         baseline = file_handler.read_baseline(test_baseline_file)
         assert (diag1._get_baseline_of_metric(baseline, 'kernel-launch/event_overhead:0') == 0.00596)

@@ -148,7 +148,8 @@ class TestDataDiagnosis(unittest.TestCase):
         (details_row, summary_data_row) = diag1._run_diagnosis_rules_for_single_node('sb-validation-02')
         assert (not details_row)
         # Test - _run_diagnosis_rules
-        data_not_accept_df, label_df = diag1.run_diagnosis_rules(test_rule_file, test_baseline_file)
+        baseline = file_handler.read_baseline(test_baseline_file)
+        data_not_accept_df, label_df = diag1.run_diagnosis_rules(rules, baseline)
         assert (len(label_df) == 3)
         assert (label_df.loc['sb-validation-01']['label'] == 1)
         assert (label_df.loc['sb-validation-02']['label'] == 0)

@@ -204,7 +205,7 @@ class TestDataDiagnosis(unittest.TestCase):
         data_not_accept_read_from_excel = excel_file.parse(data_sheet_name)
         expect_result_file = pd.ExcelFile(str(self.parent_path / '../data/diagnosis_summary.xlsx'), engine='openpyxl')
         expect_result = expect_result_file.parse(data_sheet_name)
-        pd.util.testing.assert_frame_equal(data_not_accept_read_from_excel, expect_result)
+        pd.testing.assert_frame_equal(data_not_accept_read_from_excel, expect_result)
         # Test - output in json
         DataDiagnosis().run(test_raw_data, test_rule_file, test_baseline_file, str(self.parent_path), 'json')
         assert (Path(self.output_json_file).is_file())

@@ -218,7 +219,7 @@ class TestDataDiagnosis(unittest.TestCase):
     def test_mutli_rules(self):
         """Test multi rules check feature."""
         diag1 = DataDiagnosis()
-        # test _check_rules
+        # test _check_and_format_rules
         false_rules = [
             {
                 'criteria': 'lambda x:x>0',

@@ -229,7 +230,7 @@ class TestDataDiagnosis(unittest.TestCase):
         ]
         metric = 'kernel-launch/event_overhead:0'
         for rules in false_rules:
-            self.assertRaises(Exception, diag1._check_rules, rules, metric)
+            self.assertRaises(Exception, diag1._check_and_format_rules, rules, metric)
         # Positive case
         true_rules = [
             {

@@ -245,7 +246,7 @@ class TestDataDiagnosis(unittest.TestCase):
             }
         ]
         for rules in true_rules:
-            assert (diag1._check_rules(rules, metric))
+            assert (diag1._check_and_format_rules(rules, metric))
         # test _run_diagnosis_rules_for_single_node
         rules = {
             'superbench': {
tests/analyzer/test_rule_base.py (new file):

@@ -0,0 +1,82 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Tests for RuleBase module."""
+
+import unittest
+from pathlib import Path
+
+from superbench.analyzer import RuleBase
+import superbench.analyzer.file_handler as file_handler
+
+
+class TestRuleBase(unittest.TestCase):
+    """Test for RuleBase class."""
+    def setUp(self):
+        """Method called to prepare the test fixture."""
+        self.parent_path = Path(__file__).parent
+
+    def test_rule_base(self):
+        """Test for rule-based functions."""
+        # Test - read_raw_data and get_metrics_from_raw_data
+        # Positive case
+        test_raw_data = str(self.parent_path / 'test_results.jsonl')
+        test_rule_file = str(self.parent_path / 'test_rules.yaml')
+        rulebase1 = RuleBase()
+        rulebase1._raw_data_df = file_handler.read_raw_data(test_raw_data)
+        rulebase1._benchmark_metrics_dict = rulebase1._get_metrics_by_benchmarks(list(rulebase1._raw_data_df))
+        assert (len(rulebase1._raw_data_df) == 3)
+        # Negative case
+        test_rule_file_fake = str(self.parent_path / 'test_rules_fake.yaml')
+        test_raw_data_fake = str(self.parent_path / 'test_results_fake.jsonl')
+        rulebase2 = RuleBase()
+        rulebase2._raw_data_df = file_handler.read_raw_data(test_raw_data_fake)
+        rulebase2._benchmark_metrics_dict = rulebase2._get_metrics_by_benchmarks(list(rulebase2._raw_data_df))
+        assert (len(rulebase2._raw_data_df) == 0)
+        assert (len(rulebase2._benchmark_metrics_dict) == 0)
+        metric_list = [
+            'gpu_temperature', 'gpu_power_limit', 'gemm-flops/FP64',
+            'bert_models/pytorch-bert-base/steptime_train_float32'
+        ]
+        self.assertDictEqual(
+            rulebase2._get_metrics_by_benchmarks(metric_list), {
+                'gemm-flops': {'gemm-flops/FP64'},
+                'bert_models': {'bert_models/pytorch-bert-base/steptime_train_float32'}
+            }
+        )
+
+        # Test - _preprocess
+        rules = rulebase1._preprocess(test_raw_data_fake, test_rule_file)
+        assert (not rules)
+        rules = rulebase1._preprocess(test_raw_data, test_rule_file_fake)
+        assert (not rules)
+        rules = rulebase1._preprocess(test_raw_data, test_rule_file)
+        assert (rules)
+
+        # Test - _check_and_format_rules
+        # Negative case
+        false_rule = {
+            'criteria': 'lambda x:x>0',
+            'function': 'variance',
+            'metrics': ['kernel-launch/event_overhead:\\d+']
+        }
+        metric = 'kernel-launch/event_overhead:0'
+        self.assertRaises(Exception, rulebase1._check_and_format_rules, false_rule, metric)
+        # Positive case
+        true_rule = {
+            'categories': 'KernelLaunch',
+            'criteria': 'lambda x:x<-0.05',
+            'function': 'variance',
+            'metrics': 'kernel-launch/event_overhead:\\d+'
+        }
+        true_rule = rulebase1._check_and_format_rules(true_rule, metric)
+        assert (true_rule)
+        assert (true_rule['metrics'] == ['kernel-launch/event_overhead:\\d+'])
+
+        # Test - _get_metrics
+        rules = rules['superbench']['rules']
+        for rule in ['rule0', 'rule1']:
+            rulebase1._sb_rules[rule] = {}
+            rulebase1._sb_rules[rule]['metrics'] = {}
+            rulebase1._get_metrics(rule, rules)
+            assert (len(rulebase1._sb_rules[rule]['metrics']) == 16)