add preln and fix hiv script (#96)
This commit is contained in:
Родитель
3b06c7e2fe
Коммит
5139627f7d
|
@ -111,9 +111,10 @@ fine-tunes our model ``pcqm4mv1_graphormer_base`` on the ``ogbg-molhiv`` dataset
|
|||
--save-dir ./ckpts \
|
||||
--pretrained-model-name pcqm4mv1_graphormer_base \
|
||||
--flag-m 3 \
|
||||
--flag-step-size 0.001 \
|
||||
--flag-mag 0.001 \
|
||||
--seed 1
|
||||
--flag-step-size 0.01 \
|
||||
--flag-mag 0 \
|
||||
--seed 1 \
|
||||
--pre-layernorm
|
||||
|
||||
After fine-tuning, use ``graphormer/evaluate/evaluate.py`` to evaluate the performance of all checkpoints:
|
||||
|
||||
|
@ -132,7 +133,8 @@ After fine-tuning, use ``graphormer/evaluate/evaluate.py`` to evaluate the perfo
|
|||
--save-dir ../../examples/property_prediction/ckpts/ \
|
||||
--split test \
|
||||
--metric auc \
|
||||
--seed 1
|
||||
--seed 1 \
|
||||
--pre-layernorm
|
||||
|
||||
|
||||
Training a New Model
|
||||
|
|
|
@ -21,7 +21,7 @@ GCN | 2.0M | -- | 0.1379 |
|
|||
GIN | 3.8M | -- | 0.1195 |
|
||||
GCN-VN | 4.9M | -- | 0.1153 |
|
||||
GIN-VN | 6.7M | -- | 0.1083 |
|
||||
Graphormer-v2 | 47.1M | 0.0253 | **0.0865** |
|
||||
Graphormer-v2 | 48.3M | 0.0253 | **0.0865** |
|
||||
|
||||
#### PCQM4Mv1
|
||||
Method | #params | train MAE | valid MAE |
|
||||
|
@ -32,7 +32,7 @@ GCN-VN | 4.9M | 0.1225 | 0.1485 |
|
|||
GIN-VN | 6.7M | 0.1150 | 0.1395 |
|
||||
Graphormer-Small| 12.5M | 0.0778 | 0.1264 |
|
||||
Graphormer | 47.1M | 0.0582 | 0.1234 |
|
||||
Graphormer-v2 | 47.1M | 0.0309 | **0.1201** |
|
||||
Graphormer-v2 | 48.3M | 0.0309 | **0.1201** |
|
||||
|
||||
## Open Graph Benchmark
|
||||
|
||||
|
@ -55,7 +55,7 @@ PHC-GNN | 111K | 79.34 |
|
|||
DeeperGCN-FLAG | 532K | 79.42 |
|
||||
DGN | 114K | 79.70 |
|
||||
Graphormer | 47.0M | 80.51 |
|
||||
Graphormer-v2 | 47.1M | **81.28** |
|
||||
Graphormer-v2 | 48.3M | **80.56** |
|
||||
|
||||
## Benchmarking Graph Neural Networks - ZINC-500K
|
||||
|
||||
|
|
|
@ -2,16 +2,14 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
n_gpu=2
|
||||
epoch=8
|
||||
n_gpu=1
|
||||
epoch=4
|
||||
max_epoch=$((epoch + 1))
|
||||
batch_size=64
|
||||
batch_size=128
|
||||
tot_updates=$((33000*epoch/batch_size/n_gpu))
|
||||
warmup_updates=$((tot_updates/10))
|
||||
warmup_updates=$((tot_updates*16/100))
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
fairseq-train \
|
||||
CUDA_VISIBLE_DEVICES=3 fairseq-train \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
|
@ -24,7 +22,7 @@ fairseq-train \
|
|||
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \
|
||||
--lr-scheduler polynomial_decay --power 1 --warmup-updates $warmup_updates --total-num-update $tot_updates \
|
||||
--lr 2e-4 --end-learning-rate 1e-9 \
|
||||
--lr 2e-4 --end-learning-rate 1e-5 \
|
||||
--batch-size $batch_size \
|
||||
--fp16 \
|
||||
--data-buffer-size 20 \
|
||||
|
@ -34,8 +32,9 @@ fairseq-train \
|
|||
--encoder-attention-heads 32 \
|
||||
--max-epoch $max_epoch \
|
||||
--save-dir ./ckpts \
|
||||
--pretrained-model-name pcqm4mv1_graphormer_base \
|
||||
--pretrained-model-name pcqm4mv1_graphormer_base_for_molhiv \
|
||||
--seed 1 \
|
||||
--flag-m 3 \
|
||||
--flag-step-size 0.001 \
|
||||
--flag-mag 0.001 \
|
||||
--flag-step-size 0.01 \
|
||||
--flag-mag 0 \
|
||||
--pre-layernorm
|
||||
|
|
|
@ -125,6 +125,11 @@ class GraphormerModel(FairseqEncoderModel):
|
|||
action="store_true",
|
||||
help="apply layernorm before each encoder block",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pre-layernorm",
|
||||
action="store_true",
|
||||
help="apply layernorm before self-attention and ffn. Without this, post layernorm will be used",
|
||||
)
|
||||
|
||||
def max_nodes(self):
|
||||
return self.encoder.max_nodes
|
||||
|
@ -171,6 +176,7 @@ class GraphormerEncoder(FairseqEncoder):
|
|||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.act_dropout,
|
||||
encoder_normalize_before=args.encoder_normalize_before,
|
||||
pre_layernorm=args.pre_layernorm,
|
||||
apply_graphormer_init=args.apply_graphormer_init,
|
||||
activation_fn=args.activation_fn,
|
||||
)
|
||||
|
@ -272,15 +278,9 @@ def base_architecture(args):
|
|||
|
||||
@register_model_architecture("graphormer", "graphormer_base")
|
||||
def graphormer_base_architecture(args):
|
||||
if args.pretrained_model_name == "pcqm4mv1_graphormer_base":
|
||||
args.encoder_layers = 12
|
||||
args.encoder_attention_heads = 32
|
||||
args.encoder_ffn_embed_dim = 768
|
||||
args.encoder_embed_dim = 768
|
||||
args.dropout = getattr(args, "dropout", 0.0)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.act_dropout = getattr(args, "act_dropout", 0.1)
|
||||
elif args.pretrained_model_name == "pcqm4mv2_graphormer_base":
|
||||
if args.pretrained_model_name == "pcqm4mv1_graphormer_base" or \
|
||||
args.pretrained_model_name == "pcqm4mv2_graphormer_base" or \
|
||||
args.pretrained_model_name == "pcqm4mv1_graphormer_base_for_molhiv":
|
||||
args.encoder_layers = 12
|
||||
args.encoder_attention_heads = 32
|
||||
args.encoder_ffn_embed_dim = 768
|
||||
|
@ -290,11 +290,12 @@ def graphormer_base_architecture(args):
|
|||
args.act_dropout = getattr(args, "act_dropout", 0.1)
|
||||
else:
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
||||
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
||||
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 32)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768)
|
||||
args.dropout = getattr(args, "dropout", 0.0)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.act_dropout = getattr(args, "act_dropout", 0.1)
|
||||
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
||||
|
@ -305,6 +306,7 @@ def graphormer_base_architecture(args):
|
|||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
args.pre_layernorm = getattr(args, "pre_layernorm", False)
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
|
@ -326,6 +328,7 @@ def graphormer_slim_architecture(args):
|
|||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
args.pre_layernorm = getattr(args, "pre_layernorm", False)
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
|
@ -347,4 +350,5 @@ def graphormer_large_architecture(args):
|
|||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
args.pre_layernorm = getattr(args, "pre_layernorm", False)
|
||||
base_architecture(args)
|
||||
|
|
|
@ -62,6 +62,7 @@ class GraphormerGraphEncoder(nn.Module):
|
|||
activation_dropout: float = 0.1,
|
||||
layerdrop: float = 0.0,
|
||||
encoder_normalize_before: bool = False,
|
||||
pre_layernorm: bool = False,
|
||||
apply_graphormer_init: bool = False,
|
||||
activation_fn: str = "gelu",
|
||||
embed_scale: float = None,
|
||||
|
@ -119,6 +120,9 @@ class GraphormerGraphEncoder(nn.Module):
|
|||
else:
|
||||
self.emb_layer_norm = None
|
||||
|
||||
if pre_layernorm:
|
||||
self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
|
||||
|
||||
if self.layerdrop > 0.0:
|
||||
self.layers = LayerDropModuleList(p=self.layerdrop)
|
||||
else:
|
||||
|
@ -136,6 +140,7 @@ class GraphormerGraphEncoder(nn.Module):
|
|||
export=export,
|
||||
q_noise=q_noise,
|
||||
qn_block_size=qn_block_size,
|
||||
pre_layernorm=pre_layernorm,
|
||||
)
|
||||
for _ in range(num_encoder_layers)
|
||||
]
|
||||
|
@ -168,6 +173,7 @@ class GraphormerGraphEncoder(nn.Module):
|
|||
export,
|
||||
q_noise,
|
||||
qn_block_size,
|
||||
pre_layernorm,
|
||||
):
|
||||
return GraphormerGraphEncoderLayer(
|
||||
embedding_dim=embedding_dim,
|
||||
|
@ -180,6 +186,7 @@ class GraphormerGraphEncoder(nn.Module):
|
|||
export=export,
|
||||
q_noise=q_noise,
|
||||
qn_block_size=qn_block_size,
|
||||
pre_layernorm=pre_layernorm,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
|
|
@ -32,6 +32,7 @@ class GraphormerGraphEncoderLayer(nn.Module):
|
|||
q_noise: float = 0.0,
|
||||
qn_block_size: int = 8,
|
||||
init_fn: Callable = None,
|
||||
pre_layernorm: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -44,6 +45,7 @@ class GraphormerGraphEncoderLayer(nn.Module):
|
|||
self.attention_dropout = attention_dropout
|
||||
self.q_noise = q_noise
|
||||
self.qn_block_size = qn_block_size
|
||||
self.pre_layernorm = pre_layernorm
|
||||
|
||||
self.dropout_module = FairseqDropout(
|
||||
dropout, module_name=self.__class__.__name__
|
||||
|
@ -119,6 +121,8 @@ class GraphormerGraphEncoderLayer(nn.Module):
|
|||
"""
|
||||
# x: T x B x C
|
||||
residual = x
|
||||
if self.pre_layernorm:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
|
@ -130,13 +134,17 @@ class GraphormerGraphEncoderLayer(nn.Module):
|
|||
)
|
||||
x = self.dropout_module(x)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
if not self.pre_layernorm:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
if self.pre_layernorm:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.activation_dropout_module(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout_module(x)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
if not self.pre_layernorm:
|
||||
x = self.final_layer_norm(x)
|
||||
return x, attn
|
||||
|
|
|
@ -5,6 +5,7 @@ PRETRAINED_MODEL_URLS = {
|
|||
"pcqm4mv1_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_best_pcqm4mv1.pt",
|
||||
"pcqm4mv2_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv2/checkpoint_best_pcqm4mv2.pt",
|
||||
"oc20is2re_graphormer3d_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt",
|
||||
"pcqm4mv1_graphormer_base_for_molhiv":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_base_preln_pcqm4mv1_for_hiv.pt",
|
||||
}
|
||||
|
||||
def load_pretrained_model(pretrained_model_name):
|
||||
|
|
Загрузка…
Ссылка в новой задаче