add costType to alignment training

This commit is contained in:
Marcin Junczys-Dowmunt 2018-08-27 17:55:43 +02:00
Родитель c1f0297082
Коммит b69ad3ce39
1 изменённых файлов: 17 добавлений и 5 удалений

Просмотреть файл

@ -15,24 +15,36 @@ static inline Expr guidedAlignmentCost(Ptr<ExpressionGraph> graph,
int dimTrg = att->shape()[-1];
//debug(att, "Attention");
auto aln = graph->constant(att->shape(),
inits::from_vector(batch->getGuidedAlignment()));
//debug(aln, "Alignment");
std::string guidedCostType
= options->get<std::string>("guided-alignment-cost");
std::string costType = options->get<std::string>("cost-type");
int div = 1;
if(costType == "ce-mean-words") {
div = dimBatch * dimSrc * dimTrg;
} else if(costType == "perplexity") {
div = dimBatch * dimSrc * dimTrg;
} else if(costType == "ce-sum") {
div = 1;
} else {
div = dimBatch;
}
Expr alnCost;
float eps = 1e-6;
if(guidedCostType == "mse") {
alnCost = sum(flatten(square(att - aln))) / (2 * dimBatch);
alnCost = sum(flatten(square(att - aln))) / (2 * div);
} else if(guidedCostType == "mult") {
alnCost = -log(sum(flatten(att * aln)) + eps) / dimBatch;
alnCost = -log(sum(flatten(att * aln)) + eps) / div;
} else if(guidedCostType == "ce") {
alnCost = -sum(flatten(aln * log(att + eps))) / dimBatch;
alnCost = -sum(flatten(aln * log(att + eps))) / div;
} else {
ABORT("Unknown alignment cost type");
}