From 774be627b66748ba668ad19cf35dbaaf551da157 Mon Sep 17 00:00:00 2001
From: volltin
Date: Tue, 28 Dec 2021 18:14:59 +0800
Subject: [PATCH] Add epoch shuffle option (#67)

---
 graphormer/data/dataset.py           | 21 ++++++++++++++++++++-
 graphormer/tasks/graph_prediction.py | 17 ++++++++++++++++-
 graphormer/tasks/is2re.py            | 22 +---------------------
 3 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/graphormer/data/dataset.py b/graphormer/data/dataset.py
index e412f7e..c1b72e5 100644
--- a/graphormer/data/dataset.py
+++ b/graphormer/data/dataset.py
@@ -4,7 +4,7 @@ import ogb
 import numpy as np
 import torch
 from torch.nn import functional as F
-from fairseq.data import FairseqDataset
+from fairseq.data import data_utils, FairseqDataset, BaseWrapperDataset

 from .wrapper import MyPygGraphPropPredDataset
 from .collator import collator
@@ -92,3 +92,22 @@ class GraphormerDataset:
         self.dataset_train = self.dataset.train_data
         self.dataset_val = self.dataset.valid_data
         self.dataset_test = self.dataset.test_data
+
+
+class EpochShuffleDataset(BaseWrapperDataset):
+    def __init__(self, dataset, num_samples, seed):
+        super().__init__(dataset)
+        self.num_samples = num_samples
+        self.seed = seed
+        self.set_epoch(1)
+
+    def set_epoch(self, epoch):
+        with data_utils.numpy_seed(self.seed + epoch - 1):
+            self.sort_order = np.random.permutation(self.num_samples)
+
+    def ordered_indices(self):
+        return self.sort_order
+
+    @property
+    def can_reuse_epoch_itr_across_epochs(self):
+        return False
diff --git a/graphormer/tasks/graph_prediction.py b/graphormer/tasks/graph_prediction.py
index a22bb99..a196ef5 100644
--- a/graphormer/tasks/graph_prediction.py
+++ b/graphormer/tasks/graph_prediction.py
@@ -21,7 +21,12 @@ from fairseq.tasks import FairseqDataclass, FairseqTask, register_task

 from graphormer.pretrain import load_pretrained_model

-from ..data.dataset import BatchedDataDataset, TargetDataset, GraphormerDataset
+from ..data.dataset import (
+    BatchedDataDataset,
+    TargetDataset,
+    GraphormerDataset,
+    EpochShuffleDataset,
+)

 import torch
 from fairseq.optim.amp_optimizer import AMPOptimizer
@@ -109,6 +114,11 @@ class GraphPredictionConfig(FairseqDataclass):
         metadata={"help": "whether to load the output layer of pretrained model"},
     )

+    train_epoch_shuffle: bool = field(
+        default=False,
+        metadata={"help": "whether to shuffle the dataset at each epoch"},
+    )
+

 @register_task("graph_prediction", dataclass=GraphPredictionConfig)
 class GraphPredictionTask(FairseqTask):
@@ -161,6 +171,11 @@ class GraphPredictionTask(FairseqTask):
             sizes=data_sizes,
         )

+        if split == "train" and self.cfg.train_epoch_shuffle:
+            dataset = EpochShuffleDataset(
+                dataset, num_samples=len(dataset), seed=self.cfg.seed
+            )
+
         logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

         self.datasets[split] = dataset
diff --git a/graphormer/tasks/is2re.py b/graphormer/tasks/is2re.py
index 77e9aaa..236a974 100644
--- a/graphormer/tasks/is2re.py
+++ b/graphormer/tasks/is2re.py
@@ -20,27 +20,7 @@ from fairseq.data import (
 )
 from fairseq.tasks import FairseqTask, register_task

-
-class EpochShuffleDataset(BaseWrapperDataset):
-    def __init__(self, dataset, num_samples, seed):
-        super().__init__(dataset)
-        self.num_samples = num_samples
-        self.seed = seed
-        self.set_epoch(1)
-
-    def set_epoch(self, epoch):
-        with data_utils.numpy_seed(self.seed + epoch - 1):
-            self.sort_order = np.random.permutation(
-                self.num_samples
-            )  # random ordered_indices will break fairseq bucket by size batch iter, but we just want to reproduce...
-
-    def ordered_indices(self):
-        return self.sort_order
-
-    @property
-    def can_reuse_epoch_itr_across_epochs(self):
-        return False
-
+from ..data.dataset import EpochShuffleDataset

 class LMDBDataset:
     def __init__(self, db_path):
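
Usage note: since GraphPredictionConfig is a fairseq dataclass, the new train_epoch_shuffle field should surface on the fairseq-train command line as --train-epoch-shuffle (off by default); this assumes fairseq's standard dataclass-to-CLI argument conversion.

Below is a minimal sketch of what the wrapper does, assuming fairseq and Graphormer (with this patch applied) are importable; ToyDataset is a hypothetical stand-in invented for illustration, not part of either library:

    import numpy as np
    from fairseq.data import FairseqDataset

    from graphormer.data.dataset import EpochShuffleDataset


    class ToyDataset(FairseqDataset):
        """Hypothetical stand-in: n integer items, just enough to wrap."""

        def __init__(self, n):
            self.n = n

        def __getitem__(self, index):
            return index

        def __len__(self):
            return self.n


    # The wrapper draws one permutation per epoch, deterministic in (seed, epoch).
    ds = EpochShuffleDataset(ToyDataset(8), num_samples=8, seed=1)
    print(ds.ordered_indices())  # epoch-1 order: a permutation of 0..7
    ds.set_epoch(2)
    print(ds.ordered_indices())  # generally a different permutation, identical across reruns

Because can_reuse_epoch_itr_across_epochs returns False, fairseq rebuilds the epoch batch iterator every epoch, so the fresh permutation drawn in set_epoch actually takes effect; seeding with seed + epoch - 1 keeps each epoch's order reproducible across runs.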