This commit is contained in:
Casey Hong 2019-07-24 11:12:58 -04:00
Родитель 46f2336c65
Коммит c611ee9c96
3 изменённых файла: 21 добавление и 21 удаление

Просмотреть файл

@ -8,7 +8,7 @@ nltk.download("punkt", quiet=True)
# Fetch the NLTK "stopwords" resource; quiet=True asks the downloader not to print messages.
nltk.download("stopwords", quiet=True)
class Split(str, Enum):
    """Names of the dataset partitions (train / dev / test).

    The ``str`` mixin makes every member an actual string instance, so a
    member compares equal to its raw value (``Split.TRAIN == "train"``) and
    can be passed wherever a plain ``str`` is expected. ``.value`` still
    returns the underlying string, so callers written against the plain
    ``Enum`` version keep working.
    """

    # Members are deliberately left unannotated: type checkers infer the
    # member type themselves and report explicit annotations on enum members
    # as errors.
    TRAIN = "train"
    DEV = "dev"
    TEST = "test"

Просмотреть файл

@ -26,17 +26,17 @@ from torch.utils.data import (
# NOTE(review): presumably the maximum sequence length (in tokens) accepted by
# the pretrained BERT models -- confirm against the code that reads it.
BERT_MAX_LEN = 512
class Language(str, Enum):
    """An enumeration of the supported pretrained models and languages.

    Each member's value is the name of a pretrained BERT model. Because of
    the ``str`` mixin a member is itself a string instance equal to that
    name, so it can be passed directly where a model-name string is expected;
    ``.value`` remains available for callers that unwrap it explicitly.
    """

    # Members are deliberately left unannotated: type checkers infer the
    # member type themselves and report explicit annotations on enum members
    # as errors.
    ENGLISH = "bert-base-uncased"
    ENGLISHCASED = "bert-base-cased"
    ENGLISHLARGE = "bert-large-uncased"
    ENGLISHLARGECASED = "bert-large-cased"
    ENGLISHLARGEWWM = "bert-large-uncased-whole-word-masking"
    ENGLISHLARGECASEDWWM = "bert-large-cased-whole-word-masking"
    CHINESE = "bert-base-chinese"
    MULTILINGUAL = "bert-base-multilingual-cased"
class Tokenizer:

Просмотреть файл

@ -23,11 +23,11 @@ from torch.utils.data import (
from utils_nlp.models.bert.common import Language, Tokenizer
class PoolingStrategy(str, Enum):
    """Enumerate pooling strategies.

    The ``str`` mixin makes every member a string instance equal to its raw
    value, so a member can be compared directly against ``"max"``,
    ``"mean"`` or ``"cls"`` without unwrapping; ``.value`` still returns the
    same string for callers that use it.
    """

    # Members are deliberately left unannotated: type checkers infer the
    # member type themselves and report explicit annotations on enum members
    # as errors.
    MAX = "max"
    MEAN = "mean"
    CLS = "cls"
class BERTSentenceEncoder:
@ -57,7 +57,7 @@ class BERTSentenceEncoder:
self.model = (
bert_model.model.bert
if bert_model
else BertModel.from_pretrained(language.value, cache_dir=cache_dir)
else BertModel.from_pretrained(language, cache_dir=cache_dir)
)
self.tokenizer = (
tokenizer
@ -190,11 +190,11 @@ class BERTSentenceEncoder:
return values[0]
try:
if pooling_strategy.value == "max":
if pooling_strategy == "max":
pool_func = max_pool
elif pooling_strategy.value == "mean":
elif pooling_strategy == "mean":
pool_func = mean_pool
elif pooling_strategy.value == "cls":
elif pooling_strategy == "cls":
pool_func = cls_pool
else:
raise ValuerError("Please enter valid pooling strategy")