Group logs online evals (#708)
* Upgrade group_logs metrics from online evaluation tasks * Support incrementing group_logs metrics table * Use real run name * Remove useless indent & fixes * Nits * Support disabled publication through WANDB_PUBLICATION * Fix linting --------- Co-authored-by: Evgeny Pavlov <epavlov@mozilla.com>
This commit is contained in:
Родитель
747821a796
Коммит
d5b94fe422
|
@ -146,3 +146,4 @@ dmypy.json
|
|||
# Tracking
|
||||
output
|
||||
wandb
|
||||
media
|
||||
|
|
|
@ -55,8 +55,14 @@ from pipeline.common.logging import get_logger
|
|||
|
||||
logger = get_logger("eval")
|
||||
try:
|
||||
import wandb
|
||||
from translations_parser.publishers import METRIC_KEYS, WandB
|
||||
from translations_parser.utils import metric_from_tc_context
|
||||
from translations_parser.wandb import add_wandb_arguments, get_wandb_publisher
|
||||
from translations_parser.wandb import (
|
||||
add_wandb_arguments,
|
||||
get_wandb_publisher,
|
||||
list_existing_group_logs_metrics,
|
||||
)
|
||||
|
||||
WANDB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
|
@ -343,7 +349,11 @@ def main(args_list: Optional[list[str]] = None) -> None:
|
|||
file.write(f"{bleu_details['score']}\n" f"{chrf_details['score']}\n" f"{comet_score}\n")
|
||||
|
||||
if WANDB_AVAILABLE:
|
||||
wandb = get_wandb_publisher( # noqa
|
||||
metric = metric_from_tc_context(
|
||||
chrf=chrf_details["score"], bleu=bleu_details["score"], comet=comet_score
|
||||
)
|
||||
|
||||
run_client = get_wandb_publisher( # noqa
|
||||
project_name=args.wandb_project,
|
||||
group_name=args.wandb_group,
|
||||
run_name=args.wandb_run_name,
|
||||
|
@ -351,16 +361,52 @@ def main(args_list: Optional[list[str]] = None) -> None:
|
|||
artifacts=args.wandb_artifacts,
|
||||
publication=args.wandb_publication,
|
||||
)
|
||||
if wandb:
|
||||
logger.info("Initializing Weight & Biases client")
|
||||
# Allow publishing metrics as a table on existing runs (i.e. previous trainings)
|
||||
wandb.open(resume=True)
|
||||
logger.info(f"Publishing metrics to Weight & Biases ({wandb.extra_kwargs})")
|
||||
metric = metric_from_tc_context(
|
||||
chrf=chrf_details["score"], bleu=bleu_details["score"], comet=comet_score
|
||||
)
|
||||
wandb.handle_metrics(metrics=[metric])
|
||||
wandb.close()
|
||||
if run_client is None:
|
||||
# W&B publication may be direclty disabled through WANDB_PUBLICATION
|
||||
return
|
||||
|
||||
logger.info(f"Publishing metrics to Weight & Biases ({run_client.extra_kwargs})")
|
||||
run_client.open(resume=True)
|
||||
run_client.handle_metrics(metrics=[metric])
|
||||
run_client.close()
|
||||
|
||||
# Publish an extra row on the group_logs summary run
|
||||
group_logs_client = WandB( # noqa
|
||||
project=run_client.wandb.project,
|
||||
group=run_client.wandb.group,
|
||||
name="group_logs",
|
||||
)
|
||||
logger.info("Adding metric row to the 'group_logs' run")
|
||||
group_logs_client.open(resume=True)
|
||||
|
||||
# Restore existing metrics data
|
||||
data = list_existing_group_logs_metrics(group_logs_client.wandb)
|
||||
data.append(
|
||||
[
|
||||
run_client.wandb.group,
|
||||
run_client.wandb.name,
|
||||
metric.importer,
|
||||
metric.dataset,
|
||||
metric.augmentation,
|
||||
]
|
||||
+ [getattr(metric, attr) for attr in METRIC_KEYS]
|
||||
)
|
||||
group_logs_client.wandb.log(
|
||||
{
|
||||
"metrics": wandb.Table(
|
||||
columns=[
|
||||
"Group",
|
||||
"Model",
|
||||
"Importer",
|
||||
"Dataset",
|
||||
"Augmenation",
|
||||
*METRIC_KEYS,
|
||||
],
|
||||
data=data,
|
||||
)
|
||||
}
|
||||
)
|
||||
group_logs_client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import wandb
|
||||
|
||||
import taskcluster
|
||||
from translations_parser.parser import logger
|
||||
|
@ -157,3 +161,26 @@ def get_wandb_publisher(
|
|||
tags=tags,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def list_existing_group_logs_metrics(
|
||||
wandb_run: wandb.sdk.wandb_run.Run,
|
||||
) -> List[List[str | float]]:
|
||||
"""Retrieve the data from groups_logs metric table"""
|
||||
if wandb_run.resumed is False:
|
||||
return []
|
||||
logger.info(f"Retrieving existing group logs metrics from group_logs ({wandb_run.id})")
|
||||
api = wandb.Api()
|
||||
run = api.run(f"{wandb_run.project}/{wandb_run.id}")
|
||||
last = next(
|
||||
(
|
||||
artifact
|
||||
for artifact in list(run.files())[::-1]
|
||||
if artifact.name.startswith("media/table/metrics")
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not last:
|
||||
return []
|
||||
data = json.load(last.download(replace=True))
|
||||
return data.get("data", [])
|
||||
|
|
Загрузка…
Ссылка в новой задаче