Use CUDA to accelerate model inference when a GPU is available.
Parent: 6407435432
Commit: c23dd46ed1
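The change imports torch.cuda and, at model load time, moves the BART model onto the GPU when one is available: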
@@ -4,6 +4,7 @@
 import logging
 from typing import Dict, List

+import torch.cuda
 from fairseq.models.bart import BARTModel

 from tapex.processor import get_default_processor
@@ -19,6 +20,8 @@ class TAPEXModelInterface:
     def __init__(self, resource_dir, checkpoint_name, table_processor=None):
         self.model = BARTModel.from_pretrained(model_name_or_path=resource_dir,
                                                checkpoint_file=checkpoint_name)
+        if torch.cuda.is_available():
+            self.model.cuda()
         self.model.eval()
         if table_processor is not None:
             self.tab_processor = table_processor
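For context, a minimal sketch of the patched __init__ as it reads after this commit. Everything shown comes from the diff's context and added lines; code outside the hunk is elided, and the comments are explanatory additions, not part of the commit.

import torch.cuda
from fairseq.models.bart import BARTModel


class TAPEXModelInterface:
    def __init__(self, resource_dir, checkpoint_name, table_processor=None):
        # Load the fine-tuned BART checkpoint from the resource directory.
        self.model = BARTModel.from_pretrained(model_name_or_path=resource_dir,
                                               checkpoint_file=checkpoint_name)
        # Move the model to the GPU only when CUDA is actually available,
        # so the same code path still runs on CPU-only machines.
        if torch.cuda.is_available():
            self.model.cuda()
        # Inference only: disable dropout and other training-time behavior.
        self.model.eval()
        if table_processor is not None:
            self.tab_processor = table_processor
        # (rest of __init__ not shown in the diff)

A hypothetical instantiation, assuming the usual fairseq layout; the directory and checkpoint names below are placeholders, not values from the commit:

# Placeholder paths for illustration only.
model = TAPEXModelInterface(resource_dir="path/to/tapex-resources",
                            checkpoint_name="model.pt")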