add config generation for parallel sgd

jeanfad 2016-05-10 19:21:30 +02:00
Parent 8582774421
Commit 744db6c37f
1 changed file with 35 additions and 21 deletions


@@ -265,10 +265,10 @@ class SGDParams:
         self.parallel_training_ = None
 
     def _set_global_parallel_params(self,
-            parallalization_method = 'none',
-            parallelization_start_epoch = 0,
-            distributed_mb_reading = False,
-            sync_perf_stats = 0):
+            parallalization_method = None,
+            parallelization_start_epoch = None,
+            distributed_mb_reading = None,
+            sync_perf_stats = None):
         self.parallel_training = {
             'parallelizationMethod':parallalization_method,
             'parallelizationStartEpoch':parallelization_start_epoch,
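A minimal standalone sketch (not part of the commit) of why the defaults move to None here: the config generator added further down filters on truthiness, so only settings the caller supplies explicitly end up in the emitted block.

# Standalone illustration of the None-skipping behaviour; names are only examples.
settings = {'parallelizationMethod': 'DataParallelSGD',
            'parallelizationStartEpoch': 1,
            'distributedMBReading': None,
            'syncPerfStats': None}
emitted = ['\t{0} = {1}'.format(k, v) for k, v in settings.items() if v]
# -> ['\tparallelizationMethod = DataParallelSGD', '\tparallelizationStartEpoch = 1']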
@@ -276,12 +276,12 @@ class SGDParams:
             'syncPerfStats':sync_perf_stats}
 
     def set_parallel_to_data_parallel(self,
-            parallelization_start_epoch = 0,
-            distributed_mb_reading = False,
-            sync_perf_stats = 0,
-            gradient_bits = 8,
-            use_zero_threshold_for_1bit = True,
-            use_buffered_async_gradient_aggregation = False):
+            parallelization_start_epoch = None,
+            distributed_mb_reading = None,
+            sync_perf_stats = None,
+            gradient_bits = None,
+            use_zero_threshold_for_1bit = None,
+            use_buffered_async_gradient_aggregation = None):
         self._set_global_parallel_params('DataParallelSGD',
                                          parallelization_start_epoch,
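A hedged usage sketch of the data-parallel setter changed above; `sgd` is assumed to be an already constructed SGDParams instance, and the values are illustrative only (gradient_bits defaulted to 8 before this change).

# sgd is an existing SGDParams instance (construction elided).
# Enable data-parallel SGD with 1-bit gradient quantization from epoch 2 on.
sgd.set_parallel_to_data_parallel(parallelization_start_epoch=2,
                                  gradient_bits=1)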
@@ -294,10 +294,10 @@ class SGDParams:
             'useBufferedAsyncGradientAggregation':use_buffered_async_gradient_aggregation}
 
     def set_parallel_to_model_average(self,
-            parallelization_start_epoch = 0,
-            distributed_mb_reading = False,
-            sync_perf_stats = 0,
-            sync_period = 40000,
+            parallelization_start_epoch = None,
+            distributed_mb_reading = None,
+            sync_perf_stats = None,
+            sync_period = None,
             sync_frequency_in_frames = None):
         self._set_global_parallel_params('ModelAveragingSGD',
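Likewise a sketch for model averaging, reusing the previous default of 40000 samples per sync purely as an example value.

# Average model parameters across workers every 40000 samples.
sgd.set_parallel_to_model_average(parallelization_start_epoch=1,
                                  sync_period=40000)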
@@ -310,13 +310,13 @@ class SGDParams:
             'syncFrequencyInFrames':sync_frequency_in_frames}
 
     def set_parallel_to_block_momentum(self,
-            parallelization_start_epoch = 0,
-            distributed_mb_reading = False,
-            sync_perf_stats = 0,
-            sync_period = 120000,
-            reset_sgd_momentum = True,
-            use_nesterov_momentum = True,
-            block_learning_rate = 1.0,
+            parallelization_start_epoch = None,
+            distributed_mb_reading = None,
+            sync_perf_stats = None,
+            sync_period = None,
+            reset_sgd_momentum = None,
+            use_nesterov_momentum = None,
+            block_learning_rate = None,
             block_momentum_per_sync = None,
             block_momentum_as_time_constant = None):
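And a sketch for block-momentum SGD, again taking the former defaults (sync_period 120000, block_learning_rate 1.0) as illustrative values.

# Block-momentum SGD: sync every 120000 samples, resetting local momentum each time.
sgd.set_parallel_to_block_momentum(parallelization_start_epoch=1,
                                   sync_period=120000,
                                   reset_sgd_momentum=True,
                                   block_learning_rate=1.0)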
@@ -333,6 +333,20 @@ class SGDParams:
             'blockMomentumPerSync':block_momentum_per_sync,
             'blockMomentumAsTimeConstant':block_momentum_as_time_constant}
 
+    def _generate_parallel_training_config(self):
+        # Build the ParallelTrain block; entries left at None are skipped.
+        config = ['ParallelTrain=[']
+        for k, v in self.parallel_training.items():
+            if v:
+                config.append('\t{0} = {1}'.format(k, v))
+        config.append('\t{0} = ['.format(self.parallel_training['parallelizationMethod']))
+        for k, v in self.parallel_training_subblock.items():
+            if v:
+                config.append('\t\t{0} = {1}'.format(k, v))
+        config.append('\t]')
+        config.append(']')
+        return config
+
     def _to_config_description(self):
         """Generate the SGDParams configuration block
         """