common updates
This commit is contained in:
Родитель
e418fc9b75
Коммит
c1aaf2592e
|
@ -85,7 +85,9 @@ def qa_test_data(qa_test_df, tmp_module):
|
|||
)
|
||||
|
||||
# xlnet
|
||||
qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased", cache_dir=tmp_module)
|
||||
qa_processor_xlnet = QAProcessor(
|
||||
model_name="xlnet-base-cased", cache_dir=tmp_module
|
||||
)
|
||||
train_features_xlnet = qa_processor_xlnet.preprocess(
|
||||
train_dataset,
|
||||
is_training=True,
|
||||
|
@ -148,13 +150,19 @@ def test_QAProcessor(qa_test_data, tmp_module):
|
|||
]:
|
||||
qa_processor = QAProcessor(model_name=model_name, cache_dir=tmp_module)
|
||||
qa_processor.preprocess(
|
||||
qa_test_data["train_dataset"], is_training=True, feature_cache_dir=tmp_module,
|
||||
qa_test_data["train_dataset"],
|
||||
is_training=True,
|
||||
feature_cache_dir=tmp_module,
|
||||
)
|
||||
qa_processor.preprocess(
|
||||
qa_test_data["train_dataset_list"], is_training=True, feature_cache_dir=tmp_module,
|
||||
qa_test_data["train_dataset_list"],
|
||||
is_training=True,
|
||||
feature_cache_dir=tmp_module,
|
||||
)
|
||||
qa_processor.preprocess(
|
||||
qa_test_data["test_dataset"], is_training=False, feature_cache_dir=tmp_module,
|
||||
qa_test_data["test_dataset"],
|
||||
is_training=False,
|
||||
feature_cache_dir=tmp_module,
|
||||
)
|
||||
|
||||
# test unsupported model type
|
||||
|
@ -188,7 +196,9 @@ def test_AnswerExtractor(qa_test_data, tmp_module):
|
|||
# bert
|
||||
qa_extractor_bert = AnswerExtractor(cache_dir=tmp_module)
|
||||
train_loader_bert = dataloader_from_dataset(qa_test_data["train_features_bert"])
|
||||
test_loader_bert = dataloader_from_dataset(qa_test_data["test_features_bert"], shuffle=False)
|
||||
test_loader_bert = dataloader_from_dataset(
|
||||
qa_test_data["test_features_bert"], shuffle=False
|
||||
)
|
||||
qa_extractor_bert.fit(train_loader_bert, verbose=False, cache_model=True)
|
||||
|
||||
# test saving fine-tuned model
|
||||
|
@ -203,13 +213,19 @@ def test_AnswerExtractor(qa_test_data, tmp_module):
|
|||
|
||||
# xlnet
|
||||
train_loader_xlnet = dataloader_from_dataset(qa_test_data["train_features_xlnet"])
|
||||
test_loader_xlnet = dataloader_from_dataset(qa_test_data["test_features_xlnet"], shuffle=False)
|
||||
qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp_module)
|
||||
test_loader_xlnet = dataloader_from_dataset(
|
||||
qa_test_data["test_features_xlnet"], shuffle=False
|
||||
)
|
||||
qa_extractor_xlnet = AnswerExtractor(
|
||||
model_name="xlnet-base-cased", cache_dir=tmp_module
|
||||
)
|
||||
qa_extractor_xlnet.fit(train_loader_xlnet, verbose=False, cache_model=False)
|
||||
qa_extractor_xlnet.predict(test_loader_xlnet, verbose=False)
|
||||
|
||||
# distilbert
|
||||
train_loader_xlnet = dataloader_from_dataset(qa_test_data["train_features_distilbert"])
|
||||
train_loader_xlnet = dataloader_from_dataset(
|
||||
qa_test_data["train_features_distilbert"]
|
||||
)
|
||||
test_loader_xlnet = dataloader_from_dataset(
|
||||
qa_test_data["test_features_distilbert"], shuffle=False
|
||||
)
|
||||
|
|
|
@ -81,7 +81,7 @@ PIP_BASE = {
|
|||
"https://github.com/explosion/spacy-models/releases/download/"
|
||||
"en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz"
|
||||
),
|
||||
"transformers": "transformers==2.5.0",
|
||||
"transformers": "transformers==2.9.0",
|
||||
"gensim": "gensim>=3.7.0",
|
||||
"nltk": "nltk>=3.4",
|
||||
"seqeval": "seqeval>=0.0.12",
|
||||
|
|
|
@ -13,28 +13,7 @@ import time
|
|||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AdamW,
|
||||
AlbertTokenizer,
|
||||
BartTokenizer,
|
||||
BertTokenizer,
|
||||
CamembertTokenizer,
|
||||
DistilBertTokenizer,
|
||||
RobertaTokenizer,
|
||||
XLMRobertaTokenizer,
|
||||
XLMTokenizer,
|
||||
XLNetTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from utils_nlp.common.pytorch_utils import (
|
||||
get_amp,
|
||||
|
@ -43,28 +22,6 @@ from utils_nlp.common.pytorch_utils import (
|
|||
parallelize_model,
|
||||
)
|
||||
|
||||
TOKENIZER_CLASS = {}
|
||||
TOKENIZER_CLASS.update({k: BertTokenizer for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP})
|
||||
TOKENIZER_CLASS.update(
|
||||
{k: RobertaTokenizer for k in ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||
)
|
||||
TOKENIZER_CLASS.update({k: XLNetTokenizer for k in XLNET_PRETRAINED_MODEL_ARCHIVE_MAP})
|
||||
TOKENIZER_CLASS.update(
|
||||
{k: DistilBertTokenizer for k in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||
)
|
||||
TOKENIZER_CLASS.update({k: BartTokenizer for k in BART_PRETRAINED_MODEL_ARCHIVE_MAP})
|
||||
TOKENIZER_CLASS.update({k: XLMTokenizer for k in XLM_PRETRAINED_MODEL_ARCHIVE_MAP})
|
||||
TOKENIZER_CLASS.update(
|
||||
{k: AlbertTokenizer for k in ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||
)
|
||||
TOKENIZER_CLASS.update(
|
||||
{k: XLMRobertaTokenizer for k in XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||
)
|
||||
TOKENIZER_CLASS.update(
|
||||
{k: CamembertTokenizer for k in CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||
)
|
||||
|
||||
|
||||
MAX_SEQ_LEN = 512
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -73,34 +30,14 @@ logger = logging.getLogger(__name__)
|
|||
class Transformer:
|
||||
def __init__(
|
||||
self,
|
||||
model_class,
|
||||
model_name="bert-base-cased",
|
||||
num_labels=2,
|
||||
cache_dir=".",
|
||||
load_model_from_dir=None,
|
||||
model_name,
|
||||
model,
|
||||
cache_dir,
|
||||
):
|
||||
if model_name not in self.list_supported_models():
|
||||
raise ValueError(
|
||||
"Model name {0} is not supported by {1}. "
|
||||
"Call '{1}.list_supported_models()' to get all supported model "
|
||||
"names.".format(model_name, self.__class__.__name__)
|
||||
)
|
||||
self._model_name = model_name
|
||||
self._model_type = model_name.split("-")[0]
|
||||
self.model = model
|
||||
self.cache_dir = cache_dir
|
||||
self.load_model_from_dir = load_model_from_dir
|
||||
if load_model_from_dir is None:
|
||||
self.model = model_class[model_name].from_pretrained(
|
||||
model_name,
|
||||
cache_dir=cache_dir,
|
||||
num_labels=num_labels,
|
||||
output_loading_info=False,
|
||||
)
|
||||
else:
|
||||
logger.info("Loading cached model from {}".format(load_model_from_dir))
|
||||
self.model = model_class[model_name].from_pretrained(
|
||||
load_model_from_dir, num_labels=num_labels, output_loading_info=False
|
||||
)
|
||||
|
||||
@property
|
||||
def model_name(self):
|
||||
|
@ -342,7 +279,7 @@ class Transformer:
|
|||
saved_model_path = os.path.join(
|
||||
self.cache_dir, f"{self.model_name}_step_{global_step}.pt"
|
||||
)
|
||||
self.save_model(global_step, saved_model_path)
|
||||
self.save_model(saved_model_path)
|
||||
if validation_function:
|
||||
validation_log = validation_function(self)
|
||||
logger.info(validation_log)
|
||||
|
|
|
@ -289,7 +289,7 @@ class TokenClassifier(Transformer):
|
|||
model = AutoModelForTokenClassification.from_pretrained(
|
||||
model_name, cache_dir=cache_dir, config=config, output_loading_info=False
|
||||
)
|
||||
super().__init__(model_name=model_name, model=model)
|
||||
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||
|
||||
@staticmethod
|
||||
def list_supported_models():
|
||||
|
|
|
@ -215,7 +215,7 @@ class SequenceClassifier(Transformer):
|
|||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, cache_dir=cache_dir, config=config, output_loading_info=False
|
||||
)
|
||||
super().__init__(model_name=model_name, model=model)
|
||||
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||
|
||||
@staticmethod
|
||||
def list_supported_models():
|
||||
|
|
Загрузка…
Ссылка в новой задаче