From c611ee9c968df9d8dbbc09f124f22e974cca61e8 Mon Sep 17 00:00:00 2001 From: Casey Hong Date: Wed, 24 Jul 2019 11:12:58 -0400 Subject: [PATCH] make enums json serializable --- utils_nlp/dataset/__init__.py | 8 ++++---- utils_nlp/models/bert/common.py | 18 +++++++++--------- utils_nlp/models/bert/extract_features.py | 16 ++++++++-------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/utils_nlp/dataset/__init__.py b/utils_nlp/dataset/__init__.py index 6e59904..b1b50c4 100644 --- a/utils_nlp/dataset/__init__.py +++ b/utils_nlp/dataset/__init__.py @@ -8,7 +8,7 @@ nltk.download("punkt", quiet=True) nltk.download("stopwords", quiet=True) -class Split(Enum): - TRAIN = "train" - DEV = "dev" - TEST = "test" \ No newline at end of file +class Split(str, Enum): + TRAIN : str = "train" + DEV : str = "dev" + TEST : str = "test" \ No newline at end of file diff --git a/utils_nlp/models/bert/common.py b/utils_nlp/models/bert/common.py index dd48409..ee5b339 100644 --- a/utils_nlp/models/bert/common.py +++ b/utils_nlp/models/bert/common.py @@ -26,17 +26,17 @@ from torch.utils.data import ( BERT_MAX_LEN = 512 -class Language(Enum): +class Language(str, Enum): """An enumeration of the supported pretrained models and languages.""" - ENGLISH = "bert-base-uncased" - ENGLISHCASED = "bert-base-cased" - ENGLISHLARGE = "bert-large-uncased" - ENGLISHLARGECASED = "bert-large-cased" - ENGLISHLARGEWWM = "bert-large-uncased-whole-word-masking" - ENGLISHLARGECASEDWWM = "bert-large-cased-whole-word-masking" - CHINESE = "bert-base-chinese" - MULTILINGUAL = "bert-base-multilingual-cased" + ENGLISH : str = "bert-base-uncased" + ENGLISHCASED : str = "bert-base-cased" + ENGLISHLARGE : str = "bert-large-uncased" + ENGLISHLARGECASED : str = "bert-large-cased" + ENGLISHLARGEWWM : str = "bert-large-uncased-whole-word-masking" + ENGLISHLARGECASEDWWM : str = "bert-large-cased-whole-word-masking" + CHINESE : str = "bert-base-chinese" + MULTILINGUAL : str = "bert-base-multilingual-cased" class Tokenizer: diff --git a/utils_nlp/models/bert/extract_features.py b/utils_nlp/models/bert/extract_features.py index e77f8de..5cf6174 100644 --- a/utils_nlp/models/bert/extract_features.py +++ b/utils_nlp/models/bert/extract_features.py @@ -23,11 +23,11 @@ from torch.utils.data import ( from utils_nlp.models.bert.common import Language, Tokenizer -class PoolingStrategy(Enum): +class PoolingStrategy(str, Enum): """Enumerate pooling strategies""" - MAX = "max" - MEAN = "mean" - CLS = "cls" + MAX : str = "max" + MEAN : str = "mean" + CLS : str = "cls" class BERTSentenceEncoder: @@ -57,7 +57,7 @@ class BERTSentenceEncoder: self.model = ( bert_model.model.bert if bert_model - else BertModel.from_pretrained(language.value, cache_dir=cache_dir) + else BertModel.from_pretrained(language, cache_dir=cache_dir) ) self.tokenizer = ( tokenizer @@ -190,11 +190,11 @@ class BERTSentenceEncoder: return values[0] try: - if pooling_strategy.value == "max": + if pooling_strategy == "max": pool_func = max_pool - elif pooling_strategy.value == "mean": + elif pooling_strategy == "mean": pool_func = mean_pool - elif pooling_strategy.value == "cls": + elif pooling_strategy == "cls": pool_func = cls_pool else: raise ValuerError("Please enter valid pooling strategy")