Fix seed and split of customized datasets (#76)

This commit is contained in:
shiyu1994 2021-12-30 11:05:25 +08:00 коммит произвёл GitHub
Родитель 46b38748c5
Коммит 3de5d51cd5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 7 добавлений и 10 удалений

Просмотреть файл

@ -73,15 +73,17 @@ class GraphormerDataset:
super().__init__()
if dataset is not None:
if dataset_source == "dgl":
self.dataset = GraphormerDGLDataset(dataset, train_idx, valid_idx, test_idx)
self.dataset = GraphormerDGLDataset(dataset, seed=seed, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx)
elif dataset_source == "pyg":
self.dataset = GraphormerPYGDataset(dataset, train_idx, valid_idx, test_idx)
self.dataset = GraphormerPYGDataset(dataset, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx)
else:
raise ValueError("customized dataset can only have source pyg or dgl")
elif dataset_source == "dgl":
self.dataset = DGLDatasetLookupTable.GetDGLDataset(dataset_spec, seed)
self.dataset = DGLDatasetLookupTable.GetDGLDataset(dataset_spec, seed=seed)
elif dataset_source == "pyg":
self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed)
self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed=seed)
elif dataset_source == "ogb":
self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed)
self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed=seed)
self.setup()
def setup(self):

Просмотреть файл

@ -276,11 +276,6 @@ class GraphPredictionWithFlagTask(GraphPredictionTask):
def __init__(self, cfg):
super().__init__(cfg)
self.dm = GraphormerDataset(
dataset_spec=cfg.dataset_name,
dataset_source=cfg.dataset_source,
seed=cfg.seed,
)
self.flag_m = cfg.flag_m
self.flag_step_size = cfg.flag_step_size
self.flag_mag = cfg.flag_mag