From c23dd46ed187d31a5720d781dfb52ab285c05d5f Mon Sep 17 00:00:00 2001 From: SivilTaram Date: Wed, 15 Sep 2021 16:33:12 +0800 Subject: [PATCH] update the CUDA to accelerate model inference if there is any. --- tapex/model_interface.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tapex/model_interface.py b/tapex/model_interface.py index 7e9eeba..3279839 100644 --- a/tapex/model_interface.py +++ b/tapex/model_interface.py @@ -4,6 +4,7 @@ import logging from typing import Dict, List +import torch.cuda from fairseq.models.bart import BARTModel from tapex.processor import get_default_processor @@ -19,6 +20,8 @@ class TAPEXModelInterface: def __init__(self, resource_dir, checkpoint_name, table_processor=None): self.model = BARTModel.from_pretrained(model_name_or_path=resource_dir, checkpoint_file=checkpoint_name) + if torch.cuda.is_available(): + self.model.cuda() self.model.eval() if table_processor is not None: self.tab_processor = table_processor