firefox-translations-training/utils/config_generator.py


"""
Generate a training config for a language pair based on the latest production
training config, taskcluster/configs/config.prod.yml.
"""

import argparse
import re
import subprocess
import sys
from io import StringIO
from pathlib import Path
from typing import Any, Literal, Union

import ruamel.yaml

from pipeline.common.downloads import get_download_size, location_exists
from pipeline.data.importers.mono.hplt import language_has_hplt_support
from utils.find_corpus import (
    fetch_mtdata,
    fetch_news_crawl,
    fetch_opus,
    fetch_sacrebleu,
    get_remote_file_size,
)
"""
Generate a training config for a language pair based on the latest production
training config, taskcluster/configs/config.prod.yml.
"""
root_dir = Path(__file__).parent.parent
prod_config_path = root_dir / "taskcluster/configs/config.prod.yml"

pretrained_student_models = {
    ("ru", "en"): "https://storage.googleapis.com/releng-translations-dev/models/ru-en/better-teacher/student"
}  # fmt: skip

skip_datasets = [
    # The NLLB dataset is based on the CCMatrix dataset, and is mostly duplicated.
    "CCMatrix",
    # Skip Multi* datasets as they are generally multilingual versions of the original datasets.
    "MultiMaCoCu",
    "MultiHPLT",
    # In Russian, the WikiTitles data had its direction reversed. The `LinguaTools-WikiTitles`
    # version is fine.
    "WikiTitles",
    # This mtdata dataset fails in a task, and is a duplicate of OPUS.
    "swedish_work_environment",
    # Fails to load from mtdata.
    "lithuanian_legislation_seimas_lithuania",
    # Fails to load from OPUS.
    "SPC",
]

# Do not include small datasets. This works around #508, and minimizes dataset tasks that
# won't bring a lot more data.
minimum_dataset_sentences = 200

# If a task name is too long, it will fail.
max_dataset_name_size = 80

flores_101_languages = {
    "af", "amh", "ar", "as", "ast", "az", "be", "bn", "bs", "bg", "ca", "ceb", "cs", "ckb", "cy",
    "da", "de", "el", "en", "et", "fa", "fi", "fr", "ful", "ga", "gl", "gu", "ha", "he", "hi",
    "hr", "hu", "hy", "ig", "id", "is", "it", "jv", "ja", "kam", "kn", "ka", "kk", "kea", "km",
    "ky", "ko", "lo", "lv", "ln", "lt", "lb", "lg", "luo", "ml", "mr", "mk", "mt", "mn", "mi",
    "ms", "my", "nl", "nb", "npi", "nso", "ny", "oc", "om", "or", "pa", "pl", "pt", "pus", "ro",
    "ru", "sk", "sl", "sna", "snd", "so", "es", "sr", "sv", "sw", "ta", "te", "tg", "tl", "th",
    "tr", "uk", "umb", "ur", "uz", "vi", "wo", "xh", "yo", "zh", "zu"
}  # fmt: skip

# mtdata points to raw downloads, and does some processing to normalize the data. This means
# that if we measure the download size, it may be inaccurate.
bad_mtdata_sizes = {
    # These are stored in a big archive with train/test/dev. Keep "train" estimates as they are
    # the largest, but ignore test/dev.
    "tedtalks_test",
    "tedtalks_dev",
}


def get_git_revision_hash(remote_branch: str) -> str:
    """
    The git hash should be something that will always be around. Check the remote main branch
    for the most recent common ancestor of the local changes. The prod config locally could be
    different than the remote one, but it's better to point at a commit that is guaranteed to
    exist on the remote.
    """
    return (
        subprocess.check_output(["git", "merge-base", remote_branch, "HEAD"])
        .decode("ascii")
        .strip()
    )
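

# For reference, get_git_revision_hash("origin/main") is equivalent to running:
#
#     git merge-base origin/main HEAD
#
# which prints the most recent commit shared by origin/main and the local HEAD.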


def update_config(
    prod_config: Any, name: str, source: str, target: str, fast: bool
) -> dict[str, str]:
    experiment = prod_config["experiment"]

    # Update the prod config for this language pair.
    experiment["name"] = name
    experiment["src"] = source
    experiment["trg"] = target
    experiment["bicleaner"]["dataset-thresholds"] = {}

    pretrained_model = pretrained_student_models.get((source, target))
    if pretrained_model:
        # Switch to the one-stage teacher mode, as the higher quality backtranslations lead
        # to issues with early stopping when switching between stages.
        experiment["teacher-mode"] = "one-stage"
        experiment["pretrained-models"]["train-backwards"]["urls"] = [pretrained_model]
    else:
        experiment["pretrained-models"] = {}

    datasets = prod_config["datasets"]

    # Clear out the base config.
    datasets["train"].clear()
    datasets["devtest"].clear()
    datasets["test"].clear()
    datasets["mono-src"].clear()
    datasets["mono-trg"].clear()

    # ruamel.yaml only supports inline comments. This dict does string matching to apply
    # comments to the top of a section.
    comment_section = {}

    add_train_data(source, target, datasets, comment_section, fast)
    add_test_data(
        source,
        target,
        datasets["test"],
        datasets["devtest"],
        comment_section,
    )
    add_mono_data(
        source,
        "src",
        datasets,
        experiment,
        comment_section,
    )
    add_mono_data(
        target,
        "trg",
        datasets,
        experiment,
        comment_section,
    )

    return comment_section
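

# The returned comment_section maps an indented YAML key to a multi-line comment, e.g.
# (hypothetical counts) {" train:": "The training data contains:\n 1,234,567 sentences"}.
# apply_comment_section() below does the string matching to place each comment above its key.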


def add_train_data(
    source: str,
    target: str,
    datasets: dict[str, list[str]],
    comment_section: dict[str, str],
    fast: bool,
):
    opus_datasets = fetch_opus(source, target)
    total_sentences = 0
    skipped_datasets = []
    visited_corpora = set()

    for dataset in opus_datasets:
        sentences = dataset.alignment_pairs or 0
        visited_corpora.add(normalize_corpus_name(dataset.corpus))

        # Some datasets are ignored or too small to be included.
        if dataset.corpus in skip_datasets:
            skipped_datasets.append(
                f"{dataset.corpus_key()} - ignored datasets ({sentences:,} sentences)"
            )
            continue
        if (dataset.alignment_pairs or 0) < minimum_dataset_sentences:
            skipped_datasets.append(
                f"{dataset.corpus_key()} - not enough data ({sentences:,} sentences)"
            )
            continue
        if len(dataset.corpus) > max_dataset_name_size:
            skipped_datasets.append(f"{dataset.corpus_key()} - corpus name is too long for tasks")
            continue

        total_sentences += sentences
        corpus_key = dataset.corpus_key()
        datasets["train"].append(corpus_key)
        datasets["train"].yaml_add_eol_comment(
            f"{sentences:,} sentences".rjust(70 - len(corpus_key), " "),
            len(datasets["train"]) - 1,
        )

    print("Fetching mtdata")
    entries = fetch_mtdata(source, target)
    for corpus_key, entry in entries.items():
        # mtdata can have test and devtest data as well.
        if entry.did.name.endswith("test"):
            dataset = datasets["test"]
        elif entry.did.name.endswith("dev"):
            dataset = datasets["devtest"]
        else:
            dataset = datasets["train"]

        corpus_name = normalize_corpus_name(entry.did.name)
        group_corpus_name = normalize_corpus_name(entry.did.group + entry.did.name)
        if corpus_name in visited_corpora or group_corpus_name in visited_corpora:
            skipped_datasets.append(f"{corpus_key} - duplicate with opus")
            continue
        if entry.did.name in skip_datasets:
            skipped_datasets.append(f"{entry.did.name} - ignored datasets")
            continue
        if len(entry.did.name) > max_dataset_name_size:
            skipped_datasets.append(f"{entry.did.name} - corpus name is too long for tasks")
            continue

        if fast:
            # Just add the dataset when in fast mode.
            dataset.append(corpus_key)
        else:
            byte_size, display_size = get_remote_file_size(entry.url)
            if byte_size is None:
                # There was a network error, skip the dataset.
                skipped_datasets.append(f"{corpus_key} - Error fetching ({entry.url})")
            else:
                # Don't add the sentences to total_sentences, as mtdata is less reliable
                # than opus.
                sentences = estimate_sentence_size(byte_size)
                dataset.append(corpus_key)
                if byte_size:
                    dataset.yaml_add_eol_comment(
                        f"~{sentences:,} sentences ".rjust(70 - len(corpus_key), " ")
                        + f"({display_size})",
                        len(dataset) - 1,
                    )
                else:
                    dataset.yaml_add_eol_comment(
                        "No Content-Length reported ".rjust(70 - len(corpus_key), " ")
                        + f"({entry.url})",
                        len(dataset) - 1,
                    )

    comments = [
        "The training data contains:",
        f" {total_sentences:,} sentences",
    ]
    if skipped_datasets:
        comments.append("")
        comments.append("Skipped datasets:")
        for d in skipped_datasets:
            comments.append(f" - {d}")
    train_comment = "\n".join(comments)
    comment_section[" train:"] = train_comment


def normalize_corpus_name(corpus_name: str):
    """Normalize the corpus name so that it's easy to deduplicate between opus and mtdata."""
    # Remove the language tags at the end.
    # mtdata_ELRC-vnk.fi-1-eng-fin
    #                     ^^^^^^^^
    corpus_name = re.sub(r"-\w{3}-\w{3}$", "", corpus_name)
    corpus_name = corpus_name.lower()

    # Remove numbers and anything else that is not a letter. This is a little aggressive, but
    # should help deduplicate more datasets. For example:
    # opus:   725-Hallituskausi_2011_2
    # mtdata: hallituskausi_2011_2015-1-eng-fin
    corpus_name = re.sub(r"[^a-z]", "", corpus_name)

    # Datasets could be split by train/test/dev. Remove the "train" word so that it will match
    # between Opus and mtdata.
    # opus:   NeuLab-TedTalks/v1
    # mtdata: Neulab-tedtalks_train-1-eng-fin
    # mtdata: Neulab-tedtalks_test-1-eng-fin
    # mtdata: Neulab-tedtalks_dev-1-eng-fin
    corpus_name = re.sub(r"train$", "", corpus_name)

    return corpus_name
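

# For example, the OPUS corpus name "NeuLab-TedTalks" and the mtdata dataset id
# "Neulab-tedtalks_train-1-eng-fin" both normalize to "neulabtedtalks", so the mtdata copy is
# detected as a duplicate of the OPUS one in add_train_data().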


def add_test_data(
    source: str,
    target: str,
    test_datasets: list[str],
    devtest_datasets: list[str],
    comment_section: dict[str, str],
):
    skipped_datasets = []

    print("Fetching flores")
    if source in flores_101_languages and target in flores_101_languages:
        test_datasets.append("flores_devtest")

        # Add augmented datasets to check performance for specific cases.
        test_datasets.append("flores_aug-mix_devtest")
        test_datasets.append("flores_aug-title_devtest")
        test_datasets.append("flores_aug-upper_devtest")
        test_datasets.append("flores_aug-typos_devtest")
        test_datasets.append("flores_aug-noise_devtest")
        test_datasets.append("flores_aug-inline-noise_devtest")

        devtest_datasets.append("flores_aug-mix_dev")

    is_test = True  # Flip between devtest and test.
    print("Fetching sacrebleu")
    for d in fetch_sacrebleu(source, target):
        # Work around: PLW2901 `for` loop variable `dataset_name` overwritten by assignment target
        dataset_name = d
        if dataset_name in skip_datasets:
            # This could be a dataset with a variant design.
            skipped_datasets.append(f"{dataset_name} - variant dataset")
        elif len(dataset_name) > max_dataset_name_size:
            skipped_datasets.append(f"{dataset_name} - corpus name is too long for tasks")
        else:
            dataset_name = dataset_name.replace("sacrebleu_", "")
            if is_test:
                test_datasets.append(f"sacrebleu_{dataset_name}")
            else:
                devtest_datasets.append(f"sacrebleu_aug-mix_{dataset_name}")
            is_test = not is_test

    if skipped_datasets:
        test_comment = "\n".join(
            [
                "Skipped test/devtest datasets:",
                *[f" - {d}" for d in skipped_datasets],
            ]
        )
        comment_section[" devtest:"] = test_comment


def estimate_sentence_size(bytes: int) -> int:
    """Estimate the number of sentences based on the compressed byte size."""
    # One dataset measured 113 bytes per sentence, use that as a rough estimate.
    bytes_per_sentence = 113
    return bytes // bytes_per_sentence
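

# For reference, with the 113 bytes-per-sentence heuristic above, a ~1 MiB (1,048,576 byte)
# compressed download is estimated at 1,048,576 // 113 = 9,279 sentences.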


def add_mono_data(
    lang: str,
    direction: Union[Literal["src"], Literal["trg"]],
    datasets: dict[str, list[str]],
    experiment: Any,
    comment_section: dict[str, str],
):
    mono_datasets = datasets[f"mono-{direction}"]
    max_per_dataset: int = experiment[f"mono-max-sentences-{direction}"]["per-dataset"]

    def add_comment(dataset_name: str, comment: str):
        """Add a right justified comment to a dataset."""
        mono_datasets.yaml_add_eol_comment(
            comment.rjust(50 - len(dataset_name), " "),
            len(mono_datasets) - 1,
        )

    extra_comments: list[str] = []
    skipped_datasets = []

    print("Fetching newscrawl for", lang)
    sentence_count = 0
    for dataset in fetch_news_crawl(lang):
        mono_datasets.append(dataset.name)
        if dataset.size:
            sentences = estimate_sentence_size(dataset.size)
            sentence_count += sentences
            add_comment(dataset.name, f"~{sentences:,} sentences")

    print("Fetching HPLT mono for", lang)
    if language_has_hplt_support(lang):
        dataset_name = "hplt_mono/v1.2"
        mono_datasets.append(dataset_name)
        add_comment(dataset_name, f"Up to {max_per_dataset:,} sentences")
        extra_comments.append(f" Up to {max_per_dataset:,} sentences from HPLT")

    print("Fetching NLLB mono for", lang)
    opus_nllb_url = f"https://object.pouta.csc.fi/OPUS-NLLB/v1/mono/{lang}.txt.gz"
    if location_exists(opus_nllb_url):
        dataset_name = "opus_NLLB/v1"
        lines_num = estimate_sentence_size(get_download_size(opus_nllb_url))
        if direction == "trg":
            skipped_datasets.append(
                f"{dataset_name} - data may have lower quality, disable for back-translations ({lines_num:,} sentences)"
            )
        else:
            mono_datasets.append(dataset_name)
            sentence_count += lines_num
            add_comment(dataset_name, f"~{lines_num:,} sentences")

    skipped_datasets_final = []
    if skipped_datasets:
        skipped_datasets_final.append("")
        skipped_datasets_final.append("Skipped datasets:")
        for d in skipped_datasets:
            skipped_datasets_final.append(f" - {d}")

    comment = "\n".join(
        [
            "The monolingual data contains:",
            f" ~{sentence_count:,} sentences",
            # Append any additional information.
            *extra_comments,
            *skipped_datasets_final,
        ]
    )
    comment_section[f" mono-{direction}:"] = comment


def strip_comments(yaml_text: str) -> str:
    """
    ruamel.yaml preserves key ordering and comments. This function strips out the comments.
    """
    result = ""
    for l in yaml_text.splitlines():
        # Work around: PLW2901 `for` loop variable `line` overwritten by assignment target
        line = l
        if line.strip().startswith("#"):
            continue
        # Remove any comments at the end of the line.
        line = re.sub(r"#[\s\w\-.]*$", "", line)
        # Don't add any empty lines.
        if line.strip():
            result += line.rstrip() + "\n"
    return result
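

# Note that strip_comments() drops whole-line comments entirely and trims simple trailing
# comments (ones made up only of word characters, whitespace, hyphens, and periods after the
# "#"); trailing comments containing other punctuation are left for ruamel.yaml to parse.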


def apply_comments_to_yaml_string(yaml, prod_config, comment_section, remote_branch: str) -> str:
    """
    ruamel.yaml only supports inline comments, so do direct string manipulation to apply
    all the comments needed.
    """
    # Dump out the yaml to a string so that it can be manipulated.
    output_stream = StringIO()
    yaml.dump(prod_config, output_stream)
    yaml_string: str = output_stream.getvalue()

    yaml_string = apply_comment_section(comment_section, yaml_string)

    script_args = " ".join(sys.argv[1:])
    return "\n".join(
        [
            "# The initial configuration was generated using:",
            f"# task config-generator -- {script_args}",
            "#",
            "# The documentation for this config can be found here:",
            f"# https://github.com/mozilla/firefox-translations-training/blob/{get_git_revision_hash(remote_branch)}/taskcluster/configs/config.prod.yml",
            yaml_string,
        ]
    )


def apply_comment_section(comment_section: dict[str, str], yaml_string: str) -> str:
    for key, raw_comment in comment_section.items():
        # Find the indent amount for the key.
        match = re.search(r"^(?P<indent>\s*)", key)
        if not match:
            raise Exception("Could not find regex match")
        indent = match.group("indent")

        # Indent the lines, and add the # comment.
        comment = "\n".join([f"{indent}# {line}" for line in raw_comment.splitlines()])
        yaml_string = yaml_string.replace(f"\n{key}", f"\n\n{comment}\n{key}")

    return yaml_string


def main() -> None:
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Preserves whitespace in the help text.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("source", metavar="SOURCE", type=str, help="The source language tag")
    parser.add_argument("target", metavar="TARGET", type=str, help="The target language tag")
    parser.add_argument(
        "--name",
        metavar="name",
        type=str,
        required=True,
        help="The name of the config, which gets constructed like so: configs/autogenerated/{source}-{target}-{name}.yml",
    )
    parser.add_argument(
        "--remote_branch",
        metavar="REF",
        type=str,
        default="origin/main",
        help="The remote branch that contains the config.prod.yml. Typically origin/main, or origin/release",
    )
    parser.add_argument(
        "--fast",
        action="store_true",
        help="Skip slow network requests like looking up dataset size",
    )
    args = parser.parse_args()

    # Validate the inputs.
    langtag_re = r"[a-z]{2,3}"
    if not re.fullmatch(langtag_re, args.source):
        print("The source language should be a 2 or 3 letter lang tag.", file=sys.stderr)
        sys.exit(1)
    if not re.fullmatch(langtag_re, args.target):
        print("The target language should be a 2 or 3 letter lang tag.", file=sys.stderr)
        sys.exit(1)
    if not re.fullmatch(r"[\w\d-]+", args.name):
        print(
            "The name of the training config should only contain alphanumeric characters, underscores, and dashes.",
            file=sys.stderr,
        )
        sys.exit(1)

    # ruamel.yaml preserves comments and ordering, unlike PyYAML.
    yaml = ruamel.yaml.YAML()

    # Load the prod yaml.
    with prod_config_path.open() as f:
        yaml_string = f.read()
    yaml_string = strip_comments(yaml_string)
    prod_config = yaml.load(StringIO(yaml_string))

    comment_section = update_config(prod_config, args.name, args.source, args.target, args.fast)

    final_config = apply_comments_to_yaml_string(
        yaml, prod_config, comment_section, args.remote_branch
    )

    final_config_path = (
        root_dir / "configs/autogenerated" / f"{args.source}-{args.target}-{args.name}.yml"
    )
    print("Writing config to:", str(final_config_path))
    final_config_path.write_text(final_config)


if __name__ == "__main__":
    main()