Add word-level length for char embedding

Quanjia Yan 2019-06-06 16:01:56 +08:00
Parent d22ed5b0fa
Commit 31bca9e35a
6 changed files: 23 additions and 12 deletions

@@ -438,6 +438,8 @@ class LearningMachine(object):
 param_list, inputs_desc, length_desc = transform_params2tensors(data_batches[i], length_batches[i])
 logits = self.model(inputs_desc, length_desc, *param_list)
+for single_key in temp_key_list:
+    length_batches[i][single_key] = length_batches[i][single_key]['sentence_length']
 logits_softmax = {}
 if isinstance(self.model, nn.DataParallel):
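A note on the data structure this loop expects: after this commit each entry of `length_batches[i]` is a dict keyed by 'sentence_length' and 'word_length' rather than a plain tensor, and the loop collapses it back to sentence lengths once the model has consumed it. A minimal sketch of that shape (the 'query' cluster name and all values are invented for illustration):

```python
import torch

# Hypothetical lengths for one input cluster of a batch of 2 sentences padded
# to 5 words: per-sentence word counts plus per-word character counts.
length_batch = {
    'query': {
        'sentence_length': torch.tensor([5, 3]),           # [batch_size]
        'word_length': torch.tensor([[4, 2, 6, 3, 5],
                                     [7, 2, 4, 0, 0]]),    # [batch_size, seq_len]
    }
}

# Mirror of the loop added above: once the forward pass is done, keep only the
# sentence-level lengths for the evaluation code that follows.
for single_key in length_batch:
    length_batch[single_key] = length_batch[single_key]['sentence_length']

print(length_batch)   # {'query': tensor([5, 3])}
```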

@@ -368,11 +368,11 @@ class Model(nn.Module):
 repre_lengths[EMBED_LAYER_ID] = dict()
 for input in inputs:
-    representation[input] = self.layers[EMBED_LAYER_ID](inputs[input], use_gpu=self.is_cuda())
+    representation[input] = self.layers[EMBED_LAYER_ID](inputs[input], lengths[input], use_gpu=self.is_cuda())
     if self.use_gpu:
-        repre_lengths[input] = transfer_to_gpu(lengths[input])
+        repre_lengths[input] = transfer_to_gpu(lengths[input]['sentence_length'])
     else:
-        repre_lengths[input] = lengths[input]
+        repre_lengths[input] = lengths[input]['sentence_length']
 for layer_id in self.layer_topological_sequence:
     #logging.debug("To proces layer %s" % layer_id)
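Here the embedding layer receives the full per-cluster lengths dict (so a char-level sub-embedding can reach 'word_length'), while `repre_lengths` keeps only 'sentence_length' for the layers after it. A small sketch of that split, with a plain `.cuda()` call standing in for `transfer_to_gpu` and all names and values hypothetical:

```python
import torch

def lengths_for_later_layers(cluster_lengths, use_gpu=False):
    """Collapse one cluster's lengths dict to the sentence-level tensor."""
    sentence_length = cluster_lengths['sentence_length']
    if use_gpu and torch.cuda.is_available():
        sentence_length = sentence_length.cuda()   # stand-in for transfer_to_gpu
    return sentence_length

lengths = {'query': {'sentence_length': torch.tensor([5, 3]),
                     'word_length': torch.randint(1, 8, (2, 5))}}

# The embedding layer would be called with the whole lengths['query'] dict;
# every layer after it only sees the collapsed sentence lengths.
repre_lengths = {name: lengths_for_later_layers(d) for name, d in lengths.items()}
print(repre_lengths['query'])   # tensor([5, 3])
```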

@@ -21,6 +21,7 @@
     4. [Compression for MRC Model](#task-6.4)
 * [Task 7: Chinese Sentiment Analysis](#task-7)
 * [Task 8: Chinese Text Matching](#task-8)
+* [Task 9: Sequence Tagging](#task-9)
 * [Advanced Usage](#advanced-usage)
     * [Extra Feature Support](#extra-feature)
     * [Learning Rate Decay](#lr-decay)
@@ -562,6 +563,7 @@ Here is an example using Chinese data, for text matching task.
 ```
 *Tips: you can try different models by running different JSON config files. The model file and train log file can be found in the JSON config file's outputs/save_base_dir after you finish training.*
+### <span id="task-9">Task 9: Sequence Tagging</span>
 ## <span id="advanced-usage">Advanced Usage</span>

@@ -135,7 +135,7 @@ class Embedding(BaseLayer):
         logging.info("The Embedding[%s][fix_weight] is true, fix the embeddings[%s]'s weight" % (input_cluster, input_cluster))
-def forward(self, inputs, use_gpu=False):
+def forward(self, inputs, lengths, use_gpu=False):
     """ process inputs
     Args:
@@ -144,6 +144,9 @@
         'word': word ids (Variable), shape: [batch_size, seq_len],\n
         'postag': postag ids (Variable), shape: [batch_size, seq_len],\n
         ...
+    lengths (dict): a dictionary describing the length of each input type, e.g.:\n
+        'sentence_length': [batch_size]
+        'word_length': [batch_size, sentence_length]
     use_gpu (bool): put embedding matrix on GPU (True) or not (False)
 Returns:
@@ -156,14 +159,17 @@
 if 'extra' in input_cluster:
     continue
 input = inputs[input_cluster]
-# if 'type' in self.layer_conf.conf[input_cluster]:
-#     emb = self.embeddings[input_cluster](input, lengths[input]).float()
-# else:
-#     emb = self.embeddings[input_cluster](input).float()
 if list(self.embeddings[input_cluster].parameters())[0].device.type == 'cpu':
-    emb = self.embeddings[input_cluster](input.cpu()).float()
+    # for char embedding type
+    if 'type' in self.layer_conf.conf[input_cluster]:
+        emb = self.embeddings[input_cluster](input.cpu(), lengths['word_length'].cpu()).float()
+    else:
+        emb = self.embeddings[input_cluster](input.cpu()).float()
 else:
-    emb = self.embeddings[input_cluster](input).float()
+    if 'type' in self.layer_conf.conf[input_cluster]:
+        emb = self.embeddings[input_cluster](input, lengths['word_length']).float()
+    else:
+        emb = self.embeddings[input_cluster](input).float()
 if use_gpu is True:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     emb = emb.to(device)
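The branching above dispatches on whether a sub-embedding declares a 'type' in its configuration: char-style embeddings are called with `lengths['word_length']` in addition to the ids, plain lookup tables with the ids only. Below is a simplified, hypothetical illustration of that dispatch; the mean-pooling toy char embedding, the conf layout, and all sizes are stand-ins, not the actual NeuronBlocks classes:

```python
import torch
import torch.nn as nn

class ToyCharEmbedding(nn.Module):
    """Stand-in for a char-level embedding: mean-pools char vectors per word,
    using word_length to ignore padded character positions."""
    def __init__(self, char_vocab_size, dim):
        super().__init__()
        self.emb = nn.Embedding(char_vocab_size, dim)

    def forward(self, char_ids, word_length):
        emb = self.emb(char_ids)                                  # [batch, seq_len, char_num, dim]
        char_num = char_ids.size(-1)
        mask = (torch.arange(char_num)[None, None, :] < word_length.unsqueeze(-1)).float()
        summed = (emb * mask.unsqueeze(-1)).sum(dim=2)            # [batch, seq_len, dim]
        return summed / word_length.clamp(min=1).unsqueeze(-1).float()

conf = {'word': {}, 'char': {'type': 'CNNCharEmbedding'}}         # 'type' marks a char-style embedding
embeddings = {'word': nn.Embedding(100, 8), 'char': ToyCharEmbedding(50, 8)}

inputs = {'word': torch.randint(0, 100, (2, 5)),                  # [batch, seq_len]
          'char': torch.randint(1, 50, (2, 5, 7))}                # [batch, seq_len, char_num]
lengths = {'word_length': torch.randint(1, 8, (2, 5))}            # chars per word

parts = []
for input_cluster, ids in inputs.items():
    if 'type' in conf[input_cluster]:                             # char-style: needs word lengths
        parts.append(embeddings[input_cluster](ids, lengths['word_length']))
    else:                                                         # plain word lookup
        parts.append(embeddings[input_cluster](ids))
emb = torch.cat(parts, dim=-1)
print(emb.shape)                                                  # torch.Size([2, 5, 16])
```

The real CNNCharEmbedding runs convolutions over the character dimension rather than mean-pooling; the point here is only the 'type'-based dispatch and the extra word-length argument.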

@@ -80,7 +80,7 @@ class CNNCharEmbedding(BaseLayer):
 if self.activation and hasattr(self.activation, 'weight'):
     self.activation.weight = torch.nn.Parameter(self.activation.weight.cuda())
-def forward(self, string):
+def forward(self, string, length):
     """
     Step1: [batch_size, seq_len, char num in words] -> [batch_size, seq_len * char num in words]
     Step2: lookup embedding matrix -> [batch_size, seq_len * char num in words, embedding_dim]
@@ -91,6 +91,7 @@
 Args:
     string (Variable): [[char ids of word1], [char ids of word2], [...], ...], shape: [batch_size, seq_len, char num in words]
+    length (Variable): the length of each word, shape: [batch_size, seq_len]
 Returns:
     Variable: [batch_size, seq_len, output_dim]
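The `length` argument records how many real characters each word contains, matching the first two dimensions of `string`. A small sketch of how such a tensor could be derived from the padded char-id input, assuming a pad id of 0 (the actual pad id comes from the problem's vocabulary):

```python
import torch

pad_id = 0
# string: [batch_size=2, seq_len=3, char num in words=5]
string = torch.tensor([[[3, 7, 2, 0, 0],
                        [9, 4, 0, 0, 0],
                        [0, 0, 0, 0, 0]],    # fully padded word position
                       [[5, 5, 5, 5, 1],
                        [8, 0, 0, 0, 0],
                        [2, 6, 7, 0, 0]]])

# Count non-padding characters per word -> [batch_size, seq_len]
length = (string != pad_id).sum(dim=-1)
print(length)
# tensor([[3, 2, 0],
#         [5, 1, 3]])
```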

@@ -357,7 +357,7 @@ def get_batches(problem, data, length, target, batch_size, input_types, pad_ids=
     else:
         data_batch[input_cluster][input_type] = data[input_cluster][input_type][stidx: stidx + batch_size]
     # word_length is used for padding char sequence, now only save sentence_length
-    length_batch[input_cluster] = length_batch[input_cluster]['sentence_length']
+    # length_batch[input_cluster] = length_batch[input_cluster]['sentence_length']
 data_batches.append(data_batch)
 length_batches.append(length_batch)
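With the collapsing line commented out, each batch keeps the whole per-cluster lengths dict, so 'word_length' is still available when the batch reaches the embedding layer. A hypothetical sketch of slicing such a dict into a mini-batch (the cluster name, values, and slicing logic are invented for illustration, not the actual get_batches code):

```python
import torch

# Full-dataset lengths for one input cluster (4 sentences, padded to 5 words).
length = {'query': {'sentence_length': torch.tensor([5, 3, 4, 2]),
                    'word_length': torch.randint(1, 8, (4, 5))}}

stidx, batch_size = 0, 2
# Slice every entry of the dict instead of keeping only 'sentence_length'.
length_batch = {cluster: {key: value[stidx: stidx + batch_size]
                          for key, value in cluster_lengths.items()}
                for cluster, cluster_lengths in length.items()}

print(length_batch['query']['sentence_length'])    # tensor([5, 3])
print(length_batch['query']['word_length'].shape)  # torch.Size([2, 5])
```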