Add config generation for parallel SGD
This commit is contained in:
Parent
8582774421
Commit
744db6c37f
@@ -265,10 +265,10 @@ class SGDParams:
         self.parallel_training_ = None
 
     def _set_global_parallel_params(self,
-                                    parallelization_method = 'none',
-                                    parallelization_start_epoch = 0,
-                                    distributed_mb_reading = False,
-                                    sync_perf_stats = 0):
+                                    parallelization_method = None,
+                                    parallelization_start_epoch = None,
+                                    distributed_mb_reading = None,
+                                    sync_perf_stats = None):
         self.parallel_training = {
             'parallelizationMethod': parallelization_method,
             'parallelizationStartEpoch': parallelization_start_epoch,
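
A note on the change itself: moving every default from a concrete value ('none', 0, False) to None lets the generator distinguish "left unset" from "explicitly set", so unset options can simply be omitted from the emitted config. A minimal standalone sketch of the filtering idea (the helper name is hypothetical, not part of the commit):

    def emit_options(options):
        # Only options the caller actually set (i.e. not None) are emitted.
        return ['\t{0} = {1}'.format(k, v)
                for k, v in options.items() if v is not None]

This is also why the generator added below should test against None rather than truthiness: an explicit False or 0 is still a value the user asked for.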
@@ -276,12 +276,12 @@ class SGDParams:
             'syncPerfStats': sync_perf_stats}
 
     def set_parallel_to_data_parallel(self,
-                                      parallelization_start_epoch = 0,
-                                      distributed_mb_reading = False,
-                                      sync_perf_stats = 0,
-                                      gradient_bits = 8,
-                                      use_zero_threshold_for_1bit = True,
-                                      use_buffered_async_gradient_aggregation = False):
+                                      parallelization_start_epoch = None,
+                                      distributed_mb_reading = None,
+                                      sync_perf_stats = None,
+                                      gradient_bits = None,
+                                      use_zero_threshold_for_1bit = None,
+                                      use_buffered_async_gradient_aggregation = None):
         self._set_global_parallel_params('DataParallelSGD',
                                          parallelization_start_epoch,
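
The display elides the tail of this method, but the context line at the top of the next hunk shows that each set_parallel_to_* method also fills a method-specific dictionary. A plausible reconstruction of the elided assignment (only the last line is visible in the diff; the first two keys are assumed from the corresponding CNTK option names):

    self.parallel_training_subblock = {
        'gradientBits': gradient_bits,
        'useZeroThresholdFor1Bit': use_zero_threshold_for_1bit,
        'useBufferedAsyncGradientAggregation': use_buffered_async_gradient_aggregation}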
@@ -294,10 +294,10 @@ class SGDParams:
             'useBufferedAsyncGradientAggregation': use_buffered_async_gradient_aggregation}
 
     def set_parallel_to_model_average(self,
-                                      parallelization_start_epoch = 0,
-                                      distributed_mb_reading = False,
-                                      sync_perf_stats = 0,
-                                      sync_period = 40000,
+                                      parallelization_start_epoch = None,
+                                      distributed_mb_reading = None,
+                                      sync_perf_stats = None,
+                                      sync_period = None,
                                       sync_frequency_in_frames = None):
 
         self._set_global_parallel_params('ModelAveragingSGD',
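
Again, only the options a caller sets will reach the config. syncPeriod and syncFrequencyInFrames appear to be two spellings of the synchronization interval (the latter looks like the older CNTK name), so a caller would typically set just one, e.g.:

    params = SGDParams()  # assuming a no-argument constructor
    params.set_parallel_to_model_average(sync_period=40000)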
@@ -310,13 +310,13 @@ class SGDParams:
             'syncFrequencyInFrames': sync_frequency_in_frames}
 
     def set_parallel_to_block_momentum(self,
-                                       parallelization_start_epoch = 0,
-                                       distributed_mb_reading = False,
-                                       sync_perf_stats = 0,
-                                       sync_period = 120000,
-                                       reset_sgd_momentum = True,
-                                       use_nesterov_momentum = True,
-                                       block_learning_rate = 1.0,
+                                       parallelization_start_epoch = None,
+                                       distributed_mb_reading = None,
+                                       sync_perf_stats = None,
+                                       sync_period = None,
+                                       reset_sgd_momentum = None,
+                                       use_nesterov_momentum = None,
+                                       block_learning_rate = None,
                                        block_momentum_per_sync = None,
                                        block_momentum_as_time_constant = None):
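
blockMomentumPerSync and blockMomentumAsTimeConstant likewise appear to be alternative ways of specifying the block momentum, and both already defaulted to None, so with the None-filter only the form the caller chooses is emitted. An illustrative call on the same params object (values are examples only):

    params.set_parallel_to_block_momentum(sync_period=120000,
                                          block_momentum_as_time_constant=40000)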
@@ -333,6 +333,23 @@ class SGDParams:
             'blockMomentumPerSync': block_momentum_per_sync,
             'blockMomentumAsTimeConstant': block_momentum_as_time_constant}
 
+    def _generate_parallel_training_config(self):
+        # Emit the shared top-level options; anything still at None was
+        # never set by the caller and is skipped.
+        config = ['ParallelTrain=[']
+        for k, v in self.parallel_training.items():
+            if v is not None:
+                config.append('\t{0} = {1}'.format(k, v))
+
+        # Emit the method-specific sub-block, e.g. DataParallelSGD = [ ... ].
+        config.append('\t{0} = ['.format(self.parallel_training['parallelizationMethod']))
+        for k, v in self.parallel_training_subblock.items():
+            if v is not None:
+                config.append('\t\t{0} = {1}'.format(k, v))
+        config.append('\t]')
+        config.append(']')
+        return config
+
     def _to_config_description(self):
         """Generate the SGDParams configuration block
         """
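
Taken together, a rough end-to-end sketch of how the new generator is meant to be used. This assumes SGDParams() takes no required constructor arguments; the output is reconstructed by tracing the code above, not captured from a run, and key order may vary:

    params = SGDParams()
    params.set_parallel_to_data_parallel(parallelization_start_epoch=2,
                                         distributed_mb_reading=True,
                                         gradient_bits=1)
    print('\n'.join(params._generate_parallel_training_config()))

which should print something like (tabs shown as spaces):

    ParallelTrain=[
        parallelizationMethod = DataParallelSGD
        parallelizationStartEpoch = 2
        distributedMBReading = True
        DataParallelSGD = [
            gradientBits = 1
        ]
    ]

One caveat: Python renders booleans as True/False, while CNTK config files conventionally use lowercase true/false, so boolean formatting may need a follow-up.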