зеркало из https://github.com/microsoft/archai.git
fix(datasets): Use only essential information for the FastHfDatasetProvider fingerprint.
This commit is contained in:
Родитель
4dcfaf037c
Коммит
f631895dff
|
@ -3,6 +3,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import pickle
|
||||
import sys
|
||||
|
@ -49,7 +50,7 @@ class FastHfDatasetProvider(DatasetProvider):
|
|||
tokenizer_name: Optional[str] = None,
|
||||
tokenizer_max_length: Optional[int] = None,
|
||||
mapping_column_name: Optional[List[str]] = None,
|
||||
validation_split: Optional[float] = 0.1,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
seed: Optional[int] = 42,
|
||||
num_workers: Optional[int] = 1,
|
||||
use_eos_token: Optional[bool] = True,
|
||||
|
@ -127,6 +128,7 @@ class FastHfDatasetProvider(DatasetProvider):
|
|||
"mapping_column_name": self.mapping_column_name,
|
||||
"validation_split": self.validation_split,
|
||||
"seed": self.seed,
|
||||
"num_workers": self.num_workers,
|
||||
"use_eos_token": self.use_eos_token,
|
||||
"use_shared_memory": self.use_shared_memory,
|
||||
"cache_dir": self.cache_dir,
|
||||
|
@ -136,7 +138,13 @@ class FastHfDatasetProvider(DatasetProvider):
|
|||
def fingerprint(self) -> str:
|
||||
"""Return a unique fingerprint for the dataset provider."""
|
||||
|
||||
return sha1(repr(self.config).encode("ascii")).hexdigest()
|
||||
# Only use keys that affect the dataset for the fingerprint
|
||||
config = copy.deepcopy(self.config)
|
||||
config.pop("num_workers")
|
||||
config.pop("use_shared_memory")
|
||||
config.pop("cache_dir")
|
||||
|
||||
return sha1(repr(config).encode("ascii")).hexdigest()
|
||||
|
||||
def _encode_dataset(self) -> None:
|
||||
dtype = np.uint16 if self.tokenizer.vocab_size < 64 * 1024 else np.int32
|
||||
|
@ -151,8 +159,9 @@ class FastHfDatasetProvider(DatasetProvider):
|
|||
if "validation" not in raw_dataset:
|
||||
logger.info("Creating validation split ...")
|
||||
|
||||
validation_split = self.validation_split or 0.1
|
||||
tmp_dataset_dict = raw_dataset["train"].train_test_split(
|
||||
test_size=self.validation_split, shuffle=True, seed=self.seed
|
||||
test_size=validation_split, shuffle=True, seed=self.seed
|
||||
)
|
||||
raw_dataset["train"] = tmp_dataset_dict["train"]
|
||||
raw_dataset["validation"] = tmp_dataset_dict["test"]
|
||||
|
|
Загрузка…
Ссылка в новой задаче