Check value of metrics over a long time (#930)

Fixes #852

Also skip analysis on the standard deviation, as it's unstable, fixing #946
This commit is contained in:
Boris Feld 2019-09-27 12:40:36 +02:00 коммит произвёл Marco
Родитель 9fe82c1a6d
Коммит 099f9dba89
1 изменённых файлов: 112 добавлений и 62 удалений

Просмотреть файл

@ -12,10 +12,11 @@ import sys
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict
from typing import Any, Dict, Tuple
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
from pandas import DataFrame
LOGGER = logging.getLogger(__name__)
@ -23,7 +24,8 @@ logging.basicConfig(level=logging.INFO)
# By default, if the latest metric point is 5% lower than the previous one, show a warning and exit
# with 1.
WARNING_THRESHOLD = 0.95
RELATIVE_THRESHOLD = 0.95
ABSOLUTE_THRESHOLD = 0.1
REPORT_METRICS = ["accuracy", "precision", "recall"]
@ -31,33 +33,26 @@ REPORT_METRICS = ["accuracy", "precision", "recall"]
def plot_graph(
model_name: str,
metric_name: str,
values_dict: Dict[datetime, float],
df: DataFrame,
title: str,
output_directory: Path,
warning_threshold: float,
) -> bool:
sorted_metrics = sorted(values_dict.items())
x, y = zip(*sorted_metrics)
# Compute the threshold
if len(y) >= 2:
before_last_value = y[-2]
else:
before_last_value = y[-1]
metric_threshold = before_last_value * warning_threshold
file_path: str,
metric_threshold: float,
) -> None:
figure = plt.figure()
axes = plt.axes()
axes = df.plot(y="value")
# Formatting of the figure
figure.autofmt_xdate()
axes.fmt_xdata = mdates.DateFormatter("%Y-%m-%d-%H-%M")
axes.set_title(f"{model_name} {metric_name}")
axes.set_title(title)
# Display threshold
axes.axhline(y=metric_threshold, linestyle="--", color="red")
plt.annotate(
"{:.4f}".format(metric_threshold),
(x[-1], metric_threshold),
(df.index[-1], metric_threshold),
textcoords="offset points", # how to position the text
xytext=(-10, 10), # distance from text to points (x,y)
ha="center",
@ -65,7 +60,7 @@ def plot_graph(
)
# Display point values
for single_x, single_y in zip(x, y):
for single_x, single_y in zip(df.index, df.value):
label = "{:.4f}".format(single_y)
plt.annotate(
@ -76,20 +71,45 @@ def plot_graph(
ha="center",
)
axes.plot_date(x, y, marker=".", fmt="-")
output_file_path = output_directory.resolve() / f"{model_name}_{metric_name}.svg"
output_file_path = output_directory.resolve() / file_path
LOGGER.info("Saving %s figure", output_file_path)
plt.savefig(output_file_path)
plt.close(figure)
# Check if the threshold has been crossed
return y[-1] < metric_threshold
def parse_metric_file(metric_file_path: Path) -> Tuple[datetime, str, Dict[str, Any]]:
    """Load a metric JSON file and decode the metadata stored in its file name.

    The file stem is split on ``_`` and is expected to have the shape
    ``metric_project_relman_bugbug_train_<model>_per_date_<Y>_<m>_<d>_<H>_<M>_<S>``,
    optionally followed by further ``_``-separated parts (currently ignored).

    Args:
        metric_file_path: Path to the metric JSON file.

    Returns:
        A ``(date, model_name, metric)`` tuple: the UTC-aware datetime built
        from the six numeric file name parts, the model name taken from the
        sixth ``_``-separated part, and the decoded JSON content.

    Raises:
        ValueError: If the file name does not follow the expected pattern, or
            if one of the date parts is not an integer / not a valid date.
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Load the metric; JSON is UTF-8, so do not depend on the locale encoding
    with open(metric_file_path, "r", encoding="utf-8") as metric_file:
        metric = json.load(metric_file)

    # Get the model, date and version from the file
    # TODO: Might be better storing it in the file
    file_path_parts = metric_file_path.stem.split("_")

    # Explicit checks instead of `assert`, which is stripped under `python -O`
    if file_path_parts[:5] != ["metric", "project", "relman", "bugbug", "train"]:
        raise ValueError(
            f"Unexpected metric file name prefix: {metric_file_path.name!r}"
        )
    if file_path_parts[6:8] != ["per", "date"]:
        raise ValueError(
            f"Missing 'per_date' marker in metric file name: {metric_file_path.name!r}"
        )

    # NOTE(review): only a single `_`-separated part is read as the model
    # name, so a model name containing `_` would fail the check above.
    model_name = file_path_parts[5]

    date_parts = [int(part) for part in file_path_parts[8:14]]
    # Slicing never raises, so make sure all six components are present
    # before unpacking; otherwise a truncated name could yield a wrong date.
    if len(date_parts) != 6:
        raise ValueError(
            f"Expected 6 date components in metric file name: {metric_file_path.name!r}"
        )
    date = datetime(*date_parts, tzinfo=timezone.utc)

    # version = file_path_parts[14:]  # TODO: Use version

    return (date, model_name, metric)
def analyze_metrics(
metrics_directory: str, output_directory: str, warning_threshold: float
metrics_directory: str,
output_directory: str,
relative_threshold: float,
absolute_threshold: float,
):
root = Path(metrics_directory)
@ -97,32 +117,12 @@ def analyze_metrics(
lambda: defaultdict(dict)
)
threshold_ever_crossed = False
clean = True
# First process the metrics JSON files
for metric_file_path in root.glob("metric*.json"):
# Load the metric
with open(metric_file_path, "r") as metric_file:
metric = json.load(metric_file)
# Get the model, date and version from the file
# TODO: Might be better storing it in the file
file_path_parts = metric_file_path.stem.split("_")
assert file_path_parts[:5] == ["metric", "project", "relman", "bugbug", "train"]
model_name = file_path_parts[5]
assert file_path_parts[6:8] == ["per", "date"]
date_parts = list(map(int, file_path_parts[8:14]))
date = datetime(
date_parts[0],
date_parts[1],
date_parts[2],
date_parts[3],
date_parts[4],
date_parts[5],
tzinfo=timezone.utc,
)
# version = file_path_parts[14:] # TODO: Use version
date, model_name, metric = parse_metric_file(metric_file_path)
# Then process the report
for key, value in metric["report"]["average"].items():
@ -139,19 +139,46 @@ def analyze_metrics(
metrics[model_name][f"{key}_mean"][date] = value["mean"]
metrics[model_name][f"{key}_std"][date] = value["std"]
# Then analyze them
for model_name in metrics:
for metric_name, values in metrics[model_name].items():
threshold_crossed = plot_graph(
model_name,
metric_name,
values,
Path(output_directory),
warning_threshold,
)
diff = (1 - warning_threshold) * 100
if metric_name.endswith("_std"):
LOGGER.info(
"Skipping analysis of %r, analysis is not efficient on standard deviation",
metric_name,
)
continue
df = DataFrame.from_dict(values, orient="index", columns=["value"])
df = df.sort_index()
# Compute the absolute threshold for the metric
max_value = max(df["value"])
metric_threshold = max_value - absolute_threshold
threshold_crossed = df.value[-1] < metric_threshold
if threshold_crossed:
LOGGER.warning(
"Last metric %r for model %s is at least %f less than the max",
metric_name,
model_name,
ABSOLUTE_THRESHOLD,
)
clean = False
# Compute the relative threshold for the metric
if len(df["value"]) >= 2:
before_last_value = df["value"][-2]
else:
before_last_value = df["value"][-1]
relative_metric_threshold = before_last_value * relative_threshold
relative_threshold_crossed = df.value[-1] < relative_metric_threshold
if relative_threshold_crossed:
diff = (1 - relative_threshold) * 100
LOGGER.warning(
"Last metric %r for model %s is %f%% worse than the previous one",
metric_name,
@ -159,9 +186,23 @@ def analyze_metrics(
diff,
)
threshold_ever_crossed = threshold_ever_crossed or threshold_crossed
clean = False
if threshold_ever_crossed:
# Plot the non-smoothed graph
title = f"{model_name} {metric_name}"
file_path = f"{model_name}_{metric_name}.svg"
plot_graph(
model_name,
metric_name,
df,
title,
Path(output_directory),
file_path,
metric_threshold,
)
if not clean:
sys.exit(1)
@ -179,16 +220,25 @@ def main():
help="In which directory the script will save the generated graphs",
)
parser.add_argument(
"--warning_threshold",
default=WARNING_THRESHOLD,
"--relative_threshold",
default=RELATIVE_THRESHOLD,
type=float,
help="If the last metric value is below the previous one*warning_threshold, fails. Default to 0.95",
help="If the last metric value is below the previous_one * relative_threshold, fails. Default to 0.95",
)
parser.add_argument(
"--absolute_threshold",
default=ABSOLUTE_THRESHOLD,
type=float,
help="If the last metric value is below the max value - absolute_threshod, fails. Default to 0.1",
)
args = parser.parse_args()
analyze_metrics(
args.metrics_directory, args.output_directory, args.warning_threshold
args.metrics_directory,
args.output_directory,
args.relative_threshold,
args.absolute_threshold,
)