зеркало из https://github.com/microsoft/EdgeML.git
Merge branch 'ProtoNN-dev' of https://github.com/Microsoft/EdgeML into ProtoNN-dev
This commit is contained in:
Коммит
29e1ea41b8
|
@ -38,7 +38,7 @@ int main(int argc, char **argv)
|
|||
BonsaiPredictor predictor(modelBytes, model); // use the constructor predictor(modelBytes, model, false) for loading a sparse model.
|
||||
predictor.importMeanVar(meanVarBytes, meanVar);
|
||||
|
||||
predictor.batchEvaluate(trainer.data.Xtest, trainer.data.Ytest, dataDir, currResultsPath);
|
||||
predictor.batchEvaluate(trainer.data.Xvalidation, trainer.data.Yvalidation, dataDir, currResultsPath);
|
||||
|
||||
delete[] model, meanVar;
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ namespace EdgeML
|
|||
|
||||
int seed;
|
||||
int iters, epochs;
|
||||
dataCount_t ntrain, ntest, batchSize;
|
||||
dataCount_t ntrain, nvalidation, batchSize;
|
||||
bool isOneIndex;
|
||||
|
||||
FP_TYPE Sigma; ///< Sigmoid parameter for prediction
|
||||
|
|
|
@ -957,7 +957,7 @@ void Bonsai::parseInput(const int& argc, const char** argv,
|
|||
required++;
|
||||
break;
|
||||
case 'E':
|
||||
hyperParam.ntest = int(atoi(argv[i]));
|
||||
hyperParam.nvalidation = int(atoi(argv[i]));
|
||||
required++;
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ BonsaiModel::BonsaiHyperParams::BonsaiHyperParams()
|
|||
seed = 42;
|
||||
|
||||
ntrain = 0;
|
||||
ntest = 0;
|
||||
nvalidation = 0;
|
||||
batchSize = 0;
|
||||
|
||||
iters = 0;
|
||||
|
@ -102,7 +102,6 @@ void BonsaiModel::BonsaiHyperParams::finalizeHyperParams()
|
|||
// Following asserts removed to faciliate support for TLC
|
||||
// which does not know how many datapoints are going to be fed before-hand!
|
||||
// assert(ntrain >= 1);
|
||||
// assert(ntest >= 0);
|
||||
assert(projectionDimension <= dataDimension + 1);
|
||||
assert(numClasses > 0);
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ int main()
|
|||
|
||||
hyperParam.numClasses = 10;
|
||||
|
||||
hyperParam.ntest = 0;
|
||||
hyperParam.nvalidation = 0;
|
||||
hyperParam.ntrain = 5000;
|
||||
|
||||
hyperParam.Sigma = 1.0;
|
||||
|
|
|
@ -29,7 +29,7 @@ BonsaiTrainer::BonsaiTrainer(
|
|||
data(dataIngestType,
|
||||
DataFormatParams{
|
||||
model.hyperParams.ntrain,
|
||||
model.hyperParams.ntest,
|
||||
model.hyperParams.nvalidation,
|
||||
model.hyperParams.numClasses,
|
||||
model.hyperParams.dataDimension })
|
||||
{
|
||||
|
@ -47,7 +47,7 @@ BonsaiTrainer::BonsaiTrainer(
|
|||
mean = MatrixXuf::Zero(model.hyperParams.dataDimension, 1);
|
||||
variance = MatrixXuf::Zero(model.hyperParams.dataDimension, 1);
|
||||
|
||||
data.loadDataFromFile(model.hyperParams.dataformatType, dataDir + "/train.txt", dataDir + "/test.txt");
|
||||
data.loadDataFromFile(model.hyperParams.dataformatType, dataDir + "/train.txt", dataDir + "/test.txt", "");
|
||||
finalizeData();
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ BonsaiTrainer::BonsaiTrainer(
|
|||
data(dataIngestType,
|
||||
DataFormatParams{
|
||||
model.hyperParams.ntrain,
|
||||
model.hyperParams.ntest,
|
||||
model.hyperParams.nvalidation,
|
||||
model.hyperParams.numClasses,
|
||||
model.hyperParams.dataDimension })
|
||||
{
|
||||
|
@ -81,7 +81,7 @@ BonsaiTrainer::BonsaiTrainer(
|
|||
mean = MatrixXuf::Zero(model.hyperParams.dataDimension, 1);
|
||||
variance = MatrixXuf::Zero(model.hyperParams.dataDimension, 1);
|
||||
|
||||
data.loadDataFromFile(model.hyperParams.dataformatType, dataDir + "/train.txt", dataDir + "/test.txt");
|
||||
data.loadDataFromFile(model.hyperParams.dataformatType, dataDir + "/train.txt", dataDir + "/test.txt", "");
|
||||
finalizeData();
|
||||
|
||||
initializeModel();
|
||||
|
@ -96,7 +96,7 @@ BonsaiTrainer::BonsaiTrainer(
|
|||
data(dataIngestType,
|
||||
DataFormatParams{
|
||||
model.hyperParams.ntrain,
|
||||
model.hyperParams.ntest,
|
||||
model.hyperParams.nvalidation,
|
||||
model.hyperParams.numClasses,
|
||||
model.hyperParams.dataDimension })
|
||||
{
|
||||
|
@ -155,12 +155,11 @@ void BonsaiTrainer::finalizeData()
|
|||
// This condition means that the ingest type is Interface ingest,
|
||||
// hence the number of training points was not known beforehand.
|
||||
model.hyperParams.ntrain = data.Xtrain.cols();
|
||||
assert(data.Xtest.cols() == 0);
|
||||
model.hyperParams.ntest = 0;
|
||||
assert(data.Xvalidation.cols() == 0);
|
||||
model.hyperParams.nvalidation = 0;
|
||||
}
|
||||
else {
|
||||
assert(model.hyperParams.ntrain == data.Xtrain.cols());
|
||||
// assert(model.hyperParams.ntest == data.Xtest.cols());
|
||||
}
|
||||
|
||||
// Following asserts can only be made in finalieData since TLC
|
||||
|
@ -340,11 +339,15 @@ void BonsaiTrainer::exportMeanVar(
|
|||
void BonsaiTrainer::normalize()
|
||||
{
|
||||
if (model.hyperParams.normalizationType == minMax) {
|
||||
minMaxNormalize(data.Xtrain, data.Xtest);
|
||||
computeMinMax(data.Xtrain, data.min, data.max);
|
||||
minMaxNormalize(data.Xtrain, data.min, data.max);
|
||||
if (data.Xvalidation.cols() > 0)
|
||||
minMaxNormalize(data.Xvalidation, data.min, data.max);
|
||||
}
|
||||
else if (model.hyperParams.normalizationType == l2) {
|
||||
l2Normalize(data.Xtrain);
|
||||
l2Normalize(data.Xtest);
|
||||
if (data.Xvalidation.cols() > 0)
|
||||
l2Normalize(data.Xvalidation);
|
||||
}
|
||||
else;
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче