Use pad_id for padding in tensorflow op

This commit is contained in:
Taku Kudo 2018-08-07 00:03:05 +09:00
Родитель 1697a22423
Коммит 6273c3f0e4
2 изменённых файлов: 16 добавлений и 2 удалений

Просмотреть файл

@ -141,6 +141,15 @@ class SentencePieceBaseOp : public OpKernel {
}
protected:
void GetPad(int32* pad) const { *pad = pad_id_; }
void GetPad(std::string* pad) const {
pad->clear();
if (sentencepiece_processor_ && pad_id_ >= 0 &&
pad_id_ != sentencepiece_processor_->unk_id())
*pad = sentencepiece_processor_->IdToPiece(pad_id_);
}
std::shared_ptr<SentencePieceProcessor> sentencepiece_processor_;
int bos_id_ = -1;
int eos_id_ = -1;
@ -378,7 +387,9 @@ template <typename T>
class SentencePieceEncodeDenseOp : public SentencePieceEncodeOpBase<T> {
public:
explicit SentencePieceEncodeDenseOp(OpKernelConstruction* context)
: SentencePieceEncodeOpBase<T>(context) {}
: SentencePieceEncodeOpBase<T>(context) {
this->GetPad(&pad_);
}
// protected:
void MakeOutputTensor(OpKernelContext* context,
@ -406,11 +417,14 @@ class SentencePieceEncodeDenseOp : public SentencePieceEncodeOpBase<T> {
for (int row = 0; row < batch_size; ++row) {
for (int col = 0; col < max_sequence_length; ++col) {
values_tensor_output(row, col) =
col < pieces[row].size() ? pieces[row][col] : T();
col < pieces[row].size() ? pieces[row][col] : pad_;
}
length_tensor_output(row) = pieces[row].size();
}
}
private:
T pad_;
};
template <typename T>

Двоичный файл не отображается.