Switch from deprecated tfv1.app to absl-py

This commit is contained in:
Reuben Morais 2019-08-28 10:55:33 +02:00
Родитель 7e96961e35
Коммит 24bcdeb3d6
4 изменённых файлов: 10 добавлений и 7 удалений

Просмотреть файл

@ -8,6 +8,7 @@ import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0 LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3' os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
import absl.app
import numpy as np import numpy as np
import progressbar import progressbar
import shutil import shutil
@ -891,4 +892,4 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
create_flags() create_flags()
tfv1.app.run(main) absl.app.run(main)

Просмотреть файл

@ -7,6 +7,7 @@ import json
from multiprocessing import cpu_count from multiprocessing import cpu_count
import absl.app
import numpy as np import numpy as np
import progressbar import progressbar
import tensorflow as tf import tensorflow as tf
@ -167,5 +168,4 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
create_flags() create_flags()
tf.app.flags.DEFINE_string('test_output_file', '', 'path to a file to save all src/decoded/distance/loss tuples') absl.app.run(main)
tfv1.app.run(main)

Просмотреть файл

@ -6,6 +6,7 @@ pandas
six six
pyxdg pyxdg
attrdict attrdict
absl-py
# Requirements for building native_client files # Requirements for building native_client files
setuptools setuptools

Просмотреть файл

@ -1,16 +1,15 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import tensorflow as tf
import os import os
import absl.flags
FLAGS = tf.app.flags.FLAGS FLAGS = absl.flags.FLAGS
def create_flags(): def create_flags():
# Importer # Importer
# ======== # ========
f = tf.app.flags f = absl.flags
f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.') f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.') f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
@ -89,6 +88,8 @@ def create_flags():
f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification') f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
f.DEFINE_string('test_output_file', '', 'path to a file to save all src/decoded/distance/loss tuples generated during a test epoch')
# Geometry # Geometry
f.DEFINE_integer('n_hidden', 2048, 'layer width to use when initialising layers') f.DEFINE_integer('n_hidden', 2048, 'layer width to use when initialising layers')