зеркало из https://github.com/microsoft/EdgeML.git
Moving example to examples/ directory and renaming driver filese
This commit is contained in:
Родитель
f0e1a12061
Коммит
bd24e4734d
|
@ -328,19 +328,24 @@ class BonsaiTrainer:
|
|||
self.bonsaiObj.sigmaI = oldSigmaI
|
||||
sys.stdout.flush()
|
||||
|
||||
# sigmaI has to be set to infinity to ensure only a single path is used in inference
|
||||
# sigmaI has to be set to infinity to ensure
|
||||
# only a single path is used in inference
|
||||
self.bonsaiObj.sigmaI = 1e9
|
||||
print("Maximum Test accuracy at compressed model size(including early stopping): " +
|
||||
print("Maximum Test accuracy at compressed" +
|
||||
" model size(including early stopping): " +
|
||||
str(maxTestAcc) + " at Epoch: " +
|
||||
str(maxTestAccEpoch + 1) + "\nFinal Test Accuracy: " + str(testAcc))
|
||||
str(maxTestAccEpoch + 1) + "\nFinal Test" +
|
||||
" Accuracy: " + str(testAcc))
|
||||
print("\nNon-Zeros: " + str(self.getModelSize()[1]) + " Model Size: " +
|
||||
str(float(self.getModelSize()[1]) / 1024.0) + " KB hasSparse: " +
|
||||
str(self.getModelSize()[2]) + "\n")
|
||||
|
||||
resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
|
||||
" at Epoch(totalEpochs): " + str(maxTestAccEpoch + 1) +
|
||||
" at Epoch(totalEpochs): " +
|
||||
str(maxTestAccEpoch + 1) +
|
||||
"(" + str(totalEpochs) + ")" + " ModelSize: " +
|
||||
str(float(self.getModelSize()[1]) / 1024.0) +
|
||||
" KB hasSparse: " + str(self.getModelSize()[2]) +
|
||||
" Param Directory: " + str(os.path.abspath(currDir)) + "\n")
|
||||
" Param Directory: " +
|
||||
str(os.path.abspath(currDir)) + "\n")
|
||||
resultFile.close()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import bonsaiPreProcess
|
||||
import bonsaipreprocess
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from bonsaiTrainer import BonsaiTrainer
|
||||
|
@ -10,7 +10,7 @@ tf.set_random_seed(42)
|
|||
np.random.seed(42)
|
||||
|
||||
# Hyper Param pre-processing
|
||||
args = bonsaiPreProcess.getArgs()
|
||||
args = bonsaipreprocess.getArgs()
|
||||
|
||||
sigma = args.sigma
|
||||
depth = args.depth
|
||||
|
@ -28,7 +28,7 @@ learningRate = args.learningRate
|
|||
data_dir = args.data_dir
|
||||
|
||||
(dataDimension, numClasses,
|
||||
Xtrain, Ytrain, Xtest, Ytest) = bonsaiPreProcess.preProcessData(data_dir)
|
||||
Xtrain, Ytrain, Xtest, Ytest) = bonsaipreprocess.preProcessData(data_dir)
|
||||
|
||||
sparZ = args.sZ
|
||||
|
||||
|
@ -61,7 +61,7 @@ if numClasses == 2:
|
|||
X = tf.placeholder("float32", [None, dataDimension])
|
||||
Y = tf.placeholder("float32", [None, numClasses])
|
||||
|
||||
currDir = bonsaiPreProcess.createDir(data_dir)
|
||||
currDir = bonsaipreprocess.createDir(data_dir)
|
||||
|
||||
# numClasses = 1 for binary case
|
||||
bonsaiObj = Bonsai(numClasses, dataDimension,
|
|
@ -1,10 +1,12 @@
|
|||
'''
|
||||
Functions to check sanity of input arguments
|
||||
for the example script.
|
||||
'''
|
||||
import argparse
|
||||
import numpy as np
|
||||
import datetime
|
||||
import os
|
||||
|
||||
# Functions to check sanity of input arguments
|
||||
|
||||
|
||||
def checkIntPos(value):
|
||||
ivalue = int(value)
|
||||
|
@ -45,7 +47,8 @@ def getArgs():
|
|||
parser = argparse.ArgumentParser(
|
||||
description='HyperParams for Bonsai Algorithm')
|
||||
parser.add_argument('-dir', '--data_dir', required=True,
|
||||
help='Data directory containing train.npy and test.npy')
|
||||
help='Data directory containing' +
|
||||
'train.npy and test.npy')
|
||||
|
||||
parser.add_argument('-d', '--depth', type=checkIntNneg, default=2,
|
||||
help='Depth of Bonsai Tree (default: 2 try: [0, 1, 3])')
|
Загрузка…
Ссылка в новой задаче