update the CUDA to accelerate model inference if there is any.

This commit is contained in:
SivilTaram 2021-09-15 16:33:12 +08:00
Parent 6407435432
Commit c23dd46ed1
1 changed file with 3 additions and 0 deletions

View file

@@ -4,6 +4,7 @@
import logging
from typing import Dict, List
import torch.cuda
from fairseq.models.bart import BARTModel
from tapex.processor import get_default_processor
@@ -19,6 +20,8 @@ class TAPEXModelInterface:
    def __init__(self, resource_dir, checkpoint_name, table_processor=None):
        self.model = BARTModel.from_pretrained(model_name_or_path=resource_dir,
                                               checkpoint_file=checkpoint_name)
if torch.cuda.is_available():
self.model.cuda()
        self.model.eval()
        if table_processor is not None:
            self.tab_processor = table_processor