fix(datasets): Only uses essential information for FastHfDatasetProvider fingerprint.

This commit is contained in:
Gustavo Rosa 2023-02-17 09:48:28 -03:00
Родитель 4dcfaf037c
Коммит f631895dff
1 изменённых файлов: 12 добавлений и 3 удалений

Просмотреть файл

@@ -3,6 +3,7 @@
from __future__ import annotations
import copy
import json
import pickle
import sys
@@ -49,7 +50,7 @@ class FastHfDatasetProvider(DatasetProvider):
tokenizer_name: Optional[str] = None,
tokenizer_max_length: Optional[int] = None,
mapping_column_name: Optional[List[str]] = None,
validation_split: Optional[float] = 0.1,
validation_split: Optional[float] = 0.0,
seed: Optional[int] = 42,
num_workers: Optional[int] = 1,
use_eos_token: Optional[bool] = True,
@@ -127,6 +128,7 @@ class FastHfDatasetProvider(DatasetProvider):
"mapping_column_name": self.mapping_column_name,
"validation_split": self.validation_split,
"seed": self.seed,
"num_workers": self.num_workers,
"use_eos_token": self.use_eos_token,
"use_shared_memory": self.use_shared_memory,
"cache_dir": self.cache_dir,
@@ -136,7 +138,13 @@ class FastHfDatasetProvider(DatasetProvider):
def fingerprint(self) -> str:
"""Return a unique fingerprint for the dataset provider."""
return sha1(repr(self.config).encode("ascii")).hexdigest()
# Only use keys that affect the dataset for the fingerprint
config = copy.deepcopy(self.config)
config.pop("num_workers")
config.pop("use_shared_memory")
config.pop("cache_dir")
return sha1(repr(config).encode("ascii")).hexdigest()
def _encode_dataset(self) -> None:
dtype = np.uint16 if self.tokenizer.vocab_size < 64 * 1024 else np.int32
@@ -151,8 +159,9 @@ class FastHfDatasetProvider(DatasetProvider):
if "validation" not in raw_dataset:
logger.info("Creating validation split ...")
validation_split = self.validation_split or 0.1
tmp_dataset_dict = raw_dataset["train"].train_test_split(
test_size=self.validation_split, shuffle=True, seed=self.seed
test_size=validation_split, shuffle=True, seed=self.seed
)
raw_dataset["train"] = tmp_dataset_dict["train"]
raw_dataset["validation"] = tmp_dataset_dict["test"]