diff --git a/src/ProtoNN/ProtoNNPredictor.cpp b/src/ProtoNN/ProtoNNPredictor.cpp index 05ab4c22..1d147658 100644 --- a/src/ProtoNN/ProtoNNPredictor.cpp +++ b/src/ProtoNN/ProtoNNPredictor.cpp @@ -608,14 +608,14 @@ void ProtoNNPredictor::getTopKScoresBatch( } -void ProtoNNPredictor::saveTopKScores(std::string filename, int k) +void ProtoNNPredictor::saveTopKScores(std::string filename, int topk) { dataCount_t n; n = testData.Xtest.cols(); assert(n > 0); - if (k < 1) - k = 5; + if (topk < 1) + topk = 5; if (filename.empty()) filename = outDir + "/predicted_scores.txt"; @@ -629,12 +629,13 @@ void ProtoNNPredictor::saveTopKScores(std::string filename, int k) dataCount_t curBatchSize = (batchSize < n - startIdx)? batchSize : n - startIdx; MatrixXuf Yscores = MatrixXuf::Zero(model.hyperParams.l, curBatchSize); scoreBatch(Yscores, startIdx, curBatchSize); - - getTopKScoresBatch(Yscores, topKindices, topKscores, k); + getTopKScoresBatch(Yscores, topKindices, topKscores, topk); - for (Eigen::Index i = 0; i < topKindices.cols(); i++) { - for (Eigen::Index j = 0; j < topKindices.rows(); j++) { - outfile << topKindices(j, i) << ":" << topKscores(j, i) << " "; + for (Eigen::Index j = 0; j < topKindices.cols(); j++) { + for (LabelMatType::InnerIterator it(testData.Ytest, i*batchSize+j); it; ++it) + outfile << it.row() << ", "; + for (Eigen::Index k = 0; k < topKindices.rows(); k++) { + outfile << topKindices(k, j) << ":" << topKscores(k, j) << " "; } outfile << std::endl; }