Merge pull request #341 from luisquintanilla/make-runnable-on-azure
Make mnist-vscode-docs-sample run on Azure by default
This commit is contained in:
Коммит
53f7005e1c
|
@ -286,3 +286,7 @@ __pycache__/
|
|||
*.btm.cs
|
||||
*.odx.cs
|
||||
*.xsd.cs
|
||||
|
||||
# AML
|
||||
**/aml_config/*
|
||||
**/azureml_outputs/*
|
|
@ -0,0 +1,12 @@
|
|||
import azureml
|
||||
from azureml.core import Run
|
||||
|
||||
# Access the Azure ML run
|
||||
# Init run param to check if running within AML
|
||||
def get_AMLRun():
|
||||
try:
|
||||
run = Run.get_submitted_run()
|
||||
return run
|
||||
except Exception as e:
|
||||
print("Caught = {}".format(e.message))
|
||||
return None
|
|
@ -0,0 +1,9 @@
|
|||
name: vs-code-azure-ml-tutorial
|
||||
channels:
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.6.2
|
||||
- tensorflow=1.15.0
|
||||
- pip
|
||||
- pip:
|
||||
- azureml-defaults
|
|
@ -10,7 +10,7 @@ def init():
|
|||
global X, output, sess
|
||||
tf.reset_default_graph()
|
||||
# retreive the local path to the model using the model name
|
||||
model_root = Model.get_model_path('mnist_tf_model')
|
||||
model_root = Model.get_model_path('MNIST-TensorFlow-model')
|
||||
saver = tf.train.import_meta_graph(os.path.join(model_root, 'mnist-tf.model.meta'))
|
||||
X = tf.get_default_graph().get_tensor_by_name("network/X:0")
|
||||
output = tf.get_default_graph().get_tensor_by_name("network/output/MatMul:0")
|
||||
|
|
|
@ -7,6 +7,7 @@ import sys
|
|||
import gzip
|
||||
import struct
|
||||
from utils import prepare_data
|
||||
from amlrun import get_AMLRun
|
||||
|
||||
# ## Download MNIST dataset
|
||||
# In order to train on the MNIST dataset we will first need to download
|
||||
|
@ -43,6 +44,7 @@ with tf.name_scope('eval'):
|
|||
|
||||
init = tf.global_variables_initializer()
|
||||
saver = tf.train.Saver()
|
||||
run = get_AMLRun()
|
||||
|
||||
with tf.Session() as sess:
|
||||
init.run()
|
||||
|
@ -66,11 +68,17 @@ with tf.Session() as sess:
|
|||
|
||||
# train
|
||||
sess.run(train_op, feed_dict = {X: X_batch, y: y_batch})
|
||||
|
||||
# evaluate training set
|
||||
acc_train = acc_op.eval(feed_dict = {X: X_batch, y: y_batch})
|
||||
# evaluate validation set
|
||||
acc_val = acc_op.eval(feed_dict = {X: X_test, y: y_test})
|
||||
|
||||
# Log accuracies to AML logger if using AML
|
||||
if run != None:
|
||||
run.log('Validation Accuracy', np.float(acc_val))
|
||||
run.log('Training Accuracy', np.float(acc_train))
|
||||
|
||||
# print out training and validation accuracy
|
||||
print(epoch, '-- Training accuracy:', acc_train, '\b Validation accuracy:', acc_val)
|
||||
y_hat = np.argmax(output.eval(feed_dict = {X: X_test}), axis = 1)
|
||||
|
|
Загрузка…
Ссылка в новой задаче