Make enums JSON-serializable
This commit is contained in:
Родитель
46f2336c65
Коммит
c611ee9c96
|
@ -8,7 +8,7 @@ nltk.download("punkt", quiet=True)
|
|||
nltk.download("stopwords", quiet=True)
|
||||
|
||||
|
||||
class Split(str, Enum):
    """Names of the dataset partitions.

    The ``str`` mixin makes every member a real string: it compares equal
    to its value and ``json.dumps`` can encode it without a custom
    encoder, which a plain ``Enum`` member cannot do.
    """

    # Members are intentionally left unannotated; type checkers infer a
    # literal type for each one.
    TRAIN = "train"
    DEV = "dev"
    TEST = "test"
|
|
@ -26,17 +26,17 @@ from torch.utils.data import (
|
|||
BERT_MAX_LEN = 512
|
||||
|
||||
|
||||
class Language(str, Enum):
    """An enumeration of the supported pretrained models and languages.

    Each member's value is the identifier string of a pretrained BERT
    model. Because of the ``str`` mixin a member is itself a string, so
    it can be passed wherever the plain identifier is expected and can be
    serialized with ``json.dumps`` directly.
    """

    # Members are intentionally left unannotated; type checkers infer a
    # literal type for each one.
    ENGLISH = "bert-base-uncased"
    ENGLISHCASED = "bert-base-cased"
    ENGLISHLARGE = "bert-large-uncased"
    ENGLISHLARGECASED = "bert-large-cased"
    ENGLISHLARGEWWM = "bert-large-uncased-whole-word-masking"
    ENGLISHLARGECASEDWWM = "bert-large-cased-whole-word-masking"
    CHINESE = "bert-base-chinese"
    MULTILINGUAL = "bert-base-multilingual-cased"
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
|
|
|
@ -23,11 +23,11 @@ from torch.utils.data import (
|
|||
from utils_nlp.models.bert.common import Language, Tokenizer
|
||||
|
||||
|
||||
class PoolingStrategy(str, Enum):
    """Enumerate pooling strategies.

    The ``str`` mixin lets a member be compared directly against its
    string value (for example ``PoolingStrategy.MAX == "max"``) and be
    serialized with ``json.dumps`` without a custom encoder.
    """

    # Members are intentionally left unannotated; type checkers infer a
    # literal type for each one.
    MAX = "max"
    MEAN = "mean"
    CLS = "cls"
|
||||
|
||||
|
||||
class BERTSentenceEncoder:
|
||||
|
@ -57,7 +57,7 @@ class BERTSentenceEncoder:
|
|||
self.model = (
|
||||
bert_model.model.bert
|
||||
if bert_model
|
||||
else BertModel.from_pretrained(language.value, cache_dir=cache_dir)
|
||||
else BertModel.from_pretrained(language, cache_dir=cache_dir)
|
||||
)
|
||||
self.tokenizer = (
|
||||
tokenizer
|
||||
|
@ -190,11 +190,11 @@ class BERTSentenceEncoder:
|
|||
return values[0]
|
||||
|
||||
try:
|
||||
if pooling_strategy.value == "max":
|
||||
if pooling_strategy == "max":
|
||||
pool_func = max_pool
|
||||
elif pooling_strategy.value == "mean":
|
||||
elif pooling_strategy == "mean":
|
||||
pool_func = mean_pool
|
||||
elif pooling_strategy.value == "cls":
|
||||
elif pooling_strategy == "cls":
|
||||
pool_func = cls_pool
|
||||
else:
|
||||
raise ValuerError("Please enter valid pooling strategy")
|
||||
|
|
Загрузка…
Ссылка в новой задаче