From 0edfb2845d037631b445a8c438476fd83b8c64dd Mon Sep 17 00:00:00 2001 From: Brett Cannon <54418+brettcannon@users.noreply.github.com> Date: Wed, 11 Mar 2020 13:30:51 -0700 Subject: [PATCH] Simplify some code --- mnist-vscode-docs-sample/train.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/mnist-vscode-docs-sample/train.py b/mnist-vscode-docs-sample/train.py index cfd8cc0..79f7d49 100644 --- a/mnist-vscode-docs-sample/train.py +++ b/mnist-vscode-docs-sample/train.py @@ -56,15 +56,13 @@ with tf.Session() as sess: y_train = y_train[indices] # batch index - b_start = 0 - b_end = b_start + batch_size - for _ in range(training_set_size // batch_size): + b_end = batch_size + for b_start in range(0, training_set_size, batch_size): # get a batch X_batch, y_batch = X_train[b_start: b_end], y_train[b_start: b_end] # update batch index for the next batch - b_start = b_start + batch_size - b_end = min(b_start + batch_size, training_set_size) + b_end = min(b_start + (batch_size * 2), training_set_size) # train sess.run(train_op, feed_dict = {X: X_batch, y: y_batch}) @@ -75,7 +73,7 @@ with tf.Session() as sess: acc_val = acc_op.eval(feed_dict = {X: X_test, y: y_test}) # Log accuracies to AML logger if using AML - if run != None: + if run is not None: run.log('Validation Accuracy', np.float(acc_val)) run.log('Training Accuracy', np.float(acc_train)) @@ -84,4 +82,4 @@ with tf.Session() as sess: y_hat = np.argmax(output.eval(feed_dict = {X: X_test}), axis = 1) os.makedirs('./outputs/model', exist_ok = True) - saver.save(sess, './outputs/model/mnist-tf.model') \ No newline at end of file + saver.save(sess, './outputs/model/mnist-tf.model')