Added code to run on Azure by default
This commit is contained in:
Родитель
87433279a8
Коммит
a2d6e448af
|
@ -0,0 +1,12 @@
|
|||
import azureml
|
||||
from azureml.core import Run
|
||||
|
||||
# Access the Azure ML run
|
||||
# Init run param to check if running within AML
|
||||
def get_AMLRun():
|
||||
try:
|
||||
run = Run.get_submitted_run()
|
||||
return run
|
||||
except Exception as e:
|
||||
print("Caught = {}".format(e.message))
|
||||
return None
|
|
@ -7,6 +7,7 @@ import sys
|
|||
import gzip
|
||||
import struct
|
||||
from utils import prepare_data
|
||||
from amlrun import get_AMLRun
|
||||
|
||||
# ## Download MNIST dataset
|
||||
# In order to train on the MNIST dataset we will first need to download
|
||||
|
@ -43,6 +44,7 @@ with tf.name_scope('eval'):
|
|||
|
||||
init = tf.global_variables_initializer()
|
||||
saver = tf.train.Saver()
|
||||
run = get_AMLRun()
|
||||
|
||||
with tf.Session() as sess:
|
||||
init.run()
|
||||
|
@ -66,11 +68,17 @@ with tf.Session() as sess:
|
|||
|
||||
# train
|
||||
sess.run(train_op, feed_dict = {X: X_batch, y: y_batch})
|
||||
|
||||
# evaluate training set
|
||||
acc_train = acc_op.eval(feed_dict = {X: X_batch, y: y_batch})
|
||||
# evaluate validation set
|
||||
acc_val = acc_op.eval(feed_dict = {X: X_test, y: y_test})
|
||||
|
||||
# Log accuracies to AML logger if using AML
|
||||
if run != None:
|
||||
run.log('Validation Accuracy', np.float(acc_val))
|
||||
run.log('Training Accuracy', np.float(acc_train))
|
||||
|
||||
# print out training and validation accuracy
|
||||
print(epoch, '-- Training accuracy:', acc_train, '\b Validation accuracy:', acc_val)
|
||||
y_hat = np.argmax(output.eval(feed_dict = {X: X_test}), axis = 1)
|
||||
|
|
Загрузка…
Ссылка в новой задаче