This commit is contained in:
2017-11-14 12:06:22 +05:30
Родитель 1e2346b7f6 7fd713988f
Коммит 29e1ea41b8
6 изменённых файлов: 19 добавлений и 17 удалений

Просмотреть файл

@ -38,7 +38,7 @@ int main(int argc, char **argv)
BonsaiPredictor predictor(modelBytes, model); // use the constructor predictor(modelBytes, model, false) for loading a sparse model.
predictor.importMeanVar(meanVarBytes, meanVar);
predictor.batchEvaluate(trainer.data.Xtest, trainer.data.Ytest, dataDir, currResultsPath);
predictor.batchEvaluate(trainer.data.Xvalidation, trainer.data.Yvalidation, dataDir, currResultsPath);
delete[] model, meanVar;

Просмотреть файл

@ -80,7 +80,7 @@ namespace EdgeML
int seed;
int iters, epochs;
dataCount_t ntrain, ntest, batchSize;
dataCount_t ntrain, nvalidation, batchSize;
bool isOneIndex;
FP_TYPE Sigma; ///< Sigmoid parameter for prediction

Просмотреть файл

@ -957,7 +957,7 @@ void Bonsai::parseInput(const int& argc, const char** argv,
required++;
break;
case 'E':
hyperParam.ntest = int(atoi(argv[i]));
hyperParam.nvalidation = int(atoi(argv[i]));
required++;
break;
}

Просмотреть файл

@ -20,7 +20,7 @@ BonsaiModel::BonsaiHyperParams::BonsaiHyperParams()
seed = 42;
ntrain = 0;
ntest = 0;
nvalidation = 0;
batchSize = 0;
iters = 0;
@ -102,7 +102,6 @@ void BonsaiModel::BonsaiHyperParams::finalizeHyperParams()
// Following asserts removed to faciliate support for TLC
// which does not know how many datapoints are going to be fed before-hand!
// assert(ntrain >= 1);
// assert(ntest >= 0);
assert(projectionDimension <= dataDimension + 1);
assert(numClasses > 0);

Просмотреть файл

@ -24,7 +24,7 @@ int main()
hyperParam.numClasses = 10;
hyperParam.ntest = 0;
hyperParam.nvalidation = 0;
hyperParam.ntrain = 5000;
hyperParam.Sigma = 1.0;

Просмотреть файл

@ -29,7 +29,7 @@ BonsaiTrainer::BonsaiTrainer(
data(dataIngestType,
DataFormatParams{
model.hyperParams.ntrain,
model.hyperParams.ntest,
model.hyperParams.nvalidation,
model.hyperParams.numClasses,
model.hyperParams.dataDimension })
{
@ -47,7 +47,7 @@ BonsaiTrainer::BonsaiTrainer(
mean = MatrixXuf::Zero(model.hyperParams.dataDimension, 1);
variance = MatrixXuf::Zero(model.hyperParams.dataDimension, 1);
data.loadDataFromFile(model.hyperParams.dataformatType, dataDir + "/train.txt", dataDir + "/test.txt");
data.loadDataFromFile(model.hyperParams.dataformatType, dataDir + "/train.txt", dataDir + "/test.txt", "");
finalizeData();
}
@ -62,7 +62,7 @@ BonsaiTrainer::BonsaiTrainer(
data(dataIngestType,
DataFormatParams{
model.hyperParams.ntrain,
model.hyperParams.ntest,
model.hyperParams.nvalidation,
model.hyperParams.numClasses,
model.hyperParams.dataDimension })
{
@ -81,7 +81,7 @@ BonsaiTrainer::BonsaiTrainer(
mean = MatrixXuf::Zero(model.hyperParams.dataDimension, 1);
variance = MatrixXuf::Zero(model.hyperParams.dataDimension, 1);
data.loadDataFromFile(model.hyperParams.dataformatType, dataDir + "/train.txt", dataDir + "/test.txt");
data.loadDataFromFile(model.hyperParams.dataformatType, dataDir + "/train.txt", dataDir + "/test.txt", "");
finalizeData();
initializeModel();
@ -96,7 +96,7 @@ BonsaiTrainer::BonsaiTrainer(
data(dataIngestType,
DataFormatParams{
model.hyperParams.ntrain,
model.hyperParams.ntest,
model.hyperParams.nvalidation,
model.hyperParams.numClasses,
model.hyperParams.dataDimension })
{
@ -155,12 +155,11 @@ void BonsaiTrainer::finalizeData()
// This condition means that the ingest type is Interface ingest,
// hence the number of training points was not known beforehand.
model.hyperParams.ntrain = data.Xtrain.cols();
assert(data.Xtest.cols() == 0);
model.hyperParams.ntest = 0;
assert(data.Xvalidation.cols() == 0);
model.hyperParams.nvalidation = 0;
}
else {
assert(model.hyperParams.ntrain == data.Xtrain.cols());
// assert(model.hyperParams.ntest == data.Xtest.cols());
}
// Following asserts can only be made in finalieData since TLC
@ -340,11 +339,15 @@ void BonsaiTrainer::exportMeanVar(
void BonsaiTrainer::normalize()
{
if (model.hyperParams.normalizationType == minMax) {
minMaxNormalize(data.Xtrain, data.Xtest);
computeMinMax(data.Xtrain, data.min, data.max);
minMaxNormalize(data.Xtrain, data.min, data.max);
if (data.Xvalidation.cols() > 0)
minMaxNormalize(data.Xvalidation, data.min, data.max);
}
else if (model.hyperParams.normalizationType == l2) {
l2Normalize(data.Xtrain);
l2Normalize(data.Xtest);
if (data.Xvalidation.cols() > 0)
l2Normalize(data.Xvalidation);
}
else;
}