layer idx and pooling strategy at initialization

Parent: 90a6242da5
Commit: a02b95c72e
@@ -31,7 +31,7 @@
     "\n",
     "sys.path.append(\"../../\")\n",
     "from utils_nlp.models.bert.common import Language, Tokenizer\n",
-    "from utils_nlp.models.bert.extract_features import BERTSentenceEncoder, PoolingStrategy"
+    "from utils_nlp.models.bert.sequence_encoding import BERTSentenceEncoder, PoolingStrategy"
    ]
   },
   {
@@ -51,6 +51,8 @@
     "LANGUAGE = Language.ENGLISH\n",
     "TO_LOWER = True\n",
     "MAX_SEQ_LENGTH = 128\n",
+    "LAYER_INDEX = -2\n",
+    "POOLING_STRATEGY = PoolingStrategy.MEAN\n",
     "\n",
     "# path config\n",
     "CACHE_DIR = \"./temp\""
@@ -92,6 +94,8 @@
     "    cache_dir=CACHE_DIR,\n",
     "    to_lower=TO_LOWER,\n",
     "    max_len=MAX_SEQ_LENGTH,\n",
+    "    layer_index=LAYER_INDEX,\n",
+    "    pooling_strategy=POOLING_STRATEGY,\n",
     ")"
    ]
   },
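With this commit the encoding configuration is fixed once, when the encoder is constructed. The hunk above shows only the tail of the call; a plausible reconstruction of the full notebook cell after the change (the language and num_gpus arguments are assumptions, not visible in this diff):

se = BERTSentenceEncoder(
    language=LANGUAGE,        # assumption: defined in the config cell above
    num_gpus=0,               # assumption: not shown in this hunk
    cache_dir=CACHE_DIR,
    to_lower=TO_LOWER,
    max_len=MAX_SEQ_LENGTH,
    layer_index=LAYER_INDEX,              # new: set at initialization
    pooling_strategy=POOLING_STRATEGY,    # new: set at initialization
)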
@@ -118,7 +122,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "100%|██████████| 2/2 [00:00<00:00, 4082.05it/s]\n"
+      "100%|██████████| 2/2 [00:00<00:00, 2663.05it/s]\n"
      ]
     },
     {
@@ -178,8 +182,6 @@
    "source": [
     "se.encode(\n",
     "    [\"Coffee is good\", \"The moose is across the street\"],\n",
-    "    layer_indices=[-2],\n",
-    "    pooling_strategy=PoolingStrategy.MEAN,\n",
     "    as_numpy=False\n",
     ")"
    ]
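After dropping the per-call layer_indices and pooling_strategy arguments, the post-change encode cell reduces to:

se.encode(
    ["Coffee is good", "The moose is across the street"],
    as_numpy=False
)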
utils_nlp/models/bert/sequence_encoding.py
@@ -42,6 +42,8 @@ class BERTSentenceEncoder:
         cache_dir=".",
         to_lower=True,
         max_len=512,
+        layer_index=-1,
+        pooling_strategy=PoolingStrategy.MEAN,
     ):
         """Initialize the encoder's underlying model and tokenizer

@@ -53,6 +55,9 @@ class BERTSentenceEncoder:
             cache_dir: Location of BERT's cache directory. Defaults to "."
             to_lower: True to lowercase before tokenization. Defaults to True.
             max_len: Maximum number of tokens.
+            layer_index: The layer from which to extract features.
+                Defaults to the last layer; can also be a list of integers for experimentation.
+            pooling_strategy: Pooling strategy to aggregate token embeddings into sentence embedding.
         """
         self.model = (
             bert_model.model.bert
@@ -66,13 +71,17 @@ class BERTSentenceEncoder:
         )
         self.num_gpus = num_gpus
         self.max_len = max_len
+        if isinstance(layer_index, int):
+            self.layer_index = [layer_index]
+        else:
+            self.layer_index = layer_index
+        self.pooling_strategy = pooling_strategy

-    def get_hidden_states(self, text, layer_indices=[-2], batch_size=32):
+    def get_hidden_states(self, text, batch_size=32):
         """Extract the hidden states from the pretrained model

         Args:
             text: List of documents to extract features from.
-            layer_indices: List of indices of the layers to extract features from. Defaults to the second-to-last layer.
             batch_size: Batch size, defaults to 32.

         Returns:
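The new constructor code normalizes layer_index so the rest of the class can always iterate over a list, whether the caller passed a single int or a list of ints. A minimal standalone sketch of the same pattern (normalize_layer_index is a hypothetical helper for illustration, not part of the module):

def normalize_layer_index(layer_index):
    # Mirrors the constructor logic: wrap a bare int in a list,
    # pass a list of layer indices through unchanged.
    if isinstance(layer_index, int):
        return [layer_index]
    return layer_index

assert normalize_layer_index(-2) == [-2]
assert normalize_layer_index([-1, -2]) == [-1, -2]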
@@ -117,11 +126,12 @@
                         input_ids_tensor,
                         token_type_ids=None,
                         attention_mask=input_mask_tensor,
                     )
                 )
+                self.embedding_dim = all_encoder_layers[0].size()[-1]

                 for b, example_index in enumerate(example_indices_tensor):
                     for (i, token) in enumerate(tokens[example_index.item()]):
-                        for (j, layer_index) in enumerate(layer_indices):
+                        for (j, layer_index) in enumerate(self.layer_index):
                             layer_output = (
                                 all_encoder_layers[int(layer_index)]
                                 .detach()
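Capturing self.embedding_dim from the encoder output replaces the hard-coded 768 in the pooling helpers below, so the class also works with models whose hidden size differs (1024 for BERT-large, for example). A toy PyTorch sketch of where the value comes from (shapes are illustrative only):

import torch

# Each encoder layer output has shape [batch_size, seq_len, hidden_size];
# the trailing dimension is the embedding width the poolers need.
layer_output = torch.zeros(2, 128, 768)
embedding_dim = layer_output.size()[-1]
assert embedding_dim == 768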
@@ -148,7 +158,7 @@

         return pd.DataFrame.from_dict(hidden_states)

-    def pool(self, df, pooling_strategy=PoolingStrategy.MEAN):
+    def pool(self, df):
         """Pooling to aggregate token-wise embeddings to sentence embeddings

         Args:
@@ -161,7 +171,7 @@
         def max_pool(x):
             values = np.array(
                 [
-                    np.reshape(np.array(x.values[i]), 768)
+                    np.reshape(np.array(x.values[i]), self.embedding_dim)
                     for i in range(x.values.shape[0])
                 ]
             )
@@ -171,7 +181,7 @@
         def mean_pool(x):
             values = np.array(
                 [
-                    np.reshape(np.array(x.values[i]), 768)
+                    np.reshape(np.array(x.values[i]), self.embedding_dim)
                     for i in range(x.values.shape[0])
                 ]
             )
@@ -182,22 +192,22 @@
         def cls_pool(x):
             values = np.array(
                 [
-                    np.reshape(np.array(x.values[i]), 768)
+                    np.reshape(np.array(x.values[i]), self.embedding_dim)
                     for i in range(x.values.shape[0])
                 ]
             )
             return values[0]

         try:
-            if pooling_strategy == "max":
+            if self.pooling_strategy == "max":
                 pool_func = max_pool
-            elif pooling_strategy == "mean":
+            elif self.pooling_strategy == "mean":
                 pool_func = mean_pool
-            elif pooling_strategy == "cls":
+            elif self.pooling_strategy == "cls":
                 pool_func = cls_pool
             else:
-                raise ValuerError("Please enter valid pooling strategy")
-        except ValuerError as ve:
+                raise ValueError("Please enter valid pooling strategy")
+        except ValueError as ve:
             print(ve)

         return df.groupby(["text_index", "layer_index"])["values"].apply(lambda x: pool_func(x)).reset_index()
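For reference, each of the three strategies reduces a [num_tokens, hidden_size] matrix of token embeddings to a single sentence vector. A toy numpy sketch of the semantics (values and shapes are illustrative only):

import numpy as np

token_embeddings = np.random.rand(4, 6)  # 4 tokens, hidden size 6

mean_vec = token_embeddings.mean(axis=0)  # "mean": average over tokens
max_vec = token_embeddings.max(axis=0)    # "max": element-wise maximum over tokens
cls_vec = token_embeddings[0]             # "cls": vector of the first ([CLS]) token

assert mean_vec.shape == max_vec.shape == cls_vec.shape == (6,)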
@@ -205,21 +215,17 @@
     def encode(
         self,
         text,
-        layer_indices=[-2],
         batch_size=32,
-        pooling_strategy=PoolingStrategy.MEAN,
         as_numpy=False
     ):
         """Computes sentence encodings

         Args:
             text: List of documents to encode.
-            layer_indices: List of indexes of the layers to extract features from. Defaults to the second-to-last layer.
             batch_size: Batch size, defaults to 32.
-            pooling_strategy: Pooling strategy to aggregate token embeddings into sentence embedding.
         """
-        df = self.get_hidden_states(text, layer_indices, batch_size)
-        pooled = self.pool(df, pooling_strategy=pooling_strategy)
+        df = self.get_hidden_states(text, batch_size)
+        pooled = self.pool(df)

         if as_numpy:
             return np.array(pooled["values"].tolist())
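Taken together, the change moves every feature-extraction knob into the constructor, leaving encode with only the text and output options. A sketch of the resulting usage (constructor arguments beyond those appearing in this diff are assumptions, and the non-numpy return branch is not shown in this hunk):

from utils_nlp.models.bert.sequence_encoding import BERTSentenceEncoder, PoolingStrategy

se = BERTSentenceEncoder(
    layer_index=-2,                         # configured once here...
    pooling_strategy=PoolingStrategy.MEAN,  # ...rather than per encode() call
)
embeddings = se.encode(
    ["Coffee is good", "The moose is across the street"],
    as_numpy=True,  # return a numpy array of sentence vectors
)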