fix(archai): Fixes functions that call get_full_path with the create_folder argument.

This commit is contained in:
Gustavo Rosa 2023-01-27 16:21:25 -03:00
Parent 4a23e4c40b
Commit 5d9e88b512
5 changed files: 10 additions and 10 deletions
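
For context, every call site touched by this commit now passes the keyword create_folder instead of the old create. Below is a minimal sketch of the updated call pattern; the stand-in get_full_path only approximates archai's helper (its real module path and exact behavior are not shown in this diff and are assumptions for illustration):

import os


def get_full_path(path: str, create_folder: bool = False) -> str:
    # Stand-in for archai's helper: expand the path and optionally create the folder.
    full_path = os.path.abspath(os.path.expanduser(path))
    if create_folder:
        os.makedirs(full_path, exist_ok=True)
    return full_path


# Updated call pattern used throughout this commit.
cache_dir = get_full_path(os.path.join("cache", "wt103", "gpt2", "10000"), create_folder=True)
print(cache_dir)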

View file

@@ -59,7 +59,7 @@ class Corpus:
         # Corpus cache is created using dataset/vocab_type/vocab_size path
         self.corpus_cache_dir = get_full_path(
-            os.path.join(cache_dir, str(dataset_name), str(vocab_type), str(vocab_size)), create=True
+            os.path.join(cache_dir, str(dataset_name), str(vocab_type), str(vocab_size)), create_folder=True
         )

         # Encoded dataset (.npy files) cache paths

View file

@@ -72,7 +72,7 @@ class BbpeTokenizer(TokenizerBase):
         )

         self._tokenizer = None
-        self._tokenizer_filepath = os.path.join(get_full_path(save_path, create=True), "bbpe_tokenizer.json")
+        self._tokenizer_filepath = os.path.join(get_full_path(save_path, create_folder=True), "bbpe_tokenizer.json")

         self.vocab_size = vocab_size
         self.sorted_vocab = sorted_vocab

View file

@@ -204,7 +204,7 @@ class WordTokenizer(TokenizerBase):
         self.sym2idx = OrderedDict()

     def _vocab_filepath(self) -> str:
-        vocab_dir = get_full_path(os.path.join(self.save_path), create=True)
+        vocab_dir = get_full_path(os.path.join(self.save_path), create_folder=True)
         return os.path.join(vocab_dir, "vocab.txt")

View file

@@ -192,14 +192,14 @@ class NvidiaTrainer(TrainerBase):
         else:
             raise RuntimeError(f"Split: {split} is not supported yet.")

-        if self.dataset_name in ["wt2", "wt103"] or self.dataset_name.startswith("olx_"):
+        if self.args.dataset_name in ["wt2", "wt103"] or self.args.dataset_name.startswith("olx_"):
             return LMOrderedIterator(
                 input_ids,
                 self.args.global_batch_size,
                 self.args.seq_len,
                 device=self.args.device,
             )
-        elif self.dataset_name == "lm1b":
+        elif self.args.dataset_name == "lm1b":
             return LMMultiFileIterator(
                 input_ids,
                 self.vocab,
@@ -208,7 +208,7 @@ class NvidiaTrainer(TrainerBase):
                 device=self.args.device,
             )
         else:
-            raise RuntimeError(f"Dataset: {self.dataset_name} is not supported yet.")
+            raise RuntimeError(f"Dataset: {self.args.dataset_name} is not supported yet.")

     def _create_optimizer(self) -> None:
         optimizer_name = self.args.optim.lower()
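
The hunks above also stop reading dataset_name from the trainer itself and read it from the trainer's argument object instead. A small self-contained sketch of that access pattern, using stand-ins for NvidiaTrainingArguments and NvidiaTrainer (only the field exercised here is modeled; the underscore-prefixed names are hypothetical):

from dataclasses import dataclass


@dataclass
class _TrainingArgs:
    # Stand-in for NvidiaTrainingArguments; only dataset_name is modeled.
    dataset_name: str = "wt103"


class _Trainer:
    # Stand-in trainer: dataset_name is read from self.args, not from the trainer.
    def __init__(self, args: _TrainingArgs) -> None:
        self.args = args

    def describe_iterator(self) -> str:
        if self.args.dataset_name in ["wt2", "wt103"] or self.args.dataset_name.startswith("olx_"):
            return "LMOrderedIterator"
        elif self.args.dataset_name == "lm1b":
            return "LMMultiFileIterator"
        raise RuntimeError(f"Dataset: {self.args.dataset_name} is not supported yet.")


print(_Trainer(_TrainingArgs()).describe_iterator())  # -> LMOrderedIterator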

View file

@@ -20,9 +20,9 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("-es", "--eval_steps", type=int, default=100, help="Number of steps between evaluations.")
-    parser.add_argument("-d", "--dataset", type=str, default="wt103", help="Name of the dataset.")
-    parser.add_argument("-v", "--vocab", type=str, default="gpt2", help="Name of the vocabulary/tokenizer.")
+    parser.add_argument("-dn", "--dataset_name", type=str, default="wt103", help="Name of the dataset.")
+    parser.add_argument("-vt", "--vocab_type", type=str, default="gpt2", help="Name of the vocabulary/tokenizer.")
     parser.add_argument("-vs", "--vocab_size", type=int, default=10000, help="Size of the vocabulary.")
@@ -54,8 +54,8 @@ if __name__ == "__main__":
         no_cuda=args.no_cuda,
         logging_steps=args.logging_steps,
         eval_steps=args.eval_steps,
-        dataset=args.dataset,
-        vocab=args.vocab,
+        dataset_name=args.dataset_name,
+        vocab_type=args.vocab_type,
         vocab_size=args.vocab_size,
         global_batch_size=args.global_batch_size,
         seq_len=args.seq_len,
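
Because the flags were renamed, existing invocations that pass --dataset or --vocab will no longer parse. A small self-contained check of the renamed options (only the fragment shown in the hunk above is reproduced; the rest of the script is omitted and the sample argv is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-dn", "--dataset_name", type=str, default="wt103", help="Name of the dataset.")
parser.add_argument("-vt", "--vocab_type", type=str, default="gpt2", help="Name of the vocabulary/tokenizer.")
parser.add_argument("-vs", "--vocab_size", type=int, default=10000, help="Size of the vocabulary.")

# The old --dataset / --vocab spellings no longer exist; callers must use the new names.
args = parser.parse_args(["--dataset_name", "wt103", "--vocab_type", "gpt2", "--vocab_size", "10000"])
print(args.dataset_name, args.vocab_type, args.vocab_size)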