Added code to run on Azure by default

This commit is contained in:
luisquintanilla 2020-01-06 11:40:27 -05:00
Родитель 87433279a8
Коммит a2d6e448af
2 изменённых файлов: 20 добавлений и 0 удалений

Просмотреть файл

@ -0,0 +1,12 @@
import azureml
from azureml.core import Run
# Access the Azure ML run
# Init run param to check if running within AML
def get_AMLRun():
try:
run = Run.get_submitted_run()
return run
except Exception as e:
print("Caught = {}".format(e.message))
return None

Просмотреть файл

@ -7,6 +7,7 @@ import sys
import gzip
import struct
from utils import prepare_data
from amlrun import get_AMLRun
# ## Download MNIST dataset
# In order to train on the MNIST dataset we will first need to download
@ -43,6 +44,7 @@ with tf.name_scope('eval'):
init = tf.global_variables_initializer()
saver = tf.train.Saver()
run = get_AMLRun()
with tf.Session() as sess:
init.run()
@ -66,11 +68,17 @@ with tf.Session() as sess:
# train
sess.run(train_op, feed_dict = {X: X_batch, y: y_batch})
# evaluate training set
acc_train = acc_op.eval(feed_dict = {X: X_batch, y: y_batch})
# evaluate validation set
acc_val = acc_op.eval(feed_dict = {X: X_test, y: y_test})
# Log accuracies to AML logger if using AML
if run != None:
run.log('Validation Accuracy', np.float(acc_val))
run.log('Training Accuracy', np.float(acc_train))
# print out training and validation accuracy
print(epoch, '-- Training accuracy:', acc_train, '\b Validation accuracy:', acc_val)
y_hat = np.argmax(output.eval(feed_dict = {X: X_test}), axis = 1)