# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""A base module for rule-related module."""
|
|
|
|

import re

from superbench.common.utils import logger
from superbench.analyzer import file_handler


class RuleBase():
    """RuleBase class."""
    def __init__(self):
        """Init function."""
        self._sb_rules = {}    # rule name -> formatted rule definition
        self._benchmark_metrics_dict = {}    # benchmark name -> set of metric names
        self._enable_metrics = set()    # metrics selected (matched) by the rules

    def _get_metrics_by_benchmarks(self, metrics_list):
        """Get mappings of benchmarks:metrics from metrics_list.

        Args:
            metrics_list (list): list of metrics

        Returns:
            dict: metrics organized by benchmarks
        """
        benchmarks_metrics = {}
        for metric in metrics_list:
            if '/' not in metric:
                logger.warning('RuleBase: get_metrics_by_benchmarks - {} does not have benchmark_name'.format(metric))
            else:
                benchmark = metric.split('/')[0]
                # support annotations in benchmark naming
                if ':' in benchmark:
                    benchmark = metric.split(':')[0]
                if benchmark not in benchmarks_metrics:
                    benchmarks_metrics[benchmark] = set()
                benchmarks_metrics[benchmark].add(metric)
        return benchmarks_metrics
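
    # For illustration (hypothetical metric names):
    #   _get_metrics_by_benchmarks(['kernel-launch/event_overhead:0', 'kernel-launch:arch/wall_overhead', 'cpu_usage'])
    # returns
    #   {'kernel-launch': {'kernel-launch/event_overhead:0', 'kernel-launch:arch/wall_overhead'}}
    # and logs a warning for 'cpu_usage', which carries no 'benchmark/' prefix.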

    def _check_and_format_rules(self, rule, name):
        """Check whether the format of the rule for the metric is valid.

        Args:
            rule (dict): the rule
            name (str): the rule name

        Returns:
            dict: the rule for the metric
        """
        # check if the rule is supported
        if 'categories' not in rule:
            logger.log_and_raise(exception=Exception, msg='{} lacks a category'.format(name))
        if 'metrics' in rule:
            if isinstance(rule['metrics'], str):
                rule['metrics'] = [rule['metrics']]
        return rule
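
    # For illustration, a minimal rule that passes this check (values are hypothetical):
    #   rule = {'categories': 'KernelLaunch', 'metrics': 'kernel-launch/event_overhead.*'}
    # _check_and_format_rules(rule, 'rule1') normalizes 'metrics' into a one-element list;
    # a rule without a 'categories' key raises an exception instead.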

    def _get_metrics(self, rule, benchmark_rules):
        """Get metrics in the rule.

        Parse metric regexes in the rule, and store the (metric, -1) pair
        in _sb_rules[rule]['metrics'].

        Args:
            rule (str): the name of the rule
            benchmark_rules (dict): the dict of rules
        """
        metrics_in_rule = benchmark_rules[rule]['metrics']
        benchmark_metrics_dict_in_rule = self._get_metrics_by_benchmarks(metrics_in_rule)
        for benchmark_name in benchmark_metrics_dict_in_rule:
            if benchmark_name not in self._benchmark_metrics_dict:
                logger.warning('RuleBase: get metrics failed - {}'.format(benchmark_name))
                continue
            # get rules and criteria for each metric
            for metric in self._benchmark_metrics_dict[benchmark_name]:
                # metric full name in baseline
                if metric in metrics_in_rule:
                    self._sb_rules[rule]['metrics'][metric] = -1
                    self._enable_metrics.add(metric)
                    continue
                # metric full name not in baseline, use regex to match
                for metric_regex in benchmark_metrics_dict_in_rule[benchmark_name]:
                    if re.search(metric_regex, metric):
                        self._sb_rules[rule]['metrics'][metric] = -1
                        self._enable_metrics.add(metric)
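
    # For illustration (hypothetical rule and metric names): with
    #   benchmark_rules = {'rule1': {'categories': 'KernelLaunch',
    #                                'metrics': ['kernel-launch/event_overhead.*']}}
    # and 'kernel-launch/event_overhead:0' present in self._benchmark_metrics_dict['kernel-launch'],
    # _get_metrics('rule1', benchmark_rules) records
    #   self._sb_rules['rule1']['metrics']['kernel-launch/event_overhead:0'] = -1
    # (assuming self._sb_rules['rule1'] was initialized with a 'metrics' dict beforehand).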

    def _preprocess(self, raw_data_file, rule_file):
        """Preprocess/preparation operations for the rules.

        Args:
            raw_data_file (str): the path of the raw data file
            rule_file (str): the path of the rule file

        Returns:
            dict: dict of rules
        """
        # read raw data from file
        self._raw_data_df = file_handler.read_raw_data(raw_data_file)
        # re-organize metrics by benchmark names
        self._benchmark_metrics_dict = self._get_metrics_by_benchmarks(list(self._raw_data_df.columns))
        # check whether the raw data is empty
        if len(self._raw_data_df) == 0:
            logger.log_and_raise(exception=Exception, msg='RuleBase: empty raw data')
        # read rules
        rules = file_handler.read_rules(rule_file)
        return rules
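

# A minimal usage sketch (hypothetical subclass and file paths, for illustration only;
# the layout of the dict returned by file_handler.read_rules is an assumption here):
#
#   class MyAnalyzer(RuleBase):
#       def run(self, raw_data_file, rule_file):
#           rules = self._preprocess(raw_data_file, rule_file)
#           for name, rule in rules.items():
#               rule = self._check_and_format_rules(rule, name)
#               if 'metrics' in rule:
#                   self._sb_rules[name] = {'metrics': {}}
#                   self._get_metrics(name, rules)
#
# Afterwards, self._sb_rules[name]['metrics'] maps each matched metric to -1 and
# self._enable_metrics contains every metric selected by the rules.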