This commit is contained in:
Brett Cannon 2020-03-11 13:30:51 -07:00 коммит произвёл GitHub
Родитель 53f7005e1c
Коммит 0edfb2845d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 5 добавлений и 7 удалений

Просмотреть файл

@ -56,15 +56,13 @@ with tf.Session() as sess:
y_train = y_train[indices]
# batch index
b_start = 0
b_end = b_start + batch_size
for _ in range(training_set_size // batch_size):
b_end = batch_size
for b_start in range(0, training_set_size, batch_size):
# get a batch
X_batch, y_batch = X_train[b_start: b_end], y_train[b_start: b_end]
# update batch index for the next batch
b_start = b_start + batch_size
b_end = min(b_start + batch_size, training_set_size)
b_end = min(b_start + (batch_size * 2), training_set_size)
# train
sess.run(train_op, feed_dict = {X: X_batch, y: y_batch})
@ -75,7 +73,7 @@ with tf.Session() as sess:
acc_val = acc_op.eval(feed_dict = {X: X_test, y: y_test})
# Log accuracies to AML logger if using AML
if run != None:
if run is not None:
run.log('Validation Accuracy', np.float(acc_val))
run.log('Training Accuracy', np.float(acc_train))
@ -84,4 +82,4 @@ with tf.Session() as sess:
y_hat = np.argmax(output.eval(feed_dict = {X: X_test}), axis = 1)
os.makedirs('./outputs/model', exist_ok = True)
saver.save(sess, './outputs/model/mnist-tf.model')
saver.save(sess, './outputs/model/mnist-tf.model')