Simplify some code
This commit is contained in:
Родитель
53f7005e1c
Коммит
0edfb2845d
|
@ -56,15 +56,13 @@ with tf.Session() as sess:
|
|||
y_train = y_train[indices]
|
||||
|
||||
# batch index
|
||||
b_start = 0
|
||||
b_end = b_start + batch_size
|
||||
for _ in range(training_set_size // batch_size):
|
||||
b_end = batch_size
|
||||
for b_start in range(0, training_set_size, batch_size):
|
||||
# get a batch
|
||||
X_batch, y_batch = X_train[b_start: b_end], y_train[b_start: b_end]
|
||||
|
||||
# update batch index for the next batch
|
||||
b_start = b_start + batch_size
|
||||
b_end = min(b_start + batch_size, training_set_size)
|
||||
b_end = min(b_start + (batch_size * 2), training_set_size)
|
||||
|
||||
# train
|
||||
sess.run(train_op, feed_dict = {X: X_batch, y: y_batch})
|
||||
|
@ -75,7 +73,7 @@ with tf.Session() as sess:
|
|||
acc_val = acc_op.eval(feed_dict = {X: X_test, y: y_test})
|
||||
|
||||
# Log accuracies to AML logger if using AML
|
||||
if run != None:
|
||||
if run is not None:
|
||||
run.log('Validation Accuracy', np.float(acc_val))
|
||||
run.log('Training Accuracy', np.float(acc_train))
|
||||
|
||||
|
@ -84,4 +82,4 @@ with tf.Session() as sess:
|
|||
y_hat = np.argmax(output.eval(feed_dict = {X: X_test}), axis = 1)
|
||||
|
||||
os.makedirs('./outputs/model', exist_ok = True)
|
||||
saver.save(sess, './outputs/model/mnist-tf.model')
|
||||
saver.save(sess, './outputs/model/mnist-tf.model')
|
||||
|
|
Загрузка…
Ссылка в новой задаче