Merge pull request #341 from luisquintanilla/make-runnable-on-azure

Make mnist-vscode-docs-sample run on Azure by default
This commit is contained in:
Luis Quintanilla 2020-01-08 12:21:01 -05:00 коммит произвёл GitHub
Родитель 87433279a8 b8ce49928d
Коммит 53f7005e1c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 34 добавлений и 1 удалений

4
.gitignore поставляемый
Просмотреть файл

@ -286,3 +286,7 @@ __pycache__/
*.btm.cs *.btm.cs
*.odx.cs *.odx.cs
*.xsd.cs *.xsd.cs
# AML
**/aml_config/*
**/azureml_outputs/*

Просмотреть файл

@ -0,0 +1,12 @@
import azureml
from azureml.core import Run
# Access the Azure ML run
# Init run param to check if running within AML
def get_AMLRun():
try:
run = Run.get_submitted_run()
return run
except Exception as e:
print("Caught = {}".format(e.message))
return None

Просмотреть файл

@ -0,0 +1,9 @@
name: vs-code-azure-ml-tutorial
channels:
- defaults
dependencies:
- python=3.6.2
- tensorflow=1.15.0
- pip
- pip:
- azureml-defaults

Просмотреть файл

@ -10,7 +10,7 @@ def init():
global X, output, sess global X, output, sess
tf.reset_default_graph() tf.reset_default_graph()
# retreive the local path to the model using the model name # retreive the local path to the model using the model name
model_root = Model.get_model_path('mnist_tf_model') model_root = Model.get_model_path('MNIST-TensorFlow-model')
saver = tf.train.import_meta_graph(os.path.join(model_root, 'mnist-tf.model.meta')) saver = tf.train.import_meta_graph(os.path.join(model_root, 'mnist-tf.model.meta'))
X = tf.get_default_graph().get_tensor_by_name("network/X:0") X = tf.get_default_graph().get_tensor_by_name("network/X:0")
output = tf.get_default_graph().get_tensor_by_name("network/output/MatMul:0") output = tf.get_default_graph().get_tensor_by_name("network/output/MatMul:0")

Просмотреть файл

@ -7,6 +7,7 @@ import sys
import gzip import gzip
import struct import struct
from utils import prepare_data from utils import prepare_data
from amlrun import get_AMLRun
# ## Download MNIST dataset # ## Download MNIST dataset
# In order to train on the MNIST dataset we will first need to download # In order to train on the MNIST dataset we will first need to download
@ -43,6 +44,7 @@ with tf.name_scope('eval'):
init = tf.global_variables_initializer() init = tf.global_variables_initializer()
saver = tf.train.Saver() saver = tf.train.Saver()
run = get_AMLRun()
with tf.Session() as sess: with tf.Session() as sess:
init.run() init.run()
@ -66,11 +68,17 @@ with tf.Session() as sess:
# train # train
sess.run(train_op, feed_dict = {X: X_batch, y: y_batch}) sess.run(train_op, feed_dict = {X: X_batch, y: y_batch})
# evaluate training set # evaluate training set
acc_train = acc_op.eval(feed_dict = {X: X_batch, y: y_batch}) acc_train = acc_op.eval(feed_dict = {X: X_batch, y: y_batch})
# evaluate validation set # evaluate validation set
acc_val = acc_op.eval(feed_dict = {X: X_test, y: y_test}) acc_val = acc_op.eval(feed_dict = {X: X_test, y: y_test})
# Log accuracies to AML logger if using AML
if run != None:
run.log('Validation Accuracy', np.float(acc_val))
run.log('Training Accuracy', np.float(acc_train))
# print out training and validation accuracy # print out training and validation accuracy
print(epoch, '-- Training accuracy:', acc_train, '\b Validation accuracy:', acc_val) print(epoch, '-- Training accuracy:', acc_train, '\b Validation accuracy:', acc_val)
y_hat = np.argmax(output.eval(feed_dict = {X: X_test}), axis = 1) y_hat = np.argmax(output.eval(feed_dict = {X: X_test}), axis = 1)