Add epoch shuffle option (#67)
This commit is contained in:
Parent
49286ac809
Commit
774be627b6
|
@ -4,7 +4,7 @@ import ogb
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from fairseq.data import FairseqDataset
|
||||
from fairseq.data import data_utils, FairseqDataset, BaseWrapperDataset
|
||||
|
||||
from .wrapper import MyPygGraphPropPredDataset
|
||||
from .collator import collator
|
||||
|
@ -92,3 +92,22 @@ class GraphormerDataset:
|
|||
self.dataset_train = self.dataset.train_data
|
||||
self.dataset_val = self.dataset.valid_data
|
||||
self.dataset_test = self.dataset.test_data
|
||||
|
||||
|
||||
class EpochShuffleDataset(BaseWrapperDataset):
    """Dataset wrapper that re-shuffles sample order once per epoch.

    The permutation is drawn under a seed derived from ``seed + epoch - 1``,
    so training runs are reproducible while every epoch still visits the
    samples in a different order.
    """

    def __init__(self, dataset, num_samples, seed):
        super().__init__(dataset)
        self.num_samples = num_samples
        self.seed = seed
        # Draw the permutation for the first epoch immediately so that
        # ordered_indices() is valid before the first set_epoch() call.
        self.set_epoch(1)

    def set_epoch(self, epoch):
        # Re-seed per epoch: deterministic given (seed, epoch), different
        # across epochs.
        with data_utils.numpy_seed(self.seed + epoch - 1):
            self.sort_order = np.random.permutation(self.num_samples)

    def ordered_indices(self):
        return self.sort_order

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # The index order changes every epoch, so fairseq must rebuild the
        # epoch iterator instead of reusing a cached one.
        return False
|
||||
|
|
|
@ -21,7 +21,12 @@ from fairseq.tasks import FairseqDataclass, FairseqTask, register_task
|
|||
|
||||
from graphormer.pretrain import load_pretrained_model
|
||||
|
||||
from ..data.dataset import BatchedDataDataset, TargetDataset, GraphormerDataset
|
||||
from ..data.dataset import (
|
||||
BatchedDataDataset,
|
||||
TargetDataset,
|
||||
GraphormerDataset,
|
||||
EpochShuffleDataset,
|
||||
)
|
||||
|
||||
import torch
|
||||
from fairseq.optim.amp_optimizer import AMPOptimizer
|
||||
|
@ -109,6 +114,11 @@ class GraphPredictionConfig(FairseqDataclass):
|
|||
metadata={"help": "whether to load the output layer of pretrained model"},
|
||||
)
|
||||
|
||||
train_epoch_shuffle: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to shuffle the dataset at each epoch"},
|
||||
)
|
||||
|
||||
|
||||
@register_task("graph_prediction", dataclass=GraphPredictionConfig)
|
||||
class GraphPredictionTask(FairseqTask):
|
||||
|
@ -161,6 +171,11 @@ class GraphPredictionTask(FairseqTask):
|
|||
sizes=data_sizes,
|
||||
)
|
||||
|
||||
if split == "train" and self.cfg.train_epoch_shuffle:
    # Wrap the training split so its order is re-shuffled each epoch.
    # EpochShuffleDataset.__init__ takes `num_samples`, not `size`;
    # passing `size=` raises TypeError as soon as shuffling is enabled.
    dataset = EpochShuffleDataset(
        dataset, num_samples=len(dataset), seed=self.cfg.seed
    )
|
||||
|
||||
logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
|
||||
|
||||
self.datasets[split] = dataset
|
||||
|
|
|
@ -20,27 +20,7 @@ from fairseq.data import (
|
|||
)
|
||||
from fairseq.tasks import FairseqTask, register_task
|
||||
|
||||
|
||||
class EpochShuffleDataset(BaseWrapperDataset):
    """Wraps a dataset and presents a fresh, seeded permutation every epoch."""

    def __init__(self, dataset, num_samples, seed):
        super().__init__(dataset)
        self.num_samples = num_samples
        self.seed = seed
        # Initialize sort_order for epoch 1 up front.
        self.set_epoch(1)

    def set_epoch(self, epoch):
        with data_utils.numpy_seed(self.seed + epoch - 1):
            # NOTE: a fully random ordered_indices defeats fairseq's
            # bucket-by-size batch iterator, but it is kept this way to
            # reproduce the original training setup.
            self.sort_order = np.random.permutation(self.num_samples)

    def ordered_indices(self):
        return self.sort_order

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # A new permutation is drawn each epoch, so cached epoch iterators
        # would serve stale indices.
        return False
|
||||
|
||||
from ..data.dataset import EpochShuffleDataset
|
||||
|
||||
class LMDBDataset:
|
||||
def __init__(self, db_path):
|
||||
|
|
Loading…
Open link in a new issue