Fix seed and split of customized datasets (#76)
This commit is contained in:
Родитель
46b38748c5
Коммит
3de5d51cd5
|
@ -73,15 +73,17 @@ class GraphormerDataset:
|
|||
super().__init__()
|
||||
if dataset is not None:
|
||||
if dataset_source == "dgl":
|
||||
self.dataset = GraphormerDGLDataset(dataset, train_idx, valid_idx, test_idx)
|
||||
self.dataset = GraphormerDGLDataset(dataset, seed=seed, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx)
|
||||
elif dataset_source == "pyg":
|
||||
self.dataset = GraphormerPYGDataset(dataset, train_idx, valid_idx, test_idx)
|
||||
self.dataset = GraphormerPYGDataset(dataset, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx)
|
||||
else:
|
||||
raise ValueError("customized dataset can only have source pyg or dgl")
|
||||
elif dataset_source == "dgl":
|
||||
self.dataset = DGLDatasetLookupTable.GetDGLDataset(dataset_spec, seed)
|
||||
self.dataset = DGLDatasetLookupTable.GetDGLDataset(dataset_spec, seed=seed)
|
||||
elif dataset_source == "pyg":
|
||||
self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed)
|
||||
self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed=seed)
|
||||
elif dataset_source == "ogb":
|
||||
self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed)
|
||||
self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed=seed)
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
|
|
|
@ -276,11 +276,6 @@ class GraphPredictionWithFlagTask(GraphPredictionTask):
|
|||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self.dm = GraphormerDataset(
|
||||
dataset_spec=cfg.dataset_name,
|
||||
dataset_source=cfg.dataset_source,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
self.flag_m = cfg.flag_m
|
||||
self.flag_step_size = cfg.flag_step_size
|
||||
self.flag_mag = cfg.flag_mag
|
||||
|
|
Загрузка…
Ссылка в новой задаче