Recover customized dataset (#68)
Parent: 774be627b6
Commit: 46b38748c5
@@ -0,0 +1,24 @@
+from graphormer.data import register_dataset
+from dgl.data import QM9
+import numpy as np
+from sklearn.model_selection import train_test_split
+
+@register_dataset("customized_qm9_dataset")
+def create_customized_dataset():
+    dataset = QM9(label_keys=["mu"])
+    num_graphs = len(dataset)
+
+    # customized dataset split
+    train_valid_idx, test_idx = train_test_split(
+        np.arange(num_graphs), test_size=num_graphs // 10, random_state=0
+    )
+    train_idx, valid_idx = train_test_split(
+        train_valid_idx, test_size=num_graphs // 5, random_state=0
+    )
+    return {
+        "dataset": dataset,
+        "train_idx": train_idx,
+        "valid_idx": valid_idx,
+        "test_idx": test_idx,
+        "source": "dgl"
+    }
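Note: the diff does not show the implementation of register_dataset, but the task code further down indexes DATASET_REGISTRY[cfg.dataset_name] as a plain dict. A minimal sketch consistent with that usage (an assumption, not the committed code) would evaluate the decorated factory at import time and store its return value under the given name:

# Hypothetical sketch of graphormer.data's registry, inferred from how the
# task code below consumes it.
DATASET_REGISTRY = {}

def register_dataset(name: str):
    def register_dataset_fn(fn):
        # Call the factory eagerly so the registry maps the dataset name
        # directly to the dict it returns ({"dataset": ..., "train_idx": ...}).
        DATASET_REGISTRY[name] = fn()
        return fn
    return register_dataset_fn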
@@ -11,6 +11,7 @@ import logging
 import contextlib
 from dataclasses import dataclass, field
 from omegaconf import II, open_dict, OmegaConf
+import importlib
 
 import numpy as np
 from fairseq.data import (
@@ -32,6 +33,10 @@ import torch
 from fairseq.optim.amp_optimizer import AMPOptimizer
 import math
 
+from ..data import DATASET_REGISTRY
+import sys
+import os
+
 logger = logging.getLogger(__name__)
 
 
@@ -119,6 +124,11 @@ class GraphPredictionConfig(FairseqDataclass):
         metadata={"help": "whether to shuffle the dataset at each epoch"},
     )
 
+    user_data_dir: str = field(
+        default="",
+        metadata={"help": "path to the module of user-defined dataset"},
+    )
+
 
 @register_task("graph_prediction", dataclass=GraphPredictionConfig)
 class GraphPredictionTask(FairseqTask):
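The user_data_dir option added here is what connects the example file above to training: assuming fairseq's usual dataclass-to-CLI mapping (underscores become hyphens), it should surface as --user-data-dir, so a hypothetical invocation would pass --user-data-dir customized_dataset --dataset-name customized_qm9_dataset alongside the usual fairseq-train arguments.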
@@ -128,11 +138,40 @@
 
     def __init__(self, cfg):
         super().__init__(cfg)
-        self.dm = GraphormerDataset(
-            dataset_spec=cfg.dataset_name,
-            dataset_source=cfg.dataset_source,
-            seed=cfg.seed,
-        )
+        if cfg.user_data_dir != "":
+            self.__import_user_defined_datasets(cfg.user_data_dir)
+            if cfg.dataset_name in DATASET_REGISTRY:
+                dataset_dict = DATASET_REGISTRY[cfg.dataset_name]
+                self.dm = GraphormerDataset(
+                    dataset=dataset_dict["dataset"],
+                    dataset_source=dataset_dict["source"],
+                    train_idx=dataset_dict["train_idx"],
+                    valid_idx=dataset_dict["valid_idx"],
+                    test_idx=dataset_dict["test_idx"],
+                    seed=cfg.seed)
+            else:
+                raise ValueError(f"dataset {cfg.dataset_name} is not found in customized dataset module {cfg.user_data_dir}")
+        else:
+            self.dm = GraphormerDataset(
+                dataset_spec=cfg.dataset_name,
+                dataset_source=cfg.dataset_source,
+                seed=cfg.seed,
+            )
 
+    def __import_user_defined_datasets(self, dataset_dir):
+        dataset_dir = dataset_dir.strip("/")
+        module_parent, module_name = os.path.split(dataset_dir)
+        sys.path.insert(0, module_parent)
+        importlib.import_module(module_name)
+        for file in os.listdir(dataset_dir):
+            path = os.path.join(dataset_dir, file)
+            if (
+                not file.startswith("_")
+                and not file.startswith(".")
+                and (file.endswith(".py") or os.path.isdir(path))
+            ):
+                task_name = file[: file.find(".py")] if file.endswith(".py") else file
+                importlib.import_module(module_name + "." + task_name)
+
     @classmethod
     def setup_task(cls, cfg, **kwargs):
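To make the discovery logic in __import_user_defined_datasets concrete, here is a hedged walk-through with a hypothetical directory layout (names invented for illustration):

# Assumed layout on disk:
#   customized_dataset/
#       __init__.py
#       customized_dataset.py   <- the example file from the top of this commit
#
# For user_data_dir == "customized_dataset", __import_user_defined_datasets
# effectively performs:
import importlib
import os
import sys

dataset_dir = "customized_dataset".strip("/")            # no-op for this path
module_parent, module_name = os.path.split(dataset_dir)  # ("", "customized_dataset")
sys.path.insert(0, module_parent)             # "" == current working directory
importlib.import_module(module_name)          # runs customized_dataset/__init__.py
# Every non-underscore, non-hidden .py file (or subdirectory) is then imported
# as a submodule, which fires its @register_dataset decorator and populates
# DATASET_REGISTRY:
importlib.import_module(module_name + ".customized_dataset")

One caveat: strip("/") also removes a leading slash, so an absolute user_data_dir ends up on sys.path as a relative prefix; paths relative to the working directory appear to be the intended usage.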