This commit is contained in:
shiyu1994 2021-12-28 23:35:15 +08:00 коммит произвёл GitHub
Родитель 774be627b6
Коммит 46b38748c5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 68 добавлений и 5 удалений

Просмотреть файл

@ -0,0 +1,24 @@
from graphormer.data import register_dataset
from dgl.data import QM9
import numpy as np
from sklearn.model_selection import train_test_split
@register_dataset("customized_qm9_dataset")
def create_customized_dataset():
dataset = QM9(label_keys=["mu"])
num_graphs = len(dataset)
# customized dataset split
train_valid_idx, test_idx = train_test_split(
np.arange(num_graphs), test_size=num_graphs // 10, random_state=0
)
train_idx, valid_idx = train_test_split(
train_valid_idx, test_size=num_graphs // 5, random_state=0
)
return {
"dataset": dataset,
"train_idx": train_idx,
"valid_idx": valid_idx,
"test_idx": test_idx,
"source": "dgl"
}

Просмотреть файл

@ -11,6 +11,7 @@ import logging
import contextlib
from dataclasses import dataclass, field
from omegaconf import II, open_dict, OmegaConf
import importlib
import numpy as np
from fairseq.data import (
@ -32,6 +33,10 @@ import torch
from fairseq.optim.amp_optimizer import AMPOptimizer
import math
from ..data import DATASET_REGISTRY
import sys
import os
logger = logging.getLogger(__name__)
@ -119,6 +124,11 @@ class GraphPredictionConfig(FairseqDataclass):
metadata={"help": "whether to shuffle the dataset at each epoch"},
)
user_data_dir: str = field(
default="",
metadata={"help": "path to the module of user-defined dataset"},
)
@register_task("graph_prediction", dataclass=GraphPredictionConfig)
class GraphPredictionTask(FairseqTask):
@ -128,11 +138,40 @@ class GraphPredictionTask(FairseqTask):
def __init__(self, cfg):
super().__init__(cfg)
self.dm = GraphormerDataset(
dataset_spec=cfg.dataset_name,
dataset_source=cfg.dataset_source,
seed=cfg.seed,
)
if cfg.user_data_dir != "":
self.__import_user_defined_datasets(cfg.user_data_dir)
if cfg.dataset_name in DATASET_REGISTRY:
dataset_dict = DATASET_REGISTRY[cfg.dataset_name]
self.dm = GraphormerDataset(
dataset=dataset_dict["dataset"],
dataset_source=dataset_dict["source"],
train_idx=dataset_dict["train_idx"],
valid_idx=dataset_dict["valid_idx"],
test_idx=dataset_dict["test_idx"],
seed=cfg.seed)
else:
raise ValueError(f"dataset {cfg.dataset_name} is not found in customized dataset module {cfg.user_data_dir}")
else:
self.dm = GraphormerDataset(
dataset_spec=cfg.dataset_name,
dataset_source=cfg.dataset_source,
seed=cfg.seed,
)
def __import_user_defined_datasets(self, dataset_dir):
dataset_dir = dataset_dir.strip("/")
module_parent, module_name = os.path.split(dataset_dir)
sys.path.insert(0, module_parent)
importlib.import_module(module_name)
for file in os.listdir(dataset_dir):
path = os.path.join(dataset_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
task_name = file[: file.find(".py")] if file.endswith(".py") else file
importlib.import_module(module_name + "." + task_name)
@classmethod
def setup_task(cls, cfg, **kwargs):