Expose option to shuffle training tensor iterator.

This commit is contained in:
Miltos Allamanis 2020-09-28 23:33:50 +01:00
Родитель 128f01c040
Коммит 7d19a30ad4
2 изменённых файлов: 6 добавлений и 2 удалений

Просмотреть файл

@@ -213,7 +213,7 @@ class AbstractNeuralModel(ABC, Generic[TRawDatapoint, TTensorizedDatapoint, TNeu
d if return_input_data else None,
),
dataset_iterator,
-chunksize=200,
+chunksize=20,
):
if tensorized_sample[0] is not None:
yield tensorized_sample

Просмотреть файл

@@ -177,6 +177,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
parallelize,
scheduler,
show_progress_bar,
+shuffle_input: bool = True,
):
sum_epoch_loss, running_avg_loss, num_minibatches, num_samples = 0.0, 0.0, 0, 0
start_time = time.time()
@@ -188,7 +189,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
device=device,
max_minibatch_size=self.__minibatch_size,
yield_partial_minibatches=False,
-shuffle_input=True,
+shuffle_input=shuffle_input,
parallelize=parallelize,
)
):
@@ -307,6 +308,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
exponential_running_average_factor: float = 0.97,
device=None,
store_tensorized_data_in_memory: bool = False,
+shuffle_training_data: bool = True,
) -> None:
"""
The training-validation loop for `AbstractNeuralModel`s.
@@ -324,6 +326,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
displayed in the progress bar.
:param device: the target PyTorch device for training
:param store_tensorized_data_in_memory: store all tensorized data in memory instead of computing them on-line.
+:param shuffle_training_data: shuffle the incoming data from `training_data`.
"""
if initialize_metadata:
self.__load_metadata_and_create_network(training_data, parallelize, show_progress_bar)
@@ -384,6 +387,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
parallelize,
scheduler,
show_progress_bar,
+shuffle_training_data,
)
target_metric, target_metric_improved = self._run_validation(