update the CUDA to accelerate model inference if there is any.
This commit is contained in:
Родитель
6407435432
Коммит
c23dd46ed1
|
@ -4,6 +4,7 @@
|
|||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import torch.cuda
|
||||
from fairseq.models.bart import BARTModel
|
||||
|
||||
from tapex.processor import get_default_processor
|
||||
|
@ -19,6 +20,8 @@ class TAPEXModelInterface:
|
|||
def __init__(self, resource_dir, checkpoint_name, table_processor=None):
|
||||
self.model = BARTModel.from_pretrained(model_name_or_path=resource_dir,
|
||||
checkpoint_file=checkpoint_name)
|
||||
if torch.cuda.is_available():
|
||||
self.model.cuda()
|
||||
self.model.eval()
|
||||
if table_processor is not None:
|
||||
self.tab_processor = table_processor
|
||||
|
|
Загрузка…
Ссылка в новой задаче