layer idx and pooling strategy at initialization

Casey Hong 2019-07-24 13:33:45 -04:00
Parent 90a6242da5
Commit a02b95c72e
2 changed files with 31 additions and 23 deletions

View file

@@ -31,7 +31,7 @@
     "\n",
     "sys.path.append(\"../../\")\n",
     "from utils_nlp.models.bert.common import Language, Tokenizer\n",
-    "from utils_nlp.models.bert.extract_features import BERTSentenceEncoder, PoolingStrategy"
+    "from utils_nlp.models.bert.sequence_encoding import BERTSentenceEncoder, PoolingStrategy"
    ]
   },
   {
@@ -51,6 +51,8 @@
     "LANGUAGE = Language.ENGLISH\n",
     "TO_LOWER = True\n",
     "MAX_SEQ_LENGTH = 128\n",
+    "LAYER_INDEX = -2\n",
+    "POOLING_STRATEGY = PoolingStrategy.MEAN\n",
     "\n",
     "# path config\n",
     "CACHE_DIR = \"./temp\""
@@ -92,6 +94,8 @@
     "    cache_dir=CACHE_DIR,\n",
     "    to_lower=TO_LOWER,\n",
     "    max_len=MAX_SEQ_LENGTH,\n",
+    "    layer_index=LAYER_INDEX,\n",
+    "    pooling_strategy=POOLING_STRATEGY,\n",
     ")"
    ]
   },
@@ -118,7 +122,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "100%|██████████| 2/2 [00:00<00:00, 4082.05it/s]\n"
+     "100%|██████████| 2/2 [00:00<00:00, 2663.05it/s]\n"
     ]
    },
    {
@@ -178,8 +182,6 @@
    "source": [
     "se.encode(\n",
     "    [\"Coffee is good\", \"The moose is across the street\"],\n",
-    "    layer_indices=[-2],\n",
-    "    pooling_strategy=PoolingStrategy.MEAN,\n",
     "    as_numpy=False\n",
     ")"
    ]
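
For context, the cell change above removes the per-call arguments: after this commit the layer and pooling choices are fixed when the encoder is constructed, and encode() only takes the text plus batch/format options. A minimal before/after sketch (se, LAYER_INDEX, and POOLING_STRATEGY are the notebook's own names):

# before: layer and pooling strategy chosen on every call
# embeddings = se.encode(sentences, layer_indices=[-2], pooling_strategy=PoolingStrategy.MEAN, as_numpy=False)

# after: both fixed at construction (see the config cell), so calls stay consistent
embeddings = se.encode(
    ["Coffee is good", "The moose is across the street"],
    as_numpy=False,
)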

View file

@@ -42,6 +42,8 @@ class BERTSentenceEncoder:
         cache_dir=".",
         to_lower=True,
         max_len=512,
+        layer_index=-1,
+        pooling_strategy=PoolingStrategy.MEAN,
     ):
         """Initialize the encoder's underlying model and tokenizer

@@ -53,6 +55,9 @@
             cache_dir: Location of BERT's cache directory. Defaults to ".".
             to_lower: True to lowercase before tokenization. Defaults to True.
             max_len: Maximum number of tokens.
+            layer_index: The layer from which to extract features. Defaults to the
+                last layer; a list of indices can also be given for experimentation.
+            pooling_strategy: Pooling strategy to aggregate token embeddings into a sentence embedding.
         """
         self.model = (
             bert_model.model.bert
@@ -66,13 +71,17 @@
         )
         self.num_gpus = num_gpus
         self.max_len = max_len
+        if isinstance(layer_index, int):
+            self.layer_index = [layer_index]
+        else:
+            self.layer_index = layer_index
+        self.pooling_strategy = pooling_strategy

-    def get_hidden_states(self, text, layer_indices=[-2], batch_size=32):
+    def get_hidden_states(self, text, batch_size=32):
         """Extract the hidden states from the pretrained model

         Args:
             text: List of documents to extract features from.
-            layer_indices: List of indices of the layers to extract features from. Defaults to the second-to-last layer.
             batch_size: Batch size, defaults to 32.

         Returns:
@@ -117,11 +126,12 @@
                     input_ids_tensor,
                     token_type_ids=None,
                     attention_mask=input_mask_tensor,
                 )
             )
+            self.embedding_dim = all_encoder_layers[0].size()[-1]
             for b, example_index in enumerate(example_indices_tensor):
                 for (i, token) in enumerate(tokens[example_index.item()]):
-                    for (j, layer_index) in enumerate(layer_indices):
+                    for (j, layer_index) in enumerate(self.layer_index):
                         layer_output = (
                             all_encoder_layers[int(layer_index)]
                             .detach()
@@ -148,7 +158,7 @@
         return pd.DataFrame.from_dict(hidden_states)


-    def pool(self, df, pooling_strategy=PoolingStrategy.MEAN):
+    def pool(self, df):
         """Pooling to aggregate token-wise embeddings to sentence embeddings

         Args:
@@ -161,7 +171,7 @@
         def max_pool(x):
             values = np.array(
                 [
-                    np.reshape(np.array(x.values[i]), 768)
+                    np.reshape(np.array(x.values[i]), self.embedding_dim)
                     for i in range(x.values.shape[0])
                 ]
             )
@@ -171,7 +181,7 @@
         def mean_pool(x):
             values = np.array(
                 [
-                    np.reshape(np.array(x.values[i]), 768)
+                    np.reshape(np.array(x.values[i]), self.embedding_dim)
                     for i in range(x.values.shape[0])
                 ]
             )
@@ -182,22 +192,22 @@
         def cls_pool(x):
             values = np.array(
                 [
-                    np.reshape(np.array(x.values[i]), 768)
+                    np.reshape(np.array(x.values[i]), self.embedding_dim)
                     for i in range(x.values.shape[0])
                 ]
             )
             return values[0]

         try:
-            if pooling_strategy == "max":
+            if self.pooling_strategy == "max":
                 pool_func = max_pool
-            elif pooling_strategy == "mean":
+            elif self.pooling_strategy == "mean":
                 pool_func = mean_pool
-            elif pooling_strategy == "cls":
+            elif self.pooling_strategy == "cls":
                 pool_func = cls_pool
             else:
-                raise ValuerError("Please enter valid pooling strategy")
-        except ValuerError as ve:
+                raise ValueError("Please enter valid pooling strategy")
+        except ValueError as ve:
             print(ve)

         return df.groupby(["text_index", "layer_index"])["values"].apply(lambda x: pool_func(x)).reset_index()
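
All three pool_func variants reduce the (num_tokens, embedding_dim) block for one (text_index, layer_index) group to a single vector before the groupby-apply above. A standalone numpy sketch of their semantics (the 768 here is an assumption matching bert-base; the class itself now reads the width from self.embedding_dim):

import numpy as np

token_embeddings = np.random.rand(5, 768)  # 5 token vectors for one sentence

mean_vec = token_embeddings.mean(axis=0)   # "mean": average over tokens
max_vec = token_embeddings.max(axis=0)     # "max": element-wise maximum over tokens
cls_vec = token_embeddings[0]              # "cls": keep only the first ([CLS]) token

assert mean_vec.shape == max_vec.shape == cls_vec.shape == (768,)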
@@ -205,21 +215,17 @@
     def encode(
         self,
         text,
-        layer_indices=[-2],
         batch_size=32,
-        pooling_strategy=PoolingStrategy.MEAN,
         as_numpy=False
     ):
         """Computes sentence encodings

         Args:
             text: List of documents to encode.
-            layer_indices: List of indexes of the layers to extract features from. Defaults to the second-to-last layer.
             batch_size: Batch size, defaults to 32.
-            pooling_strategy: Pooling strategy to aggregate token embeddings into sentence embedding.
         """
-        df = self.get_hidden_states(text, layer_indices, batch_size)
-        pooled = self.pool(df, pooling_strategy=pooling_strategy)
+        df = self.get_hidden_states(text, batch_size)
+        pooled = self.pool(df)

         if as_numpy:
             return np.array(pooled["values"].tolist())
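
Taken together, the new surface looks roughly as follows. This is a hedged sketch, not code from the commit: only cache_dir, to_lower, max_len, layer_index, and pooling_strategy are confirmed by the diff, and the language keyword is an assumption mirroring the notebook's LANGUAGE config.

from utils_nlp.models.bert.common import Language
from utils_nlp.models.bert.sequence_encoding import BERTSentenceEncoder, PoolingStrategy

se = BERTSentenceEncoder(
    language=Language.ENGLISH,              # assumed keyword, per the notebook config
    cache_dir="./temp",
    to_lower=True,
    max_len=128,
    layer_index=-2,                         # an int is normalized to [-2] in __init__
    pooling_strategy=PoolingStrategy.MEAN,
)

# encode() now takes only the text plus batch_size / as_numpy
vectors = se.encode(
    ["Coffee is good", "The moose is across the street"],
    as_numpy=True,
)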