This commit is contained in:
volltin 2021-12-28 18:14:59 +08:00 committed by GitHub
Parent 49286ac809
Commit 774be627b6
No key matching this signature was found
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 23 deletions

View File

@@ -4,7 +4,7 @@ import ogb
import numpy as np
import torch
from torch.nn import functional as F
from fairseq.data import FairseqDataset
from fairseq.data import data_utils, FairseqDataset, BaseWrapperDataset
from .wrapper import MyPygGraphPropPredDataset
from .collator import collator
@@ -92,3 +92,22 @@ class GraphormerDataset:
self.dataset_train = self.dataset.train_data
self.dataset_val = self.dataset.valid_data
self.dataset_test = self.dataset.test_data
class EpochShuffleDataset(BaseWrapperDataset):
    """Wrapper that re-shuffles the underlying dataset at every epoch."""

    def __init__(self, dataset, num_samples, seed):
        super().__init__(dataset)
        self.num_samples = num_samples
        self.seed = seed
        self.set_epoch(1)

    def set_epoch(self, epoch):
        # Derive a per-epoch seed so each epoch gets a different but
        # reproducible permutation.
        with data_utils.numpy_seed(self.seed + epoch - 1):
            self.sort_order = np.random.permutation(self.num_samples)

    def ordered_indices(self):
        return self.sort_order

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # Force fairseq to rebuild the epoch iterator so the new
        # permutation takes effect each epoch.
        return False
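
For illustration, a minimal sketch of how the wrapper behaves across epochs; the ToyDataset below is a hypothetical stand-in for a real FairseqDataset and is not part of this commit:

# Hypothetical usage sketch (assumes fairseq and numpy are available).
from fairseq.data import FairseqDataset

class ToyDataset(FairseqDataset):
    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.n

ds = EpochShuffleDataset(ToyDataset(5), num_samples=5, seed=42)
print(ds.ordered_indices())  # permutation drawn for epoch 1
ds.set_epoch(2)              # fairseq calls this at the start of each epoch
print(ds.ordered_indices())  # a different, but still reproducible, permutation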

View File

@@ -21,7 +21,12 @@ from fairseq.tasks import FairseqDataclass, FairseqTask, register_task
from graphormer.pretrain import load_pretrained_model
from ..data.dataset import BatchedDataDataset, TargetDataset, GraphormerDataset
from ..data.dataset import (
    BatchedDataDataset,
    TargetDataset,
    GraphormerDataset,
    EpochShuffleDataset,
)
import torch
from fairseq.optim.amp_optimizer import AMPOptimizer
@@ -109,6 +114,11 @@ class GraphPredictionConfig(FairseqDataclass):
        metadata={"help": "whether to load the output layer of pretrained model"},
    )
    train_epoch_shuffle: bool = field(
        default=False,
        metadata={"help": "whether to shuffle the dataset at each epoch"},
    )
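
Fairseq generates command-line flags from dataclass fields (underscores become dashes, booleans act as store-true switches), so the new option should be reachable from the training command. A hypothetical invocation, with placeholder paths and options:

fairseq-train \
    --user-dir ../../graphormer \
    --task graph_prediction \
    --seed 1 \
    --train-epoch-shuffle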
@register_task("graph_prediction", dataclass=GraphPredictionConfig)
class GraphPredictionTask(FairseqTask):
@@ -161,6 +171,11 @@ class GraphPredictionTask(FairseqTask):
            sizes=data_sizes,
        )
        if split == "train" and self.cfg.train_epoch_shuffle:
            dataset = EpochShuffleDataset(
                dataset, num_samples=len(dataset), seed=self.cfg.seed
            )
        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
        self.datasets[split] = dataset
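
For context, a rough sketch of the per-epoch flow this enables: because can_reuse_epoch_itr_across_epochs is False, fairseq rebuilds the epoch iterator every epoch, calling set_epoch before drawing indices. The loop below is a hypothetical simplification, not fairseq's actual trainer code:

# Simplified, hypothetical driver loop.
for epoch in range(1, max_epoch + 1):
    dataset.set_epoch(epoch)                 # draws a fresh permutation
    for index in dataset.ordered_indices():  # new order every epoch
        sample = dataset[index]
        # ... collate into batches and feed the model ...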

View File

@@ -20,27 +20,7 @@ from fairseq.data import (
)
from fairseq.tasks import FairseqTask, register_task
class EpochShuffleDataset(BaseWrapperDataset):
    def __init__(self, dataset, num_samples, seed):
        super().__init__(dataset)
        self.num_samples = num_samples
        self.seed = seed
        self.set_epoch(1)

    def set_epoch(self, epoch):
        with data_utils.numpy_seed(self.seed + epoch - 1):
            self.sort_order = np.random.permutation(
                self.num_samples
            )  # a random ordered_indices will break fairseq's bucket-by-size batch iterator, but we just want to reproduce...

    def ordered_indices(self):
        return self.sort_order

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

from ..data.dataset import EpochShuffleDataset
class LMDBDataset:
    def __init__(self, db_path):