This commit is contained in:
shiyu1994 2022-03-11 15:42:58 +08:00 committed by GitHub
Parent 3b06c7e2fe
Commit 5139627f7d
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 52 additions and 31 deletions

View file

@ -111,9 +111,10 @@ fine-tunes our model ``pcqm4mv1_graphormer_base`` on the ``ogbg-molhiv`` dataset
--save-dir ./ckpts \
--pretrained-model-name pcqm4mv1_graphormer_base \
--flag-m 3 \
--flag-step-size 0.001 \
--flag-mag 0.001 \
--seed 1
--flag-step-size 0.01 \
--flag-mag 0 \
--seed 1 \
--pre-layernorm
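The three --flag-* values tuned here control FLAG adversarial augmentation during fine-tuning. As a rough sketch under assumed FLAG semantics (not this repo's exact implementation; `model`, `criterion`, and the toy data below are stand-ins): --flag-m is the number of ascent steps per batch, --flag-step-size the size of each ascent step on the input perturbation, and --flag-mag an optional cap on the perturbation magnitude (0 disables the cap).

```python
import torch
import torch.nn as nn

def flag_step(model, x, y, criterion, optimizer, m=3, step_size=0.01, mag=0.0):
    """One FLAG-style training step: m gradient-ascent steps on an input
    perturbation while model gradients accumulate, then a single update."""
    optimizer.zero_grad()
    perturb = torch.zeros_like(x).uniform_(-step_size, step_size).requires_grad_()
    for _ in range(m):
        loss = criterion(model(x + perturb), y) / m     # average over the m inner steps
        loss.backward()                                 # grads accumulate on model AND perturb
        with torch.no_grad():
            perturb += step_size * perturb.grad.sign()  # ascent step on the perturbation
            if mag > 0:
                perturb.clamp_(-mag, mag)               # cap only when --flag-mag > 0
            perturb.grad.zero_()
    optimizer.step()

# toy usage
model = nn.Linear(16, 1)
x, y = torch.randn(8, 16), torch.randn(8, 1)
flag_step(model, x, y, nn.MSELoss(), torch.optim.SGD(model.parameters(), lr=0.1))
```

The command-line changes above simply pick different values for these hyperparameters (step size 0.01, magnitude 0) and add --pre-layernorm.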
After fine-tuning, use ``graphormer/evaluate/evaluate.py`` to evaluate the performance of all checkpoints:
@ -132,7 +133,8 @@ After fine-tuning, use ``graphormer/evaluate/evaluate.py`` to evaluate the perfo
--save-dir ../../examples/property_prediction/ckpts/ \
--split test \
--metric auc \
--seed 1
--seed 1 \
--pre-layernorm
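"Evaluate the performance of all checkpoints" amounts to scoring every checkpoint file saved under --save-dir with the chosen metric (auc here). A minimal illustration of that loop, assuming a hypothetical `load_and_predict(path)` helper that returns labels and scores for the chosen split (this is not evaluate.py itself):

```python
from pathlib import Path
from sklearn.metrics import roc_auc_score

def evaluate_all(ckpt_dir, load_and_predict):
    # score every saved checkpoint and report its ROC-AUC on the chosen split
    for path in sorted(Path(ckpt_dir).glob("checkpoint*.pt")):
        y_true, y_score = load_and_predict(path)
        print(f"{path.name}: auc={roc_auc_score(y_true, y_score):.4f}")
```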
Training a New Model

View file

@ -21,7 +21,7 @@ GCN | 2.0M | -- | 0.1379 |
GIN | 3.8M | -- | 0.1195 |
GCN-VN | 4.9M | -- | 0.1153 |
GIN-VN | 6.7M | -- | 0.1083 |
Graphormer-v2 | 47.1M | 0.0253 | **0.0865** |
Graphormer-v2 | 48.3M | 0.0253 | **0.0865** |
#### PCQM4Mv1
Method | #params | train MAE | valid MAE |
@ -32,7 +32,7 @@ GCN-VN | 4.9M | 0.1225 | 0.1485 |
GIN-VN | 6.7M | 0.1150 | 0.1395 |
Graphormer-Small| 12.5M | 0.0778 | 0.1264 |
Graphormer | 47.1M | 0.0582 | 0.1234 |
Graphormer-v2 | 47.1M | 0.0309 | **0.1201** |
Graphormer-v2 | 48.3M | 0.0309 | **0.1201** |
## Open Graph Benchmark
@ -55,7 +55,7 @@ PHC-GNN | 111K | 79.34 |
DeeperGCN-FLAG | 532K | 79.42 |
DGN | 114K | 79.70 |
Graphormer | 47.0M | 80.51 |
Graphormer-v2 | 47.1M | **81.28** |
Graphormer-v2 | 48.3M | **80.56** |
## Benchmarking Graph Neural Networks - ZINC-500K

View file

@ -2,16 +2,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
n_gpu=2
epoch=8
n_gpu=1
epoch=4
max_epoch=$((epoch + 1))
batch_size=64
batch_size=128
tot_updates=$((33000*epoch/batch_size/n_gpu))
warmup_updates=$((tot_updates/10))
warmup_updates=$((tot_updates*16/100))
CUDA_VISIBLE_DEVICES=0,1
fairseq-train \
CUDA_VISIBLE_DEVICES=3 fairseq-train \
--user-dir ../../graphormer \
--num-workers 16 \
--ddp-backend=legacy_ddp \
@ -24,7 +22,7 @@ fairseq-train \
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \
--lr-scheduler polynomial_decay --power 1 --warmup-updates $warmup_updates --total-num-update $tot_updates \
--lr 2e-4 --end-learning-rate 1e-9 \
--lr 2e-4 --end-learning-rate 1e-5 \
--batch-size $batch_size \
--fp16 \
--data-buffer-size 20 \
@ -34,8 +32,9 @@ fairseq-train \
--encoder-attention-heads 32 \
--max-epoch $max_epoch \
--save-dir ./ckpts \
--pretrained-model-name pcqm4mv1_graphormer_base \
--pretrained-model-name pcqm4mv1_graphormer_base_for_molhiv \
--seed 1 \
--flag-m 3 \
--flag-step-size 0.001 \
--flag-mag 0.001 \
--flag-step-size 0.01 \
--flag-mag 0 \
--pre-layernorm
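With the new settings (n_gpu=1, epoch=4, batch_size=128), the shell arithmetic above produces the following schedule; the 33000 constant comes from the script and roughly matches the ogbg-molhiv training-split size. A plain-Python mirror of that integer arithmetic:

```python
n_gpu, epoch, batch_size = 1, 4, 128
tot_updates = 33000 * epoch // batch_size // n_gpu   # 1031 optimizer steps in total
warmup_updates = tot_updates * 16 // 100             # 164 warmup steps (16% of the total)
print(tot_updates, warmup_updates)                   # 1031 164
```

Note that max_epoch is epoch + 1, so training runs for 5 epochs while the learning-rate schedule is sized for 4.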

View file

@ -125,6 +125,11 @@ class GraphormerModel(FairseqEncoderModel):
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--pre-layernorm",
action="store_true",
help="apply layernorm before self-attention and ffn. Without this, post layernorm will used",
)
def max_nodes(self):
return self.encoder.max_nodes
@ -171,6 +176,7 @@ class GraphormerEncoder(FairseqEncoder):
attention_dropout=args.attention_dropout,
activation_dropout=args.act_dropout,
encoder_normalize_before=args.encoder_normalize_before,
pre_layernorm=args.pre_layernorm,
apply_graphormer_init=args.apply_graphormer_init,
activation_fn=args.activation_fn,
)
@ -272,15 +278,9 @@ def base_architecture(args):
@register_model_architecture("graphormer", "graphormer_base")
def graphormer_base_architecture(args):
if args.pretrained_model_name == "pcqm4mv1_graphormer_base":
args.encoder_layers = 12
args.encoder_attention_heads = 32
args.encoder_ffn_embed_dim = 768
args.encoder_embed_dim = 768
args.dropout = getattr(args, "dropout", 0.0)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.act_dropout = getattr(args, "act_dropout", 0.1)
elif args.pretrained_model_name == "pcqm4mv2_graphormer_base":
if args.pretrained_model_name == "pcqm4mv1_graphormer_base" or \
args.pretrained_model_name == "pcqm4mv2_graphormer_base" or \
args.pretrained_model_name == "pcqm4mv1_graphormer_base_for_molhiv":
args.encoder_layers = 12
args.encoder_attention_heads = 32
args.encoder_ffn_embed_dim = 768
@ -290,11 +290,12 @@ def graphormer_base_architecture(args):
args.act_dropout = getattr(args, "act_dropout", 0.1)
else:
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 32)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768)
args.dropout = getattr(args, "dropout", 0.0)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.act_dropout = getattr(args, "act_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
@ -305,6 +306,7 @@ def graphormer_base_architecture(args):
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.pre_layernorm = getattr(args, "pre_layernorm", False)
base_architecture(args)
@ -326,6 +328,7 @@ def graphormer_slim_architecture(args):
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.pre_layernorm = getattr(args, "pre_layernorm", False)
base_architecture(args)
@ -347,4 +350,5 @@ def graphormer_large_architecture(args):
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.pre_layernorm = getattr(args, "pre_layernorm", False)
base_architecture(args)
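The getattr calls above follow the usual fairseq architecture-function pattern: the command-line flag sets the attribute, and getattr supplies a default when the attribute is absent. A small self-contained sketch of that interplay (argparse only, not the repo's wiring):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pre-layernorm", action="store_true",
                    help="apply layernorm before self-attention and ffn")

args = parser.parse_args([])                                # flag not passed
args.pre_layernorm = getattr(args, "pre_layernorm", False)  # default stays False
assert args.pre_layernorm is False

args = parser.parse_args(["--pre-layernorm"])               # flag passed
assert args.pre_layernorm is True                           # getattr would not override it
```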

View file

@ -62,6 +62,7 @@ class GraphormerGraphEncoder(nn.Module):
activation_dropout: float = 0.1,
layerdrop: float = 0.0,
encoder_normalize_before: bool = False,
pre_layernorm: bool = False,
apply_graphormer_init: bool = False,
activation_fn: str = "gelu",
embed_scale: float = None,
@ -119,6 +120,9 @@ class GraphormerGraphEncoder(nn.Module):
else:
self.emb_layer_norm = None
if pre_layernorm:
self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
if self.layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.layerdrop)
else:
@ -136,6 +140,7 @@ class GraphormerGraphEncoder(nn.Module):
export=export,
q_noise=q_noise,
qn_block_size=qn_block_size,
pre_layernorm=pre_layernorm,
)
for _ in range(num_encoder_layers)
]
@ -168,6 +173,7 @@ class GraphormerGraphEncoder(nn.Module):
export,
q_noise,
qn_block_size,
pre_layernorm,
):
return GraphormerGraphEncoderLayer(
embedding_dim=embedding_dim,
@ -180,6 +186,7 @@ class GraphormerGraphEncoder(nn.Module):
export=export,
q_noise=q_noise,
qn_block_size=qn_block_size,
pre_layernorm=pre_layernorm,
)
def forward(

View file

@ -32,6 +32,7 @@ class GraphormerGraphEncoderLayer(nn.Module):
q_noise: float = 0.0,
qn_block_size: int = 8,
init_fn: Callable = None,
pre_layernorm: bool = False,
) -> None:
super().__init__()
@ -44,6 +45,7 @@ class GraphormerGraphEncoderLayer(nn.Module):
self.attention_dropout = attention_dropout
self.q_noise = q_noise
self.qn_block_size = qn_block_size
self.pre_layernorm = pre_layernorm
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
@ -119,6 +121,8 @@ class GraphormerGraphEncoderLayer(nn.Module):
"""
# x: T x B x C
residual = x
if self.pre_layernorm:
x = self.self_attn_layer_norm(x)
x, attn = self.self_attn(
query=x,
key=x,
@ -130,13 +134,17 @@ class GraphormerGraphEncoderLayer(nn.Module):
)
x = self.dropout_module(x)
x = residual + x
x = self.self_attn_layer_norm(x)
if not self.pre_layernorm:
x = self.self_attn_layer_norm(x)
residual = x
if self.pre_layernorm:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = residual + x
x = self.final_layer_norm(x)
if not self.pre_layernorm:
x = self.final_layer_norm(x)
return x, attn
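The change above toggles where LayerNorm sits relative to each residual connection: post-LN (the previous behavior) normalizes the residual sum after each sublayer, while pre-LN normalizes the input to each sublayer and relies on the extra final_layer_norm added to the encoder to normalize the output. A stripped-down sketch of the two orderings for a single feed-forward sublayer (plain torch modules, not the repo's classes):

```python
import torch
import torch.nn as nn

class ResidualFFNBlock(nn.Module):
    def __init__(self, dim: int, pre_layernorm: bool = False):
        super().__init__()
        self.pre_layernorm = pre_layernorm
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        if self.pre_layernorm:        # pre-LN: normalize the sublayer input
            x = self.norm(x)
        x = self.ffn(x)
        x = residual + x
        if not self.pre_layernorm:    # post-LN: normalize the residual sum
            x = self.norm(x)
        return x

block = ResidualFFNBlock(dim=32, pre_layernorm=True)
print(block(torch.randn(4, 32)).shape)   # torch.Size([4, 32])
```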

View file

@ -5,6 +5,7 @@ PRETRAINED_MODEL_URLS = {
"pcqm4mv1_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_best_pcqm4mv1.pt",
"pcqm4mv2_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv2/checkpoint_best_pcqm4mv2.pt",
"oc20is2re_graphormer3d_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt",
"pcqm4mv1_graphormer_base_for_molhiv":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_base_preln_pcqm4mv1_for_hiv.pt",
}
def load_pretrained_model(pretrained_model_name):
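The new pcqm4mv1_graphormer_base_for_molhiv entry is resolved like the others: the name indexes PRETRAINED_MODEL_URLS and the checkpoint is downloaded from the corresponding URL. A hedged sketch of that lookup (not necessarily the body of load_pretrained_model; fairseq checkpoints usually keep the weights under the "model" key):

```python
from torch.hub import load_state_dict_from_url

def fetch_pretrained(name: str):
    # hypothetical helper mirroring the registry lookup above
    if name not in PRETRAINED_MODEL_URLS:
        raise ValueError(f"Unknown pretrained model name: {name}")
    checkpoint = load_state_dict_from_url(PRETRAINED_MODEL_URLS[name], progress=True)
    return checkpoint["model"]   # assumed fairseq checkpoint layout
```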