fix(archai): Fix functions that call get_full_path to use the renamed create_folder argument.

Gustavo Rosa 2023-01-27 16:21:25 -03:00
Parent: 4a23e4c40b
Commit: 5d9e88b512
5 changed files with 10 additions and 10 deletions
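For context, the first three hunks below all update calls to the same path helper. A minimal sketch of the assumed signature follows; only the rename of the keyword from create to create_folder is evidenced by this commit, and the body is a guess at typical behavior:

import os


def get_full_path(path: str, create_folder: bool = False) -> str:
    # Hypothetical sketch: resolve the path and optionally create the folder.
    # Only the create_folder keyword name is confirmed by this diff.
    path = os.path.abspath(os.path.expanduser(os.path.expandvars(path)))
    if create_folder:
        os.makedirs(path, exist_ok=True)
    return path

Under this assumption, callers still passing create=True would fail with a TypeError about an unexpected keyword argument, which is what the hunks below fix.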


@@ -59,7 +59,7 @@ class Corpus:
# Corpus cache is created using dataset/vocab_type/vocab_size path
self.corpus_cache_dir = get_full_path(
-os.path.join(cache_dir, str(dataset_name), str(vocab_type), str(vocab_size)), create=True
+os.path.join(cache_dir, str(dataset_name), str(vocab_type), str(vocab_size)), create_folder=True
)
# Encoded dataset (.npy files) cache paths


@@ -72,7 +72,7 @@ class BbpeTokenizer(TokenizerBase):
)
self._tokenizer = None
-self._tokenizer_filepath = os.path.join(get_full_path(save_path, create=True), "bbpe_tokenizer.json")
+self._tokenizer_filepath = os.path.join(get_full_path(save_path, create_folder=True), "bbpe_tokenizer.json")
self.vocab_size = vocab_size
self.sorted_vocab = sorted_vocab


@@ -204,7 +204,7 @@ class WordTokenizer(TokenizerBase):
self.sym2idx = OrderedDict()
def _vocab_filepath(self) -> str:
-vocab_dir = get_full_path(os.path.join(self.save_path), create=True)
+vocab_dir = get_full_path(os.path.join(self.save_path), create_folder=True)
return os.path.join(vocab_dir, "vocab.txt")


@@ -192,14 +192,14 @@ class NvidiaTrainer(TrainerBase):
else:
raise RuntimeError(f"Split: {split} is not supported yet.")
-if self.dataset_name in ["wt2", "wt103"] or self.dataset_name.startswith("olx_"):
+if self.args.dataset_name in ["wt2", "wt103"] or self.args.dataset_name.startswith("olx_"):
return LMOrderedIterator(
input_ids,
self.args.global_batch_size,
self.args.seq_len,
device=self.args.device,
)
-elif self.dataset_name == "lm1b":
+elif self.args.dataset_name == "lm1b":
return LMMultiFileIterator(
input_ids,
self.vocab,
@@ -208,7 +208,7 @@ class NvidiaTrainer(TrainerBase):
device=self.args.device,
)
else:
raise RuntimeError(f"Dataset: {self.dataset_name} is not supported yet.")
raise RuntimeError(f"Dataset: {self.args.dataset_name} is not supported yet.")
def _create_optimizer(self) -> None:
optimizer_name = self.args.optim.lower()
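The trainer hunks above now read the dataset name from self.args instead of the trainer instance. A minimal sketch of the argument object's assumed shape, limited to fields referenced in this diff (the class name and defaults are assumptions):

from dataclasses import dataclass


@dataclass
class NvidiaTrainingArguments:
    # Hypothetical shape; fields mirror the attributes used in the hunks above.
    dataset_name: str = "wt103"
    vocab_type: str = "gpt2"
    vocab_size: int = 10000
    global_batch_size: int = 256
    seq_len: int = 192
    device: str = "cuda"
    optim: str = "adam"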


@@ -20,9 +20,9 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("-es", "--eval_steps", type=int, default=100, help="Number of steps between evaluations.")
parser.add_argument("-d", "--dataset", type=str, default="wt103", help="Name of the dataset.")
parser.add_argument("-dn", "--dataset_name", type=str, default="wt103", help="Name of the dataset.")
parser.add_argument("-v", "--vocab", type=str, default="gpt2", help="Name of the vocabulary/tokenizer.")
parser.add_argument("-vt", "--vocab_type", type=str, default="gpt2", help="Name of the vocabulary/tokenizer.")
parser.add_argument("-vs", "--vocab_size", type=int, default=10000, help="Size of the vocabulary.")
@@ -54,8 +54,8 @@ if __name__ == "__main__":
no_cuda=args.no_cuda,
logging_steps=args.logging_steps,
eval_steps=args.eval_steps,
-dataset=args.dataset,
-vocab=args.vocab,
+dataset_name=args.dataset_name,
+vocab_type=args.vocab_type,
vocab_size=args.vocab_size,
global_batch_size=args.global_batch_size,
seq_len=args.seq_len,
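Assuming the script then feeds the parsed flags into the trainer's argument object, a hedged sketch of the renamed flow (the NvidiaTrainingArguments shape is the sketch above; the NvidiaTrainer constructor signature is an assumption):

args = parse_args()  # now exposes args.dataset_name and args.vocab_type
training_args = NvidiaTrainingArguments(
    dataset_name=args.dataset_name,
    vocab_type=args.vocab_type,
    vocab_size=args.vocab_size,
)
trainer = NvidiaTrainer(args=training_args)  # reads self.args.dataset_name internally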