* Upgrade group_logs metrics from online evaluation tasks

* Support incrementing group_logs metrics table

* Use real run name

* Remove useless indent & fixes

* Nits

* Support disabled publication through WANDB_PUBLICATION

* Fix linting

---------

Co-authored-by: Evgeny Pavlov <epavlov@mozilla.com>
This commit is contained in:
Valentin Rigal 2024-08-08 20:19:55 +02:00 коммит произвёл GitHub
Родитель 747821a796
Коммит d5b94fe422
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 86 добавлений и 12 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -146,3 +146,4 @@ dmypy.json
# Tracking
output
wandb
media

Просмотреть файл

@ -55,8 +55,14 @@ from pipeline.common.logging import get_logger
logger = get_logger("eval")
try:
import wandb
from translations_parser.publishers import METRIC_KEYS, WandB
from translations_parser.utils import metric_from_tc_context
from translations_parser.wandb import add_wandb_arguments, get_wandb_publisher
from translations_parser.wandb import (
add_wandb_arguments,
get_wandb_publisher,
list_existing_group_logs_metrics,
)
WANDB_AVAILABLE = True
except ImportError as e:
@ -343,7 +349,11 @@ def main(args_list: Optional[list[str]] = None) -> None:
file.write(f"{bleu_details['score']}\n" f"{chrf_details['score']}\n" f"{comet_score}\n")
if WANDB_AVAILABLE:
wandb = get_wandb_publisher( # noqa
metric = metric_from_tc_context(
chrf=chrf_details["score"], bleu=bleu_details["score"], comet=comet_score
)
run_client = get_wandb_publisher( # noqa
project_name=args.wandb_project,
group_name=args.wandb_group,
run_name=args.wandb_run_name,
@ -351,16 +361,52 @@ def main(args_list: Optional[list[str]] = None) -> None:
artifacts=args.wandb_artifacts,
publication=args.wandb_publication,
)
if wandb:
logger.info("Initializing Weight & Biases client")
# Allow publishing metrics as a table on existing runs (i.e. previous trainings)
wandb.open(resume=True)
logger.info(f"Publishing metrics to Weight & Biases ({wandb.extra_kwargs})")
metric = metric_from_tc_context(
chrf=chrf_details["score"], bleu=bleu_details["score"], comet=comet_score
)
wandb.handle_metrics(metrics=[metric])
wandb.close()
if run_client is None:
# W&B publication may be direclty disabled through WANDB_PUBLICATION
return
logger.info(f"Publishing metrics to Weight & Biases ({run_client.extra_kwargs})")
run_client.open(resume=True)
run_client.handle_metrics(metrics=[metric])
run_client.close()
# Publish an extra row on the group_logs summary run
group_logs_client = WandB( # noqa
project=run_client.wandb.project,
group=run_client.wandb.group,
name="group_logs",
)
logger.info("Adding metric row to the 'group_logs' run")
group_logs_client.open(resume=True)
# Restore existing metrics data
data = list_existing_group_logs_metrics(group_logs_client.wandb)
data.append(
[
run_client.wandb.group,
run_client.wandb.name,
metric.importer,
metric.dataset,
metric.augmentation,
]
+ [getattr(metric, attr) for attr in METRIC_KEYS]
)
group_logs_client.wandb.log(
{
"metrics": wandb.Table(
columns=[
"Group",
"Model",
"Importer",
"Dataset",
"Augmenation",
*METRIC_KEYS,
],
data=data,
)
}
)
group_logs_client.close()
if __name__ == "__main__":

Просмотреть файл

@ -1,5 +1,9 @@
import json
import os
from pathlib import Path
from typing import List
import wandb
import taskcluster
from translations_parser.parser import logger
@ -157,3 +161,26 @@ def get_wandb_publisher(
tags=tags,
config=config,
)
def list_existing_group_logs_metrics(
wandb_run: wandb.sdk.wandb_run.Run,
) -> List[List[str | float]]:
"""Retrieve the data from groups_logs metric table"""
if wandb_run.resumed is False:
return []
logger.info(f"Retrieving existing group logs metrics from group_logs ({wandb_run.id})")
api = wandb.Api()
run = api.run(f"{wandb_run.project}/{wandb_run.id}")
last = next(
(
artifact
for artifact in list(run.files())[::-1]
if artifact.name.startswith("media/table/metrics")
),
None,
)
if not last:
return []
data = json.load(last.download(replace=True))
return data.get("data", [])