diff --git a/poetry.lock b/poetry.lock index 78ccc0de..3a4959cf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -910,6 +910,22 @@ dev = ["autoflake (>=1.4.0,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "passlib[bcrypt] doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer (>=0.4.1,<0.5.0)"] test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==22.3.0)", "databases[sqlite] (>=0.3.2,<0.6.0)", "email_validator (>=1.1.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.14.0,<0.19.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "orjson (>=3.2.1,<4.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=6.2.4,<7.0.0)", "pytest-cov (>=2.12.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.6)", "requests (>=2.24.0,<3.0.0)", "sqlalchemy (>=1.3.18,<1.5.0)", "types-dataclasses (==0.6.5)", "types-orjson (==3.6.2)", "types-ujson (==4.2.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"] +[[package]] +name = "filelock" +version = "3.13.1" +description = "A platform independent file lock." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, + {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +typing = ["typing-extensions (>=4.8)"] + [[package]] name = "frozenlist" version = "1.4.1" @@ -996,6 +1012,41 @@ files = [ {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, ] +[[package]] +name = "fsspec" +version = "2023.12.2" +description = "File-system specification" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fsspec-2023.12.2-py3-none-any.whl", hash = "sha256:d800d87f72189a745fa3d6b033b9dc4a34ad069f60ca60b943a63599f5501960"}, + {file = "fsspec-2023.12.2.tar.gz", hash = "sha256:8548d39e8810b59c38014934f6b31e57f40c1b20f911f4cc2b85389c7e9bf0cb"}, +] + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +devel = ["pytest", "pytest-cov"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = 
["tqdm"] + [[package]] name = "gitdb" version = "4.0.11" @@ -1156,6 +1207,38 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "huggingface-hub" +version = "0.20.3" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "huggingface_hub-0.20.3-py3-none-any.whl", hash = "sha256:d988ae4f00d3e307b0c80c6a05ca6dbb7edba8bba3079f74cda7d9c2e562a7b6"}, + {file = "huggingface_hub-0.20.3.tar.gz", hash = "sha256:94e7f8e074475fbc67d6a71957b678e1b4a74ff1b64a644fd6cbb83da962d05d"}, +] + +[package.dependencies] +filelock = "*" +fsspec = ">=2023.5.0" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] 
def get_size(tags: list[str]) -> str:
    """
    Extract the Hugging Face size-category tag from a dataset's tag list,
    e.g. "size_categories:1M<n<10M" -> "1m<n<10m". Returns "" when absent.
    """
    size = next(
        (tag for tag in tags if tag.startswith("size_categories:")),
        None,
    )

    if not size or size == "unknown":
        return ""

    # Lowercase the text since it's not consistent.
    return size.replace("size_categories:", "").lower()


def get_language_count(tags: list[str]) -> int:
    """Count how many "language:xx" tags a dataset carries."""
    return sum(1 for tag in tags if tag.startswith("language:"))


# Maps the Hugging Face "size_categories" bucket to an ordinal so datasets can
# be sorted largest-first. NOTE(review): the entries after "n<1k" were cut off
# in this view and are reconstructed from the standard HF size buckets —
# TODO confirm against the original file.
HF_DATASET_SIZES = {
    "": 0,
    "unknown": 0,
    "n<1k": 1,
    "1k<n<10k": 2,
    "10k<n<100k": 3,
    "100k<n<1m": 4,
    "1m<n<10m": 5,
    "10m<n<100m": 6,
    "100m<n<1b": 7,
    "1b<n<10b": 8,
    "10b<n<100b": 9,
    "100b<n<1t": 10,
}


def is_useful_dataset(dataset) -> bool:
    """Determines if a dataset is useful or not (speech datasets are excluded)."""
    return "task_categories:automatic-speech-recognition" not in dataset.tags


def get_huggingface_any(language: str):
    """
    Print datasets of any kind for the given language, ordered by size
    category (largest first) and then by download count. Speech-recognition
    datasets are excluded; nothing is filtered by download count.
    """
    # Imported lazily so the other importers work without huggingface-hub.
    from huggingface_hub import DatasetFilter, HfApi

    api = HfApi()

    datasets = list(
        api.list_datasets(
            filter=DatasetFilter(
                #
                language=language,
            )
        )
    )

    # Two stable sorts: downloads becomes the secondary key, size the primary.
    datasets.sort(key=lambda dataset: -dataset.downloads)
    datasets.sort(key=lambda dataset: -HF_DATASET_SIZES.get(get_size(dataset.tags), 0))

    print("")
    print("┌─────────────────────────────────────────────────────────────────────────────┐")
    print(f"│ huggingface any data https://huggingface.co/datasets?language=language:{language}")
    print("└─────────────────────────────────────────────────────────────────────────────┘")
    print_table(
        [
            ["ID", "Size", "Downloads"],
            *[
                [
                    #
                    f"https://huggingface.co/datasets/{dataset.id}",
                    get_size(dataset.tags),
                    dataset.downloads,
                ]
                for dataset in datasets
                if is_useful_dataset(dataset)
            ],
        ]
    )
list[list[any]]): """ Nicely print a table, the first row is the header @@ -268,6 +434,14 @@ def print_table(table: list[list[any]]): def main(args: Optional[list[str]] = None) -> None: + importers = [ + "opus", + "sacrebleu", + "mtdata", + "huggingface_mono", + "huggingface_parallel", + "huggingface_any", + ] parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawTextHelpFormatter, # Preserves whitespace in the help text. @@ -275,7 +449,9 @@ def main(args: Optional[list[str]] = None) -> None: parser.add_argument("source", type=str, nargs="?", help="Source language code") parser.add_argument("target", type=str, nargs="?", help="Target language code") parser.add_argument( - "--importer", type=str, help="The importer to use: mtdata, opus, sacrebleu" + "--importer", + type=str, + help=f"The importer to use: {', '.join(importers)}", ) parser.add_argument( "--download_url", @@ -290,7 +466,7 @@ def main(args: Optional[list[str]] = None) -> None: parser.print_help() sys.exit(1) - if args.importer and args.importer not in ["opus", "sacrebleu", "mtdata"]: + if args.importer and args.importer not in importers: print(f'"{args.importer}" is not a valid importer.') sys.exit(1) @@ -303,6 +479,15 @@ def main(args: Optional[list[str]] = None) -> None: if args.importer == "mtdata" or not args.importer: get_mtdata(args.source, args.target) + if args.importer == "huggingface_mono" or not args.importer: + get_huggingface_monolingual(args.target if args.source == "en" else args.source) + + if args.importer == "huggingface_parallel" or not args.importer: + get_huggingface_parallel(args.source, args.target) + + if args.importer == "huggingface_any" or not args.importer: + get_huggingface_any(args.target if args.source == "en" else args.source) + if __name__ == "__main__": main()