The `add_space_before_punct_symbol` is only for TransfoXL (#5549)

This commit is contained in:
Lysandre Debut 2020-07-06 12:17:05 -04:00 коммит произвёл GitHub
Родитель d6b0b9d451
Коммит 9d9b872b66
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 7 добавлений и 1 удалений

Просмотреть файл

@ -214,8 +214,14 @@ def main():
if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
else:
tokenizer_kwargs = {}
encoded_prompt = tokenizer.encode(
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
)
else:
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")