Mirror of https://github.com/microsoft/archai.git
fix(archai): Fixes functions that call get_full_path with create_folder argument.
Parent: 4a23e4c40b
Commit: 5d9e88b512
@@ -59,7 +59,7 @@ class Corpus:
 
         # Corpus cache is created using dataset/vocab_type/vocab_size path
         self.corpus_cache_dir = get_full_path(
-            os.path.join(cache_dir, str(dataset_name), str(vocab_type), str(vocab_size)), create=True
+            os.path.join(cache_dir, str(dataset_name), str(vocab_type), str(vocab_size)), create_folder=True
         )
 
         # Encoded dataset (.npy files) cache paths
@@ -72,7 +72,7 @@ class BbpeTokenizer(TokenizerBase):
         )
 
         self._tokenizer = None
-        self._tokenizer_filepath = os.path.join(get_full_path(save_path, create=True), "bbpe_tokenizer.json")
+        self._tokenizer_filepath = os.path.join(get_full_path(save_path, create_folder=True), "bbpe_tokenizer.json")
 
         self.vocab_size = vocab_size
         self.sorted_vocab = sorted_vocab
@@ -204,7 +204,7 @@ class WordTokenizer(TokenizerBase):
         self.sym2idx = OrderedDict()
 
     def _vocab_filepath(self) -> str:
-        vocab_dir = get_full_path(os.path.join(self.save_path), create=True)
+        vocab_dir = get_full_path(os.path.join(self.save_path), create_folder=True)
 
         return os.path.join(vocab_dir, "vocab.txt")
 
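The three hunks above all update call sites to the renamed keyword argument of get_full_path. For reference, a minimal sketch of what a helper with this signature could look like; this is an illustration, not archai's actual implementation, and only the create_folder parameter name is taken from the diff:

import os


def get_full_path(path: str, create_folder: bool = False) -> str:
    # Expand user home and environment variables, then make the path absolute.
    path = os.path.abspath(os.path.expandvars(os.path.expanduser(path)))

    # Optionally create the directory, mirroring the create_folder=True call sites above.
    if create_folder:
        os.makedirs(path, exist_ok=True)

    return path

Against such a signature, call sites still passing create=True would raise a TypeError for an unexpected keyword, which is what this commit fixes.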
@@ -192,14 +192,14 @@ class NvidiaTrainer(TrainerBase):
         else:
             raise RuntimeError(f"Split: {split} is not supported yet.")
 
-        if self.dataset_name in ["wt2", "wt103"] or self.dataset_name.startswith("olx_"):
+        if self.args.dataset_name in ["wt2", "wt103"] or self.args.dataset_name.startswith("olx_"):
             return LMOrderedIterator(
                 input_ids,
                 self.args.global_batch_size,
                 self.args.seq_len,
                 device=self.args.device,
             )
-        elif self.dataset_name == "lm1b":
+        elif self.args.dataset_name == "lm1b":
             return LMMultiFileIterator(
                 input_ids,
                 self.vocab,
@@ -208,7 +208,7 @@ class NvidiaTrainer(TrainerBase):
                 device=self.args.device,
             )
         else:
-            raise RuntimeError(f"Dataset: {self.dataset_name} is not supported yet.")
+            raise RuntimeError(f"Dataset: {self.args.dataset_name} is not supported yet.")
 
     def _create_optimizer(self) -> None:
         optimizer_name = self.args.optim.lower()
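The two NvidiaTrainer hunks replace bare self.dataset_name lookups with self.args.dataset_name, which suggests the trainer keeps the whole arguments object on the instance rather than copying individual fields onto it. A minimal sketch of that assumed layout, where every name other than dataset_name and self.args is illustrative:

from types import SimpleNamespace


class TrainerSketch:
    def __init__(self, args: SimpleNamespace) -> None:
        # The arguments object is stored whole; there is no per-field copy such as
        # self.dataset_name, hence the switch to self.args.dataset_name above.
        self.args = args

    def describe(self) -> str:
        return f"training on {self.args.dataset_name}"


trainer = TrainerSketch(SimpleNamespace(dataset_name="wt103"))
print(trainer.describe())  # training on wt103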
@@ -20,9 +20,9 @@ def parse_args() -> argparse.Namespace:
 
     parser.add_argument("-es", "--eval_steps", type=int, default=100, help="Number of steps between evaluations.")
 
-    parser.add_argument("-d", "--dataset", type=str, default="wt103", help="Name of the dataset.")
+    parser.add_argument("-dn", "--dataset_name", type=str, default="wt103", help="Name of the dataset.")
 
-    parser.add_argument("-v", "--vocab", type=str, default="gpt2", help="Name of the vocabulary/tokenizer.")
+    parser.add_argument("-vt", "--vocab_type", type=str, default="gpt2", help="Name of the vocabulary/tokenizer.")
 
     parser.add_argument("-vs", "--vocab_size", type=int, default=10000, help="Size of the vocabulary.")
 
@@ -54,8 +54,8 @@ if __name__ == "__main__":
        no_cuda=args.no_cuda,
        logging_steps=args.logging_steps,
        eval_steps=args.eval_steps,
-       dataset=args.dataset,
-       vocab=args.vocab,
+       dataset_name=args.dataset_name,
+       vocab_type=args.vocab_type,
        vocab_size=args.vocab_size,
        global_batch_size=args.global_batch_size,
        seq_len=args.seq_len,
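The last two hunks have to change together: the parser now produces args.dataset_name and args.vocab_type, and the keyword names forwarded to the training-arguments object are renamed to match. A small hedged sketch of that coupling; the receiving class here is illustrative, and only the flag and keyword names come from the diff:

import argparse


class TrainingConfig:
    # Illustrative receiver for the renamed keywords; the real archai class differs.
    def __init__(self, dataset_name: str, vocab_type: str, vocab_size: int) -> None:
        self.dataset_name = dataset_name
        self.vocab_type = vocab_type
        self.vocab_size = vocab_size


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-dn", "--dataset_name", type=str, default="wt103", help="Name of the dataset.")
    parser.add_argument("-vt", "--vocab_type", type=str, default="gpt2", help="Name of the vocabulary/tokenizer.")
    parser.add_argument("-vs", "--vocab_size", type=int, default=10000, help="Size of the vocabulary.")
    args = parser.parse_args()

    # Keyword names follow the parser's renamed attributes, as in the final hunk.
    config = TrainingConfig(
        dataset_name=args.dataset_name,
        vocab_type=args.vocab_type,
        vocab_size=args.vocab_size,
    )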