зеркало из https://github.com/microsoft/ptgnn.git
Expose option to shuffle training tensor iterator.
This commit is contained in:
Родитель
128f01c040
Коммит
7d19a30ad4
|
@ -213,7 +213,7 @@ class AbstractNeuralModel(ABC, Generic[TRawDatapoint, TTensorizedDatapoint, TNeu
|
|||
d if return_input_data else None,
|
||||
),
|
||||
dataset_iterator,
|
||||
chunksize=200,
|
||||
chunksize=20,
|
||||
):
|
||||
if tensorized_sample[0] is not None:
|
||||
yield tensorized_sample
|
||||
|
|
|
@ -177,6 +177,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
|
|||
parallelize,
|
||||
scheduler,
|
||||
show_progress_bar,
|
||||
shuffle_input: bool = True,
|
||||
):
|
||||
sum_epoch_loss, running_avg_loss, num_minibatches, num_samples = 0.0, 0.0, 0, 0
|
||||
start_time = time.time()
|
||||
|
@ -188,7 +189,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
|
|||
device=device,
|
||||
max_minibatch_size=self.__minibatch_size,
|
||||
yield_partial_minibatches=False,
|
||||
shuffle_input=True,
|
||||
shuffle_input=shuffle_input,
|
||||
parallelize=parallelize,
|
||||
)
|
||||
):
|
||||
|
@ -307,6 +308,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
|
|||
exponential_running_average_factor: float = 0.97,
|
||||
device=None,
|
||||
store_tensorized_data_in_memory: bool = False,
|
||||
shuffle_training_data: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
The training-validation loop for `AbstractNeuralModel`s.
|
||||
|
@ -324,6 +326,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
|
|||
displayed in the progress bar.
|
||||
:param device: the target PyTorch device for training
|
||||
:param store_tensorized_data_in_memory: store all tensorized data in memory instead of computing them on-line.
|
||||
:param shuffle_training_data: shuffle the incoming data from `training_data`.
|
||||
"""
|
||||
if initialize_metadata:
|
||||
self.__load_metadata_and_create_network(training_data, parallelize, show_progress_bar)
|
||||
|
@ -384,6 +387,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
|
|||
parallelize,
|
||||
scheduler,
|
||||
show_progress_bar,
|
||||
shuffle_training_data,
|
||||
)
|
||||
|
||||
target_metric, target_metric_improved = self._run_validation(
|
||||
|
|
Загрузка…
Ссылка в новой задаче