init v2 (#51)
This commit is contained in:
Parent: 1fdec9ab3d
Commit: 9fa61b9417
|
@ -127,3 +127,13 @@ exps
|
|||
|
||||
# Weights and Biases logs
|
||||
wandb/
|
||||
|
||||
|
||||
*.pyc
|
||||
*.log
|
||||
ckpts
|
||||
examples/dataset
|
||||
examples/property_prediction/ckpts
|
||||
#examples/property_prediction/dataset
|
||||
!examples/property_prediction/dataset/pcqm4m-v2/RELEASE_v1.txt
|
||||
!examples/property_prediction/dataset/pcqm4m_kddcup2021/RELEASE_v1.txt
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
[submodule "fairseq"]
|
||||
path = fairseq
|
||||
url = https://github.com/pytorch/fairseq
|
|
@ -0,0 +1,19 @@
|
|||
# Required
|
||||
version: 1
|
||||
|
||||
# Set the version of Python and other tools you might need
|
||||
build:
|
||||
os: ubuntu-20.04
|
||||
tools:
|
||||
python: "3.9"
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
|
||||
# Optionally declare the Python requirements required to build your docs
|
||||
python:
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- method: pip
|
||||
path: fairseq
|
87
README.md
87
README.md
|
@ -1,9 +1,18 @@
|
|||
# Graphormer <img src="docs/graphformer_logo.png" width="100" align="left">
|
||||
<img src="docs/logo-10.png" width=100%>
|
||||
|
||||
Graphormer is a deep learning package that allows researchers and developers to train custom models for molecule modeling tasks. It aims to accelerate the research and application in AI for molecule science, such as material discovery, drug discovery, etc. [Project website](https://www.microsoft.com/en-us/research/project/graphormer/).
|
||||
|
||||
This repo is the official implementation of ["Do Transformers Really Perform Badly for Graph Representation?"](https://openreview.net/forum?id=OeWooOxFwDa).
|
||||
## Highlights in Graphormer v2.0
|
||||
* The model, code, and script used in the [Open Catalyst Challenge](https://opencatalystproject.org/challenge.html) are available.
|
||||
* Pre-trained models on PCQM4M and PCQM4Mv2 are available; more pre-trained models are coming soon.
* Supports interfaces and datasets of PyG, DGL, OGB, and OCP.
* Supports the fairseq backbone.
* Documentation is online!
|
||||
|
||||
## News
|
||||
## What's New:
|
||||
|
||||
***12/19/2021***
|
||||
1. Graphormer v2.0 is released. Enjoy!
|
||||
|
||||
***12/10/2021***
|
||||
1. **Graphormer** has won the [Open Catalyst Challenge](https://opencatalystproject.org/challenge.html). The technical talk could be found through this [link](https://www.youtube.com/watch?v=uKJX3Mpu3OA&ab_channel=OpenCatalystProject).
|
||||
|
@ -14,86 +23,34 @@ This repo is the official implementation of ["Do Transformers Really Perform Bad
|
|||
1. **Graphormer** has been accepted by **NeurIPS 2021**.
|
||||
2. We're hiring! Please contact ``shuz[at]microsoft.com`` for more information.
|
||||
|
||||
|
||||
***08/03/2021***
|
||||
1. Codes and scripts are released.
|
||||
|
||||
***06/16/2021***
|
||||
1. Graphormer has won the **1st place** of quantum prediction track of Open Graph Benchmark Large-Scale Challenge (KDD CUP 2021) [[Competition Description]](https://ogb.stanford.edu/kddcup2021/pcqm4m/) [[Competition Result]](https://ogb.stanford.edu/kddcup2021/results/) [[Technical Report]](https://arxiv.org/pdf/2106.08279.pdf) [[Blog (English)]](https://www.microsoft.com/en-us/research/lab/microsoft-research-asia/articles/transformer-stands-out-as-the-best-graph-learner-researchers-from-microsoft-research-asia-wins-the-kdd-cups-2021-graph-prediction-track/) [[Blog (Chinese)]](https://www.msra.cn/zh-cn/news/features/ogb-lsc)
|
||||
|
||||
## Hiring
|
||||
We are hiring at all levels (including FTE researchers and interns)! If you are interested in working with us on AI for Molecule Science, please send your resume to <a href="mailto:shuz@microsoft.com" class="x-hidden-focus">shuz@microsoft.com</a>.
|
||||
|
||||
## Get Started
|
||||
|
||||
## Introduction
|
||||
**Graphormer** was initially described in [arxiv](https://arxiv.org/abs/2106.05234). It is a standard Transformer architecture with several structural encodings that effectively encode the structural information of a graph into the model.
|
||||
Our primary documentation is at https://graphormer.readthedocs.io/ and is generated from this repository, which contains instructions for getting started, training new models and extending Graphormer with new model types and tasks.
|
||||
|
||||
Graphormer achieves strong performance on PCQM4M-LSC (`0.1234 MAE` on val), MolPCBA (`31.39 AP(%)` on test), MolHIV (`80.51 AUC(%)` on test) and ZINC (`0.122 MAE on test`), surpassing previous models by a large margin.
|
||||
Next you may want to read:
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="docs/graphformer.png" width="600">
|
||||
</p>
|
||||
|
||||
## Main Results
|
||||
|
||||
#### PCQM4M-LSC
|
||||
Method | #params | train MAE | valid MAE |
|
||||
--------------|---------|-----------|-----------|
|
||||
GCN | 2.0M | 0.1318 | 0.1691 |
|
||||
GIN | 3.8M | 0.1203 | 0.1537 |
|
||||
GCN-VN | 4.9M | 0.1225 | 0.1485 |
|
||||
GIN-VN | 6.7M | 0.1150 | 0.1395 |
|
||||
Graphormer-Small| 12.5M | 0.0778 | 0.1264 |
|
||||
Graphormer | 47.1M | 0.0582 | **0.1234** |
|
||||
|
||||
#### OGBG-MolPCBA
|
||||
Method | #params | test AP (%)|
|
||||
--------------|---------|------------|
|
||||
DeeperGCN-VN+FLAG | 5.6M | 28.42 |
|
||||
DGN | 6.7M | 28.85 |
|
||||
GINE-VN | 6.1M | 29.17 |
|
||||
PHC-GNN | 1.7M | 29.47 |
|
||||
GINE-APPNP | 6.1M | 29.79 |
|
||||
Graphormer | 119.5M | **31.39** |
|
||||
|
||||
#### OGBG-MolHIV
|
||||
Method | #params | test AUC (%)|
|
||||
--------------|---------|------------|
|
||||
GCN-GraphNorm | 526K | 78.83 |
|
||||
PNA | 326K | 79.05 |
|
||||
PHC-GNN | 111K | 79.34 |
|
||||
DeeperGCN-FLAG | 532K | 79.42 |
|
||||
DGN | 114K | 79.70 |
|
||||
Graphormer | 47.0M | **80.51** |
|
||||
|
||||
#### ZINC-500K
|
||||
Method | #params | test MAE |
|
||||
--------------|---------|------------|
|
||||
GIN | 509.5K | 0.526 |
|
||||
GraphSage | 505.3K | 0.398 |
|
||||
GAT | 531.3K | 0.384 |
|
||||
GCN | 505.1K | 0.367 |
|
||||
GT | 588.9K | 0.226 |
|
||||
GatedGCN-PE | 505.0K | 0.214 |
|
||||
MPNN (sum) | 480.8K | 0.145 |
|
||||
PNA | 387.2K | 0.142 |
|
||||
SAN | 508.6K | 0.139 |
|
||||
Graphormer-Slim | 489.3K | **0.122** |
|
||||
* [Examples](https://github.com/microsoft/Graphormer/tree/main/examples) showing command line usage of common tasks.
|
||||
|
||||
|
||||
## Requirements and Installation
|
||||
|
||||
#### Setup with Conda
|
||||
|
||||
```
|
||||
# create a new environment
|
||||
conda create --name graphormer python=3.7
|
||||
conda activate graphormer
|
||||
# install requirements
|
||||
pip install rdkit-pypi cython
|
||||
pip install ogb==1.3.1 pytorch-lightning==1.3.0
|
||||
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
pip install torch-geometric==1.6.3 ogb==1.3.1 pytorch-lightning==1.3.1 tqdm torch-sparse==0.6.9 torch-scatter==2.0.6 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
|
||||
bash install.sh
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Citation
|
||||
Please kindly cite this paper if you use the code:
|
||||
```
|
||||
|
|
50
SUPPORT.md
50
SUPPORT.md
|
@ -1,25 +1,25 @@
|
|||
# TODO: The maintainer of this repo has not yet edited this file
|
||||
|
||||
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
|
||||
|
||||
- **No CSS support:** Fill out this template with information about how to file issues and get help.
|
||||
- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
|
||||
- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
|
||||
|
||||
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
|
||||
|
||||
# Support
|
||||
|
||||
## How to file issues and get help
|
||||
|
||||
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
||||
issues before filing new issues to avoid duplicates. For new issues, file your bug or
|
||||
feature request as a new Issue.
|
||||
|
||||
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
|
||||
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
|
||||
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
|
||||
|
||||
## Microsoft Support Policy
|
||||
|
||||
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
|
|
@ -0,0 +1,126 @@
|
|||
Datasets
|
||||
==================
|
||||
|
||||
Graphormer supports training with both existing datasets in graph libraries and customized datasets.
|
||||
|
||||
Existing Datasets
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
Graphormer supports training with datasets in existing libraries.
|
||||
Users can easily exploit datasets in these libraries by specifying the ``--dataset-source`` and ``--dataset-name`` parameters.
|
||||
|
||||
``--dataset-source`` specifies the source of the dataset, which can be one of:
|
||||
|
||||
1. ``dgl`` for `DGL <https://docs.dgl.ai/>`__
|
||||
|
||||
2. ``pyg`` for `Pytorch Geometric <https://pytorch-geometric.readthedocs.io/en/latest/>`__
|
||||
|
||||
3. ``ogb`` for `OGB <https://ogb.stanford.edu/>`__
|
||||
|
||||
``--dataset-name`` specifies the dataset in the source.
|
||||
For example, by specifying ``--dataset-source pyg`` and ``--dataset-name zinc``, Graphormer will load the `ZINC <https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.ZINC>`__ dataset from Pytorch Geometric.
|
||||
When a dataset requires additional parameters to construct, the parameters are specified as ``<dataset_name>:<param_1>=<value_1>,<param_2>=<value_2>,...,<param_n>=<value_n>``.
|
||||
When the type of a parameter value is a list, the value is represented as a string with the list elements concatenated by ``+``.
|
||||
For example, if we want to specify multiple ``label_keys`` with ``mu``, ``alpha``, and ``homo`` for `QM9 <https://docs.dgl.ai/en/0.6.x/api/python/dgl.data.html#qm9-dataset>`__ dataset,
|
||||
``--dataset-name`` should be ``qm9:label_keys=mu+alpha+homo``.
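For instance, a training command using this dataset specification could look like the following sketch (the remaining ``fairseq-train`` flags are elided here and follow the examples in :ref:`Quick-Start`):

.. code-block:: console

    fairseq-train --user-dir ../../graphormer \
        --dataset-source dgl \
        --dataset-name qm9:label_keys=mu+alpha+homo \
        --task graph_prediction \
        ...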
|
||||
|
||||
When dataset split (``train``, ``valid`` and ``test`` subsets) is not configured in the original dataset source, we randomly partition
|
||||
the full set into ``train``, ``valid`` and ``test`` with ratios ``0.7``, ``0.2`` and ``0.1``, respectively.
|
||||
If you want a customized split of a dataset, you may implement a customized dataset (see :ref:`Customized Datasets`).
|
||||
Currently, only integer features of nodes and edges in the datasets are used.
|
||||
|
||||
A full list of supported datasets of each data source:
|
||||
|
||||
+------------------+----------------+-----------------------------------------+-----------------------------+
|
||||
| Dataset Source | Dataset Name | Link | #Label/#Class |
|
||||
+==================+================+=========================================+=============================+
|
||||
| ``dgl`` | ``qm7b`` | QM7B_ dataset | 14 |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| | ``qm9`` | QM9_ dataset | Depending on ``label_keys`` |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| | ``qm9edge`` | QM9Edge_ dataset | Depending on ``label_keys`` |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| | ``minigc`` | MiniGC_ dataset | 8 |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| | ``gin`` | `Graph Isomorphism Network`_ dataset | 1 |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| | ``fakenews`` | `FakeNewsDataset`_ dataset | 1 |
|
||||
+------------------+----------------+-----------------------------------------+-----------------------------+
|
||||
| ``pyg``          |``moleculenet`` | MoleculeNet_ dataset                    | 1                           |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| | ``zinc`` | ZINC_ dataset | 1 |
|
||||
+------------------+----------------+-----------------------------------------+-----------------------------+
|
||||
| ``ogb`` |``ogbg-molhiv`` | ogbg-molhiv_ dataset | 1 |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| |``ogbg-molpcba``| ogbg-molpcba_ dataset | 128 |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| |``pcqm4m`` | PCQM4M_ dataset | 1 |
|
||||
| +----------------+-----------------------------------------+-----------------------------+
|
||||
| |``pcqm4mv2`` | PCQM4Mv2_ dataset | 1 |
|
||||
+------------------+----------------+-----------------------------------------+-----------------------------+
|
||||
|
||||
|
||||
.. _QM7B: https://docs.dgl.ai/en/0.6.x/api/python/dgl.data.html#qm7b-dataset
|
||||
.. _QM9: https://docs.dgl.ai/en/0.6.x/api/python/dgl.data.html#qm9-dataset
|
||||
.. _QM9Edge: https://docs.dgl.ai/en/0.6.x/api/python/dgl.data.html#qm9edge-dataset
|
||||
.. _MiniGC: https://docs.dgl.ai/en/0.6.x/api/python/dgl.data.html#mini-graph-classification-dataset
|
||||
.. _TU: https://docs.dgl.ai/en/0.6.x/api/python/dgl.data.html#tu-dataset
|
||||
.. _Graph Isomorphism Network: https://docs.dgl.ai/en/0.6.x/api/python/dgl.data.html#qm9-dataset
|
||||
.. _FakeNewsDataset: https://docs.dgl.ai/en/0.7.x/_modules/dgl/data/fakenews.html
|
||||
|
||||
.. _KarateClub: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.KarateClub
|
||||
.. _MoleculeNet: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.MoleculeNet
|
||||
.. _ZINC: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.ZINC
|
||||
.. _MD17: https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.MD17
|
||||
|
||||
.. _ogbg-molhiv: https://ogb.stanford.edu/docs/graphprop/#ogbg-mol
|
||||
.. _ogbg-molpcba: https://ogb.stanford.edu/docs/graphprop/#ogbg-mol
|
||||
.. _PCQM4M: https://ogb.stanford.edu/kddcup2021/pcqm4m/
|
||||
.. _PCQM4Mv2: https://ogb.stanford.edu/docs/lsc/pcqm4mv2/
|
||||
.. _ogbg-ppa: https://ogb.stanford.edu/docs/graphprop/#ogbg-ppa
|
||||
|
||||
.. _Customized Datasets:
|
||||
Customized Datasets
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Users may create their own datasets. To use a customized dataset:
|
||||
|
||||
1. Create a folder (for example, named ``customized_dataset``), and a python script with an arbitrary name in the folder.
|
||||
|
||||
2. In the created python script, define a function which returns the created dataset, and register it with ``register_dataset``. Here is a sample python script,
which defines a `QM9 <https://docs.dgl.ai/en/0.6.x/api/python/dgl.data.html#qm9-dataset>`__ dataset from ``dgl`` with a customized split.
|
||||
|
||||
.. code-block:: python
|
||||
:linenos:
|
||||
|
||||
from graphormer.data import register_dataset
|
||||
from dgl.data import QM9
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
@register_dataset("customized_qm9_dataset")
|
||||
def create_customized_dataset():
|
||||
dataset = QM9(label_keys=["mu"])
|
||||
num_graphs = len(dataset)
|
||||
|
||||
# customized dataset split
|
||||
train_valid_idx, test_idx = train_test_split(
|
||||
np.arange(num_graphs), test_size=num_graphs // 10, random_state=0
|
||||
)
|
||||
train_idx, valid_idx = train_test_split(
|
||||
train_valid_idx, test_size=num_graphs // 5, random_state=0
|
||||
)
|
||||
return {
|
||||
"dataset": dataset,
|
||||
"train_idx": train_idx,
|
||||
"valid_idx": valid_idx,
|
||||
"test_idx": test_idx,
|
||||
"source": "dgl"
|
||||
}
|
||||
|
||||
The function returns a dictionary. In the dictionary, ``dataset`` is the dataset object, and ``train_idx`` contains the graph indices used for training. Similarly, we have
``valid_idx`` and ``test_idx``. Finally, ``source`` records the underlying graph library used by the dataset.
|
||||
|
||||
3. Specify ``--user-data-dir`` as ``customized_dataset`` and set ``--dataset-name`` to ``customized_qm9_dataset`` when training, as sketched below.
Note that ``--user-data-dir`` should not be used together with ``--dataset-source``. All datasets defined in all python scripts under ``customized_dataset``
will be registered automatically.
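A minimal sketch of such a training command (the model and optimization flags are not specific to customized datasets and follow the examples in :ref:`Quick-Start`):

.. code-block:: console

    fairseq-train --user-dir ../../graphormer \
        --user-data-dir customized_dataset \
        --dataset-name customized_qm9_dataset \
        --task graph_prediction \
        --criterion l1_loss \
        --arch graphormer_base \
        --num-classes 1 \
        --batch-size 64 \
        ...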
|
|
@ -0,0 +1,19 @@
|
|||
Installation Guide
|
||||
==================
|
||||
|
||||
This is a guide to installing Graphormer. Currently, Graphormer supports installation on Linux only.
|
||||
|
||||
Linux
|
||||
~~~~~
|
||||
|
||||
On Linux, Graphormer can be easily installed with the ``install.sh`` script in a prepared Python environment.
|
||||
|
||||
1. Please use Python 3.9 for Graphormer. It is recommended to create a virtual environment with `conda <https://docs.conda.io/en/latest/>`__ or `virtualenv <https://virtualenv.pypa.io/en/latest/>`__.
|
||||
|
||||
2. Run the following commands:
|
||||
|
||||
.. code::
|
||||
|
||||
git clone --recursive https://github.com/microsoft/Graphormer.git
|
||||
cd Graphormer
|
||||
bash install.sh
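After the script finishes, a quick sanity check (assuming the environment created above is still active) is to confirm that the fairseq command-line entry point is available:

.. code::

    fairseq-train --help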
|
|
@ -0,0 +1,20 @@
|
|||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
|
@ -0,0 +1,11 @@
|
|||
Overview
|
||||
========
|
||||
|
||||
Graphormer inherits the extension mechanism of fairseq, which means it can easily support user-defined `plug-ins <https://fairseq.readthedocs.io/en/latest/overview.html>`_.
|
||||
|
||||
For example, the Graphormer-base model could be defined through :class:`~graphormer.models.GraphormerModel`, which inherits the :class:`~fairseq.models.FairseqModel` class.
|
||||
|
||||
|
||||
It's also easy to extend the Graphormer-base model, which means you could define your own `model <https://fairseq.readthedocs.io/en/latest/models.html>`_ and `criterion <https://fairseq.readthedocs.io/en/latest/criterions.html>`_, and then use them in Graphormer.
|
||||
|
||||
Developing a new model is also easy. We provide a tutorial on how to implement a simple MLP model on graphs in :ref:`Tutorials`.
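In practice, these plug-ins are loaded by fairseq from the directory passed via the ``--user-dir`` flag; every example in this documentation uses ``--user-dir ../../graphormer``, so models and criterions registered under that package are picked up automatically. A sketch of how the flag appears in a training command (remaining flags elided):

.. code-block:: console

    fairseq-train --user-dir ../../graphormer --task graph_prediction ...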
|
|
@ -0,0 +1,179 @@
|
|||
.. _Command-line Tools:
|
||||
|
||||
Command-line Tools
|
||||
==================
|
||||
|
||||
Graphormer reuses the ``fairseq-train`` command-line tools of `fairseq <https://fairseq.readthedocs.io/en/latest/command_line_tools.html>`__ for training, and here we mainly document the additional parameters in Graphormer
|
||||
and parameters of ``fairseq-train`` used by Graphormer.
|
||||
|
||||
Model
|
||||
-----
|
||||
- ``--arch``, type=enum, options: ``graphormer_base``, ``graphormer_slim``, ``graphormer_large``
|
||||
|
||||
- Predefined graphormer architectures
|
||||
|
||||
- ``--encoder-ffn-embed-dim``, type = int
|
||||
|
||||
- encoder embedding dimension for FFN
|
||||
|
||||
- ``--encoder-layers``, type = int
|
||||
|
||||
- number of graphormer encoder layers
|
||||
|
||||
- ``--encoder-embed-dim``, type = int
|
||||
|
||||
- encoder embedding dimension
|
||||
|
||||
- ``--share-encoder-input-output-embed``, type = bool
|
||||
|
||||
- if set, share encoder input and output embeddings
|
||||
|
||||
- ``--encoder-learned-pos``, type = bool
|
||||
|
||||
- if set, use learned positional embeddings in the encoder
|
||||
|
||||
- ``--no-token-positional-embeddings``, type = bool
|
||||
|
||||
- if set, disables positional embeddings (outside self attention)
|
||||
|
||||
- ``--max-positions``, type = int
|
||||
|
||||
- number of positional embeddings to learn
|
||||
|
||||
- ``--activation-fn``, type = enum, options: ``gelu``, ``relu``
|
||||
|
||||
- activation function to use
|
||||
|
||||
- ``--encoder-normalize-before``
|
||||
|
||||
- if set, apply layernorm before each encoder block
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
- ``--apply-graphormer-init``, type = bool
|
||||
|
||||
- if set, use custom param initialization for Graphormer
|
||||
|
||||
- ``--dropout``, type = float
|
||||
|
||||
- dropout probability
|
||||
|
||||
- ``--attention-dropout``, type = float
|
||||
|
||||
- dropout probability for attention weights
|
||||
|
||||
- ``--act-dropout``, type = float
|
||||
|
||||
- dropout probability after activation in FFN
|
||||
|
||||
- ``--seed``, type = int
|
||||
|
||||
- random seed
|
||||
|
||||
- ``--pretrained-model-name``, type = enum, default= ``none``, options: ``pcqm4mv1_graphormer_base``, ``pcqm4mv2_graphormer_base``
|
||||
|
||||
- name of used pretrained model
|
||||
|
||||
- ``pcqm4mv1_graphormer_base``: Pretrained Graphormer base model with `PCQM4M v1 <https://ogb.stanford.edu/kddcup2021/pcqm4m/>`__ dataset.
|
||||
|
||||
- ``pcqm4mv2_graphormer_base``: Pretrained Graphormer base model with `PCQM4M v2 <https://ogb.stanford.edu/docs/lsc/pcqm4mv2/>`__ dataset.
|
||||
|
||||
- ``--load-pretrained-model-output-layer``, type = bool
|
||||
|
||||
- if set, the weights of the final fully connected layer in the pre-trained model are loaded
|
||||
|
||||
- ``--optimizer``, type = enum
|
||||
|
||||
- optimizers from `fairseq <https://fairseq.readthedocs.io/en/latest/optim.html>`__
|
||||
|
||||
- ``--lr``, type = float
|
||||
|
||||
- learning rate
|
||||
|
||||
- ``--lr-scheduler``, type=enum
|
||||
|
||||
- learning rate scheduler from `fairseq <https://fairseq.readthedocs.io/en/latest/lr_scheduler.html>`__
|
||||
|
||||
- ``--fp16``, type=bool
|
||||
|
||||
- if set, use mixed precision training
|
||||
|
||||
- ``--data-buffer-size``, type=int, default=10
|
||||
|
||||
- number of batches to preload
|
||||
|
||||
- ``--batch-size``, type=int
|
||||
|
||||
- number of examples in a batch
|
||||
|
||||
- ``--max-epoch``, type=int, default=0
|
||||
|
||||
- force stop training at specified epoch
|
||||
|
||||
- ``--save-dir``, type=str, default=``checkpoints``
|
||||
|
||||
- path to save checkpoints
|
||||
|
||||
|
||||
Dataset
|
||||
-------
|
||||
- ``--dataset-name``, type = str, default= ``pcqm4m``
|
||||
|
||||
- name of the dataset
|
||||
|
||||
- ``--dataset-source``, type = str, default= ``ogb``
|
||||
|
||||
- source of graph dataset, can be: ``pyg``, ``dgl``, ``ogb``
|
||||
|
||||
- ``--num-classes``, type = int, default=-1
|
||||
|
||||
- number of classes or regression targets
|
||||
|
||||
- ``--num-atoms``, type = int, default=512 * 9
|
||||
|
||||
- number of atom types in the graph
|
||||
|
||||
- ``--num-edges``, type = int, default=512 * 3
|
||||
|
||||
- number of edge types in the graph
|
||||
|
||||
- ``--num-in-degree``, type = int, default=512
|
||||
|
||||
- number of in degree types in the graph
|
||||
|
||||
- ``--num-out-degree``, type = int, default=512
|
||||
|
||||
- number of out degree types in the graph
|
||||
|
||||
- ``--num-spatial``, type = int, default=512
|
||||
|
||||
- number of spatial types in the graph
|
||||
|
||||
- ``--num-edge-dis``, type = int, default=128
|
||||
|
||||
- number of edge distance types in the graph
|
||||
|
||||
- ``--multi-hop-max-dist``, type = int, default=5
|
||||
|
||||
- max number of edges considered in the edge encoding
|
||||
|
||||
- ``--spatial-pos-max``, type = int, default=1024
|
||||
|
||||
- max distance of attention in graph
|
||||
|
||||
- ``--edge-type``, type = str, default="multi_hop"
|
||||
|
||||
- edge type in the graph
|
||||
|
||||
- ``--user-data-dir``, type = str, default=""
|
||||
|
||||
- path to the module of user-defined dataset
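As an illustration of how several of these dataset parameters combine in practice, here is a sketch of a ``fairseq-train`` invocation (model and optimization flags are elided and follow :ref:`Quick-Start`):

.. code-block:: console

    fairseq-train --user-dir ../../graphormer \
        --dataset-name ogbg-molhiv \
        --dataset-source ogb \
        --num-classes 1 \
        --batch-size 64 \
        --save-dir ./ckpts \
        ...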
|
|
@ -0,0 +1,10 @@
|
|||
Pretrained Models
|
||||
==================
|
||||
|
||||
Graphormer provides a series of pre-trained models to help users leverage the power of the model quickly and smoothly. You can contribute your pre-trained model by creating a pull request.
|
||||
|
||||
- :ref:`Quick-Start`: an example of using pre-trained models.
|
||||
|
||||
Pre-trained models from specific papers:
|
||||
|
||||
- `Do Transformers Really Perform Badly for Graph Representation? <https://proceedings.neurips.cc/paper/2021/hash/f1c1592588411002af340cbaedd6fc33-Abstract.html>`__
|
|
@ -0,0 +1,175 @@
|
|||
Start with Example
|
||||
==================
|
||||
|
||||
Graphormer provides example scripts to train your own models on several datasets.
|
||||
For example, to train a Graphormer-slim on ZINC-500K on a single GPU card:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
> cd examples/property_prediction/
|
||||
> bash zinc.sh
|
||||
|
||||
The content of ``zinc.sh`` is simply a ``fairseq-train`` command:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 fairseq-train \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name zinc \
|
||||
--dataset-source pyg \
|
||||
--task graph_prediction \
|
||||
--criterion l1_loss \
|
||||
--arch graphormer_slim \
|
||||
--num-classes 1 \
|
||||
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 \
|
||||
--lr-scheduler polynomial_decay --power 1 --warmup-updates 60000 --total-num-update 400000 \
|
||||
--lr 2e-4 --end-learning-rate 1e-9 \
|
||||
--batch-size 64 \
|
||||
--fp16 \
|
||||
--data-buffer-size 20 \
|
||||
--encoder-layers 12 \
|
||||
--encoder-embed-dim 80 \
|
||||
--encoder-ffn-embed-dim 80 \
|
||||
--encoder-attention-heads 8 \
|
||||
--max-epoch 10000 \
|
||||
--save-dir ./ckpts
|
||||
|
||||
``CUDA_VISIBLE_DEVICES`` specifies the GPUs to use. With multiple GPUs, the GPU IDs should be separated by commas.
|
||||
The ``fairseq-train`` command with a Graphormer model is used to launch training.
|
||||
:ref:`Command-line Tools` gives detailed explanations to the parameters.
|
||||
|
||||
Similarly, to train a Graphormer-base on the PCQM4M dataset on multiple GPU cards:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
> cd examples/property_prediction/
|
||||
> bash pcqv1.sh
|
||||
|
||||
|
||||
By running these scripts, Graphormer will automatically download the needed datasets and pre-process them.
|
||||
|
||||
|
||||
Evaluate Pre-trained Models
|
||||
===========================
|
||||
|
||||
Graphormer provides pre-trained models so that users can easily evaluate and fine-tune them.
|
||||
To evaluate a pre-trained model, use the script ``graphormer/evaluate/evaluate.py``.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
python evaluate.py \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name pcqm4m \
|
||||
--dataset-source ogb \
|
||||
--task graph_prediction \
|
||||
--criterion l1_loss \
|
||||
--arch graphormer_base \
|
||||
--num-classes 1 \
|
||||
--batch-size 64 \
|
||||
--pretrained-model-name pcqm4mv1_graphormer_base \
|
||||
--load-pretrained-model-output-layer \
|
||||
--split valid \
|
||||
--seed 1
|
||||
|
||||
``--pretrained-model-name`` specifies the pre-trained model to be evaluated; it will be automatically downloaded. ``--load-pretrained-model-output-layer`` is set so that the weights of the
final fully connected layer in the pre-trained model are loaded. ``--split`` specifies the split of the dataset to be evaluated, which can be ``train`` or ``valid``.
|
||||
|
||||
Fine-tuning Pre-trained Models
|
||||
==============================
|
||||
To fine-tune pre-trained models, use ``--pretrained-model-name`` to set the model name. For example, the script ``examples/property_prediction/hiv_pre.sh``
|
||||
fine-tunes our model ``pcqm4mv1_graphormer_base`` on the ``ogbg-molhiv`` dataset. The command for fine-tuning is
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
fairseq-train \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name ogbg-molhiv \
|
||||
--dataset-source ogb \
|
||||
--task graph_prediction_with_flag \
|
||||
--criterion binary_logloss_with_flag \
|
||||
--arch graphormer_base \
|
||||
--num-classes 1 \
|
||||
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --weight-decay 0.0 \
|
||||
--lr-scheduler polynomial_decay --power 1 --warmup-updates $warmup_updates --total-num-update $tot_updates \
|
||||
--lr 2e-4 --end-learning-rate 1e-9 \
|
||||
--batch-size $batch_size \
|
||||
--fp16 \
|
||||
--data-buffer-size 20 \
|
||||
--encoder-layers 12 \
|
||||
--encoder-embed-dim 768 \
|
||||
--encoder-ffn-embed-dim 768 \
|
||||
--encoder-attention-heads 32 \
|
||||
--max-epoch $max_epoch \
|
||||
--save-dir ./ckpts \
|
||||
--pretrained-model-name pcqm4mv1_graphormer_base \
|
||||
--flag-m 2 \
|
||||
--flag-step-size 0.2 \
|
||||
--flag-mag 0 \
|
||||
--seed 1
|
||||
|
||||
After fine-tuning, use ``graphormer/evaluate/evaluate.py`` to evaluate the performance of all checkpoints:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
python evaluate.py \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name ogbg-molhiv \
|
||||
--dataset-source ogb \
|
||||
--task graph_prediction \
|
||||
--arch graphormer_base \
|
||||
--num-classes 1 \
|
||||
--batch-size 64 \
|
||||
--save-dir ../../examples/property_prediction/ckpts/ \
|
||||
--split test \
|
||||
--metric auc \
|
||||
--seed 1
|
||||
|
||||
|
||||
Training a New Model
|
||||
====================
|
||||
|
||||
We take OC20 as an example to show how to train a new model on your own datasets.
|
||||
|
||||
First, download IS2RE train, validation, and test data in LMDB format by:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
> cd examples/oc20/ && mkdir data && cd data/
|
||||
> wget -c https://dl.fbaipublicfiles.com/opencatalystproject/data/is2res_train_val_test_lmdbs.tar.gz && tar -xzvf is2res_train_val_test_lmdbs.tar.gz
|
||||
|
||||
Create ``ckpt`` folder to save checkpoints during the training:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
> cd ../ && mkdir ckpt/
|
||||
|
||||
Now we train a 48-layer ``graphormer-3D`` architecture, which has 4 blocks, each containing 12 Graphormer layers; the parameters are shared across blocks. The total number of training steps is 1 million, and we warm up the learning rate over the first 10 thousand steps.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
> fairseq-train --user-dir ../../graphormer \
|
||||
./data/is2res_train_val_test_lmdbs/data/is2re/all --valid-subset val_id,val_ood_ads,val_ood_cat,val_ood_both --best-checkpoint-metric loss \
|
||||
--num-workers 0 --ddp-backend=c10d \
|
||||
--task is2re --criterion mae_deltapos --arch graphormer3d_base \
|
||||
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 5 \
|
||||
--lr-scheduler polynomial_decay --lr 3e-4 --warmup-updates 10000 --total-num-update 1000000 --batch-size 4 \
|
||||
--dropout 0.0 --attention-dropout 0.1 --weight-decay 0.001 --update-freq 1 --seed 1 \
|
||||
--fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir ./tsbs \
|
||||
--embed-dim 768 --ffn-embed-dim 768 --attention-heads 48 \
|
||||
--max-update 1000000 --log-interval 100 --log-format simple \
|
||||
--save-interval-updates 5000 --validate-interval-updates 2500 --keep-interval-updates 30 --no-epoch-checkpoints \
|
||||
--save-dir ./ckpt --layers 12 --blocks 4 --required-batch-size-multiple 1 --node-loss-weight 15
|
||||
|
||||
Please note that ``--batch-size 4`` requires at least 32GB of GPU memory. If you run out of GPU memory, you may reduce the batch size and train with more GPU cards, or increase ``--update-freq`` to accumulate gradients.
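For instance, the ``examples/oc20/oc20.sh`` script shipped with this release reads these values from environment variables, so a smaller per-GPU batch size with gradient accumulation (keeping the effective batch size at 4) can be requested like this sketch:

.. code-block:: console

    > batch_size=2 update_freq=2 bash oc20.sh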
|
||||
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
Simple MLP Tutorial
|
||||
===================
|
||||
|
||||
In this tutorial, we will extend Graphormer by adding a new :class:`~graphormer.models.GraphMLP` that transforms the node features and uses a sum pooling layer to combine the outputs of the MLP into the graph representation.
|
||||
|
||||
This tutorial covers:
|
||||
|
||||
1. **Writing a new Model** so that the node token embeddings can be transformed by the MLP.
|
||||
2. **Training the Model** using the existing command-line tools.
|
||||
|
||||
1. Writing a new GraphMLP Model
|
||||
--------------------------------
|
||||
|
||||
First, we create a new file with filename :file:`graphormer/models/graphmlp.py`::
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairseq.models import FairseqEncoderModel, register_model
|
||||
|
||||
@register_model("graphmlp")
|
||||
class GraphMLP(FairseqEncoderModel):
|
||||
def __init__(self, args, encoder):
|
||||
super().__init__(encoder)
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add model-specific arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-nodes", type=int, metavar="N", help="num max nodes"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder-embed-dim",type=int, metavar="N", help="encoder embedding dimension",
|
||||
)
|
||||
|
||||
def max_nodes(self):
|
||||
return self.encoder.max_nodes
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
"""Build a new model instance."""
|
||||
# make sure all arguments are present in older models
|
||||
graphmlp_architecture(args)
|
||||
encoder = GraphMLPEncoder(args)
|
||||
return cls(args, encoder)
|
||||
|
||||
def forward(self, batched_data, **kwargs):
|
||||
return self.encoder(batched_data, **kwargs)
|
||||
|
||||
The main component in :class:`~graphormer.models.GraphMLP` is the :class:`~graphormer.models.GraphMLPEncoder`. Here we implement it by adding following codes in :file:`graphormer/models/graphmlp.py`::
|
||||
|
||||
from fairseq.models import FairseqEncoder
|
||||
from ..modules import GraphNodeFeature
|
||||
|
||||
class GraphMLPEncoder(FairseqEncoder):
|
||||
def __init__(self, args):
|
||||
super().__init__(dictionary=None)
|
||||
self.max_nodes = args.max_nodes
|
||||
self.emb_dim = args.encoder_embed_dim
|
||||
self.num_layer = args.encoder_layers
|
||||
self.num_classes = args.num_classes
|
||||
|
||||
self.atom_encoder = GraphNodeFeature(
|
||||
num_heads=1,
|
||||
num_atoms=512*9,
|
||||
num_in_degree=512,
|
||||
num_out_degree=512,
|
||||
hidden_dim=self.emb_dim,
|
||||
n_layers=self.num_layer,
|
||||
)
|
||||
|
||||
self.linear = torch.nn.ModuleList()
|
||||
self.batch_norms = torch.nn.ModuleList()
|
||||
|
||||
for layer in range(self.num_layer):
|
||||
self.linear.append(torch.nn.Linear(self.emb_dim, self.emb_dim))
|
||||
self.batch_norms.append(torch.nn.BatchNorm1d(self.emb_dim))
|
||||
|
||||
|
||||
self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_classes)
|
||||
|
||||
|
||||
def forward(self, batched_data, **unused):
|
||||
h=self.atom_encoder(batched_data)
|
||||
for layer in range(self.num_layer):
|
||||
h = self.linear[layer](h)
|
||||
h = h.transpose(1,2)
|
||||
h = self.batch_norms[layer](h)
|
||||
h = h.transpose(1,2)
|
||||
|
||||
if layer != self.num_layer - 1:
|
||||
h = F.relu(h)
|
||||
|
||||
h = h.sum(dim=1)
|
||||
out = self.graph_pred_linear(h)
|
||||
|
||||
return out.unsqueeze(1)
|
||||
|
||||
|
||||
def max_nodes(self):
|
||||
return self.max_nodes
|
||||
|
||||
Since we will validate our GraphMLP model on a graph representation task, we choose a dataset from MoleculeNet. Therefore, we employ :class:`~graphormer.modules.GraphNodeFeature` to encode the node features.
|
||||
|
||||
And finally, we register the model architecture by adding following codes in :file:`graphormer/models/graphmlp.py`::
|
||||
|
||||
from fairseq.models import register_model_architecture
|
||||
@register_model_architecture("graphmlp", "graphmlp")
|
||||
def graphmlp_architecture(args):
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
||||
|
||||
args.max_nodes = getattr(args, "max_nodes", 512)
|
||||
|
||||
2. Training the Model
|
||||
---------------------
|
||||
|
||||
Next, we prepare the training script for the model. We create a bash file :file:`examples/property_prediction/graphmlp.sh`::
|
||||
|
||||
#!/bin/bash
|
||||
CUDA_VISIBLE_DEVICES=0 fairseq-train \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name moleculenet:name=bbbp \
|
||||
--dataset-source pyg \
|
||||
--task graph_prediction \
|
||||
--criterion binary_logloss \
|
||||
--arch graphmlp \
|
||||
--num-classes 1 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \
|
||||
--lr-scheduler polynomial_decay --power 1 --total-num-update 1000000 \
|
||||
--lr 0.001 --end-learning-rate 1e-9 \
|
||||
--batch-size 32 \
|
||||
--fp16 \
|
||||
--data-buffer-size 20 \
|
||||
--encoder-layers 5 \
|
||||
--encoder-embed-dim 256 \
|
||||
--max-epoch 100 \
|
||||
--save-dir ./ckpts \
|
||||
--save-interval-updates 50000 \
|
||||
--no-epoch-checkpoints
|
||||
|
||||
After executing the script, the dataset is downloaded and processed, and training of the GraphMLP model then starts.
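To launch it (assuming the script was saved under :file:`examples/property_prediction/` as above), run::

    cd examples/property_prediction/
    bash graphmlp.sh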
|
|
@ -0,0 +1,66 @@
|
|||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# This file only contains a selection of the most common options. For a full
|
||||
# list see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Path setup --------------------------------------------------------------
|
||||
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
# import os
|
||||
# import sys
|
||||
# sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
from pathlib import Path
|
||||
CURR_PATH = Path(__file__).absolute().parent
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'Graphormer'
|
||||
copyright = '2021, Microsoft Corporation'
|
||||
author = 'MSRA Graphormer Team'
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = '1.0'
|
||||
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
# further. For a list of options available for each theme, see the
|
||||
# documentation.
|
||||
html_theme_options = {
|
||||
'includehidden': False,
|
||||
'logo_only': True,
|
||||
}
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
|
||||
html_logo = str(CURR_PATH / 'logo-09.png')
|
Binary data
docs/graphformer.png
Binary file not shown. Before — Width: | Height: | Size: 75 KiB
Binary data
docs/graphformer_logo.png
Binary file not shown. Before — Width: | Height: | Size: 4.5 KiB
|
@ -0,0 +1,43 @@
|
|||
.. Graphormer documentation master file, created by
|
||||
sphinx-quickstart on Mon Dec 20 08:26:49 2021.
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
:github_url: https://github.com/Microsoft/Graphormer
|
||||
|
||||
Welcome to Graphormer's documentation!
|
||||
======================================
|
||||
|
||||
Graphormer is a deep learning package extended from `fairseq <https://fairseq.readthedocs.io/en/latest/>`__ that allows researchers and developers to train custom models for molecule modeling tasks. It aims to accelerate the research and application in AI for molecule science, such as material discovery, drug discovery, etc.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Getting Started
|
||||
|
||||
Installation Guide <Installation-Guide>
|
||||
Quick Start <Quick-Start>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Features
|
||||
|
||||
Command-line Tools <Parameters>
|
||||
Datasets <Datasets>
|
||||
Pretrained Models <Pretrained-Models>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Extending Usage
|
||||
|
||||
Overview <Overview>
|
||||
Simple MLP Tutorial <Tutorisals>
|
||||
|
||||
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
||||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
||||
* :ref:`search`
|
Binary file not shown. After — Width: | Height: | Size: 116 KiB
Binary file not shown. After — Width: | Height: | Size: 116 KiB
|
@ -0,0 +1,35 @@
|
|||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
|
@ -1,85 +0,0 @@
|
|||
# Benchmarking Graph Neural Networks
|
||||
|
||||
[https://arxiv.org/abs/2003.00982](https://arxiv.org/abs/2003.00982)
|
||||
|
||||
[https://github.com/graphdeeplearning/benchmarking-gnns](https://github.com/graphdeeplearning/benchmarking-gnns)
|
||||
|
||||
## Results
|
||||
|
||||
#### ZINC-500K
|
||||
Method | #params | test MAE |
|
||||
--------------|---------|------------|
|
||||
GIN | 509.5K | 0.526 |
|
||||
GraphSage | 505.3K | 0.398 |
|
||||
GAT | 531.3K | 0.384 |
|
||||
GCN | 505.1K | 0.367 |
|
||||
GT | 588.9K | 0.226 |
|
||||
GatedGCN-PE | 505.0K | 0.214 |
|
||||
MPNN (sum) | 480.8K | 0.145 |
|
||||
PNA | 387.2K | 0.142 |
|
||||
SAN | 508.6K | 0.139 |
|
||||
Graphormer-Slim | 489.3K | **0.122** |
|
||||
|
||||
## Example Usage
|
||||
|
||||
```
|
||||
[ -z "${exp_name}" ] && exp_name="zinc"
|
||||
[ -z "${seed}" ] && seed="1"
|
||||
[ -z "${arch}" ] && arch="--ffn_dim 80 --hidden_dim 80 --num_heads 8 --dropout_rate 0.1 --n_layers 12 --peak_lr 2e-4 --edge_type multi_hop --multi_hop_max_dist 20"
|
||||
[ -z "${warmup_updates}" ] && warmup_updates="40000"
|
||||
[ -z "${tot_updates}" ] && tot_updates="400000"
|
||||
|
||||
echo -e "\n\n"
|
||||
echo "=====================================ARGS======================================"
|
||||
echo "arg0: $0"
|
||||
echo "arch: ${arch}"
|
||||
echo "seed: ${seed}"
|
||||
echo "exp_name: ${exp_name}"
|
||||
echo "warmup_updates: ${warmup_updates}"
|
||||
echo "tot_updates: ${tot_updates}"
|
||||
echo "==============================================================================="
|
||||
|
||||
save_path="../../exps/zinc/$exp_name-$warmup_updates-$tot_updates/$seed"
|
||||
mkdir -p $save_path
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python ../../graphormer/entry.py --num_workers 8 --seed $seed --batch_size 256 \
|
||||
--dataset_name ZINC \
|
||||
--gpus 1 --accelerator ddp --precision 16 \
|
||||
$arch \
|
||||
--check_val_every_n_epoch 10 --warmup_updates $warmup_updates --tot_updates $tot_updates \
|
||||
--default_root_dir $save_path
|
||||
```
|
||||
|
||||
## Citation
|
||||
Please kindly cite this paper if you use the code:
|
||||
```
|
||||
@article{ying2021transformers,
|
||||
title={Do Transformers Really Perform Bad for Graph Representation?},
|
||||
author={Ying, Chengxuan and Cai, Tianle and Luo, Shengjie and Zheng, Shuxin and Ke, Guolin and He, Di and Shen, Yanming and Liu, Tie-Yan},
|
||||
journal={arXiv preprint arXiv:2106.05234},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
|
@ -0,0 +1,74 @@
|
|||
# Open Catalyst Challenge
|
||||
|
||||
|
||||
<img src="ocp.gif" width=70%>
|
||||
|
||||
The Open Catalyst Project is a collaborative research effort between Facebook AI Research (FAIR) and Carnegie Mellon University’s (CMU) Department of Chemical Engineering. The aim is to use AI to model and discover new catalysts for use in renewable energy storage to help in addressing climate change.
|
||||
|
||||
A detailed description of this dataset can be found [here](https://opencatalystproject.org/).
|
||||
|
||||
|
||||
|
||||
### Example Usage
|
||||
|
||||
Data Preparation: Follow the instructions [here](https://github.com/Open-Catalyst-Project/ocp/blob/master/DATASET.md) to prepare your OC20 dataset.
|
||||
|
||||
To train a Graphormer-3D for IS2RE direct task:
|
||||
|
||||
```bash oc20.sh```
|
||||
|
||||
#### IS2RE Direct Energy MAE (eV) on test split
|
||||
|
||||
Method | ID | OOD Ads | OOD Cat | OOD Both | Avg |
|
||||
--------------|---------|-----------|-----------|-----------|---------|
|
||||
CGCNN | 0.6149 | 0.9155 | 0.6219 | 0.8511 | 0.7509 |
|
||||
SchNet | 0.6387 | 0.7342 | 0.6616 | 0.7037 | 0.6846 |
|
||||
DimeNet++ | 0.5620 | 0.7252 | 0.5756 | 0.6613 | 0.6311 |
|
||||
SphereNet | 0.5625 | 0.7033 | 0.5708 | 0.6378 | 0.6186 |
|
||||
SpinConv | 0.5583 | 0.7230 | 0.5687 | 0.6738 | 0.6310 |
|
||||
Noisy Node | 0.4776 | 0.5646 | 0.4932 | 0.5042 | 0.5099 |
|
||||
Graphormer-3D (ensemble) | 0.3976 | 0.5719 | 0.4166 | 0.5029 | 0.4722 |
|
||||
|
||||
*Note: evaluation of model performance on the test split requires submission through [EvalAI](https://eval.ai/web/challenges/challenge-page/712/overview).*
|
||||
|
||||
|
||||
## Citation
|
||||
Please kindly cite this paper if you use the code:
|
||||
```
|
||||
@inproceedings{
|
||||
ying2021do,
|
||||
title={Do Transformers Really Perform Badly for Graph Representation?},
|
||||
author={Chengxuan Ying and Tianle Cai and Shengjie Luo and Shuxin Zheng and Guolin Ke and Di He and Yanming Shen and Tie-Yan Liu},
|
||||
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
|
||||
year={2021},
|
||||
url={https://openreview.net/forum?id=OeWooOxFwDa}
|
||||
}
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#!/usr/bin/env bash
|
||||
|
||||
lr=${lr:-3e-4}
|
||||
warmup_steps=${warmup_steps:-10000}
|
||||
total_steps=${total_steps:-1000000}
|
||||
layers=${layers:-12}
|
||||
hidden_size=${hidden_size:-768}
|
||||
num_head=${num_head:-48}
|
||||
batch_size=${batch_size:-2}
|
||||
seed=${seed:-1}
|
||||
clip_norm=${clip_norm:-5}
|
||||
blocks=${blocks:-4}
|
||||
node_loss_weight=${node_loss_weight:-15}
|
||||
update_freq=${update_freq:-1}
|
||||
|
||||
save_dir=./ckpts
|
||||
tsb_dir=./tsbs
|
||||
mkdir -p $save_dir
|
||||
|
||||
echo -e "\n\n"
|
||||
echo "=====================================ARGS======================================"
|
||||
echo "arg0: $0"
|
||||
echo "seed: ${seed}"
|
||||
echo "batch_size: ${batch_size}"
|
||||
echo "layers: ${layers}"
|
||||
echo "update_freq: ${update_freq}"
|
||||
echo "lr: ${lr}"
|
||||
echo "warmup_steps: ${warmup_steps}"
|
||||
echo "total_steps: ${total_steps}"
|
||||
echo "clip_norm: ${clip_norm}"
|
||||
echo "blocks: ${blocks}"
|
||||
echo "node_loss_weight: ${node_loss_weight}"
|
||||
echo "save_dir: ${save_dir}"
|
||||
echo "tsb_dir: ${tsb_dir}"
|
||||
echo "==============================================================================="
|
||||
|
||||
fairseq-train --user-dir ../../graphormer \
|
||||
/home/$USER/ocp/data/is2re/all --valid-subset val_id,val_ood_ads,val_ood_cat,val_ood_both --best-checkpoint-metric loss \
|
||||
--num-workers 0 --ddp-backend=c10d \
|
||||
--task is2re --criterion mae_deltapos --arch graphormer3d_base \
|
||||
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm $clip_norm \
|
||||
--lr-scheduler polynomial_decay --lr $lr --warmup-updates $warmup_steps --total-num-update $total_steps --batch-size $batch_size \
|
||||
--dropout 0.0 --attention-dropout 0.1 --weight-decay 0.001 --update-freq $update_freq --seed $seed \
|
||||
--fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir $tsb_dir \
|
||||
--embed-dim $hidden_size --ffn-embed-dim $hidden_size --attention-heads $num_head \
|
||||
--max-update $total_steps --log-interval 100 --log-format simple \
|
||||
--save-interval-updates 5000 --validate-interval-updates 2500 --keep-interval-updates 30 --no-epoch-checkpoints \
|
||||
--save-dir $save_dir --layers $layers --blocks $blocks --required-batch-size-multiple 1 --node-loss-weight $node_loss_weight
|
Binary file not shown. After — Width: | Height: | Size: 6.8 MiB
|
@ -1,79 +0,0 @@
|
|||
# Open Graph Benchmark - Large-Scale Challenge (KDD Cup 2021)
|
||||
|
||||
[https://ogb.stanford.edu/kddcup2021/](https://ogb.stanford.edu/kddcup2021/)
|
||||
|
||||
[https://arxiv.org/abs/2103.09430](https://arxiv.org/abs/2103.09430)
|
||||
|
||||
## Results
|
||||
|
||||
#### PCQM4M-LSC
|
||||
Method | #params | train MAE | valid MAE |
|
||||
--------------|---------|-----------|-----------|
|
||||
GCN | 2.0M | 0.1318 | 0.1691 |
|
||||
GIN | 3.8M | 0.1203 | 0.1537 |
|
||||
GCN-VN | 4.9M | 0.1225 | 0.1485 |
|
||||
GIN-VN | 6.7M | 0.1150 | 0.1395 |
|
||||
Graphormer-Small| 12.5M | 0.0778 | 0.1264 |
|
||||
Graphormer | 47.1M | 0.0582 | **0.1234** |
|
||||
|
||||
## Example Usage
|
||||
|
||||
```
|
||||
[ -z "${exp_name}" ] && exp_name="pcq"
|
||||
[ -z "${seed}" ] && seed="1"
|
||||
[ -z "${arch}" ] && arch="--ffn_dim 768 --hidden_dim 768 --dropout_rate 0.1 --n_layers 12 --peak_lr 2e-4 --edge_type multi_hop --multi_hop_max_dist 5"
|
||||
[ -z "${batch_size}" ] && batch_size="256"
|
||||
|
||||
echo -e "\n\n"
|
||||
echo "=====================================ARGS======================================"
|
||||
echo "arg0: $0"
|
||||
echo "exp_name: ${exp_name}"
|
||||
echo "arch: ${arch}"
|
||||
echo "seed: ${seed}"
|
||||
echo "batch_size: ${batch_size}"
|
||||
echo "==============================================================================="
|
||||
|
||||
default_root_dir="../../exps/pcq/$exp_name/$seed"
|
||||
mkdir -p $default_root_dir
|
||||
n_gpu=$(nvidia-smi -L | wc -l)
|
||||
|
||||
python ../../graphormer/entry.py --num_workers 8 --seed $seed --batch_size $batch_size \
|
||||
--dataset_name PCQM4M-LSC \
|
||||
--gpus $n_gpu --accelerator ddp --precision 16 --gradient_clip_val 5.0 \
|
||||
$arch \
|
||||
--default_root_dir $default_root_dir
|
||||
```
|
||||
|
||||
|
||||
## Citation
|
||||
Please kindly cite this paper if you use the code:
|
||||
```
|
||||
@article{ying2021transformers,
|
||||
title={Do Transformers Really Perform Bad for Graph Representation?},
|
||||
author={Ying, Chengxuan and Cai, Tianle and Luo, Shengjie and Zheng, Shuxin and Ke, Guolin and He, Di and Shen, Yanming and Liu, Tie-Yan},
|
||||
journal={arXiv preprint arXiv:2106.05234},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos is subject to those third parties' policies.
|
|
@@ -1,77 +0,0 @@
|
|||
# Open Graph Benchmark
|
||||
|
||||
[https://arxiv.org/abs/2005.00687](https://arxiv.org/abs/2005.00687)
|
||||
|
||||
|
||||
[https://ogb.stanford.edu/](https://ogb.stanford.edu/)
|
||||
|
||||
## Results
|
||||
|
||||
#### OGBG-MolPCBA
|
||||
Method | #params | test AP (%)|
|
||||
--------------|---------|------------|
|
||||
DeeperGCN-VN+FLAG | 5.6M | 28.42 |
|
||||
DGN | 6.7M | 28.85 |
|
||||
GINE-VN | 6.1M | 29.17 |
|
||||
PHC-GNN | 1.7M | 29.47 |
|
||||
GINE-APPNP | 6.1M | 29.79 |
|
||||
Graphormer | 119.5M | **31.39** |
|
||||
|
||||
#### OGBG-MolHIV
|
||||
Method | #params | test AUC (%)|
|
||||
--------------|---------|------------|
|
||||
GCN-GraphNorm | 526K | 78.83 |
|
||||
PNA | 326K | 79.05 |
|
||||
PHC-GNN | 111K | 79.34 |
|
||||
DeeperGCN-FLAG | 532K | 79.42 |
|
||||
DGN | 114K | 79.70 |
|
||||
Graphormer | 47.0M | **80.51** |
|
||||
|
||||
## Example Usage
|
||||
|
||||
Prepare your pre-trained models following our paper ["Do Transformers Really Perform Bad for Graph Representation?"](https://arxiv.org/abs/2106.05234).
|
||||
|
||||
Fine-tuning your pre-trained model on OGBG-MolPCBA:
|
||||
|
||||
```
|
||||
bash pcba.sh
|
||||
```
|
||||
|
||||
Fine-tuning your pre-trained model on OGBG-MolHIV:
|
||||
|
||||
```
|
||||
bash hiv.sh
|
||||
```
|
||||
|
||||
## Citation
|
||||
Please kindly cite this paper if you use the code:
|
||||
```
|
||||
@article{ying2021transformers,
|
||||
title={Do Transformers Really Perform Bad for Graph Representation?},
|
||||
author={Ying, Chengxuan and Cai, Tianle and Luo, Shengjie and Zheng, Shuxin and Ke, Guolin and He, Di and Shen, Yanming and Liu, Tie-Yan},
|
||||
journal={arXiv preprint arXiv:2106.05234},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos is subject to those third parties' policies.
|
|
@@ -1,78 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#!/usr/bin/env bash
|
||||
|
||||
[ -z "${exp_name}" ] && exp_name="hiv_flag"
|
||||
[ -z "${seed}" ] && seed="1"
|
||||
[ -z "${arch}" ] && arch="--ffn_dim 768 --hidden_dim 768 --intput_dropout_rate 0.0 --attention_dropout_rate 0.1 --dropout_rate 0.1 --weight_decay 0.0 --n_layers 12 --edge_type multi_hop --multi_hop_max_dist 5"
|
||||
[ -z "${batch_size}" ] && batch_size="128" # Alternatively, you can decrease the bsz to 64 and use 2 GPUs, if you do not have 32G GPU memory.
|
||||
[ -z "${epoch}" ] && epoch="8"
|
||||
[ -z "${peak_lr}" ] && peak_lr="2e-4"
|
||||
[ -z "${end_lr}" ] && end_lr="1e-9"
|
||||
|
||||
[ -z "${flag_m}" ] && flag_m="2"
|
||||
[ -z "${flag_step_size}" ] && flag_step_size="0.2"
|
||||
[ -z "${flag_mag}" ] && flag_mag="0"
|
||||
|
||||
[ -z "${ckpt_path}" ] && ckpt_path="../../checkpoints/hiv/<your_pretrained_model_for_hiv>"
|
||||
|
||||
echo -e "\n\n"
|
||||
echo "=====================================ARGS======================================"
|
||||
echo "arg0: $0"
|
||||
echo "exp_name: ${exp_name}"
|
||||
echo "ckpt_path ${ckpt_path}"
|
||||
echo "arch: ${arch}"
|
||||
echo "batch_size: ${batch_size}"
|
||||
echo "peak_lr ${peak_lr}"
|
||||
echo "end_lr ${end_lr}"
|
||||
echo "flag_m ${flag_m}"
|
||||
echo "flag_step_size :${flag_step_size}"
|
||||
echo "flag_mag: ${flag_mag}"
|
||||
echo "seed: ${seed}"
|
||||
echo "epoch: ${epoch}"
|
||||
echo "==============================================================================="
|
||||
|
||||
n_gpu=1 # Please use 1 GPU (We use 1 32GB V100 card) to reproduce our results.
|
||||
tot_updates=$((33000*epoch/batch_size/n_gpu))
|
||||
warmup_updates=$((tot_updates/10))
|
||||
max_epochs=$((epoch+1))
|
||||
echo "=====================================ARGS======================================"
|
||||
echo "tot_updates ${tot_updates}"
|
||||
echo "warmup_updates: ${warmup_updates}"
|
||||
echo "max_epochs: ${max_epochs}"
|
||||
echo "==============================================================================="
|
||||
|
||||
default_root_dir=../../exps/hiv/$exp_name/$seed
|
||||
mkdir -p $default_root_dir
|
||||
|
||||
python ../../graphormer/entry.py --num_workers 8 --seed $seed --batch_size $batch_size \
|
||||
--dataset_name ogbg-molhiv \
|
||||
--gpus $n_gpu --accelerator ddp --precision 16 $arch \
|
||||
--default_root_dir $default_root_dir \
|
||||
--tot_updates $tot_updates --warmup_updates $warmup_updates --max_epochs $max_epochs \
|
||||
--checkpoint_path $ckpt_path \
|
||||
--peak_lr $peak_lr --end_lr $end_lr --progress_bar_refresh_rate 10 \
|
||||
--flag --flag_m $flag_m --flag_step_size $flag_step_size --flag_mag $flag_mag
|
||||
|
||||
|
||||
# validate and test on every checkpoint
|
||||
checkpoint_dir=$default_root_dir/lightning_logs/checkpoints/
|
||||
echo "=====================================EVAL======================================"
|
||||
for file in `ls $checkpoint_dir/*.ckpt`
|
||||
do
|
||||
echo -e "\n\n\n ckpt:"
|
||||
echo "$file"
|
||||
echo -e "\n\n\n"
|
||||
python ../../graphormer/entry.py --num_workers 8 --seed 1 --batch_size $batch_size \
|
||||
--dataset_name ogbg-molhiv \
|
||||
--gpus 1 --accelerator ddp --precision 16 $arch \
|
||||
--default_root_dir tmp/ \
|
||||
--checkpoint_path $file --validate --progress_bar_refresh_rate 100
|
||||
|
||||
python ../../graphormer/entry.py --num_workers 8 --seed 1 --batch_size $batch_size \
|
||||
--dataset_name ogbg-molhiv \
|
||||
--gpus 1 --accelerator ddp --precision 16 $arch \
|
||||
--default_root_dir tmp/ \
|
||||
--checkpoint_path $file --test --progress_bar_refresh_rate 100
|
||||
done
|
||||
echo "==============================================================================="
|
|
@@ -1,78 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#!/usr/bin/env bash
|
||||
|
||||
[ -z "${exp_name}" ] && exp_name="pcba_flag"
|
||||
[ -z "${seed}" ] && seed="1"
|
||||
[ -z "${arch}" ] && arch="--ffn_dim 1024 --hidden_dim 1024 --intput_dropout_rate 0.0 --attention_dropout_rate 0.3 --dropout_rate 0.1 --weight_decay 0.0 --n_layers 18 --edge_type multi_hop --multi_hop_max_dist 5"
|
||||
[ -z "${batch_size}" ] && batch_size="64" # Alternatively, you can decrease the bsz to 32, if you do not have 32G GPU memory.
|
||||
[ -z "${epoch}" ] && epoch="10"
|
||||
[ -z "${peak_lr}" ] && peak_lr="3e-4"
|
||||
[ -z "${end_lr}" ] && end_lr="1e-9"
|
||||
|
||||
[ -z "${flag_m}" ] && flag_m="4"
|
||||
[ -z "${flag_step_size}" ] && flag_step_size="0.001"
|
||||
[ -z "${flag_mag}" ] && flag_mag="0.001"
|
||||
|
||||
[ -z "${ckpt_path}" ] && ckpt_path="../../checkpoints/<your_pretrained_model_for_pcba>"
|
||||
|
||||
echo -e "\n\n"
|
||||
echo "=====================================ARGS======================================"
|
||||
echo "arg0: $0"
|
||||
echo "exp_name: ${exp_name}"
|
||||
echo "ckpt_path ${ckpt_path}"
|
||||
echo "arch: ${arch}"
|
||||
echo "batch_size: ${batch_size}"
|
||||
echo "peak_lr ${peak_lr}"
|
||||
echo "end_lr ${end_lr}"
|
||||
echo "flag_m ${flag_m}"
|
||||
echo "flag_step_size :${flag_step_size}"
|
||||
echo "flag_mag: ${flag_mag}"
|
||||
echo "seed: ${seed}"
|
||||
echo "epoch: ${epoch}"
|
||||
echo "==============================================================================="
|
||||
|
||||
n_gpu=$(nvidia-smi -L | wc -l) # Please use 4 GPUs (We use 4 V100 cards) to reproduce our results.
|
||||
tot_updates=$((350000*epoch/batch_size/n_gpu))
|
||||
warmup_updates=$((tot_updates/16))
|
||||
max_epochs=$((epoch+1))
|
||||
echo "=====================================ARGS======================================"
|
||||
echo "tot_updates ${tot_updates}"
|
||||
echo "warmup_updates: ${warmup_updates}"
|
||||
echo "max_epochs: ${max_epochs}"
|
||||
echo "==============================================================================="
|
||||
|
||||
default_root_dir=../../exps/pcba/$exp_name/$seed
|
||||
mkdir -p $default_root_dir
|
||||
|
||||
python ../../graphormer/entry.py --num_workers 8 --seed $seed --batch_size $batch_size \
|
||||
--dataset_name ogbg-molpcba \
|
||||
--gpus $n_gpu --accelerator ddp --precision 32 $arch \
|
||||
--default_root_dir $default_root_dir \
|
||||
--tot_updates $tot_updates --warmup_updates $warmup_updates --max_epochs $max_epochs \
|
||||
--checkpoint_path $ckpt_path \
|
||||
--peak_lr $peak_lr --end_lr $end_lr --progress_bar_refresh_rate 10 \
|
||||
--flag --flag_m $flag_m --flag_step_size $flag_step_size --flag_mag $flag_mag
|
||||
|
||||
|
||||
# validate and test on every checkpoint
|
||||
checkpoint_dir=$default_root_dir/lightning_logs/checkpoints/
|
||||
echo "=====================================EVAL======================================"
|
||||
for file in `ls $checkpoint_dir/*.ckpt`
|
||||
do
|
||||
echo -e "\n\n\n ckpt:"
|
||||
echo "$file"
|
||||
echo -e "\n\n\n"
|
||||
python ../../graphormer/entry.py --num_workers 8 --seed 1 --batch_size $batch_size \
|
||||
--dataset_name ogbg-molpcba \
|
||||
--gpus 1 --accelerator ddp --precision 16 $arch \
|
||||
--default_root_dir tmp/ \
|
||||
--checkpoint_path $file --validate --progress_bar_refresh_rate 100
|
||||
|
||||
python ../../graphormer/entry.py --num_workers 8 --seed 1 --batch_size $batch_size \
|
||||
--dataset_name ogbg-molpcba \
|
||||
--gpus 1 --accelerator ddp --precision 16 $arch \
|
||||
--default_root_dir tmp/ \
|
||||
--checkpoint_path $file --test --progress_bar_refresh_rate 100
|
||||
done
|
||||
echo "==============================================================================="
|
|
@@ -0,0 +1,121 @@
|
|||
## Open Graph Benchmark - Large-Scale Challenge (KDD Cup 2021)
|
||||
|
||||
The detailed description of the dataset can be found [here](https://ogb.stanford.edu/kddcup2021/).
|
||||
|
||||
### Example Usage
|
||||
To train Graphormer-Base on PCQM4Mv1 dataset:
|
||||
|
||||
```bash pcqv1.sh```
|
||||
|
||||
To train Graphormer-Base on PCQM4Mv2 dataset:
|
||||
|
||||
```bash pcqv2.sh```
|
||||
|
||||
|
||||
Note that ```--batch-size``` should be adjusted so that the total batch size (```--batch-size``` × number of GPUs × ```--update-freq```) equals 1024; for example, with 4 GPUs and ```--update-freq 1```, use ```--batch-size 256```.
|
||||
|
||||
#### PCQM4Mv2
|
||||
Method | #params | train MAE | valid MAE |
|
||||
--------------|---------|-----------|-----------|
|
||||
GCN | 2.0M | -- | 0.1379 |
|
||||
GIN | 3.8M | -- | 0.1195 |
|
||||
GCN-VN | 4.9M | -- | 0.1153 |
|
||||
GIN-VN | 6.7M | -- | 0.1083 |
|
||||
Graphormer-v2 | 47.1M | 0.0253 | **0.0865** |
|
||||
|
||||
#### PCQM4Mv1
|
||||
Method | #params | train MAE | valid MAE |
|
||||
--------------|---------|-----------|-----------|
|
||||
GCN | 2.0M | 0.1318 | 0.1691 |
|
||||
GIN | 3.8M | 0.1203 | 0.1537 |
|
||||
GCN-VN | 4.9M | 0.1225 | 0.1485 |
|
||||
GIN-VN | 6.7M | 0.1150 | 0.1395 |
|
||||
Graphormer-Small| 12.5M | 0.0778 | 0.1264 |
|
||||
Graphormer | 47.1M | 0.0582 | 0.1234 |
|
||||
Graphormer-v2 | 47.1M | 0.0309 | **0.1201** |
|
||||
|
||||
## Open Graph Benchmark
|
||||
|
||||
The detailed description of the dataset can be found [here](https://ogb.stanford.edu/).
|
||||
|
||||
### Example Usage
|
||||
|
||||
Fine-tuning the pre-trained model on OGBG-MolHIV:
|
||||
|
||||
```
|
||||
bash hiv_pre.sh
|
||||
```
|
||||
|
||||
#### OGBG-MolHIV
|
||||
Method | #params | test AUC (%)|
|
||||
--------------|---------|------------|
|
||||
GCN-GraphNorm | 526K | 78.83 |
|
||||
PNA | 326K | 79.05 |
|
||||
PHC-GNN | 111K | 79.34 |
|
||||
DeeperGCN-FLAG | 532K | 79.42 |
|
||||
DGN | 114K | 79.70 |
|
||||
Graphormer | 47.0M | 80.51 |
|
||||
Graphormer-v2 | 47.1M | **81.28** |
|
||||
|
||||
## Benchmarking Graph Neural Networks - ZINC-500K
|
||||
|
||||
|
||||
The detailed description of the dataset can be found [here](https://github.com/graphdeeplearning/benchmarking-gnns).
|
||||
|
||||
### Example Usage
|
||||
|
||||
To train Graphormer-Slim on ZINC-500K dataset:
|
||||
|
||||
```bash zinc.sh```
|
||||
|
||||
#### ZINC-500K
|
||||
Method | #params | test MAE |
|
||||
--------------|---------|------------|
|
||||
GIN | 509.5K | 0.526 |
|
||||
GraphSage | 505.3K | 0.398 |
|
||||
GAT | 531.3K | 0.384 |
|
||||
GCN | 505.1K | 0.367 |
|
||||
GT | 588.9K | 0.226 |
|
||||
GatedGCN-PE | 505.0K | 0.214 |
|
||||
MPNN (sum) | 480.8K | 0.145 |
|
||||
PNA | 387.2K | 0.142 |
|
||||
SAN | 508.6K | 0.139 |
|
||||
Graphormer-Slim | 489.3K | **0.122** |
|
||||
|
||||
|
||||
|
||||
## Citation
|
||||
Please kindly cite this paper if you use the code:
|
||||
```
|
||||
@inproceedings{
|
||||
ying2021do,
|
||||
title={Do Transformers Really Perform Badly for Graph Representation?},
|
||||
author={Chengxuan Ying and Tianle Cai and Shengjie Luo and Shuxin Zheng and Guolin Ke and Di He and Yanming Shen and Tie-Yan Liu},
|
||||
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
|
||||
year={2021},
|
||||
url={https://openreview.net/forum?id=OeWooOxFwDa}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos is subject to those third parties' policies.
|
|
@@ -0,0 +1,41 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#!/usr/bin/env bash
|
||||
|
||||
n_gpu=2
|
||||
epoch=8
|
||||
max_epoch=$((epoch + 1))
|
||||
batch_size=64
|
||||
tot_updates=$((33000*epoch/batch_size/n_gpu))
|
||||
warmup_updates=$((tot_updates/10))
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
fairseq-train \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name ogbg-molhiv \
|
||||
--dataset-source ogb \
|
||||
--task graph_prediction_with_flag \
|
||||
--criterion binary_logloss_with_flag \
|
||||
--arch graphormer_base \
|
||||
--num-classes 1 \
|
||||
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \
|
||||
--lr-scheduler polynomial_decay --power 1 --warmup-updates $warmup_updates --total-num-update $tot_updates \
|
||||
--lr 2e-4 --end-learning-rate 1e-9 \
|
||||
--batch-size $batch_size \
|
||||
--fp16 \
|
||||
--data-buffer-size 20 \
|
||||
--encoder-layers 12 \
|
||||
--encoder-embed-dim 768 \
|
||||
--encoder-ffn-embed-dim 768 \
|
||||
--encoder-attention-heads 32 \
|
||||
--max-epoch $max_epoch \
|
||||
--save-dir ./ckpts \
|
||||
--pretrained-model-name pcqm4mv1_graphormer_base \
|
||||
--seed 1 \
|
||||
--flag-m 3 \
|
||||
--flag-step-size 0.001 \
|
||||
--flag-mag 0.001 \
|
|
@@ -0,0 +1,27 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#!/usr/bin/env bash
|
||||
|
||||
fairseq-train \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name pcqm4m \
|
||||
--dataset-source ogb \
|
||||
--task graph_prediction \
|
||||
--criterion l1_loss \
|
||||
--arch graphormer_base \
|
||||
--num-classes 1 \
|
||||
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \
|
||||
--lr-scheduler polynomial_decay --power 1 --warmup-updates 60000 --total-num-update 1000000 \
|
||||
--lr 2e-4 --end-learning-rate 1e-9 \
|
||||
--batch-size 64 \
|
||||
--fp16 \
|
||||
--data-buffer-size 20 \
|
||||
--encoder-layers 12 \
|
||||
--encoder-embed-dim 768 \
|
||||
--encoder-ffn-embed-dim 768 \
|
||||
--encoder-attention-heads 32 \
|
||||
--max-epoch 300 \
|
||||
--save-dir ./ckpts
|
|
@@ -0,0 +1,24 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#!/usr/bin/env bash
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
fairseq-train \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name pcqm4mv2 \
|
||||
--dataset-source ogb \
|
||||
--task graph_prediction \
|
||||
--criterion l1_loss \
|
||||
--arch graphormer_base \
|
||||
--num-classes 1 \
|
||||
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.0 \
|
||||
--lr-scheduler polynomial_decay --power 1 --warmup-updates 60000 --total-num-update 1000000 \
|
||||
--lr 2e-4 --end-learning-rate 1e-9 \
|
||||
--batch-size 256 \
|
||||
--fp16 \
|
||||
--data-buffer-size 20 \
|
||||
--save-dir ./ckpts
|
|
@@ -0,0 +1,27 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
#!/usr/bin/env bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 fairseq-train \
|
||||
--user-dir ../../graphormer \
|
||||
--num-workers 16 \
|
||||
--ddp-backend=legacy_ddp \
|
||||
--dataset-name zinc \
|
||||
--dataset-source pyg \
|
||||
--task graph_prediction \
|
||||
--criterion l1_loss \
|
||||
--arch graphormer_slim \
|
||||
--num-classes 1 \
|
||||
--attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 \
|
||||
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 \
|
||||
--lr-scheduler polynomial_decay --power 1 --warmup-updates 60000 --total-num-update 400000 \
|
||||
--lr 2e-4 --end-learning-rate 1e-9 \
|
||||
--batch-size 64 \
|
||||
--fp16 \
|
||||
--data-buffer-size 20 \
|
||||
--encoder-layers 12 \
|
||||
--encoder-embed-dim 80 \
|
||||
--encoder-ffn-embed-dim 80 \
|
||||
--encoder-attention-heads 8 \
|
||||
--max-epoch 10000 \
|
||||
--save-dir ./ckpts
|
|
@@ -0,0 +1 @@
|
|||
Subproject commit 98ebe4f1ada75d006717d84f9d603519d8ff5579
|
|
@@ -0,0 +1,20 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch.multiprocessing.set_start_method("fork", force=True)
|
||||
except:
|
||||
import sys
|
||||
|
||||
print(
|
||||
"Your OS does not support multiprocessing based on fork, please use num_workers=0",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
import graphormer.criterions
|
|
@@ -0,0 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
|
||||
# automatically import any Python files in the criterions/ directory
|
||||
for file in sorted(Path(__file__).parent.glob("*.py")):
|
||||
if not file.name.startswith("_"):
|
||||
importlib.import_module("graphormer.criterions." + file.name[:-3])
|
|
@@ -0,0 +1,113 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from fairseq.dataclass.configs import FairseqDataclass
|
||||
|
||||
import torch
|
||||
from torch.nn import functional
|
||||
from fairseq import metrics
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
|
||||
|
||||
@register_criterion("binary_logloss", dataclass=FairseqDataclass)
|
||||
class GraphPredictionBinaryLogLoss(FairseqCriterion):
|
||||
"""
|
||||
Implementation for the binary log loss used in graphormer model training.
|
||||
"""
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
sample_size = sample["nsamples"]
|
||||
|
||||
with torch.no_grad():
|
||||
natoms = sample["net_input"]["batched_data"]["x"].shape[1]
|
||||
|
||||
logits = model(**sample["net_input"])
|
||||
logits = logits[:, 0, :]
|
||||
targets = model.get_targets(sample, [logits])
|
||||
preds = torch.where(torch.sigmoid(logits) < 0.5, 0, 1)
|
||||
|
||||
logits_flatten = logits.reshape(-1)
|
||||
targets_flatten = targets[: logits.size(0)].reshape(-1)
|
||||
mask = ~torch.isnan(targets_flatten)
|
||||
loss = functional.binary_cross_entropy_with_logits(
|
||||
logits_flatten[mask].float(), targets_flatten[mask].float(), reduction="sum"
|
||||
)
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.data,
|
||||
"sample_size": torch.sum(mask.type(torch.int64)),
|
||||
"nsentences": sample_size,
|
||||
"ntokens": natoms,
|
||||
"ncorrect": (preds == targets[:preds.size(0)]).sum(),
|
||||
}
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
||||
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
|
||||
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
|
||||
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
|
||||
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
|
||||
metrics.log_scalar(
|
||||
"accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improve distributed training speed.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@register_criterion("binary_logloss_with_flag", dataclass=FairseqDataclass)
|
||||
class GraphPredictionBinaryLogLossWithFlag(GraphPredictionBinaryLogLoss):
|
||||
"""
|
||||
Implementation for the binary log loss used in graphormer model training.
|
||||
"""
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
sample_size = sample["nsamples"]
|
||||
perturb = sample.get("perturb", None)
|
||||
|
||||
batch_data = sample["net_input"]["batched_data"]["x"]
|
||||
with torch.no_grad():
|
||||
natoms = batch_data.shape[1]
|
||||
logits = model(**sample["net_input"], perturb=perturb)[:, 0, :]
|
||||
targets = model.get_targets(sample, [logits])
|
||||
preds = torch.where(torch.sigmoid(logits) < 0.5, 0, 1)
|
||||
|
||||
logits_flatten = logits.reshape(-1)
|
||||
targets_flatten = targets[: logits.size(0)].reshape(-1)
|
||||
mask = ~torch.isnan(targets_flatten)
|
||||
loss = functional.binary_cross_entropy_with_logits(
|
||||
logits_flatten[mask].float(), targets_flatten[mask].float(), reduction="sum"
|
||||
)
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.data,
|
||||
"sample_size": logits.size(0),
|
||||
"nsentences": sample_size,
|
||||
"ntokens": natoms,
|
||||
"ncorrect": (preds == targets[:preds.size(0)]).sum(),
|
||||
}
|
||||
return loss, sample_size, logging_output
|
|
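The `torch.isnan` masking above lets graphs with partially missing binary labels (common in multi-task molecular property datasets such as ogbg-molpcba) contribute only their labelled entries to the loss. A minimal, self-contained sketch of that masking on toy tensors (the numbers are illustrative, not actual Graphormer inputs):

```
import torch
from torch.nn import functional

# Two graphs, three binary tasks; one label is missing (NaN).
logits = torch.tensor([[1.2, -0.3, 0.7],
                       [-2.0, 0.5, 0.1]])
targets = torch.tensor([[1.0, 0.0, float("nan")],
                        [0.0, 1.0, 1.0]])

logits_flatten = logits.reshape(-1)
targets_flatten = targets.reshape(-1)
mask = ~torch.isnan(targets_flatten)  # keep only labelled entries

loss = functional.binary_cross_entropy_with_logits(
    logits_flatten[mask], targets_flatten[mask], reduction="sum"
)
print(loss)  # summed BCE over the five labelled entries only
```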
@@ -0,0 +1,92 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from fairseq.dataclass.configs import FairseqDataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fairseq import metrics
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
|
||||
|
||||
@register_criterion("l1_loss", dataclass=FairseqDataclass)
|
||||
class GraphPredictionL1Loss(FairseqCriterion):
|
||||
"""
|
||||
Implementation for the L1 loss (MAE loss) used in graphormer model training.
|
||||
"""
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
sample_size = sample["nsamples"]
|
||||
|
||||
with torch.no_grad():
|
||||
natoms = sample["net_input"]["batched_data"]["x"].shape[1]
|
||||
|
||||
logits = model(**sample["net_input"])
|
||||
logits = logits[:, 0, :]
|
||||
targets = model.get_targets(sample, [logits])
|
||||
|
||||
loss = nn.L1Loss(reduction="sum")(logits, targets[: logits.size(0)])
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.data,
|
||||
"sample_size": logits.size(0),
|
||||
"nsentences": sample_size,
|
||||
"ntokens": natoms,
|
||||
}
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
||||
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
|
||||
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=6)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improve distributed training speed.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
@register_criterion("l1_loss_with_flag", dataclass=FairseqDataclass)
|
||||
class GraphPredictionL1LossWithFlag(GraphPredictionL1Loss):
|
||||
"""
|
||||
Implementation for the L1 loss (MAE loss) with FLAG used in graphormer model training.
|
||||
"""
|
||||
|
||||
def perturb_forward(self, model, sample, perturb, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
sample_size = sample["nsamples"]
|
||||
|
||||
batch_data = sample["net_input"]["batched_data"]["x"]
|
||||
with torch.no_grad():
|
||||
natoms = batch_data.shape[1]
|
||||
logits = model(**sample["net_input"], perturb=perturb)[:, 0, :]
|
||||
targets = model.get_targets(sample, [logits])
|
||||
loss = nn.L1Loss(reduction="sum")(logits, targets[: logits.size(0)])
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.data,
|
||||
"sample_size": logits.size(0),
|
||||
"nsentences": sample_size,
|
||||
"ntokens": natoms,
|
||||
}
|
||||
return loss, sample_size, logging_output
|
|
@@ -0,0 +1,117 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Callable, Mapping, Sequence, Tuple
|
||||
from numpy import mod
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from fairseq import metrics
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
|
||||
|
||||
@register_criterion("mae_deltapos")
|
||||
class IS2RECriterion(FairseqCriterion):
|
||||
e_thresh = 0.02
|
||||
e_mean = -1.4729953244844094
|
||||
e_std = 2.2707848125378405
|
||||
d_mean = [0.1353900283575058, 0.06877671927213669, 0.08111362904310226]
|
||||
d_std = [1.7862379550933838, 1.78688645362854, 0.8023099899291992]
|
||||
|
||||
def __init__(self, task, cfg):
|
||||
super().__init__(task)
|
||||
self.node_loss_weight = cfg.node_loss_weight
|
||||
self.min_node_loss_weight = cfg.min_node_loss_weight
|
||||
self.max_update = cfg.max_update
|
||||
self.node_loss_weight_range = max(
|
||||
0, self.node_loss_weight - self.min_node_loss_weight
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
model: Callable[..., Tuple[Tensor, Tensor, Tensor]],
|
||||
sample: Mapping[str, Mapping[str, Tensor]],
|
||||
reduce=True,
|
||||
):
|
||||
update_num = model.num_updates
|
||||
assert update_num >= 0
|
||||
node_loss_weight = (
|
||||
self.node_loss_weight
|
||||
- self.node_loss_weight_range * update_num / self.max_update
|
||||
)
|
||||
|
||||
valid_nodes = sample["net_input"]["atoms"].ne(0).sum()
|
||||
output, node_output, node_target_mask = model(
|
||||
**sample["net_input"],
|
||||
)
|
||||
|
||||
relaxed_energy = sample["targets"]["relaxed_energy"]
|
||||
relaxed_energy = relaxed_energy.float()
|
||||
relaxed_energy = (relaxed_energy - self.e_mean) / self.e_std
|
||||
sample_size = relaxed_energy.numel()
|
||||
loss = F.l1_loss(output.float().view(-1), relaxed_energy, reduction="none")
|
||||
with torch.no_grad():
|
||||
energy_within_threshold = (loss.detach() * self.e_std < self.e_thresh).sum()
|
||||
loss = loss.sum()
|
||||
|
||||
deltapos = sample["targets"]["deltapos"].float()
|
||||
deltapos = (deltapos - deltapos.new_tensor(self.d_mean)) / deltapos.new_tensor(
|
||||
self.d_std
|
||||
)
|
||||
deltapos *= node_target_mask
|
||||
node_output *= node_target_mask
|
||||
target_cnt = node_target_mask.sum(dim=[1, 2])
|
||||
node_loss = (
|
||||
F.l1_loss(node_output.float(), deltapos, reduction="none")
|
||||
.mean(dim=-1)
|
||||
.sum(dim=-1)
|
||||
/ target_cnt
|
||||
).sum()
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.detach(),
|
||||
"energy_within_threshold": energy_within_threshold,
|
||||
"node_loss": node_loss.detach(),
|
||||
"sample_size": sample_size,
|
||||
"nsentences": sample_size,
|
||||
"num_nodes": valid_nodes.detach(),
|
||||
"node_loss_weight": node_loss_weight * sample_size,
|
||||
}
|
||||
return loss + node_loss_weight * node_loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs: Sequence[Mapping]) -> None:
|
||||
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
||||
energy_within_threshold_sum = sum(
|
||||
log.get("energy_within_threshold", 0) for log in logging_outputs
|
||||
)
|
||||
node_loss_sum = sum(log.get("node_loss", 0) for log in logging_outputs)
|
||||
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
|
||||
mean_loss = (loss_sum / sample_size) * IS2RECriterion.e_std
|
||||
energy_within_threshold = energy_within_threshold_sum / sample_size
|
||||
mean_node_loss = (node_loss_sum / sample_size) * sum(IS2RECriterion.d_std) / 3.0
|
||||
mean_n_nodes = (
|
||||
sum([log.get("num_nodes", 0) for log in logging_outputs]) / sample_size
|
||||
)
|
||||
node_loss_weight = (
|
||||
sum([log.get("node_loss_weight", 0) for log in logging_outputs])
|
||||
/ sample_size
|
||||
)
|
||||
|
||||
metrics.log_scalar("loss", mean_loss, sample_size, round=6)
|
||||
metrics.log_scalar("ewth", energy_within_threshold, sample_size, round=6)
|
||||
metrics.log_scalar("node_loss", mean_node_loss, sample_size, round=6)
|
||||
metrics.log_scalar("nodes_per_graph", mean_n_nodes, sample_size, round=6)
|
||||
metrics.log_scalar("node_loss_weight", node_loss_weight, sample_size, round=6)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improve distributed training speed.
|
||||
"""
|
||||
return True
|
|
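The auxiliary node-position loss above is weighted by a coefficient that decays linearly from `node_loss_weight` to `min_node_loss_weight` over `max_update` steps. A toy illustration of that schedule (`node_loss_weight=15` and one million updates match the training script earlier in this diff; `min_node_loss_weight` is an illustrative assumption):

```
# Linearly decaying auxiliary-loss weight, mirroring the logic in forward().
node_loss_weight = 15.0       # matches the script default above
min_node_loss_weight = 1.0    # illustrative assumption, not a shipped default
max_update = 1_000_000        # matches total_steps in the script above
weight_range = max(0.0, node_loss_weight - min_node_loss_weight)

for update_num in (0, 250_000, 500_000, 1_000_000):
    w = node_loss_weight - weight_range * update_num / max_update
    print(update_num, w)      # 15.0, 11.5, 8.0, 1.0
```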
@@ -1,150 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from collator import collator
|
||||
from wrapper import MyGraphPropPredDataset, MyPygPCQM4MDataset, MyZINCDataset
|
||||
|
||||
from pytorch_lightning import LightningDataModule
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
import ogb
|
||||
import ogb.lsc
|
||||
import ogb.graphproppred
|
||||
from functools import partial
|
||||
|
||||
|
||||
dataset = None
|
||||
|
||||
|
||||
def get_dataset(dataset_name='abaaba'):
|
||||
global dataset
|
||||
if dataset is not None:
|
||||
return dataset
|
||||
|
||||
# max_node is set to max(max(num_val_graph_nodes), max(num_test_graph_nodes))
|
||||
if dataset_name == 'ogbg-molpcba':
|
||||
dataset = {
|
||||
'num_class': 128,
|
||||
'loss_fn': F.binary_cross_entropy_with_logits,
|
||||
'metric': 'ap',
|
||||
'metric_mode': 'max',
|
||||
'evaluator': ogb.graphproppred.Evaluator('ogbg-molpcba'),
|
||||
'dataset': MyGraphPropPredDataset('ogbg-molpcba', root='../../dataset'),
|
||||
'max_node': 128,
|
||||
}
|
||||
elif dataset_name == 'ogbg-molhiv':
|
||||
dataset = {
|
||||
'num_class': 1,
|
||||
'loss_fn': F.binary_cross_entropy_with_logits,
|
||||
'metric': 'rocauc',
|
||||
'metric_mode': 'max',
|
||||
'evaluator': ogb.graphproppred.Evaluator('ogbg-molhiv'),
|
||||
'dataset': MyGraphPropPredDataset('ogbg-molhiv', root='../../dataset'),
|
||||
'max_node': 128,
|
||||
}
|
||||
elif dataset_name == 'PCQM4M-LSC':
|
||||
dataset = {
|
||||
'num_class': 1,
|
||||
'loss_fn': F.l1_loss,
|
||||
'metric': 'mae',
|
||||
'metric_mode': 'min',
|
||||
'evaluator': ogb.lsc.PCQM4MEvaluator(),
|
||||
'dataset': MyPygPCQM4MDataset(root='../../dataset'),
|
||||
'max_node': 128,
|
||||
}
|
||||
elif dataset_name == 'ZINC':
|
||||
dataset = {
|
||||
'num_class': 1,
|
||||
'loss_fn': F.l1_loss,
|
||||
'metric': 'mae',
|
||||
'metric_mode': 'min',
|
||||
'evaluator': ogb.lsc.PCQM4MEvaluator(), # same objective function, so reuse it
|
||||
'train_dataset': MyZINCDataset(subset=True, root='../../dataset/pyg_zinc', split='train'),
|
||||
'valid_dataset': MyZINCDataset(subset=True, root='../../dataset/pyg_zinc', split='val'),
|
||||
'test_dataset': MyZINCDataset(subset=True, root='../../dataset/pyg_zinc', split='test'),
|
||||
'max_node': 128,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
print(f' > {dataset_name} loaded!')
|
||||
print(dataset)
|
||||
print(f' > dataset info ends')
|
||||
return dataset
|
||||
|
||||
|
||||
class GraphDataModule(LightningDataModule):
|
||||
name = "OGB-GRAPH"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_name: str = 'ogbg-molpcba',
|
||||
num_workers: int = 0,
|
||||
batch_size: int = 256,
|
||||
seed: int = 42,
|
||||
multi_hop_max_dist: int = 5,
|
||||
spatial_pos_max: int = 1024,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_name = dataset_name
|
||||
self.dataset = get_dataset(self.dataset_name)
|
||||
|
||||
self.num_workers = num_workers
|
||||
self.batch_size = batch_size
|
||||
self.dataset_train = ...
|
||||
self.dataset_val = ...
|
||||
self.multi_hop_max_dist = multi_hop_max_dist
|
||||
self.spatial_pos_max = spatial_pos_max
|
||||
|
||||
def setup(self, stage: str = None):
|
||||
if self.dataset_name == 'ZINC':
|
||||
self.dataset_train = self.dataset['train_dataset']
|
||||
self.dataset_val = self.dataset['valid_dataset']
|
||||
self.dataset_test = self.dataset['test_dataset']
|
||||
else:
|
||||
split_idx = self.dataset['dataset'].get_idx_split()
|
||||
self.dataset_train = self.dataset['dataset'][split_idx["train"]]
|
||||
self.dataset_val = self.dataset['dataset'][split_idx["valid"]]
|
||||
self.dataset_test = self.dataset['dataset'][split_idx["test"]]
|
||||
|
||||
def train_dataloader(self):
|
||||
loader = DataLoader(
|
||||
self.dataset_train,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=partial(collator, max_node=get_dataset(self.dataset_name)[
|
||||
'max_node'], multi_hop_max_dist=self.multi_hop_max_dist, spatial_pos_max=self.spatial_pos_max),
|
||||
)
|
||||
print('len(train_dataloader)', len(loader))
|
||||
return loader
|
||||
|
||||
def val_dataloader(self):
|
||||
loader = DataLoader(
|
||||
self.dataset_val,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=False,
|
||||
collate_fn=partial(collator, max_node=get_dataset(self.dataset_name)[
|
||||
'max_node'], multi_hop_max_dist=self.multi_hop_max_dist, spatial_pos_max=self.spatial_pos_max),
|
||||
)
|
||||
print('len(val_dataloader)', len(loader))
|
||||
return loader
|
||||
|
||||
def test_dataloader(self):
|
||||
loader = DataLoader(
|
||||
self.dataset_test,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=False,
|
||||
collate_fn=partial(collator, max_node=get_dataset(self.dataset_name)[
|
||||
'max_node'], multi_hop_max_dist=self.multi_hop_max_dist, spatial_pos_max=self.spatial_pos_max),
|
||||
)
|
||||
print('len(test_dataloader)', len(loader))
|
||||
return loader
|
|
@@ -0,0 +1,6 @@
|
|||
DATASET_REGISTRY = {}
|
||||
|
||||
def register_dataset(name: str):
|
||||
def register_dataset_func(func):
|
||||
DATASET_REGISTRY[name] = func()
|
||||
return register_dataset_func
|
|
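This small registry is how custom datasets could be plugged in: the decorated factory is called once and its result stored under the given name. A hypothetical usage sketch, assuming the two definitions above are in scope (the dataset name and the returned dict are purely illustrative, not part of the repository's API):

```
@register_dataset("my_toy_dataset")
def create_my_toy_dataset():
    # Any object describing the dataset could be returned here.
    return {"num_class": 1, "metric": "mae"}

print(DATASET_REGISTRY["my_toy_dataset"])  # {'num_class': 1, 'metric': 'mae'}
```

Note that, as written, the decorator calls the factory immediately and returns nothing, so the decorated name itself becomes `None`; only the registry entry is kept.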
@@ -27,8 +27,7 @@ def pad_2d_unsqueeze(x, padlen):
|
|||
def pad_attn_bias_unsqueeze(x, padlen):
|
||||
xlen = x.size(0)
|
||||
if xlen < padlen:
|
||||
new_x = x.new_zeros(
|
||||
[padlen, padlen], dtype=x.dtype).fill_(float('-inf'))
|
||||
new_x = x.new_zeros([padlen, padlen], dtype=x.dtype).fill_(float("-inf"))
|
||||
new_x[:xlen, :xlen] = x
|
||||
new_x[xlen:, :xlen] = 0
|
||||
x = new_x
|
||||
|
@@ -45,7 +44,6 @@ def pad_edge_type_unsqueeze(x, padlen):
|
|||
|
||||
|
||||
def pad_spatial_pos_unsqueeze(x, padlen):
|
||||
|
||||
x = x + 1
|
||||
xlen = x.size(0)
|
||||
if xlen < padlen:
|
||||
|
@@ -65,62 +63,61 @@ def pad_3d_unsqueeze(x, padlen1, padlen2, padlen3):
|
|||
return x.unsqueeze(0)
|
||||
|
||||
|
||||
class Batch():
|
||||
def __init__(self, idx, attn_bias, attn_edge_type, spatial_pos, in_degree, out_degree, x, edge_input, y):
|
||||
super(Batch, self).__init__()
|
||||
self.idx = idx
|
||||
self.in_degree, self.out_degree = in_degree, out_degree
|
||||
self.x, self.y = x, y
|
||||
self.attn_bias, self.attn_edge_type, self.spatial_pos = attn_bias, attn_edge_type, spatial_pos
|
||||
self.edge_input = edge_input
|
||||
|
||||
def to(self, device):
|
||||
self.idx = self.idx.to(device)
|
||||
self.in_degree, self.out_degree = self.in_degree.to(
|
||||
device), self.out_degree.to(device)
|
||||
self.x, self.y = self.x.to(device), self.y.to(device)
|
||||
self.attn_bias, self.attn_edge_type, self.spatial_pos = self.attn_bias.to(
|
||||
device), self.attn_edge_type.to(device), self.spatial_pos.to(device)
|
||||
self.edge_input = self.edge_input.to(device)
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
return self.in_degree.size(0)
|
||||
|
||||
|
||||
def collator(items, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20):
|
||||
items = [item for item in items if item is not None and item.x.size(0) <= max_node]
|
||||
items = [
|
||||
item for item in items if item is not None and item.x.size(0) <= max_node]
|
||||
items = [(item.idx, item.attn_bias, item.attn_edge_type, item.spatial_pos, item.in_degree,
|
||||
item.out_degree, item.x, item.edge_input[:, :, :multi_hop_max_dist, :], item.y) for item in items]
|
||||
idxs, attn_biases, attn_edge_types, spatial_poses, in_degrees, out_degrees, xs, edge_inputs, ys = zip(
|
||||
*items)
|
||||
(
|
||||
item.idx,
|
||||
item.attn_bias,
|
||||
item.attn_edge_type,
|
||||
item.spatial_pos,
|
||||
item.in_degree,
|
||||
item.out_degree,
|
||||
item.x,
|
||||
item.edge_input[:, :, :multi_hop_max_dist, :],
|
||||
item.y,
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
(
|
||||
idxs,
|
||||
attn_biases,
|
||||
attn_edge_types,
|
||||
spatial_poses,
|
||||
in_degrees,
|
||||
out_degrees,
|
||||
xs,
|
||||
edge_inputs,
|
||||
ys,
|
||||
) = zip(*items)
|
||||
|
||||
for idx, _ in enumerate(attn_biases):
|
||||
attn_biases[idx][1:, 1:][spatial_poses[idx] >= spatial_pos_max] = float('-inf')
|
||||
attn_biases[idx][1:, 1:][spatial_poses[idx] >= spatial_pos_max] = float("-inf")
|
||||
max_node_num = max(i.size(0) for i in xs)
|
||||
max_dist = max(i.size(-2) for i in edge_inputs)
|
||||
y = torch.cat(ys)
|
||||
x = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in xs])
|
||||
edge_input = torch.cat([pad_3d_unsqueeze(
|
||||
i, max_node_num, max_node_num, max_dist) for i in edge_inputs])
|
||||
attn_bias = torch.cat([pad_attn_bias_unsqueeze(
|
||||
i, max_node_num + 1) for i in attn_biases])
|
||||
edge_input = torch.cat(
|
||||
[pad_3d_unsqueeze(i, max_node_num, max_node_num, max_dist) for i in edge_inputs]
|
||||
)
|
||||
attn_bias = torch.cat(
|
||||
[pad_attn_bias_unsqueeze(i, max_node_num + 1) for i in attn_biases]
|
||||
)
|
||||
attn_edge_type = torch.cat(
|
||||
[pad_edge_type_unsqueeze(i, max_node_num) for i in attn_edge_types])
|
||||
spatial_pos = torch.cat([pad_spatial_pos_unsqueeze(i, max_node_num)
|
||||
for i in spatial_poses])
|
||||
in_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num)
|
||||
for i in in_degrees])
|
||||
out_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num)
|
||||
for i in out_degrees])
|
||||
return Batch(
|
||||
[pad_edge_type_unsqueeze(i, max_node_num) for i in attn_edge_types]
|
||||
)
|
||||
spatial_pos = torch.cat(
|
||||
[pad_spatial_pos_unsqueeze(i, max_node_num) for i in spatial_poses]
|
||||
)
|
||||
in_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num) for i in in_degrees])
|
||||
|
||||
return dict(
|
||||
idx=torch.LongTensor(idxs),
|
||||
attn_bias=attn_bias,
|
||||
attn_edge_type=attn_edge_type,
|
||||
spatial_pos=spatial_pos,
|
||||
in_degree=in_degree,
|
||||
out_degree=out_degree,
|
||||
out_degree=in_degree, # for undirected graph
|
||||
x=x,
|
||||
edge_input=edge_input,
|
||||
y=y,
|
|
@@ -0,0 +1,94 @@
|
|||
from functools import lru_cache
|
||||
|
||||
import ogb
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from fairseq.data import FairseqDataset
|
||||
|
||||
from .wrapper import MyPygGraphPropPredDataset
|
||||
from .collator import collator
|
||||
|
||||
from typing import Optional, Union
|
||||
from torch_geometric.data import Data as PYGDataset
|
||||
from dgl.data import DGLDataset
|
||||
from .dgl_datasets import DGLDatasetLookupTable, GraphormerDGLDataset
|
||||
from .pyg_datasets import PYGDatasetLookupTable, GraphormerPYGDataset
|
||||
from .ogb_datasets import OGBDatasetLookupTable
|
||||
|
||||
|
||||
class BatchedDataDataset(FairseqDataset):
|
||||
def __init__(
|
||||
self, dataset, max_node=128, multi_hop_max_dist=5, spatial_pos_max=1024
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
self.max_node = max_node
|
||||
self.multi_hop_max_dist = multi_hop_max_dist
|
||||
self.spatial_pos_max = spatial_pos_max
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.dataset[int(index)]
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def collater(self, samples):
|
||||
return collator(
|
||||
samples,
|
||||
max_node=self.max_node,
|
||||
multi_hop_max_dist=self.multi_hop_max_dist,
|
||||
spatial_pos_max=self.spatial_pos_max,
|
||||
)
|
||||
|
||||
|
||||
class TargetDataset(FairseqDataset):
|
||||
def __init__(self, dataset):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def __getitem__(self, index):
|
||||
return self.dataset[index].y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def collater(self, samples):
|
||||
return torch.stack(samples, dim=0)
|
||||
|
||||
|
||||
class GraphormerDataset:
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Optional[Union[PYGDataset, DGLDataset]] = None,
|
||||
dataset_spec: Optional[str] = None,
|
||||
dataset_source: Optional[str] = None,
|
||||
seed: int = 0,
|
||||
train_idx = None,
|
||||
valid_idx = None,
|
||||
test_idx = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dataset is not None:
|
||||
if dataset_source == "dgl":
|
||||
self.dataset = GraphormerDGLDataset(dataset, train_idx, valid_idx, test_idx)
|
||||
elif dataset_source == "pyg":
|
||||
self.dataset = GraphormerPYGDataset(dataset, train_idx, valid_idx, test_idx)
|
||||
elif dataset_source == "dgl":
|
||||
self.dataset = DGLDatasetLookupTable.GetDGLDataset(dataset_spec, seed)
|
||||
elif dataset_source == "pyg":
|
||||
self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed)
|
||||
elif dataset_source == "ogb":
|
||||
self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed)
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self.train_idx = self.dataset.train_idx
|
||||
self.valid_idx = self.dataset.valid_idx
|
||||
self.test_idx = self.dataset.test_idx
|
||||
|
||||
self.dataset_train = self.dataset.train_data
|
||||
self.dataset_val = self.dataset.valid_data
|
||||
self.dataset_test = self.dataset.test_data
|
|
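Putting the pieces above together, a dataset could be loaded from a spec string and wrapped for fairseq batching roughly as follows. This is a sketch only: it assumes the classes above are importable from the graphormer package, and the "zinc"/"pyg" spec mirrors the training scripts earlier in this diff.

```
# Rough usage sketch of the dataset wrappers defined above
# (illustrative import path, adjust to the actual package layout):
# from graphormer.data.dataset import GraphormerDataset, BatchedDataDataset, TargetDataset

dm = GraphormerDataset(dataset_spec="zinc", dataset_source="pyg", seed=1)

batched_train = BatchedDataDataset(
    dm.dataset_train, max_node=128, multi_hop_max_dist=5, spatial_pos_max=1024
)
train_targets = TargetDataset(dm.dataset_train)

# collater() batches preprocessed graph items via the collator shown above.
batch = batched_train.collater([batched_train[i] for i in range(4)])
print(len(batched_train), len(train_targets))
```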
@@ -0,0 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .dgl_dataset_lookup_table import DGLDatasetLookupTable
|
||||
from .dgl_dataset import GraphormerDGLDataset
|
|
@@ -0,0 +1,161 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from scipy.sparse.construct import random
|
||||
from torch_geometric.data import Dataset
|
||||
from dgl.data import DGLDataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from typing import List
|
||||
from dgl import DGLGraph
|
||||
from torch_geometric.data import Data as PYGGraph
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ..wrapper import convert_to_single_emb
|
||||
from .. import algos
|
||||
from copy import copy
|
||||
|
||||
|
||||
class GraphormerDGLDataset(Dataset):
|
||||
def __init__(self,
|
||||
dataset: DGLDataset,
|
||||
seed: int = 0,
|
||||
train_idx=None,
|
||||
valid_idx=None,
|
||||
test_idx=None,
|
||||
):
|
||||
self.dataset = dataset
|
||||
num_data = len(self.dataset)
|
||||
self.seed = seed
|
||||
if train_idx is None:
|
||||
train_valid_idx, test_idx = train_test_split(
|
||||
np.arange(num_data), test_size=num_data // 10, random_state=seed
|
||||
)
|
||||
train_idx, valid_idx = train_test_split(
|
||||
train_valid_idx, test_size=num_data // 5, random_state=seed
|
||||
)
|
||||
self.train_idx = train_idx
|
||||
self.valid_idx = valid_idx
|
||||
self.test_idx = test_idx
|
||||
self.__indices__ = None
|
||||
self.train_data = self.index_select(train_idx)
|
||||
self.valid_data = self.index_select(valid_idx)
|
||||
self.test_data = self.index_select(test_idx)
|
||||
|
||||
def index_select(self, indices: List[int]):
|
||||
subset = copy(self)
|
||||
subset.__indices__ = indices
|
||||
subset.train_idx = None
|
||||
subset.valid_idx = None
|
||||
subset.test_idx = None
|
||||
subset.train_data = None
|
||||
subset.valid_data = None
|
||||
subset.test_data = None
|
||||
return subset
|
||||
|
||||
def __extract_edge_and_node_features(
|
||||
self, graph_data: DGLGraph
|
||||
) -> Tuple[
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]:
|
||||
def extract_tensor_from_node_or_edge_data(
|
||||
feature_dict: dict, num_nodes_or_edges
|
||||
):
|
||||
float_feature_list = []
|
||||
int_feature_list = []
|
||||
|
||||
def extract_tensor_from_dict(feature: torch.Tensor):
|
||||
if feature.dtype == torch.int32 or feature.dtype == torch.long:
|
||||
int_feature_list.append(feature.unsqueeze(1))
|
||||
elif feature.dtype == torch.float32 or feature.dtype == torch.float64:
|
||||
float_feature_list.append(feature.unsqueeze(1))
|
||||
|
||||
for feature_or_dict in feature_dict:
|
||||
if isinstance(feature_or_dict, torch.Tensor):
|
||||
extract_tensor_from_dict(feature_or_dict)
|
||||
elif isinstance(feature_or_dict, dict):
|
||||
for feature in feature_or_dict:
|
||||
extract_tensor_from_dict(feature)
|
||||
int_feature_tensor = (
|
||||
torch.from_numpy(np.zeros(shape=[num_nodes_or_edges, 1])).long()
|
||||
if len(int_feature_list) == 0
|
||||
else torch.cat(int_feature_list)
|
||||
)
|
||||
float_feature_tensor = (
|
||||
None if len(float_feature_list) == 0 else torch.cat(float_feature_list)
|
||||
)
|
||||
return int_feature_tensor, float_feature_tensor
|
||||
|
||||
node_int_feature, node_float_feature = extract_tensor_from_node_or_edge_data(
|
||||
graph_data.ndata, graph_data.num_nodes()
|
||||
)
|
||||
edge_int_feature, edge_float_feature = extract_tensor_from_node_or_edge_data(
|
||||
graph_data.edata, graph_data.num_edges()
|
||||
)
|
||||
return (
|
||||
node_int_feature,
|
||||
node_float_feature,
|
||||
edge_int_feature,
|
||||
edge_float_feature,
|
||||
)
|
||||
|
||||
def __preprocess_dgl_graph(
|
||||
self, graph_data: DGLGraph, y: torch.Tensor, idx: int
|
||||
) -> PYGGraph:
|
||||
if not graph_data.is_homogeneous:
|
||||
raise ValueError(
|
||||
"Heterogeneous DGLGraph is found. Only homogeneous graph is supported."
|
||||
)
|
||||
N = graph_data.num_nodes()
|
||||
|
||||
(
|
||||
node_int_feature,
|
||||
node_float_feature,
|
||||
edge_int_feature,
|
||||
edge_float_feature,
|
||||
) = self.__extract_edge_and_node_features(graph_data)
|
||||
edge_index = graph_data.edges()
|
||||
attn_edge_type = torch.zeros(
|
||||
[N, N, edge_int_feature.shape[1]], dtype=torch.long
|
||||
)
|
||||
attn_edge_type[
|
||||
edge_index[0].long(), edge_index[1].long()
|
||||
] = convert_to_single_emb(edge_int_feature)
|
||||
dense_adj = graph_data.adj().to_dense().type(torch.int)
|
||||
shortest_path_result, path = algos.floyd_warshall(dense_adj.numpy())
|
||||
max_dist = np.amax(shortest_path_result)
|
||||
edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy())
|
||||
spatial_pos = torch.from_numpy((shortest_path_result)).long()
|
||||
attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token
|
||||
|
||||
pyg_graph = PYGGraph()
|
||||
pyg_graph.x = convert_to_single_emb(node_int_feature)
|
||||
pyg_graph.adj = dense_adj
|
||||
pyg_graph.attn_bias = attn_bias
|
||||
pyg_graph.attn_edge_type = attn_edge_type
|
||||
pyg_graph.spatial_pos = spatial_pos
|
||||
pyg_graph.in_degree = dense_adj.long().sum(dim=1).view(-1)
|
||||
pyg_graph.out_degree = pyg_graph.in_degree
|
||||
pyg_graph.edge_input = torch.from_numpy(edge_input).long()
|
||||
if y.dim() == 0:
|
||||
y = y.unsqueeze(-1)
|
||||
pyg_graph.y = y
|
||||
pyg_graph.idx = idx
|
||||
|
||||
return pyg_graph
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
if self.__indices__ is not None:
|
||||
idx = self.__indices__[idx]
|
||||
graph, y = self.dataset[idx]
|
||||
return self.__preprocess_dgl_graph(graph, y, idx)
|
||||
else:
|
||||
raise TypeError("index to a GraphormerDGLDataset can only be an integer.")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset) if self.__indices__ is None else len(self.__indices__)
|
|
@@ -0,0 +1,159 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional
|
||||
from dgl.data import (
|
||||
QM7bDataset,
|
||||
QM9Dataset,
|
||||
QM9EdgeDataset,
|
||||
MiniGCDataset,
|
||||
TUDataset,
|
||||
GINDataset,
|
||||
FakeNewsDataset,
|
||||
)
|
||||
from dgl.data import DGLDataset
|
||||
from .dgl_dataset import GraphormerDGLDataset
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
class MyQM7bDataset(QM7bDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyQM7bDataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
class MyQM9Dataset(QM9Dataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyQM9Dataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
class MyQM9EdgeDataset(QM9EdgeDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyQM9EdgeDataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class MyMiniGCDataset(MiniGCDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyMiniGCDataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class MyTUDataset(TUDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyTUDataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
class MyGINDataset(GINDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyGINDataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
class MyFakeNewsDataset(FakeNewsDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyFakeNewsDataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class DGLDatasetLookupTable:
|
||||
@staticmethod
|
||||
def GetDGLDataset(dataset_name: str, seed: int) -> Optional[DGLDataset]:
|
||||
params = dataset_name.split(":")[-1].split(",")
|
||||
inner_dataset = None
|
||||
|
||||
if dataset_name == "qm7b":
|
||||
inner_dataset = MyQM7bDataset()
|
||||
elif dataset_name.startswith("qm9"):
|
||||
label_keys = None
|
||||
cutoff = 5.0
|
||||
for param in params:
|
||||
name, value = param.split("=")
|
||||
if name == "label_keys":
|
||||
label_keys = value.split("+")
|
||||
elif name == "cutoff":
|
||||
cutoff = float(value)
|
||||
inner_dataset = MyQM9Dataset(label_keys=label_keys, cutoff=cutoff)
|
||||
elif dataset_name.startswith("qm9edge"):
|
||||
label_keys = None
|
||||
for param in params:
|
||||
name, value = param.split("=")
|
||||
if name == "label_keys":
|
||||
label_keys = value.split("+")
|
||||
inner_dataset = MyQM9EdgeDataset(label_keys=label_keys)
|
||||
elif dataset_name.startswith("minigc"):
|
||||
num_graphs = None
|
||||
min_num_v = None
|
||||
max_num_v = None
|
||||
data_seed = seed
|
||||
for param in params:
|
||||
name, value = param.split("=")
|
||||
if name == "num_graphs":
|
||||
num_graphs = int(value)
|
||||
elif name == "min_num_v":
|
||||
min_num_v = int(value)
|
||||
elif name == "max_num_v":
|
||||
max_num_v = int(value)
|
||||
elif name == "seed":
|
||||
data_seed = int(value)
|
||||
inner_dataset = MyMiniGCDataset(
|
||||
num_graphs, min_num_v, max_num_v, seed=data_seed
|
||||
)
|
||||
elif dataset_name.startswith("tu"):
|
||||
nm = None
|
||||
for param in params:
|
||||
name, value = param.split("=")
|
||||
if name == "name":
|
||||
nm = value
|
||||
inner_dataset = MyTUDataset(name=nm)
|
||||
elif dataset_name.startswith("gin"):
|
||||
nm = None
|
||||
self_loop = None
|
||||
degree_as_nlabel = False
|
||||
for param in params:
|
||||
name, value = param.split("=")
|
||||
if name == "name":
|
||||
nm = value
|
||||
elif name == "self_loop":
|
||||
if value.lower() == "false":
|
||||
self_loop = False
|
||||
elif value.lower() == "true":
|
||||
self_loop = True
|
||||
elif name == "degree_as_nlabel":
|
||||
if value.lower() == "false":
|
||||
degree_as_nlabel = False
|
||||
elif value.lower() == "true":
|
||||
degree_as_nlabel = True
|
||||
inner_dataset = MyGINDataset(
|
||||
name=nm, self_loop=self_loop, degree_as_nlabel=degree_as_nlabel
|
||||
)
|
||||
elif dataset_name.startswith("fakenews"):
|
||||
nm = None
|
||||
feature_name = None
|
||||
for param in params:
|
||||
name, value = param.split("=")
|
||||
if name == "name":
|
||||
nm = value
|
||||
elif name == "feature_name":
|
||||
feature_name = value
|
||||
inner_dataset = MyFakeNewsDataset(name=nm, feature_name=feature_name)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset specificaion {dataset_name}")
|
||||
|
||||
return (
|
||||
None
|
||||
if inner_dataset is None
|
||||
else GraphormerDGLDataset(inner_dataset, seed)
|
||||
)
|
|
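# Usage sketch: the dataset specification is "<name>:<key>=<value>,..." and list-valued
# parameters use "+" as a separator, matching the parsing above; the call returns the
# dataset wrapped in GraphormerDGLDataset.
spec = "qm9:label_keys=mu+alpha,cutoff=5.0"
dataset = DGLDatasetLookupTable.GetDGLDataset(spec, seed=0)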
@@ -0,0 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .ogb_dataset_lookup_table import OGBDatasetLookupTable
|
|
@@ -0,0 +1,104 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional
|
||||
from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset
|
||||
from ogb.lsc.pcqm4m_pyg import PygPCQM4MDataset
|
||||
from ogb.graphproppred import PygGraphPropPredDataset
|
||||
from torch_geometric.data import Dataset
|
||||
from ..pyg_datasets import GraphormerPYGDataset
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
|
||||
class MyPygPCQM4Mv2Dataset(PygPCQM4Mv2Dataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyPygPCQM4Mv2Dataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyPygPCQM4Mv2Dataset, self).process()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class MyPygPCQM4MDataset(PygPCQM4MDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyPygPCQM4MDataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyPygPCQM4MDataset, self).process()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class MyPygGraphPropPredDataset(PygGraphPropPredDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyPygGraphPropPredDataset, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyPygGraphPropPredDataset, self).process()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class OGBDatasetLookupTable:
|
||||
@staticmethod
|
||||
def GetOGBDataset(dataset_name: str, seed: int) -> Optional[Dataset]:
|
||||
inner_dataset = None
|
||||
train_idx = None
|
||||
valid_idx = None
|
||||
test_idx = None
|
||||
if dataset_name == "ogbg-molhiv":
|
||||
folder_name = dataset_name.replace("-", "_")
|
||||
os.system(f"mkdir -p dataset/{folder_name}/")
|
||||
os.system(f"touch dataset/{folder_name}/RELEASE_v1.txt")
|
||||
inner_dataset = MyPygGraphPropPredDataset(dataset_name)
|
||||
idx_split = inner_dataset.get_idx_split()
|
||||
train_idx = idx_split["train"]
|
||||
valid_idx = idx_split["valid"]
|
||||
test_idx = idx_split["test"]
|
||||
elif dataset_name == "ogbg-molpcba":
|
||||
folder_name = dataset_name.replace("-", "_")
|
||||
os.system(f"mkdir -p dataset/{folder_name}/")
|
||||
os.system(f"touch dataset/{folder_name}/RELEASE_v1.txt")
|
||||
inner_dataset = MyPygGraphPropPredDataset(dataset_name)
|
||||
idx_split = inner_dataset.get_idx_split()
|
||||
train_idx = idx_split["train"]
|
||||
valid_idx = idx_split["valid"]
|
||||
test_idx = idx_split["test"]
|
||||
elif dataset_name == "pcqm4mv2":
|
||||
os.system("mkdir -p dataset/pcqm4m-v2/")
|
||||
os.system("touch dataset/pcqm4m-v2/RELEASE_v1.txt")
|
||||
inner_dataset = MyPygPCQM4Mv2Dataset()
|
||||
idx_split = inner_dataset.get_idx_split()
|
||||
train_idx = idx_split["train"]
|
||||
valid_idx = idx_split["valid"]
|
||||
test_idx = idx_split["test-dev"]
|
||||
elif dataset_name == "pcqm4m":
|
||||
os.system("mkdir -p dataset/pcqm4m_kddcup2021/")
|
||||
os.system("touch dataset/pcqm4m_kddcup2021/RELEASE_v1.txt")
|
||||
inner_dataset = MyPygPCQM4MDataset()
|
||||
idx_split = inner_dataset.get_idx_split()
|
||||
train_idx = idx_split["train"]
|
||||
valid_idx = idx_split["valid"]
|
||||
test_idx = idx_split["test"]
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name {dataset_name} for ogb source.")
|
||||
return (
|
||||
None
|
||||
if inner_dataset is None
|
||||
else GraphormerPYGDataset(
|
||||
inner_dataset, seed, train_idx, valid_idx, test_idx
|
||||
)
|
||||
)
|
|
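# Usage sketch: OGB datasets are addressed by their canonical names and come back as a
# GraphormerPYGDataset carrying the official train/valid/test split.
molhiv = OGBDatasetLookupTable.GetOGBDataset("ogbg-molhiv", seed=1)
train_split = molhiv.train_data  # split views provided by GraphormerPYGDataset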
@@ -0,0 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .pyg_dataset_lookup_table import PYGDatasetLookupTable
|
||||
from .pyg_dataset import GraphormerPYGDataset
|
|
@@ -0,0 +1,103 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from torch_geometric.data import Dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from typing import List
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from ..wrapper import preprocess_item
|
||||
from .. import algos
|
||||
|
||||
import copy
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class GraphormerPYGDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
seed: int = 0,
|
||||
train_idx=None,
|
||||
valid_idx=None,
|
||||
test_idx=None,
|
||||
train_set=None,
|
||||
valid_set=None,
|
||||
test_set=None,
|
||||
):
|
||||
self.dataset = dataset
|
||||
self.num_data = len(self.dataset)
|
||||
self.seed = seed
|
||||
if train_idx is None and train_set is None:
|
||||
train_valid_idx, test_idx = train_test_split(
|
||||
np.arange(self.num_data),
|
||||
test_size=self.num_data // 10,
|
||||
random_state=seed,
|
||||
)
|
||||
train_idx, valid_idx = train_test_split(
|
||||
train_valid_idx, test_size=self.num_data // 5, random_state=seed
|
||||
)
|
||||
self.train_idx = torch.from_numpy(train_idx)
|
||||
self.valid_idx = torch.from_numpy(valid_idx)
|
||||
self.test_idx = torch.from_numpy(test_idx)
|
||||
self.train_data = self.index_select(self.train_idx)
|
||||
self.valid_data = self.index_select(self.valid_idx)
|
||||
self.test_data = self.index_select(self.test_idx)
|
||||
elif train_set is not None:
|
||||
self.train_data = self.create_subset(train_set)
|
||||
self.valid_data = self.create_subset(valid_set)
|
||||
self.test_data = self.create_subset(test_set)
|
||||
self.train_idx = None
|
||||
self.valid_idx = None
|
||||
self.test_idx = None
|
||||
else:
|
||||
self.train_idx = train_idx
|
||||
self.valid_idx = valid_idx
|
||||
self.test_idx = test_idx
|
||||
self.train_data = self.index_select(self.train_idx)
|
||||
self.valid_data = self.index_select(self.valid_idx)
|
||||
self.test_data = self.index_select(self.test_idx)
|
||||
self.__indices__ = None
|
||||
|
||||
def index_select(self, idx):
|
||||
dataset = copy.copy(self)
|
||||
dataset.dataset = self.dataset.index_select(idx)
|
||||
if isinstance(idx, torch.Tensor):
|
||||
dataset.num_data = idx.size(0)
|
||||
else:
|
||||
dataset.num_data = idx.shape[0]
|
||||
dataset.__indices__ = idx
|
||||
dataset.train_data = None
|
||||
dataset.valid_data = None
|
||||
dataset.test_data = None
|
||||
dataset.train_idx = None
|
||||
dataset.valid_idx = None
|
||||
dataset.test_idx = None
|
||||
return dataset
|
||||
|
||||
def create_subset(self, subset):
|
||||
dataset = copy.copy(self)
|
||||
dataset.dataset = subset
|
||||
dataset.num_data = len(subset)
|
||||
dataset.__indices__ = None
|
||||
dataset.train_data = None
|
||||
dataset.valid_data = None
|
||||
dataset.test_data = None
|
||||
dataset.train_idx = None
|
||||
dataset.valid_idx = None
|
||||
dataset.test_idx = None
|
||||
return dataset
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
item = self.dataset[idx]
|
||||
item.idx = idx
|
||||
item.y = item.y.reshape(-1)
|
||||
return preprocess_item(item)
|
||||
else:
|
||||
raise TypeError("index to a GraphormerPYGDataset can only be an integer.")
|
||||
|
||||
def __len__(self):
|
||||
return self.num_data
|
|
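# Usage sketch: with no explicit indices or subsets, the constructor above holds out
# num_data // 10 graphs for test and num_data // 5 for validation, leaving the rest for
# training (roughly a 70/20/10 split). A minimal sketch, assuming the PyG ZINC subset
# downloads successfully:
from torch_geometric.datasets import ZINC

zinc = ZINC(root="dataset", subset=True)
wrapped = GraphormerPYGDataset(zinc, seed=0)
print(len(wrapped.train_data), len(wrapped.valid_data), len(wrapped.test_data))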
@@ -0,0 +1,139 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional
|
||||
from torch_geometric.datasets import *
|
||||
from torch_geometric.data import Dataset
|
||||
from torch.nn import functional
|
||||
from .pyg_dataset import GraphormerPYGDataset
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class MyQM7b(QM7b):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyQM7b, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyQM7b, self).process()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class MyQM9(QM9):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyQM9, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyQM9, self).process()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
class MyZINC(ZINC):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyZINC, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyZINC, self).process()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class MyMoleculeNet(MoleculeNet):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyMoleculeNet, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyMoleculeNet, self).process()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
class MyMD17(MD17):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyMD17, self).download()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
super(MyMD17, self).process()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class PYGDatasetLookupTable:
|
||||
@staticmethod
|
||||
def GetPYGDataset(dataset_spec: str, seed: int) -> Optional[Dataset]:
|
||||
split_result = dataset_spec.split(":")
|
||||
if len(split_result) == 2:
|
||||
name, params = split_result[0], split_result[1]
|
||||
params = params.split(",")
|
||||
elif len(split_result) == 1:
|
||||
name = dataset_spec
|
||||
params = []
|
||||
inner_dataset = None
|
||||
num_class = 1
|
||||
|
||||
train_set = None
|
||||
valid_set = None
|
||||
test_set = None
|
||||
|
||||
root = "dataset"
|
||||
if name == "qm7b":
|
||||
inner_dataset = MyQM7b(root=root)
|
||||
elif name == "qm9":
|
||||
inner_dataset = MyQM9(root=root)
|
||||
elif name == "zinc":
|
||||
inner_dataset = MyZINC(root=root)
|
||||
train_set = MyZINC(root=root, split="train")
|
||||
valid_set = MyZINC(root=root, split="val")
|
||||
test_set = MyZINC(root=root, split="test")
|
||||
elif name == "moleculenet":
|
||||
nm = None
|
||||
for param in params:
|
||||
name, value = param.split("=")
|
||||
if name == "name":
|
||||
nm = value
|
||||
inner_dataset = MyMoleculeNet(root=root, name=nm)
|
||||
elif name == "md17":
|
||||
nm = None
|
||||
for param in params:
|
||||
name, value = param.split("=")
|
||||
if name == "name":
|
||||
nm = value
|
||||
inner_dataset = MyMD17(root=root, name=nm)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name {name} for pyg source.")
|
||||
if train_set is not None:
|
||||
return GraphormerPYGDataset(
|
||||
None,
|
||||
seed,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
train_set,
|
||||
valid_set,
|
||||
test_set,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
None
|
||||
if inner_dataset is None
|
||||
else GraphormerPYGDataset(inner_dataset, seed)
|
||||
)
|
|
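# Usage sketch: parameterized PyG sources follow the same "<name>:key=value" convention
# as the DGL lookup table; "zinc" additionally comes with predefined train/val/test splits.
bbbp = PYGDatasetLookupTable.GetPYGDataset("moleculenet:name=bbbp", seed=0)
zinc = PYGDatasetLookupTable.GetPYGDataset("zinc", seed=0)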
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from ..wrapper import preprocess_item
|
||||
from .. import algos
|
||||
from ..pyg_datasets import GraphormerPYGDataset
|
||||
|
||||
from ogb.utils.mol import smiles2graph
from torch_geometric.data import Data
|
||||
|
||||
|
||||
class GraphormerSMILESDataset(GraphormerPYGDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: str,
|
||||
num_class: int,
|
||||
max_node: int,
|
||||
multi_hop_max_dist: int,
|
||||
spatial_pos_max: int,
|
||||
):
|
||||
self.dataset = np.genfromtxt(dataset, delimiter=",", dtype=str)
|
||||
num_data = len(self.dataset)
|
||||
self.num_class = num_class
|
||||
self.__get_graph_metainfo(max_node, multi_hop_max_dist, spatial_pos_max)
|
||||
train_valid_idx, test_idx = train_test_split(np.arange(num_data), test_size=num_data // 10)
|
||||
train_idx, valid_idx = train_test_split(train_valid_idx, test_size=num_data // 5)
|
||||
self.train_idx = train_idx
|
||||
self.valid_idx = valid_idx
|
||||
self.test_idx = test_idx
|
||||
self.__indices__ = None
|
||||
self.train_data = self.index_select(train_idx)
|
||||
self.valid_data = self.index_select(valid_idx)
|
||||
self.test_data = self.index_select(test_idx)
|
||||
|
||||
def __get_graph_metainfo(
|
||||
self, max_node: int, multi_hop_max_dist: int, spatial_pos_max: int
|
||||
):
|
||||
self.max_node = min(
|
||||
max_node,
|
||||
max(smiles2graph(self.dataset[i])["num_nodes"] for i in range(len(self.dataset))),
|
||||
)
|
||||
max_dist = 0
|
||||
for i in range(len(self.dataset)):
|
||||
graph = smiles2graph(self.dataset[i])
edge_index = torch.from_numpy(graph["edge_index"]).long()
num_nodes = graph["num_nodes"]
dense_adj = torch.zeros([num_nodes, num_nodes], dtype=torch.int)
dense_adj[edge_index[0], edge_index[1]] = 1
|
||||
shortest_path_result, _ = algos.floyd_warshall(dense_adj.numpy())
|
||||
max_dist = max(max_dist, np.amax(shortest_path_result))
|
||||
self.multi_hop_max_dist = min(multi_hop_max_dist, max_dist)
|
||||
self.spatial_pos_max = min(spatial_pos_max, max_dist)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
graph = smiles2graph(self.dataset[idx])
item = Data(
    x=torch.from_numpy(graph["node_feat"]).long(),
    edge_index=torch.from_numpy(graph["edge_index"]).long(),
    edge_attr=torch.from_numpy(graph["edge_feat"]).long(),
)
item.idx = idx
return preprocess_item(item)
|
||||
else:
|
||||
raise TypeError("index to a GraphormerPYGDataset can only be an integer.")
|
|
@@ -0,0 +1,89 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ogb.graphproppred import PygGraphPropPredDataset
|
||||
from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset
|
||||
from functools import lru_cache
|
||||
import pyximport
|
||||
import torch.distributed as dist
|
||||
|
||||
pyximport.install(setup_args={"include_dirs": np.get_include()})
|
||||
from . import algos
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def convert_to_single_emb(x, offset: int = 512):
|
||||
feature_num = x.size(1) if len(x.size()) > 1 else 1
|
||||
feature_offset = 1 + torch.arange(0, feature_num * offset, offset, dtype=torch.long)
|
||||
x = x + feature_offset
|
||||
return x
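# Worked example: with the default offset of 512, feature column k is shifted into its
# own id range starting at 1 + k * 512, so convert_to_single_emb(torch.tensor([[3, 7]]))
# yields tensor([[4, 520]]) (3 + 1 for column 0, 7 + 513 for column 1). Distinct columns
# therefore never collide in the shared embedding table, assuming raw values stay below 512.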
|
||||
|
||||
|
||||
def preprocess_item(item):
|
||||
edge_attr, edge_index, x = item.edge_attr, item.edge_index, item.x
|
||||
N = x.size(0)
|
||||
x = convert_to_single_emb(x)
|
||||
|
||||
# node adj matrix [N, N] bool
|
||||
adj = torch.zeros([N, N], dtype=torch.bool)
|
||||
adj[edge_index[0, :], edge_index[1, :]] = True
|
||||
|
||||
# edge feature here
|
||||
if len(edge_attr.size()) == 1:
|
||||
edge_attr = edge_attr[:, None]
|
||||
attn_edge_type = torch.zeros([N, N, edge_attr.size(-1)], dtype=torch.long)
|
||||
attn_edge_type[edge_index[0, :], edge_index[1, :]] = (
|
||||
convert_to_single_emb(edge_attr) + 1
|
||||
)
|
||||
|
||||
shortest_path_result, path = algos.floyd_warshall(adj.numpy())
|
||||
max_dist = np.amax(shortest_path_result)
|
||||
edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy())
|
||||
spatial_pos = torch.from_numpy((shortest_path_result)).long()
|
||||
attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token
|
||||
|
||||
# combine
|
||||
item.x = x
|
||||
item.attn_bias = attn_bias
|
||||
item.attn_edge_type = attn_edge_type
|
||||
item.spatial_pos = spatial_pos
|
||||
item.in_degree = adj.long().sum(dim=1).view(-1)
|
||||
item.out_degree = item.in_degree # for undirected graph
|
||||
item.edge_input = torch.from_numpy(edge_input).long()
|
||||
|
||||
return item
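# Usage sketch: preprocess_item expects a PyG Data object with integer node features x,
# edge_index, and integer edge_attr; the toy graph below is purely illustrative.
from torch_geometric.data import Data

toy = Data(
    x=torch.tensor([[0], [1], [2]]),                        # 3 nodes, one integer feature each
    edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),  # undirected path 0 - 1 - 2
    edge_attr=torch.tensor([[0], [0], [1], [1]]),           # one integer feature per edge
)
toy = preprocess_item(toy)  # adds spatial_pos, attn_bias, attn_edge_type, edge_input, degrees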
|
||||
|
||||
|
||||
class MyPygPCQM4MDataset(PygPCQM4Mv2Dataset):
|
||||
def download(self):
|
||||
super(MyPygPCQM4MDataset, self).download()
|
||||
|
||||
def process(self):
|
||||
super(MyPygPCQM4MDataset, self).process()
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def __getitem__(self, idx):
|
||||
item = self.get(self.indices()[idx])
|
||||
item.idx = idx
|
||||
return preprocess_item(item)
|
||||
|
||||
|
||||
class MyPygGraphPropPredDataset(PygGraphPropPredDataset):
|
||||
def download(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
    super(MyPygGraphPropPredDataset, self).download()
if dist.is_initialized():
    dist.barrier()
|
||||
|
||||
def process(self):
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
    super(MyPygGraphPropPredDataset, self).process()
if dist.is_initialized():
    dist.barrier()
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def __getitem__(self, idx):
|
||||
item = self.get(self.indices()[idx])
|
||||
item.idx = idx
|
||||
item.y = item.y.reshape(-1)
|
||||
return preprocess_item(item)
|
|
@@ -1,115 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from model import Graphormer
|
||||
from data import GraphDataModule, get_dataset
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from pprint import pprint
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
||||
import os
|
||||
|
||||
|
||||
def cli_main():
|
||||
# ------------
|
||||
# args
|
||||
# ------------
|
||||
parser = ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = Graphormer.add_model_specific_args(parser)
|
||||
parser = GraphDataModule.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
args.max_steps = args.tot_updates + 1
|
||||
if not args.test and not args.validate:
|
||||
print(args)
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
# ------------
|
||||
# data
|
||||
# ------------
|
||||
dm = GraphDataModule.from_argparse_args(args)
|
||||
|
||||
# ------------
|
||||
# model
|
||||
# ------------
|
||||
if args.checkpoint_path != '':
|
||||
model = Graphormer.load_from_checkpoint(
|
||||
args.checkpoint_path,
|
||||
strict=False,
|
||||
n_layers=args.n_layers,
|
||||
num_heads=args.num_heads,
|
||||
hidden_dim=args.hidden_dim,
|
||||
attention_dropout_rate=args.attention_dropout_rate,
|
||||
dropout_rate=args.dropout_rate,
|
||||
intput_dropout_rate=args.intput_dropout_rate,
|
||||
weight_decay=args.weight_decay,
|
||||
ffn_dim=args.ffn_dim,
|
||||
dataset_name=dm.dataset_name,
|
||||
warmup_updates=args.warmup_updates,
|
||||
tot_updates=args.tot_updates,
|
||||
peak_lr=args.peak_lr,
|
||||
end_lr=args.end_lr,
|
||||
edge_type=args.edge_type,
|
||||
multi_hop_max_dist=args.multi_hop_max_dist,
|
||||
flag=args.flag,
|
||||
flag_m=args.flag_m,
|
||||
flag_step_size=args.flag_step_size,
|
||||
)
|
||||
else:
|
||||
model = Graphormer(
|
||||
n_layers=args.n_layers,
|
||||
num_heads=args.num_heads,
|
||||
hidden_dim=args.hidden_dim,
|
||||
attention_dropout_rate=args.attention_dropout_rate,
|
||||
dropout_rate=args.dropout_rate,
|
||||
intput_dropout_rate=args.intput_dropout_rate,
|
||||
weight_decay=args.weight_decay,
|
||||
ffn_dim=args.ffn_dim,
|
||||
dataset_name=dm.dataset_name,
|
||||
warmup_updates=args.warmup_updates,
|
||||
tot_updates=args.tot_updates,
|
||||
peak_lr=args.peak_lr,
|
||||
end_lr=args.end_lr,
|
||||
edge_type=args.edge_type,
|
||||
multi_hop_max_dist=args.multi_hop_max_dist,
|
||||
flag=args.flag,
|
||||
flag_m=args.flag_m,
|
||||
flag_step_size=args.flag_step_size,
|
||||
)
|
||||
if not args.test and not args.validate:
|
||||
print(model)
|
||||
print('total params:', sum(p.numel() for p in model.parameters()))
|
||||
|
||||
# ------------
|
||||
# training
|
||||
# ------------
|
||||
metric = 'valid_' + get_dataset(dm.dataset_name)['metric']
|
||||
dirpath = args.default_root_dir + f'/lightning_logs/checkpoints'
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
monitor=metric,
|
||||
dirpath=dirpath,
|
||||
filename=dm.dataset_name + '-{epoch:03d}-{' + metric + ':.4f}',
|
||||
save_top_k=100,
|
||||
mode=get_dataset(dm.dataset_name)['metric_mode'],
|
||||
save_last=True,
|
||||
)
|
||||
if not args.test and not args.validate and os.path.exists(dirpath + '/last.ckpt'):
|
||||
args.resume_from_checkpoint = dirpath + '/last.ckpt'
|
||||
print('args.resume_from_checkpoint', args.resume_from_checkpoint)
|
||||
trainer = pl.Trainer.from_argparse_args(args)
|
||||
trainer.callbacks.append(checkpoint_callback)
|
||||
trainer.callbacks.append(LearningRateMonitor(logging_interval='step'))
|
||||
|
||||
if args.test:
|
||||
result = trainer.test(model, datamodule=dm)
|
||||
pprint(result)
|
||||
elif args.validate:
|
||||
result = trainer.validate(model, datamodule=dm)
|
||||
pprint(result)
|
||||
else:
|
||||
trainer.fit(model, datamodule=dm)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli_main()
|
|
@@ -0,0 +1,127 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from fairseq import checkpoint_utils, utils, options, tasks
|
||||
from fairseq.logging import progress_bar
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
import ogb
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
import sys
|
||||
from os import path
|
||||
|
||||
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
||||
from pretrain import load_pretrained_model
|
||||
|
||||
import logging
|
||||
|
||||
def eval(args, use_pretrained, checkpoint_path=None, logger=None):
|
||||
cfg = convert_namespace_to_omegaconf(args)
|
||||
np.random.seed(cfg.common.seed)
|
||||
utils.set_torch_seed(cfg.common.seed)
|
||||
|
||||
# initialize task
|
||||
task = tasks.setup_task(cfg.task)
|
||||
model = task.build_model(cfg.model)
|
||||
|
||||
# load checkpoint
|
||||
if use_pretrained:
|
||||
model_state = load_pretrained_model(cfg.task.pretrained_model_name)
|
||||
else:
|
||||
model_state = torch.load(checkpoint_path)["model"]
|
||||
model.load_state_dict(
|
||||
model_state, strict=True, model_cfg=cfg.model
|
||||
)
|
||||
del model_state
|
||||
|
||||
model.to(torch.cuda.current_device())
|
||||
# load dataset
|
||||
split = args.split
|
||||
task.load_dataset(split)
|
||||
batch_iterator = task.get_batch_iterator(
|
||||
dataset=task.dataset(split),
|
||||
max_tokens=cfg.dataset.max_tokens_valid,
|
||||
max_sentences=cfg.dataset.batch_size_valid,
|
||||
max_positions=utils.resolve_max_positions(
|
||||
task.max_positions(),
|
||||
model.max_positions(),
|
||||
),
|
||||
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
|
||||
seed=cfg.common.seed,
|
||||
num_workers=cfg.dataset.num_workers,
|
||||
epoch=0,
|
||||
data_buffer_size=cfg.dataset.data_buffer_size,
|
||||
disable_iterator_cache=False,
|
||||
)
|
||||
itr = batch_iterator.next_epoch_itr(
|
||||
shuffle=False, set_dataset_epoch=False
|
||||
)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=cfg.common.log_format,
|
||||
log_interval=cfg.common.log_interval,
|
||||
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple")
|
||||
)
|
||||
|
||||
# infer
|
||||
y_pred = []
|
||||
y_true = []
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
for i, sample in enumerate(progress):
|
||||
sample = utils.move_to_cuda(sample)
|
||||
y = model(**sample["net_input"])[:, 0, :].reshape(-1)
|
||||
y_pred.extend(y.detach().cpu())
|
||||
y_true.extend(sample["target"].detach().cpu().reshape(-1)[:y.shape[0]])
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# save predictions
|
||||
y_pred = torch.Tensor(y_pred)
|
||||
y_true = torch.Tensor(y_true)
|
||||
|
||||
# evaluate
|
||||
if use_pretrained:
|
||||
if cfg.task.pretrained_model_name == "pcqm4mv1_graphormer_base":
|
||||
evaluator = ogb.lsc.PCQM4MEvaluator()
|
||||
input_dict = {'y_pred': y_pred, 'y_true': y_true}
|
||||
result_dict = evaluator.eval(input_dict)
|
||||
logger.info(f'PCQM4MEvaluator: {result_dict}')
|
||||
elif cfg.task.pretrained_model_name == "pcqm4mv2_graphormer_base":
|
||||
evaluator = ogb.lsc.PCQM4Mv2Evaluator()
|
||||
input_dict = {'y_pred': y_pred, 'y_true': y_true}
|
||||
result_dict = evaluator.eval(input_dict)
|
||||
logger.info(f'PCQM4Mv2Evaluator: {result_dict}')
|
||||
else:
|
||||
if args.metric == "auc":
|
||||
auc = roc_auc_score(y_true, y_pred)
|
||||
logger.info(f"auc: {auc}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported metric {args.metric}")
|
||||
|
||||
def main():
|
||||
parser = options.get_training_parser()
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metric",
|
||||
type=str,
|
||||
)
|
||||
args = options.parse_args_and_arch(parser, modify_parser=None)
|
||||
logger = logging.getLogger(__name__)
|
||||
if hasattr(args, "save_dir"):
|
||||
for checkpoint_fname in os.listdir(args.save_dir):
|
||||
checkpoint_path = Path(args.save_dir) / checkpoint_fname
|
||||
logger.info(f"evaluating checkpoint file {checkpoint_path}")
|
||||
eval(args, False, checkpoint_path, logger)
|
||||
else:
|
||||
eval(args, True, logger=logger)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@@ -0,0 +1,92 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from fairseq import checkpoint_utils, utils, options, tasks
|
||||
from fairseq.logging import progress_bar
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
import ogb
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
|
||||
def main(iter):
|
||||
parser = options.get_training_parser()
|
||||
args = options.parse_args_and_arch(parser, modify_parser=None)
|
||||
cfg = convert_namespace_to_omegaconf(args)
|
||||
np.random.seed(cfg.common.seed)
|
||||
utils.set_torch_seed(cfg.common.seed)
|
||||
|
||||
# initialize task
|
||||
task = tasks.setup_task(cfg.task)
|
||||
model = task.build_model(cfg.model)
|
||||
criterion = task.build_criterion(cfg.criterion)
|
||||
|
||||
# load checkpoint
|
||||
checkpoint_path = cfg.checkpoint.save_dir + f"checkpoint{iter}.pt"
|
||||
state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint_path) # or best?
|
||||
model.load_state_dict(
|
||||
state["model"], strict=True, model_cfg=cfg.model
|
||||
)
|
||||
del state["model"]
|
||||
|
||||
model.to(torch.cuda.current_device())
|
||||
# load dataset
|
||||
split = "test"
|
||||
task.load_dataset(split)
|
||||
batch_iterator = task.get_batch_iterator(
|
||||
dataset=task.dataset(split),
|
||||
max_tokens=cfg.dataset.max_tokens_valid,
|
||||
max_sentences=cfg.dataset.batch_size_valid,
|
||||
max_positions=utils.resolve_max_positions(
|
||||
task.max_positions(),
|
||||
model.max_positions(),
|
||||
),
|
||||
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
|
||||
seed=cfg.common.seed,
|
||||
num_workers=cfg.dataset.num_workers,
|
||||
epoch=0,
|
||||
data_buffer_size=cfg.dataset.data_buffer_size,
|
||||
disable_iterator_cache=False,
|
||||
)
|
||||
itr = batch_iterator.next_epoch_itr(
|
||||
shuffle=False, set_dataset_epoch=False
|
||||
)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=cfg.common.log_format,
|
||||
log_interval=cfg.common.log_interval,
|
||||
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple")
|
||||
)
|
||||
|
||||
# infer
|
||||
y_pred = []
|
||||
y_true = []
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
for i, sample in enumerate(progress):
|
||||
sample = utils.move_to_cuda(sample)
|
||||
output = model(**sample["net_input"])[:, 0, :]
|
||||
y = output.reshape(-1)
|
||||
y_pred.extend(y.detach().cpu())
|
||||
y_true.extend(sample["target"].detach().cpu().reshape(-1)[:y.shape[0]])
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
auc = roc_auc_score(y_true, y_pred)
|
||||
print("auc = %f" % auc)
|
||||
|
||||
# save predictions
|
||||
y_pred = torch.Tensor(y_pred)
|
||||
y_true = torch.Tensor(y_true)
|
||||
|
||||
torch.save(y_pred, "y_pred.pt")
|
||||
torch.save(y_true, "y_true.pt")
|
||||
|
||||
# evaluate
|
||||
evaluator = ogb.lsc.PCQM4MEvaluator()
|
||||
input_dict = {'y_pred': y_pred, 'y_true': y_true}
|
||||
result_dict = evaluator.eval(input_dict)
|
||||
print('PCQM4MEvaluator:', result_dict)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
for iter in range(1, 9):
|
||||
main(iter)
|
|
@@ -0,0 +1,2 @@
|
|||
# CUDA_VISIBLE_DEVICES=0 python evaluate_hiv.py --user-dir ../ --num-workers 16 --ddp-backend=legacy_ddp --dataset-name ogbg-molhiv --dataset-source ogb --task graph_prediction_with_flag --criterion binary_logloss_with_flag --arch graphormer_base --num-classes 1 --batch-size 64 --fp16 --data-buffer-size 20 --encoder-layers 12 --encoder-embed-dim 768 --encoder-ffn-embed-dim 768 --encoder-attention-heads 32 --save-dir ../../examples/property_prediction/ckpts1/ --seed 1
|
||||
CUDA_VISIBLE_DEVICES=0 python evaluate_hiv.py --user-dir ../ --num-workers 16 --ddp-backend=legacy_ddp --dataset-name ogbg-molhiv --dataset-source ogb --task graph_prediction --criterion binary_logloss --arch graphormer_base --num-classes 1 --batch-size 64 --fp16 --data-buffer-size 20 --encoder-layers 12 --encoder-embed-dim 768 --encoder-ffn-embed-dim 768 --encoder-attention-heads 32 --save-dir ../../examples/property_prediction/ckpts1/ --seed 1
|
|
@@ -1,34 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class PolynomialDecayLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_updates, tot_updates, lr, end_lr, power, last_epoch=-1, verbose=False):
|
||||
self.warmup_updates = warmup_updates
|
||||
self.tot_updates = tot_updates
|
||||
self.lr = lr
|
||||
self.end_lr = end_lr
|
||||
self.power = power
|
||||
super(PolynomialDecayLR, self).__init__(optimizer, last_epoch, verbose)
|
||||
|
||||
def get_lr(self):
|
||||
if self._step_count <= self.warmup_updates:
|
||||
self.warmup_factor = self._step_count / float(self.warmup_updates)
|
||||
lr = self.warmup_factor * self.lr
|
||||
elif self._step_count >= self.tot_updates:
|
||||
lr = self.end_lr
|
||||
else:
|
||||
warmup = self.warmup_updates
|
||||
lr_range = self.lr - self.end_lr
|
||||
pct_remaining = 1 - (self._step_count - warmup) / (
|
||||
self.tot_updates - warmup
|
||||
)
|
||||
lr = lr_range * pct_remaining ** (self.power) + self.end_lr
|
||||
|
||||
return [lr for group in self.optimizer.param_groups]
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
assert False
|
|
@@ -1,423 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from data import get_dataset
|
||||
from lr import PolynomialDecayLR
|
||||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from utils.flag import flag_bounded
|
||||
|
||||
|
||||
def init_params(module, n_layers):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers))
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
|
||||
|
||||
class Graphormer(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
n_layers,
|
||||
num_heads,
|
||||
hidden_dim,
|
||||
dropout_rate,
|
||||
intput_dropout_rate,
|
||||
weight_decay,
|
||||
ffn_dim,
|
||||
dataset_name,
|
||||
warmup_updates,
|
||||
tot_updates,
|
||||
peak_lr,
|
||||
end_lr,
|
||||
edge_type,
|
||||
multi_hop_max_dist,
|
||||
attention_dropout_rate,
|
||||
flag=False,
|
||||
flag_m=3,
|
||||
flag_step_size=1e-3,
|
||||
flag_mag=1e-3,
|
||||
):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
self.num_heads = num_heads
|
||||
if dataset_name == 'ZINC':
|
||||
self.atom_encoder = nn.Embedding(64, hidden_dim, padding_idx=0)
|
||||
self.edge_encoder = nn.Embedding(64, num_heads, padding_idx=0)
|
||||
self.edge_type = edge_type
|
||||
if self.edge_type == 'multi_hop':
|
||||
self.edge_dis_encoder = nn.Embedding(
|
||||
40 * num_heads * num_heads, 1)
|
||||
self.spatial_pos_encoder = nn.Embedding(40, num_heads, padding_idx=0)
|
||||
self.in_degree_encoder = nn.Embedding(
|
||||
64, hidden_dim, padding_idx=0)
|
||||
self.out_degree_encoder = nn.Embedding(
|
||||
64, hidden_dim, padding_idx=0)
|
||||
else:
|
||||
self.atom_encoder = nn.Embedding(
|
||||
512 * 9 + 1, hidden_dim, padding_idx=0)
|
||||
self.edge_encoder = nn.Embedding(
|
||||
512 * 3 + 1, num_heads, padding_idx=0)
|
||||
self.edge_type = edge_type
|
||||
if self.edge_type == 'multi_hop':
|
||||
self.edge_dis_encoder = nn.Embedding(
|
||||
128 * num_heads * num_heads, 1)
|
||||
self.spatial_pos_encoder = nn.Embedding(512, num_heads, padding_idx=0)
|
||||
self.in_degree_encoder = nn.Embedding(
|
||||
512, hidden_dim, padding_idx=0)
|
||||
self.out_degree_encoder = nn.Embedding(
|
||||
512, hidden_dim, padding_idx=0)
|
||||
|
||||
self.input_dropout = nn.Dropout(intput_dropout_rate)
|
||||
encoders = [EncoderLayer(hidden_dim, ffn_dim, dropout_rate, attention_dropout_rate, num_heads)
|
||||
for _ in range(n_layers)]
|
||||
self.layers = nn.ModuleList(encoders)
|
||||
self.final_ln = nn.LayerNorm(hidden_dim)
|
||||
|
||||
if dataset_name == 'PCQM4M-LSC':
|
||||
self.out_proj = nn.Linear(hidden_dim, 1)
|
||||
else:
|
||||
self.downstream_out_proj = nn.Linear(
|
||||
hidden_dim, get_dataset(dataset_name)['num_class'])
|
||||
|
||||
self.graph_token = nn.Embedding(1, hidden_dim)
|
||||
self.graph_token_virtual_distance = nn.Embedding(1, num_heads)
|
||||
|
||||
self.evaluator = get_dataset(dataset_name)['evaluator']
|
||||
self.metric = get_dataset(dataset_name)['metric']
|
||||
self.loss_fn = get_dataset(dataset_name)['loss_fn']
|
||||
self.dataset_name = dataset_name
|
||||
|
||||
self.warmup_updates = warmup_updates
|
||||
self.tot_updates = tot_updates
|
||||
self.peak_lr = peak_lr
|
||||
self.end_lr = end_lr
|
||||
self.weight_decay = weight_decay
|
||||
self.multi_hop_max_dist = multi_hop_max_dist
|
||||
|
||||
self.flag = flag
|
||||
self.flag_m = flag_m
|
||||
self.flag_step_size = flag_step_size
|
||||
self.flag_mag = flag_mag
|
||||
self.hidden_dim = hidden_dim
|
||||
self.automatic_optimization = not self.flag
|
||||
self.apply(lambda module: init_params(module, n_layers=n_layers))
|
||||
|
||||
def forward(self, batched_data, perturb=None):
|
||||
attn_bias, spatial_pos, x = batched_data.attn_bias, batched_data.spatial_pos, batched_data.x
|
||||
in_degree, out_degree = batched_data.in_degree, batched_data.in_degree
|
||||
edge_input, attn_edge_type = batched_data.edge_input, batched_data.attn_edge_type
|
||||
# graph_attn_bias
|
||||
n_graph, n_node = x.size()[:2]
|
||||
graph_attn_bias = attn_bias.clone()
|
||||
graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
|
||||
1, self.num_heads, 1, 1) # [n_graph, n_head, n_node+1, n_node+1]
|
||||
|
||||
# spatial pos
|
||||
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
|
||||
spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
|
||||
graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:,
|
||||
:, 1:, 1:] + spatial_pos_bias
|
||||
# reset spatial pos here
|
||||
t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
|
||||
graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
|
||||
graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
|
||||
|
||||
# edge feature
|
||||
if self.edge_type == 'multi_hop':
|
||||
spatial_pos_ = spatial_pos.clone()
|
||||
spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1
|
||||
# set 1 to 1, x > 1 to x - 1
|
||||
spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
|
||||
if self.multi_hop_max_dist > 0:
|
||||
spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
|
||||
edge_input = edge_input[:, :, :, :self.multi_hop_max_dist, :]
|
||||
# [n_graph, n_node, n_node, max_dist, n_head]
|
||||
edge_input = self.edge_encoder(edge_input).mean(-2)
|
||||
max_dist = edge_input.size(-2)
|
||||
edge_input_flat = edge_input.permute(
|
||||
3, 0, 1, 2, 4).reshape(max_dist, -1, self.num_heads)
|
||||
edge_input_flat = torch.bmm(edge_input_flat, self.edge_dis_encoder.weight.reshape(
|
||||
-1, self.num_heads, self.num_heads)[:max_dist, :, :])
|
||||
edge_input = edge_input_flat.reshape(
|
||||
max_dist, n_graph, n_node, n_node, self.num_heads).permute(1, 2, 3, 0, 4)
|
||||
edge_input = (edge_input.sum(-2) /
|
||||
(spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
|
||||
else:
|
||||
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
|
||||
edge_input = self.edge_encoder(
|
||||
attn_edge_type).mean(-2).permute(0, 3, 1, 2)
|
||||
|
||||
graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:,
|
||||
:, 1:, 1:] + edge_input
|
||||
graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset
|
||||
|
||||
# node feature + graph token
|
||||
node_feature = self.atom_encoder(x).sum(
|
||||
dim=-2) # [n_graph, n_node, n_hidden]
|
||||
if self.flag and perturb is not None:
|
||||
node_feature += perturb
|
||||
|
||||
node_feature = node_feature + \
|
||||
self.in_degree_encoder(in_degree) + \
|
||||
self.out_degree_encoder(out_degree)
|
||||
graph_token_feature = self.graph_token.weight.unsqueeze(
|
||||
0).repeat(n_graph, 1, 1)
|
||||
graph_node_feature = torch.cat(
|
||||
[graph_token_feature, node_feature], dim=1)
|
||||
|
||||
# transformer encoder
|
||||
output = self.input_dropout(graph_node_feature)
|
||||
for enc_layer in self.layers:
|
||||
output = enc_layer(output, graph_attn_bias)
|
||||
output = self.final_ln(output)
|
||||
|
||||
# output part
|
||||
if self.dataset_name == 'PCQM4M-LSC':
|
||||
# get whole graph rep
|
||||
output = self.out_proj(output[:, 0, :])
|
||||
else:
|
||||
output = self.downstream_out_proj(output[:, 0, :])
|
||||
return output
|
||||
|
||||
def training_step(self, batched_data, batch_idx):
|
||||
if self.dataset_name == 'ogbg-molpcba':
|
||||
if not self.flag:
|
||||
y_hat = self(batched_data).view(-1)
|
||||
y_gt = batched_data.y.view(-1).float()
|
||||
mask = ~torch.isnan(y_gt)
|
||||
loss = self.loss_fn(y_hat[mask], y_gt[mask])
|
||||
else:
|
||||
y_gt = batched_data.y.view(-1).float()
|
||||
mask = ~torch.isnan(y_gt)
|
||||
|
||||
def forward(perturb): return self(batched_data, perturb)
|
||||
model_forward = (self, forward)
|
||||
n_graph, n_node = batched_data.x.size()[:2]
|
||||
perturb_shape = (n_graph, n_node, self.hidden_dim)
|
||||
|
||||
optimizer = self.optimizers()
|
||||
optimizer.zero_grad()
|
||||
loss, _ = flag_bounded(model_forward, perturb_shape, y_gt[mask], optimizer, batched_data.x.device, self.loss_fn,
|
||||
m=self.flag_m, step_size=self.flag_step_size, mag=self.flag_mag, mask=mask)
|
||||
self.lr_schedulers().step()
|
||||
|
||||
elif self.dataset_name == 'ogbg-molhiv':
|
||||
if not self.flag:
|
||||
y_hat = self(batched_data).view(-1)
|
||||
y_gt = batched_data.y.view(-1).float()
|
||||
loss = self.loss_fn(y_hat, y_gt)
|
||||
else:
|
||||
y_gt = batched_data.y.view(-1).float()
|
||||
def forward(perturb): return self(batched_data, perturb)
|
||||
model_forward = (self, forward)
|
||||
n_graph, n_node = batched_data.x.size()[:2]
|
||||
perturb_shape = (n_graph, n_node, self.hidden_dim)
|
||||
|
||||
optimizer = self.optimizers()
|
||||
optimizer.zero_grad()
|
||||
loss, _ = flag_bounded(model_forward, perturb_shape, y_gt, optimizer, batched_data.x.device, self.loss_fn,
|
||||
m=self.flag_m, step_size=self.flag_step_size, mag=self.flag_mag)
|
||||
self.lr_schedulers().step()
|
||||
else:
|
||||
y_hat = self(batched_data).view(-1)
|
||||
y_gt = batched_data.y.view(-1)
|
||||
loss = self.loss_fn(y_hat, y_gt)
|
||||
self.log('train_loss', loss, sync_dist=True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batched_data, batch_idx):
|
||||
if self.dataset_name in ['PCQM4M-LSC', 'ZINC']:
|
||||
y_pred = self(batched_data).view(-1)
|
||||
y_true = batched_data.y.view(-1)
|
||||
else:
|
||||
y_pred = self(batched_data)
|
||||
y_true = batched_data.y
|
||||
return {
|
||||
'y_pred': y_pred,
|
||||
'y_true': y_true,
|
||||
}
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
y_pred = torch.cat([i['y_pred'] for i in outputs])
|
||||
y_true = torch.cat([i['y_true'] for i in outputs])
|
||||
if self.dataset_name == 'ogbg-molpcba':
|
||||
mask = ~torch.isnan(y_true)
|
||||
loss = self.loss_fn(y_pred[mask], y_true[mask])
|
||||
self.log('valid_ap', loss, sync_dist=True)
|
||||
else:
|
||||
input_dict = {"y_true": y_true, "y_pred": y_pred}
|
||||
try:
|
||||
self.log('valid_' + self.metric, self.evaluator.eval(input_dict)
|
||||
[self.metric], sync_dist=True)
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_step(self, batched_data, batch_idx):
|
||||
if self.dataset_name in ['PCQM4M-LSC', 'ZINC']:
|
||||
y_pred = self(batched_data).view(-1)
|
||||
y_true = batched_data.y.view(-1)
|
||||
else:
|
||||
y_pred = self(batched_data)
|
||||
y_true = batched_data.y
|
||||
return {
|
||||
'y_pred': y_pred,
|
||||
'y_true': y_true,
|
||||
'idx': batched_data.idx,
|
||||
}
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
y_pred = torch.cat([i['y_pred'] for i in outputs])
|
||||
y_true = torch.cat([i['y_true'] for i in outputs])
|
||||
if self.dataset_name == 'PCQM4M-LSC':
|
||||
result = y_pred.cpu().float().numpy()
|
||||
idx = torch.cat([i['idx'] for i in outputs])
|
||||
torch.save(result, 'y_pred.pt')
|
||||
torch.save(idx, 'idx.pt')
|
||||
exit(0)
|
||||
input_dict = {"y_true": y_true, "y_pred": y_pred}
|
||||
self.log('test_' + self.metric, self.evaluator.eval(input_dict)
|
||||
[self.metric], sync_dist=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.AdamW(
|
||||
self.parameters(), lr=self.peak_lr, weight_decay=self.weight_decay)
|
||||
lr_scheduler = {
|
||||
'scheduler': PolynomialDecayLR(
|
||||
optimizer,
|
||||
warmup_updates=self.warmup_updates,
|
||||
tot_updates=self.tot_updates,
|
||||
lr=self.peak_lr,
|
||||
end_lr=self.end_lr,
|
||||
power=1.0,
|
||||
),
|
||||
'name': 'learning_rate',
|
||||
'interval': 'step',
|
||||
'frequency': 1,
|
||||
}
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parent_parser):
|
||||
parser = parent_parser.add_argument_group("Graphormer")
|
||||
parser.add_argument('--n_layers', type=int, default=12)
|
||||
parser.add_argument('--num_heads', type=int, default=32)
|
||||
parser.add_argument('--hidden_dim', type=int, default=512)
|
||||
parser.add_argument('--ffn_dim', type=int, default=512)
|
||||
parser.add_argument('--intput_dropout_rate', type=float, default=0.1)
|
||||
parser.add_argument('--dropout_rate', type=float, default=0.1)
|
||||
parser.add_argument('--weight_decay', type=float, default=0.01)
|
||||
parser.add_argument('--attention_dropout_rate',
|
||||
type=float, default=0.1)
|
||||
parser.add_argument('--checkpoint_path', type=str, default='')
|
||||
parser.add_argument('--warmup_updates', type=int, default=60000)
|
||||
parser.add_argument('--tot_updates', type=int, default=1000000)
|
||||
parser.add_argument('--peak_lr', type=float, default=2e-4)
|
||||
parser.add_argument('--end_lr', type=float, default=1e-9)
|
||||
parser.add_argument('--edge_type', type=str, default='multi_hop')
|
||||
parser.add_argument('--validate', action='store_true', default=False)
|
||||
parser.add_argument('--test', action='store_true', default=False)
|
||||
parser.add_argument('--flag', action='store_true')
|
||||
parser.add_argument('--flag_m', type=int, default=3)
|
||||
parser.add_argument('--flag_step_size', type=float, default=1e-3)
|
||||
parser.add_argument('--flag_mag', type=float, default=1e-3)
|
||||
return parent_parser
|
||||
|
||||
|
||||
class FeedForwardNetwork(nn.Module):
|
||||
def __init__(self, hidden_size, ffn_size, dropout_rate):
|
||||
super(FeedForwardNetwork, self).__init__()
|
||||
|
||||
self.layer1 = nn.Linear(hidden_size, ffn_size)
|
||||
self.gelu = nn.GELU()
|
||||
self.layer2 = nn.Linear(ffn_size, hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer1(x)
|
||||
x = self.gelu(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, hidden_size, attention_dropout_rate, num_heads):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.att_size = att_size = hidden_size // num_heads
|
||||
self.scale = att_size ** -0.5
|
||||
|
||||
self.linear_q = nn.Linear(hidden_size, num_heads * att_size)
|
||||
self.linear_k = nn.Linear(hidden_size, num_heads * att_size)
|
||||
self.linear_v = nn.Linear(hidden_size, num_heads * att_size)
|
||||
self.att_dropout = nn.Dropout(attention_dropout_rate)
|
||||
|
||||
self.output_layer = nn.Linear(num_heads * att_size, hidden_size)
|
||||
|
||||
def forward(self, q, k, v, attn_bias=None):
|
||||
orig_q_size = q.size()
|
||||
|
||||
d_k = self.att_size
|
||||
d_v = self.att_size
|
||||
batch_size = q.size(0)
|
||||
|
||||
# head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i)
|
||||
q = self.linear_q(q).view(batch_size, -1, self.num_heads, d_k)
|
||||
k = self.linear_k(k).view(batch_size, -1, self.num_heads, d_k)
|
||||
v = self.linear_v(v).view(batch_size, -1, self.num_heads, d_v)
|
||||
|
||||
q = q.transpose(1, 2) # [b, h, q_len, d_k]
|
||||
v = v.transpose(1, 2) # [b, h, v_len, d_v]
|
||||
k = k.transpose(1, 2).transpose(2, 3) # [b, h, d_k, k_len]
|
||||
|
||||
# Scaled Dot-Product Attention.
|
||||
# Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V
|
||||
q = q * self.scale
|
||||
x = torch.matmul(q, k) # [b, h, q_len, k_len]
|
||||
if attn_bias is not None:
|
||||
x = x + attn_bias
|
||||
|
||||
x = torch.softmax(x, dim=3)
|
||||
x = self.att_dropout(x)
|
||||
x = x.matmul(v) # [b, h, q_len, attn]
|
||||
|
||||
x = x.transpose(1, 2).contiguous() # [b, q_len, h, attn]
|
||||
x = x.view(batch_size, -1, self.num_heads * d_v)
|
||||
|
||||
x = self.output_layer(x)
|
||||
|
||||
assert x.size() == orig_q_size
|
||||
return x
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, hidden_size, ffn_size, dropout_rate, attention_dropout_rate, num_heads):
|
||||
super(EncoderLayer, self).__init__()
|
||||
|
||||
self.self_attention_norm = nn.LayerNorm(hidden_size)
|
||||
self.self_attention = MultiHeadAttention(
|
||||
hidden_size, attention_dropout_rate, num_heads)
|
||||
self.self_attention_dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
self.ffn_norm = nn.LayerNorm(hidden_size)
|
||||
self.ffn = FeedForwardNetwork(hidden_size, ffn_size, dropout_rate)
|
||||
self.ffn_dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, x, attn_bias=None):
|
||||
y = self.self_attention_norm(x)
|
||||
y = self.self_attention(y, y, y, attn_bias)
|
||||
y = self.self_attention_dropout(y)
|
||||
x = x + y
|
||||
|
||||
y = self.ffn_norm(x)
|
||||
y = self.ffn(y)
|
||||
y = self.ffn_dropout(y)
|
||||
x = x + y
|
||||
return x
|
|
@@ -0,0 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .graphormer import GraphormerModel
|
|
@@ -0,0 +1,346 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairseq import utils
|
||||
from fairseq.models import (
|
||||
FairseqEncoder,
|
||||
FairseqEncoderModel,
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
from fairseq.modules import (
|
||||
LayerNorm,
|
||||
)
|
||||
from fairseq.utils import safe_hasattr
|
||||
|
||||
from ..modules import init_graphormer_params, GraphormerGraphEncoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ..pretrain import load_pretrained_model
|
||||
|
||||
|
||||
@register_model("graphormer")
|
||||
class GraphormerModel(FairseqEncoderModel):
|
||||
def __init__(self, args, encoder):
|
||||
super().__init__(encoder)
|
||||
self.args = args
|
||||
|
||||
if getattr(args, "apply_graphormer_init", False):
|
||||
self.apply(init_graphormer_params)
|
||||
self.encoder_embed_dim = args.encoder_embed_dim
|
||||
if args.pretrained_model_name != "none":
|
||||
self.load_state_dict(load_pretrained_model(args.pretrained_model_name))
|
||||
if not args.load_pretrained_model_output_layer:
|
||||
self.encoder.reset_output_layer_parameters()
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add model-specific arguments to the parser."""
|
||||
# Arguments related to dropout
|
||||
parser.add_argument(
|
||||
"--dropout", type=float, metavar="D", help="dropout probability"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-dropout",
|
||||
type=float,
|
||||
metavar="D",
|
||||
help="dropout probability for" " attention weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--act-dropout",
|
||||
type=float,
|
||||
metavar="D",
|
||||
help="dropout probability after" " activation in FFN",
|
||||
)
|
||||
|
||||
# Arguments related to hidden states and self-attention
|
||||
parser.add_argument(
|
||||
"--encoder-ffn-embed-dim",
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="encoder embedding dimension for FFN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder-attention-heads",
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="num encoder attention heads",
|
||||
)
|
||||
|
||||
# Arguments related to input and output embeddings
|
||||
parser.add_argument(
|
||||
"--encoder-embed-dim",
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="encoder embedding dimension",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--share-encoder-input-output-embed",
|
||||
action="store_true",
|
||||
help="share encoder input" " and output embeddings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder-learned-pos",
|
||||
action="store_true",
|
||||
help="use learned positional embeddings in the encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-token-positional-embeddings",
|
||||
action="store_true",
|
||||
help="if set, disables positional embeddings" " (outside self attention)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-positions", type=int, help="number of positional embeddings to learn"
|
||||
)
|
||||
|
||||
# Arguments related to parameter initialization
|
||||
parser.add_argument(
|
||||
"--apply-graphormer-init",
|
||||
action="store_true",
|
||||
help="use custom param initialization for Graphormer",
|
||||
)
|
||||
|
||||
# misc params
|
||||
parser.add_argument(
|
||||
"--activation-fn",
|
||||
choices=utils.get_available_activation_fns(),
|
||||
help="activation function to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder-normalize-before",
|
||||
action="store_true",
|
||||
help="apply layernorm before each encoder block",
|
||||
)
|
||||
|
||||
def max_nodes(self):
|
||||
return self.encoder.max_nodes
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
"""Build a new model instance."""
|
||||
# make sure all arguments are present in older models
|
||||
base_architecture(args)
|
||||
|
||||
if not safe_hasattr(args, "max_nodes"):
|
||||
args.max_nodes = args.tokens_per_sample
|
||||
|
||||
logger.info(args)
|
||||
|
||||
encoder = GraphormerEncoder(args)
|
||||
return cls(args, encoder)
|
||||
|
||||
def forward(self, batched_data, **kwargs):
|
||||
return self.encoder(batched_data, **kwargs)
|
||||
|
||||
|
||||
class GraphormerEncoder(FairseqEncoder):
|
||||
def __init__(self, args):
|
||||
super().__init__(dictionary=None)
|
||||
self.max_nodes = args.max_nodes
|
||||
|
||||
self.graph_encoder = GraphormerGraphEncoder(
|
||||
# < for graphormer
|
||||
num_atoms=args.num_atoms,
|
||||
num_in_degree=args.num_in_degree,
|
||||
num_out_degree=args.num_out_degree,
|
||||
num_edges=args.num_edges,
|
||||
num_spatial=args.num_spatial,
|
||||
num_edge_dis=args.num_edge_dis,
|
||||
edge_type=args.edge_type,
|
||||
multi_hop_max_dist=args.multi_hop_max_dist,
|
||||
# >
|
||||
num_encoder_layers=args.encoder_layers,
|
||||
embedding_dim=args.encoder_embed_dim,
|
||||
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||
num_attention_heads=args.encoder_attention_heads,
|
||||
dropout=args.dropout,
|
||||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.act_dropout,
|
||||
encoder_normalize_before=args.encoder_normalize_before,
|
||||
apply_graphormer_init=args.apply_graphormer_init,
|
||||
activation_fn=args.activation_fn,
|
||||
)
|
||||
|
||||
self.share_input_output_embed = args.share_encoder_input_output_embed
|
||||
self.embed_out = None
|
||||
self.lm_output_learned_bias = None
|
||||
|
||||
# Remove head is set to true during fine-tuning
|
||||
self.load_softmax = not getattr(args, "remove_head", False)
|
||||
|
||||
self.masked_lm_pooler = nn.Linear(
|
||||
args.encoder_embed_dim, args.encoder_embed_dim
|
||||
)
|
||||
|
||||
self.lm_head_transform_weight = nn.Linear(
|
||||
args.encoder_embed_dim, args.encoder_embed_dim
|
||||
)
|
||||
self.activation_fn = utils.get_activation_fn(args.activation_fn)
|
||||
self.layer_norm = LayerNorm(args.encoder_embed_dim)
|
||||
|
||||
self.lm_output_learned_bias = None
|
||||
if self.load_softmax:
|
||||
self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
|
||||
|
||||
if not self.share_input_output_embed:
|
||||
self.embed_out = nn.Linear(
|
||||
args.encoder_embed_dim, args.num_classes, bias=False
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_output_layer_parameters(self):
|
||||
self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
|
||||
if self.embed_out is not None:
|
||||
self.embed_out.reset_parameters()
|
||||
|
||||
def forward(self, batched_data, perturb=None, masked_tokens=None, **unused):
|
||||
inner_states, graph_rep = self.graph_encoder(
|
||||
batched_data,
|
||||
perturb=perturb,
|
||||
)
|
||||
|
||||
x = inner_states[-1].transpose(0, 1)
|
||||
|
||||
# project masked tokens only
|
||||
if masked_tokens is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))
|
||||
|
||||
# project back to size of vocabulary
|
||||
if self.share_input_output_embed and hasattr(
|
||||
self.graph_encoder.embed_tokens, "weight"
|
||||
):
|
||||
x = F.linear(x, self.graph_encoder.embed_tokens.weight)
|
||||
elif self.embed_out is not None:
|
||||
x = self.embed_out(x)
|
||||
if self.lm_output_learned_bias is not None:
|
||||
x = x + self.lm_output_learned_bias
|
||||
|
||||
return x
|
||||
|
||||
def max_nodes(self):
|
||||
"""Maximum output length supported by the encoder."""
|
||||
return self.max_nodes
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
if not self.load_softmax:
|
||||
for k in list(state_dict.keys()):
|
||||
if "embed_out.weight" in k or "lm_output_learned_bias" in k:
|
||||
del state_dict[k]
|
||||
return state_dict
|
||||
|
||||
|
||||
@register_model_architecture("graphormer", "graphormer")
|
||||
def base_architecture(args):
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.act_dropout = getattr(args, "act_dropout", 0.0)
|
||||
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 6)
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
||||
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
||||
args.share_encoder_input_output_embed = getattr(
|
||||
args, "share_encoder_input_output_embed", False
|
||||
)
|
||||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
|
||||
args.apply_graphormer_init = getattr(args, "apply_graphormer_init", False)
|
||||
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
||||
|
||||
|
||||
@register_model_architecture("graphormer", "graphormer_base")
|
||||
def graphormer_base_architecture(args):
|
||||
if args.pretrained_model_name == "pcqm4mv1_graphormer_base":
|
||||
args.encoder_layers = 12
|
||||
args.encoder_attention_heads = 32
|
||||
args.encoder_ffn_embed_dim = 768
|
||||
args.encoder_embed_dim = 768
|
||||
args.dropout = getattr(args, "dropout", 0.0)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.act_dropout = getattr(args, "act_dropout", 0.1)
|
||||
elif args.pretrained_model_name == "pcqm4mv2_graphormer_base":
|
||||
args.dropout = getattr(args, "dropout", 0.0)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.act_dropout = getattr(args, "act_dropout", 0.1)
|
||||
else:
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
||||
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
||||
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 32)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768)
|
||||
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
||||
args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True)
|
||||
args.share_encoder_input_output_embed = getattr(
|
||||
args, "share_encoder_input_output_embed", False
|
||||
)
|
||||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
@register_model_architecture("graphormer", "graphormer_slim")
|
||||
def graphormer_slim_architecture(args):
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 80)
|
||||
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
||||
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 80)
|
||||
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
||||
args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True)
|
||||
args.share_encoder_input_output_embed = getattr(
|
||||
args, "share_encoder_input_output_embed", False
|
||||
)
|
||||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
@register_model_architecture("graphormer", "graphormer_large")
|
||||
def graphormer_large_architecture(args):
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
||||
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 24)
|
||||
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 42)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
|
||||
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
||||
args.apply_graphormer_init = getattr(args, "apply_graphormer_init", True)
|
||||
args.share_encoder_input_output_embed = getattr(
|
||||
args, "share_encoder_input_output_embed", False
|
||||
)
|
||||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
base_architecture(args)
|
|
@ -0,0 +1,403 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from fairseq.models import (
|
||||
BaseFairseqModel,
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def softmax_dropout(input, dropout_prob: float, is_training: bool):
|
||||
return F.dropout(F.softmax(input, -1), dropout_prob, is_training)
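# Illustrative sketch (not part of the commit): softmax_dropout wraps the softmax and the
# dropout in one TorchScript function so the JIT fuser enabled above can combine them; with
# is_training=False it reduces to a plain softmax.
def _softmax_dropout_demo():
    logits = torch.randn(4, 4)
    probs = softmax_dropout(logits, 0.1, is_training=False)
    assert torch.allclose(probs, F.softmax(logits, -1))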
|
||||
|
||||
|
||||
class SelfMultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
scaling_factor=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
self.scaling = (self.head_dim * scaling_factor) ** -0.5
|
||||
|
||||
self.in_proj: Callable[[Tensor], Tensor] = nn.Linear(
|
||||
embed_dim, embed_dim * 3, bias=bias
|
||||
)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
attn_bias: Tensor = None,
|
||||
) -> Tensor:
|
||||
n_node, n_graph, embed_dim = query.size()
|
||||
q, k, v = self.in_proj(query).chunk(3, dim=-1)
|
||||
|
||||
_shape = (-1, n_graph * self.num_heads, self.head_dim)
|
||||
q = q.contiguous().view(_shape).transpose(0, 1) * self.scaling
|
||||
k = k.contiguous().view(_shape).transpose(0, 1)
|
||||
v = v.contiguous().view(_shape).transpose(0, 1)
|
||||
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_bias
|
||||
attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)
|
||||
|
||||
attn = torch.bmm(attn_probs, v)
|
||||
attn = attn.transpose(0, 1).contiguous().view(n_node, n_graph, embed_dim)
|
||||
attn = self.out_proj(attn)
|
||||
return attn
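# Shape sketch (illustration only; assumes the imports and definitions above are in scope).
# SelfMultiheadAttention takes node features of shape [n_node, n_graph, embed_dim] plus an
# additive bias of shape [n_graph * num_heads, n_node, n_node] and preserves the input shape.
def _self_attention_shape_demo():
    n_node, n_graph, embed_dim, num_heads = 10, 2, 64, 8
    attn = SelfMultiheadAttention(embed_dim, num_heads)
    x = torch.randn(n_node, n_graph, embed_dim)
    bias = torch.zeros(n_graph * num_heads, n_node, n_node)
    out = attn(x, attn_bias=bias)
    assert out.shape == (n_node, n_graph, embed_dim)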
|
||||
|
||||
|
||||
class Graphormer3DEncoderLayer(nn.Module):
|
||||
"""
|
||||
Implements a Graphormer-3D Encoder Layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
ffn_embedding_dim: int = 3072,
|
||||
num_attention_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Initialize parameters
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
self.dropout = dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
|
||||
self.self_attn = SelfMultiheadAttention(
|
||||
self.embedding_dim,
|
||||
num_attention_heads,
|
||||
dropout=attention_dropout,
|
||||
)
|
||||
# layer norm associated with the self attention layer
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
|
||||
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
attn_bias: Tensor = None,
|
||||
):
|
||||
residual = x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x = self.self_attn(
|
||||
query=x,
|
||||
attn_bias=attn_bias,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
x = F.gelu(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
class RBF(nn.Module):
|
||||
def __init__(self, K, edge_types):
|
||||
super().__init__()
|
||||
self.K = K
|
||||
self.means = nn.parameter.Parameter(torch.empty(K))
|
||||
self.temps = nn.parameter.Parameter(torch.empty(K))
|
||||
self.mul: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
|
||||
self.bias: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
|
||||
nn.init.uniform_(self.means, 0, 3)
|
||||
nn.init.uniform_(self.temps, 0.1, 10)
|
||||
nn.init.constant_(self.bias.weight, 0)
|
||||
nn.init.constant_(self.mul.weight, 1)
|
||||
|
||||
def forward(self, x: Tensor, edge_types):
|
||||
mul = self.mul(edge_types)
|
||||
bias = self.bias(edge_types)
|
||||
x = mul * x.unsqueeze(-1) + bias
|
||||
mean = self.means.float()
|
||||
temp = self.temps.float().abs()
|
||||
return ((x - mean).square() * (-temp)).exp().type_as(self.means)
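# Sketch of the expansion computed above (illustration only): each scalar distance d is
# mapped, per edge type, to K Gaussian features exp(-t_k * (a*d + b - mu_k)^2), where (a, b)
# come from the edge-type embeddings and (mu_k, t_k) are the learned means and temperatures.
def _rbf_shape_demo():
    K, edge_types = 16, 4
    rbf = RBF(K, edge_types)
    dist = torch.rand(2, 5, 5)                            # [n_graph, n_node, n_node]
    pair_types = torch.randint(0, edge_types, (2, 5, 5))  # pairwise edge types
    feats = rbf(dist, pair_types)
    assert feats.shape == (2, 5, 5, K)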
|
||||
|
||||
|
||||
class NonLinear(nn.Module):
|
||||
def __init__(self, input, output_size, hidden=None):
|
||||
super(NonLinear, self).__init__()
|
||||
if hidden is None:
|
||||
hidden = input
|
||||
self.layer1 = nn.Linear(input, hidden)
|
||||
self.layer2 = nn.Linear(hidden, output_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.gelu(self.layer1(x))
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
|
||||
class NodeTaskHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.q_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
|
||||
self.k_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
|
||||
self.num_heads = num_heads
|
||||
self.scaling = (embed_dim // num_heads) ** -0.5
|
||||
self.force_proj1: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
|
||||
self.force_proj2: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
|
||||
self.force_proj3: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
attn_bias: Tensor,
|
||||
delta_pos: Tensor,
|
||||
) -> Tensor:
|
||||
bsz, n_node, _ = query.size()
|
||||
q = (
|
||||
self.q_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
|
||||
* self.scaling
|
||||
)
|
||||
k = self.k_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
|
||||
v = self.v_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
|
||||
attn = q @ k.transpose(-1, -2) # [bsz, head, n, n]
|
||||
attn_probs = softmax_dropout(
|
||||
attn.view(-1, n_node, n_node) + attn_bias, 0.1, self.training
|
||||
).view(bsz, self.num_heads, n_node, n_node)
|
||||
rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as(
|
||||
attn_probs
|
||||
) # [bsz, head, n, n, 3]
|
||||
rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3)
|
||||
x = rot_attn_probs @ v.unsqueeze(2) # [bsz, head , 3, n, d]
|
||||
x = x.permute(0, 3, 2, 1, 4).contiguous().view(bsz, n_node, 3, -1)
|
||||
f1 = self.force_proj1(x[:, :, 0, :]).view(bsz, n_node, 1)
|
||||
f2 = self.force_proj2(x[:, :, 1, :]).view(bsz, n_node, 1)
|
||||
f3 = self.force_proj3(x[:, :, 2, :]).view(bsz, n_node, 1)
|
||||
cur_force = torch.cat([f1, f2, f3], dim=-1).float()
|
||||
return cur_force
|
||||
|
||||
|
||||
@register_model("graphormer3d")
|
||||
class Graphormer3D(BaseFairseqModel):
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
"""Add model-specific arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--layers", type=int, metavar="L", help="num encoder layers"
|
||||
)
|
||||
parser.add_argument("--blocks", type=int, metavar="L", help="num blocks")
|
||||
parser.add_argument(
|
||||
"--embed-dim",
|
||||
type=int,
|
||||
metavar="H",
|
||||
help="encoder embedding dimension",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ffn-embed-dim",
|
||||
type=int,
|
||||
metavar="F",
|
||||
help="encoder embedding dimension for FFN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-heads",
|
||||
type=int,
|
||||
metavar="A",
|
||||
help="num encoder attention heads",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dropout", type=float, metavar="D", help="dropout probability"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-dropout",
|
||||
type=float,
|
||||
metavar="D",
|
||||
help="dropout probability for attention weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--activation-dropout",
|
||||
type=float,
|
||||
metavar="D",
|
||||
help="dropout probability after activation in FFN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--node-loss-weight",
|
||||
type=float,
|
||||
metavar="D",
|
||||
help="loss weight for node fitting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-node-loss-weight",
|
||||
type=float,
|
||||
metavar="D",
|
||||
help="loss weight for node fitting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-kernel",
|
||||
type=int,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
"""Build a new model instance."""
|
||||
base_architecture(args)
|
||||
return cls(args)
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.atom_types = 64
|
||||
self.edge_types = 64 * 64
|
||||
self.atom_encoder = nn.Embedding(
|
||||
self.atom_types, self.args.embed_dim, padding_idx=0
|
||||
)
|
||||
self.tag_encoder = nn.Embedding(3, self.args.embed_dim)
|
||||
self.input_dropout = self.args.input_dropout
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Graphormer3DEncoderLayer(
|
||||
self.args.embed_dim,
|
||||
self.args.ffn_embed_dim,
|
||||
num_attention_heads=self.args.attention_heads,
|
||||
dropout=self.args.dropout,
|
||||
attention_dropout=self.args.attention_dropout,
|
||||
activation_dropout=self.args.activation_dropout,
|
||||
)
|
||||
for _ in range(self.args.layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_ln: Callable[[Tensor], Tensor] = nn.LayerNorm(self.args.embed_dim)
|
||||
|
||||
self.engergy_proj: Callable[[Tensor], Tensor] = NonLinear(
|
||||
self.args.embed_dim, 1
|
||||
)
|
||||
self.energe_agg_factor: Callable[[Tensor], Tensor] = nn.Embedding(3, 1)
|
||||
nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01)
|
||||
|
||||
K = self.args.num_kernel
|
||||
|
||||
self.rbf: Callable[[Tensor, Tensor], Tensor] = RBF(K, self.edge_types)
|
||||
self.bias_proj: Callable[[Tensor], Tensor] = NonLinear(
|
||||
K, self.args.attention_heads
|
||||
)
|
||||
self.edge_proj: Callable[[Tensor], Tensor] = nn.Linear(K, self.args.embed_dim)
|
||||
self.node_proc: Callable[[Tensor, Tensor, Tensor], Tensor] = NodeTaskHead(
|
||||
self.args.embed_dim, self.args.attention_heads
|
||||
)
|
||||
|
||||
def set_num_updates(self, num_updates):
|
||||
self.num_updates = num_updates
|
||||
return super().set_num_updates(num_updates)
|
||||
|
||||
def forward(self, atoms: Tensor, tags: Tensor, pos: Tensor, real_mask: Tensor):
|
||||
padding_mask = atoms.eq(0)
|
||||
|
||||
n_graph, n_node = atoms.size()
|
||||
delta_pos = pos.unsqueeze(1) - pos.unsqueeze(2)
|
||||
dist: Tensor = delta_pos.norm(dim=-1)
|
||||
delta_pos /= dist.unsqueeze(-1) + 1e-5
|
||||
|
||||
edge_type = atoms.view(n_graph, n_node, 1) * self.atom_types + atoms.view(
|
||||
n_graph, 1, n_node
|
||||
)
|
||||
|
||||
rbf_feature = self.rbf(dist, edge_type)
|
||||
edge_features = rbf_feature.masked_fill(
|
||||
padding_mask.unsqueeze(1).unsqueeze(-1), 0.0
|
||||
)
|
||||
|
||||
graph_node_feature = (
|
||||
self.tag_encoder(tags)
|
||||
+ self.atom_encoder(atoms)
|
||||
+ self.edge_proj(edge_features.sum(dim=-2))
|
||||
)
|
||||
|
||||
# ===== MAIN MODEL =====
|
||||
output = F.dropout(
|
||||
graph_node_feature, p=self.input_dropout, training=self.training
|
||||
)
|
||||
output = output.transpose(0, 1).contiguous()
|
||||
|
||||
graph_attn_bias = self.bias_proj(rbf_feature).permute(0, 3, 1, 2).contiguous()
|
||||
graph_attn_bias.masked_fill_(
|
||||
padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
|
||||
)
|
||||
|
||||
graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
|
||||
for _ in range(self.args.blocks):
|
||||
for enc_layer in self.layers:
|
||||
output = enc_layer(output, attn_bias=graph_attn_bias)
|
||||
|
||||
output = self.final_ln(output)
|
||||
output = output.transpose(0, 1)
|
||||
|
||||
eng_output = F.dropout(output, p=0.1, training=self.training)
|
||||
eng_output = (
|
||||
self.engergy_proj(eng_output) * self.energe_agg_factor(tags)
|
||||
).flatten(-2)
|
||||
output_mask = (
|
||||
tags > 0
|
||||
) & real_mask  # padding needs no special handling: padded nodes have tag 0 and real_mask False
|
||||
|
||||
eng_output *= output_mask
|
||||
eng_output = eng_output.sum(dim=-1)
|
||||
|
||||
node_output = self.node_proc(output, graph_attn_bias, delta_pos)
|
||||
|
||||
node_target_mask = output_mask.unsqueeze(-1)
|
||||
return eng_output, node_output, node_target_mask
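# Sketch of the pairwise edge-type indexing used in forward above (illustration only): with
# 64 atom types, the (i, j) pair gets id atom_i * 64 + atom_j, giving the 64 * 64 edge types
# that the RBF layer and bias projection are sized for.
def _edge_type_demo():
    atoms = torch.tensor([[1, 6, 8]])                     # [n_graph=1, n_node=3]
    n_graph, n_node = atoms.size()
    edge_type = atoms.view(n_graph, n_node, 1) * 64 + atoms.view(n_graph, 1, n_node)
    return edge_type                                      # [1, 3, 3]; pair (1, 6) -> 70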
|
||||
|
||||
|
||||
@register_model_architecture("graphormer3d", "graphormer3d_base")
|
||||
def base_architecture(args):
|
||||
args.blocks = getattr(args, "blocks", 4)
|
||||
args.layers = getattr(args, "layers", 12)
|
||||
args.embed_dim = getattr(args, "embed_dim", 768)
|
||||
args.ffn_embed_dim = getattr(args, "ffn_embed_dim", 768)
|
||||
args.attention_heads = getattr(args, "attention_heads", 48)
|
||||
args.input_dropout = getattr(args, "input_dropout", 0.0)
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
||||
args.node_loss_weight = getattr(args, "node_loss_weight", 15)
|
||||
args.min_node_loss_weight = getattr(args, "min_node_loss_weight", 1)
|
||||
args.eng_loss_weight = getattr(args, "eng_loss_weight", 1)
|
||||
args.num_kernel = getattr(args, "num_kernel", 128)
|
|
@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .multihead_attention import MultiheadAttention
from .graphormer_layers import GraphNodeFeature, GraphAttnBias
from .graphormer_graph_encoder_layer import GraphormerGraphEncoderLayer
from .graphormer_graph_encoder import GraphormerGraphEncoder, init_graphormer_params
@ -0,0 +1,256 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fairseq.modules import FairseqDropout, LayerDropModuleList, LayerNorm
|
||||
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
|
||||
|
||||
from .multihead_attention import MultiheadAttention
|
||||
from .graphormer_layers import GraphNodeFeature, GraphAttnBias
|
||||
from .graphormer_graph_encoder_layer import GraphormerGraphEncoderLayer
|
||||
|
||||
|
||||
def init_graphormer_params(module):
|
||||
"""
|
||||
Initialize the weights specific to the Graphormer Model.
|
||||
"""
|
||||
|
||||
def normal_(data):
|
||||
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||
# so that the RNG is consistent with and without FSDP
|
||||
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
normal_(module.weight.data)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
normal_(module.weight.data)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if isinstance(module, MultiheadAttention):
|
||||
normal_(module.q_proj.weight.data)
|
||||
normal_(module.k_proj.weight.data)
|
||||
normal_(module.v_proj.weight.data)
|
||||
|
||||
|
||||
class GraphormerGraphEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_atoms: int,
|
||||
num_in_degree: int,
|
||||
num_out_degree: int,
|
||||
num_edges: int,
|
||||
num_spatial: int,
|
||||
num_edge_dis: int,
|
||||
edge_type: str,
|
||||
multi_hop_max_dist: int,
|
||||
num_encoder_layers: int = 12,
|
||||
embedding_dim: int = 768,
|
||||
ffn_embedding_dim: int = 768,
|
||||
num_attention_heads: int = 32,
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.1,
|
||||
layerdrop: float = 0.0,
|
||||
encoder_normalize_before: bool = False,
|
||||
apply_graphormer_init: bool = False,
|
||||
activation_fn: str = "gelu",
|
||||
embed_scale: float = None,
|
||||
freeze_embeddings: bool = False,
|
||||
n_trans_layers_to_freeze: int = 0,
|
||||
export: bool = False,
|
||||
traceable: bool = False,
|
||||
q_noise: float = 0.0,
|
||||
qn_block_size: int = 8,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.dropout_module = FairseqDropout(
|
||||
dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
self.layerdrop = layerdrop
|
||||
self.embedding_dim = embedding_dim
|
||||
self.apply_graphormer_init = apply_graphormer_init
|
||||
self.traceable = traceable
|
||||
|
||||
self.graph_node_feature = GraphNodeFeature(
|
||||
num_heads=num_attention_heads,
|
||||
num_atoms=num_atoms,
|
||||
num_in_degree=num_in_degree,
|
||||
num_out_degree=num_out_degree,
|
||||
hidden_dim=embedding_dim,
|
||||
n_layers=num_encoder_layers,
|
||||
)
|
||||
|
||||
self.graph_attn_bias = GraphAttnBias(
|
||||
num_heads=num_attention_heads,
|
||||
num_atoms=num_atoms,
|
||||
num_edges=num_edges,
|
||||
num_spatial=num_spatial,
|
||||
num_edge_dis=num_edge_dis,
|
||||
edge_type=edge_type,
|
||||
multi_hop_max_dist=multi_hop_max_dist,
|
||||
hidden_dim=embedding_dim,
|
||||
n_layers=num_encoder_layers,
|
||||
)
|
||||
|
||||
self.embed_scale = embed_scale
|
||||
|
||||
if q_noise > 0:
|
||||
self.quant_noise = apply_quant_noise_(
|
||||
nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
|
||||
q_noise,
|
||||
qn_block_size,
|
||||
)
|
||||
else:
|
||||
self.quant_noise = None
|
||||
|
||||
if encoder_normalize_before:
|
||||
self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
|
||||
else:
|
||||
self.emb_layer_norm = None
|
||||
|
||||
if self.layerdrop > 0.0:
|
||||
self.layers = LayerDropModuleList(p=self.layerdrop)
|
||||
else:
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers.extend(
|
||||
[
|
||||
self.build_graphormer_graph_encoder_layer(
|
||||
embedding_dim=self.embedding_dim,
|
||||
ffn_embedding_dim=ffn_embedding_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
dropout=self.dropout_module.p,
|
||||
attention_dropout=attention_dropout,
|
||||
activation_dropout=activation_dropout,
|
||||
activation_fn=activation_fn,
|
||||
export=export,
|
||||
q_noise=q_noise,
|
||||
qn_block_size=qn_block_size,
|
||||
)
|
||||
for _ in range(num_encoder_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Apply initialization of model params after building the model
|
||||
if self.apply_graphormer_init:
|
||||
self.apply(init_graphormer_params)
|
||||
|
||||
def freeze_module_params(m):
|
||||
if m is not None:
|
||||
for p in m.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
if freeze_embeddings:
|
||||
raise NotImplementedError("Freezing embeddings is not implemented yet.")
|
||||
|
||||
for layer in range(n_trans_layers_to_freeze):
|
||||
freeze_module_params(self.layers[layer])
|
||||
|
||||
def build_graphormer_graph_encoder_layer(
|
||||
self,
|
||||
embedding_dim,
|
||||
ffn_embedding_dim,
|
||||
num_attention_heads,
|
||||
dropout,
|
||||
attention_dropout,
|
||||
activation_dropout,
|
||||
activation_fn,
|
||||
export,
|
||||
q_noise,
|
||||
qn_block_size,
|
||||
):
|
||||
return GraphormerGraphEncoderLayer(
|
||||
embedding_dim=embedding_dim,
|
||||
ffn_embedding_dim=ffn_embedding_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
dropout=dropout,
|
||||
attention_dropout=attention_dropout,
|
||||
activation_dropout=activation_dropout,
|
||||
activation_fn=activation_fn,
|
||||
export=export,
|
||||
q_noise=q_noise,
|
||||
qn_block_size=qn_block_size,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batched_data,
|
||||
perturb=None,
|
||||
last_state_only: bool = False,
|
||||
token_embeddings: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
is_tpu = False
|
||||
# compute padding mask. This is needed for multi-head attention
|
||||
data_x = batched_data["x"]
|
||||
n_graph, n_node = data_x.size()[:2]
|
||||
padding_mask = (data_x[:, :, 0]).eq(0) # B x T x 1
|
||||
padding_mask_cls = torch.zeros(
|
||||
n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype
|
||||
)
|
||||
padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1)
|
||||
# B x (T+1) x 1
|
||||
|
||||
if token_embeddings is not None:
|
||||
x = token_embeddings
|
||||
else:
|
||||
x = self.graph_node_feature(batched_data)
|
||||
|
||||
if perturb is not None:
|
||||
#ic(torch.mean(torch.abs(x[:, 1, :])))
|
||||
#ic(torch.mean(torch.abs(perturb)))
|
||||
x[:, 1:, :] += perturb
|
||||
|
||||
# x: B x T x C
|
||||
|
||||
attn_bias = self.graph_attn_bias(batched_data)
|
||||
|
||||
if self.embed_scale is not None:
|
||||
x = x * self.embed_scale
|
||||
|
||||
if self.quant_noise is not None:
|
||||
x = self.quant_noise(x)
|
||||
|
||||
if self.emb_layer_norm is not None:
|
||||
x = self.emb_layer_norm(x)
|
||||
|
||||
x = self.dropout_module(x)
|
||||
|
||||
# account for padding while computing the representation
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
inner_states = []
|
||||
if not last_state_only:
|
||||
inner_states.append(x)
|
||||
|
||||
for layer in self.layers:
|
||||
x, _ = layer(
|
||||
x,
|
||||
self_attn_padding_mask=padding_mask,
|
||||
self_attn_mask=attn_mask,
|
||||
self_attn_bias=attn_bias,
|
||||
)
|
||||
if not last_state_only:
|
||||
inner_states.append(x)
|
||||
|
||||
graph_rep = x[0, :, :]
|
||||
|
||||
if last_state_only:
|
||||
inner_states = [x]
|
||||
|
||||
if self.traceable:
|
||||
return torch.stack(inner_states), graph_rep
|
||||
else:
|
||||
return inner_states, graph_rep
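# Usage sketch (illustration only; `encoder` and `batched_data` are assumed to exist).
# inner_states holds one [n_node+1, n_graph, embed_dim] tensor per layer (plus the input
# embedding), and graph_rep is the hidden state of the prepended graph token.
def _encoder_outputs_demo(encoder, batched_data):
    inner_states, graph_rep = encoder(batched_data)
    node_repr = inner_states[-1].transpose(0, 1)   # [n_graph, n_node+1, embed_dim]
    return node_repr, graph_rep                    # graph_rep: [n_graph, embed_dim]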
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fairseq import utils
|
||||
from fairseq.modules import LayerNorm
|
||||
from fairseq.modules.fairseq_dropout import FairseqDropout
|
||||
from fairseq.modules.quant_noise import quant_noise
|
||||
|
||||
from .multihead_attention import MultiheadAttention
|
||||
|
||||
|
||||
class GraphormerGraphEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
ffn_embedding_dim: int = 3072,
|
||||
num_attention_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.1,
|
||||
activation_fn: str = "relu",
|
||||
export: bool = False,
|
||||
q_noise: float = 0.0,
|
||||
qn_block_size: int = 8,
|
||||
init_fn: Callable = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if init_fn is not None:
|
||||
init_fn()
|
||||
|
||||
# Initialize parameters
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_dropout = attention_dropout
|
||||
self.q_noise = q_noise
|
||||
self.qn_block_size = qn_block_size
|
||||
|
||||
self.dropout_module = FairseqDropout(
|
||||
dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
self.activation_dropout_module = FairseqDropout(
|
||||
activation_dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
|
||||
# Initialize blocks
|
||||
self.activation_fn = utils.get_activation_fn(activation_fn)
|
||||
self.self_attn = self.build_self_attention(
|
||||
self.embedding_dim,
|
||||
num_attention_heads,
|
||||
dropout=attention_dropout,
|
||||
self_attention=True,
|
||||
q_noise=q_noise,
|
||||
qn_block_size=qn_block_size,
|
||||
)
|
||||
|
||||
# layer norm associated with the self attention layer
|
||||
self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)
|
||||
|
||||
self.fc1 = self.build_fc1(
|
||||
self.embedding_dim,
|
||||
ffn_embedding_dim,
|
||||
q_noise=q_noise,
|
||||
qn_block_size=qn_block_size,
|
||||
)
|
||||
self.fc2 = self.build_fc2(
|
||||
ffn_embedding_dim,
|
||||
self.embedding_dim,
|
||||
q_noise=q_noise,
|
||||
qn_block_size=qn_block_size,
|
||||
)
|
||||
|
||||
# layer norm associated with the position wise feed-forward NN
|
||||
self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
|
||||
|
||||
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
|
||||
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
|
||||
|
||||
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
|
||||
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
|
||||
|
||||
def build_self_attention(
|
||||
self,
|
||||
embed_dim,
|
||||
num_attention_heads,
|
||||
dropout,
|
||||
self_attention,
|
||||
q_noise,
|
||||
qn_block_size,
|
||||
):
|
||||
return MultiheadAttention(
|
||||
embed_dim,
|
||||
num_attention_heads,
|
||||
dropout=dropout,
|
||||
self_attention=True,
|
||||
q_noise=q_noise,
|
||||
qn_block_size=qn_block_size,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
self_attn_bias: Optional[torch.Tensor] = None,
|
||||
self_attn_mask: Optional[torch.Tensor] = None,
|
||||
self_attn_padding_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
LayerNorm is applied either before or after the self-attention/ffn
|
||||
modules similar to the original Transformer implementation.
|
||||
"""
|
||||
# x: T x B x C
|
||||
residual = x
|
||||
x, attn = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
attn_bias=self_attn_bias,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
need_weights=False,
|
||||
attn_mask=self_attn_mask,
|
||||
)
|
||||
x = self.dropout_module(x)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.activation_dropout_module(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout_module(x)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
return x, attn
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def init_params(module, n_layers):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers))
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
|
||||
|
||||
class GraphNodeFeature(nn.Module):
|
||||
"""
|
||||
Compute node features for each node in the graph.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_heads, num_atoms, num_in_degree, num_out_degree, hidden_dim, n_layers
|
||||
):
|
||||
super(GraphNodeFeature, self).__init__()
|
||||
self.num_heads = num_heads
|
||||
self.num_atoms = num_atoms
|
||||
|
||||
# 1 for graph token
|
||||
self.atom_encoder = nn.Embedding(num_atoms + 1, hidden_dim, padding_idx=0)
|
||||
self.in_degree_encoder = nn.Embedding(num_in_degree, hidden_dim, padding_idx=0)
|
||||
self.out_degree_encoder = nn.Embedding(
|
||||
num_out_degree, hidden_dim, padding_idx=0
|
||||
)
|
||||
|
||||
self.graph_token = nn.Embedding(1, hidden_dim)
|
||||
|
||||
self.apply(lambda module: init_params(module, n_layers=n_layers))
|
||||
|
||||
def forward(self, batched_data):
|
||||
x, in_degree, out_degree = (
|
||||
batched_data["x"],
|
||||
batched_data["in_degree"],
|
||||
batched_data["out_degree"],
|
||||
)
|
||||
n_graph, n_node = x.size()[:2]
|
||||
|
||||
# node feature + graph token
|
||||
node_feature = self.atom_encoder(x).sum(dim=-2) # [n_graph, n_node, n_hidden]
|
||||
|
||||
# if self.flag and perturb is not None:
|
||||
# node_feature += perturb
|
||||
|
||||
node_feature = (
|
||||
node_feature
|
||||
+ self.in_degree_encoder(in_degree)
|
||||
+ self.out_degree_encoder(out_degree)
|
||||
)
|
||||
|
||||
graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)
|
||||
|
||||
graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1)
|
||||
|
||||
return graph_node_feature
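# Shape sketch (illustration only; indices are dummies that fit the embedding tables).
# GraphNodeFeature sums the embeddings of every categorical atom feature, adds the in/out
# degree embeddings, and prepends one learned graph token:
# [n_graph, n_node, n_feat] -> [n_graph, n_node + 1, hidden_dim].
def _node_feature_demo():
    nf = GraphNodeFeature(num_heads=8, num_atoms=512 * 9, num_in_degree=512,
                          num_out_degree=512, hidden_dim=32, n_layers=2)
    batched_data = {
        "x": torch.randint(1, 10, (2, 5, 9)),        # 9 categorical features per node
        "in_degree": torch.randint(0, 4, (2, 5)),
        "out_degree": torch.randint(0, 4, (2, 5)),
    }
    assert nf(batched_data).shape == (2, 6, 32)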
|
||||
|
||||
|
||||
class GraphAttnBias(nn.Module):
|
||||
"""
|
||||
Compute attention bias for each head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_atoms,
|
||||
num_edges,
|
||||
num_spatial,
|
||||
num_edge_dis,
|
||||
hidden_dim,
|
||||
edge_type,
|
||||
multi_hop_max_dist,
|
||||
n_layers,
|
||||
):
|
||||
super(GraphAttnBias, self).__init__()
|
||||
self.num_heads = num_heads
|
||||
self.multi_hop_max_dist = multi_hop_max_dist
|
||||
|
||||
self.edge_encoder = nn.Embedding(num_edges + 1, num_heads, padding_idx=0)
|
||||
self.edge_type = edge_type
|
||||
if self.edge_type == "multi_hop":
|
||||
self.edge_dis_encoder = nn.Embedding(
|
||||
num_edge_dis * num_heads * num_heads, 1
|
||||
)
|
||||
self.spatial_pos_encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0)
|
||||
|
||||
self.graph_token_virtual_distance = nn.Embedding(1, num_heads)
|
||||
|
||||
self.apply(lambda module: init_params(module, n_layers=n_layers))
|
||||
|
||||
def forward(self, batched_data):
|
||||
attn_bias, spatial_pos, x = (
|
||||
batched_data["attn_bias"],
|
||||
batched_data["spatial_pos"],
|
||||
batched_data["x"],
|
||||
)
|
||||
# in_degree, out_degree = batched_data.in_degree, batched_data.in_degree
|
||||
edge_input, attn_edge_type = (
|
||||
batched_data["edge_input"],
|
||||
batched_data["attn_edge_type"],
|
||||
)
|
||||
|
||||
n_graph, n_node = x.size()[:2]
|
||||
graph_attn_bias = attn_bias.clone()
|
||||
graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
|
||||
1, self.num_heads, 1, 1
|
||||
) # [n_graph, n_head, n_node+1, n_node+1]
|
||||
|
||||
# spatial pos
|
||||
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
|
||||
spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
|
||||
graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias
|
||||
|
||||
# reset spatial pos here
|
||||
t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
|
||||
graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
|
||||
graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
|
||||
|
||||
# edge feature
|
||||
if self.edge_type == "multi_hop":
|
||||
spatial_pos_ = spatial_pos.clone()
|
||||
spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1
|
||||
# set 1 to 1, x > 1 to x - 1
|
||||
spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
|
||||
if self.multi_hop_max_dist > 0:
|
||||
spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
|
||||
edge_input = edge_input[:, :, :, : self.multi_hop_max_dist, :]
|
||||
# [n_graph, n_node, n_node, max_dist, n_head]
|
||||
edge_input = self.edge_encoder(edge_input).mean(-2)
|
||||
max_dist = edge_input.size(-2)
|
||||
edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape(
|
||||
max_dist, -1, self.num_heads
|
||||
)
|
||||
edge_input_flat = torch.bmm(
|
||||
edge_input_flat,
|
||||
self.edge_dis_encoder.weight.reshape(
|
||||
-1, self.num_heads, self.num_heads
|
||||
)[:max_dist, :, :],
|
||||
)
|
||||
edge_input = edge_input_flat.reshape(
|
||||
max_dist, n_graph, n_node, n_node, self.num_heads
|
||||
).permute(1, 2, 3, 0, 4)
|
||||
edge_input = (
|
||||
edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1))
|
||||
).permute(0, 3, 1, 2)
|
||||
else:
|
||||
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
|
||||
edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)
|
||||
|
||||
graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + edge_input
|
||||
graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset
|
||||
|
||||
return graph_attn_bias
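# Shape sketch (illustration only; any edge_type other than "multi_hop" takes the simpler
# single-hop branch above). The bias has shape [n_graph, num_heads, n_node+1, n_node+1] and
# is added to the attention logits of every encoder layer.
def _attn_bias_demo():
    ab = GraphAttnBias(num_heads=8, num_atoms=512 * 9, num_edges=512 * 3, num_spatial=512,
                       num_edge_dis=128, hidden_dim=32, edge_type="single_hop",
                       multi_hop_max_dist=5, n_layers=2)
    n_graph, n_node = 2, 5
    batched_data = {
        "attn_bias": torch.zeros(n_graph, n_node + 1, n_node + 1),
        "spatial_pos": torch.randint(1, 10, (n_graph, n_node, n_node)),
        "x": torch.randint(1, 10, (n_graph, n_node, 9)),
        "edge_input": None,                              # only read on the multi_hop branch
        "attn_edge_type": torch.randint(1, 10, (n_graph, n_node, n_node, 3)),
    }
    assert ab(batched_data).shape == (n_graph, 8, n_node + 1, n_node + 1)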
|
|
@ -0,0 +1,256 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from fairseq import utils
|
||||
from fairseq.modules.fairseq_dropout import FairseqDropout
|
||||
from fairseq.modules.quant_noise import quant_noise
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""Multi-headed attention.
|
||||
|
||||
See "Attention Is All You Need" for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
self_attention=False,
|
||||
q_noise=0.0,
|
||||
qn_block_size=8,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout_module = FairseqDropout(
|
||||
dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.self_attention = self_attention
|
||||
|
||||
assert self.self_attention, "Only support self attention"
|
||||
|
||||
assert not self.self_attention or self.qkv_same_dim, (
|
||||
"Self-attention requires query, key and " "value to be of the same size"
|
||||
)
|
||||
|
||||
self.k_proj = quant_noise(
|
||||
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
self.v_proj = quant_noise(
|
||||
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
self.q_proj = quant_noise(
|
||||
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
|
||||
self.out_proj = quant_noise(
|
||||
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
||||
)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
self.onnx_trace = False
|
||||
|
||||
def prepare_for_onnx_export_(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.qkv_same_dim:
|
||||
# Empirically observed the convergence to be much better with
|
||||
# the scaled initialization
|
||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||
else:
|
||||
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||
if self.out_proj.bias is not None:
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query,
|
||||
key: Optional[Tensor],
|
||||
value: Optional[Tensor],
|
||||
attn_bias: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
before_softmax: bool = False,
|
||||
need_head_weights: bool = False,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""Input shape: Time x Batch x Channel
|
||||
|
||||
Args:
|
||||
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||
keys that are pads, of shape `(batch, src_len)`, where
|
||||
padding elements are indicated by 1s.
|
||||
need_weights (bool, optional): return the attention weights,
|
||||
averaged over heads (default: False).
|
||||
attn_mask (ByteTensor, optional): typically used to
|
||||
implement causal attention, where the mask prevents the
|
||||
attention from looking forward in time (default: None).
|
||||
before_softmax (bool, optional): return the raw attention
|
||||
weights and values before the attention softmax.
|
||||
need_head_weights (bool, optional): return the attention
|
||||
weights for each head. Implies *need_weights*. Default:
|
||||
return the average attention weights over all heads.
|
||||
"""
|
||||
if need_head_weights:
|
||||
need_weights = True
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
src_len = tgt_len
|
||||
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
if key is not None:
|
||||
src_len, key_bsz, _ = key.size()
|
||||
if not torch.jit.is_scripting():
|
||||
assert key_bsz == bsz
|
||||
assert value is not None
|
||||
assert (src_len, bsz) == value.shape[:2]
|
||||
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(query)
|
||||
v = self.v_proj(query)
|
||||
q *= self.scaling
|
||||
|
||||
q = (
|
||||
q.contiguous()
|
||||
.view(tgt_len, bsz * self.num_heads, self.head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
if k is not None:
|
||||
k = (
|
||||
k.contiguous()
|
||||
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
if v is not None:
|
||||
v = (
|
||||
v.contiguous()
|
||||
.view(-1, bsz * self.num_heads, self.head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
assert k is not None
|
||||
assert k.size(1) == src_len
|
||||
|
||||
# This is part of a workaround to get around fork/join parallelism
|
||||
# not supporting Optional types.
|
||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||
key_padding_mask = None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||
|
||||
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_bias is not None:
|
||||
attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
attn_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||
float("-inf"),
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if before_softmax:
|
||||
return attn_weights, v
|
||||
|
||||
attn_weights_float = utils.softmax(
|
||||
attn_weights, dim=-1, onnx_trace=self.onnx_trace
|
||||
)
|
||||
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||
attn_probs = self.dropout_module(attn_weights)
|
||||
|
||||
assert v is not None
|
||||
attn = torch.bmm(attn_probs, v)
|
||||
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||
|
||||
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
attn_weights: Optional[Tensor] = None
|
||||
if need_weights:
|
||||
attn_weights = attn_weights_float.view(
|
||||
bsz, self.num_heads, tgt_len, src_len
|
||||
).transpose(1, 0)
|
||||
if not need_head_weights:
|
||||
# average attention weights over heads
|
||||
attn_weights = attn_weights.mean(dim=0)
|
||||
|
||||
return attn, attn_weights
|
||||
|
||||
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||
return attn_weights
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
prefix = name + "." if name != "" else ""
|
||||
items_to_add = {}
|
||||
keys_to_remove = []
|
||||
for k in state_dict.keys():
|
||||
if k.endswith(prefix + "in_proj_weight"):
|
||||
# in_proj_weight used to be q + k + v with same dimensions
|
||||
dim = int(state_dict[k].shape[0] / 3)
|
||||
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
||||
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
||||
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
||||
|
||||
keys_to_remove.append(k)
|
||||
|
||||
k_bias = prefix + "in_proj_bias"
|
||||
if k_bias in state_dict.keys():
|
||||
dim = int(state_dict[k].shape[0] / 3)
|
||||
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
||||
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
||||
dim : 2 * dim
|
||||
]
|
||||
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
||||
|
||||
keys_to_remove.append(prefix + "in_proj_bias")
|
||||
|
||||
for k in keys_to_remove:
|
||||
del state_dict[k]
|
||||
|
||||
for key, value in items_to_add.items():
|
||||
state_dict[key] = value
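# Sketch (illustration only) of the checkpoint upgrade performed above: a legacy fused
# in_proj_weight of shape [3 * embed_dim, embed_dim] is split into separate q/k/v weights.
def _split_in_proj_demo():
    embed_dim = 4
    fused = torch.randn(3 * embed_dim, embed_dim)
    q_w, k_w, v_w = fused[:embed_dim], fused[embed_dim:2 * embed_dim], fused[2 * embed_dim:]
    assert q_w.shape == k_w.shape == v_w.shape == (embed_dim, embed_dim)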
|
|
@ -0,0 +1,17 @@
from torch.hub import load_state_dict_from_url
import torch.distributed as dist

PRETRAINED_MODEL_URLS = {
    "pcqm4mv1_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_best_pcqm4mv1_full.pt",
    "pcqm4mv2_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv2/checkpoint_best_pcqm4mv2_full.pt",
}

def load_pretrained_model(pretrained_model_name):
    if pretrained_model_name not in PRETRAINED_MODEL_URLS:
        raise ValueError(f"Unknown pretrained model name {pretrained_model_name}")
    if not dist.is_initialized():
        return load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True)["model"]
    else:
        pretrained_model = load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True, file_name=f"{pretrained_model_name}_{dist.get_rank()}")["model"]
        dist.barrier()
        return pretrained_model
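# Usage sketch (illustration only; mirrors GraphormerModel.__init__ above). The valid names
# are the keys of PRETRAINED_MODEL_URLS; the returned dict is a fairseq "model" state dict.
def _load_pretrained_demo(model):
    state = load_pretrained_model("pcqm4mv1_graphormer_base")
    model.load_state_dict(state)
    model.encoder.reset_output_layer_parameters()   # as done when fine-tuning with a new head
    return model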
@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
@ -0,0 +1,313 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import contextlib
|
||||
from dataclasses import dataclass, field
|
||||
from omegaconf import II, open_dict, OmegaConf
|
||||
|
||||
import numpy as np
|
||||
from fairseq.data import (
|
||||
NestedDictionaryDataset,
|
||||
NumSamplesDataset,
|
||||
)
|
||||
from fairseq.tasks import FairseqDataclass, FairseqTask, register_task
|
||||
|
||||
from graphormer.pretrain import load_pretrained_model
|
||||
|
||||
from ..data.dataset import BatchedDataDataset, TargetDataset, GraphormerDataset
|
||||
|
||||
import torch
|
||||
from fairseq.optim.amp_optimizer import AMPOptimizer
|
||||
import math
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphPredictionConfig(FairseqDataclass):
|
||||
dataset_name: str = field(
|
||||
default="pcqm4m",
|
||||
metadata={"help": "name of the dataset"},
|
||||
)
|
||||
|
||||
num_classes: int = field(
|
||||
default=-1,
|
||||
metadata={"help": "number of classes or regression targets"},
|
||||
)
|
||||
|
||||
max_nodes: int = field(
|
||||
default=128,
|
||||
metadata={"help": "max nodes per graph"},
|
||||
)
|
||||
|
||||
dataset_source: str = field(
|
||||
default="pyg",
|
||||
metadata={"help": "source of graph dataset, can be: pyg, dgl, ogb, smiles"},
|
||||
)
|
||||
|
||||
num_atoms: int = field(
|
||||
default=512 * 9,
|
||||
metadata={"help": "number of atom types in the graph"},
|
||||
)
|
||||
|
||||
num_edges: int = field(
|
||||
default=512 * 3,
|
||||
metadata={"help": "number of edge types in the graph"},
|
||||
)
|
||||
|
||||
num_in_degree: int = field(
|
||||
default=512,
|
||||
metadata={"help": "number of in degree types in the graph"},
|
||||
)
|
||||
|
||||
num_out_degree: int = field(
|
||||
default=512,
|
||||
metadata={"help": "number of out degree types in the graph"},
|
||||
)
|
||||
|
||||
num_spatial: int = field(
|
||||
default=512,
|
||||
metadata={"help": "number of spatial types in the graph"},
|
||||
)
|
||||
|
||||
num_edge_dis: int = field(
|
||||
default=128,
|
||||
metadata={"help": "number of edge dis types in the graph"},
|
||||
)
|
||||
|
||||
multi_hop_max_dist: int = field(
|
||||
default=5,
|
||||
metadata={"help": "max distance of multi-hop edges"},
|
||||
)
|
||||
|
||||
spatial_pos_max: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "max distance of multi-hop edges"},
|
||||
)
|
||||
|
||||
edge_type: str = field(
|
||||
default="multi_hop",
|
||||
metadata={"help": "edge type in the graph"},
|
||||
)
|
||||
|
||||
seed: int = II("common.seed")
|
||||
|
||||
pretrained_model_name: str = field(
|
||||
default="none",
|
||||
metadata={"help": "name of used pretrained model"},
|
||||
)
|
||||
|
||||
load_pretrained_model_output_layer: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to load the output layer of pretrained model"},
|
||||
)
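# Sketch (illustration only; assumes the dataclass above is importable). The task options are
# plain dataclass/hydra fields, so a minimal regression setup only overrides num_classes and
# keeps the defaults (pcqm4m through the pyg source, at most 128 nodes per graph).
_example_cfg = GraphPredictionConfig(num_classes=1)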
|
||||
|
||||
|
||||
@register_task("graph_prediction", dataclass=GraphPredictionConfig)
|
||||
class GraphPredictionTask(FairseqTask):
|
||||
"""
|
||||
Graph prediction (classification or regression) task.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self.dm = GraphormerDataset(
|
||||
dataset_spec=cfg.dataset_name,
|
||||
dataset_source=cfg.dataset_source,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, cfg, **kwargs):
|
||||
assert cfg.num_classes > 0, "Must set task.num_classes"
|
||||
return cls(cfg)
|
||||
|
||||
def load_dataset(self, split, combine=False, **kwargs):
|
||||
"""Load a given dataset split (e.g., train, valid, test)."""
|
||||
|
||||
assert split in ["train", "valid", "test"]
|
||||
|
||||
if split == "train":
|
||||
batched_data = self.dm.dataset_train
|
||||
elif split == "valid":
|
||||
batched_data = self.dm.dataset_val
|
||||
elif split == "test":
|
||||
batched_data = self.dm.dataset_test
|
||||
|
||||
batched_data = BatchedDataDataset(
|
||||
batched_data,
|
||||
max_node=self.max_nodes(),
|
||||
multi_hop_max_dist=self.cfg.multi_hop_max_dist,
|
||||
spatial_pos_max=self.cfg.spatial_pos_max,
|
||||
)
|
||||
|
||||
data_sizes = np.array([self.max_nodes()] * len(batched_data))
|
||||
|
||||
target = TargetDataset(batched_data)
|
||||
|
||||
dataset = NestedDictionaryDataset(
|
||||
{
|
||||
"nsamples": NumSamplesDataset(),
|
||||
"net_input": {"batched_data": batched_data},
|
||||
"target": target,
|
||||
},
|
||||
sizes=data_sizes,
|
||||
)
|
||||
|
||||
logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
|
||||
|
||||
self.datasets[split] = dataset
|
||||
return self.datasets[split]
|
||||
|
||||
def build_model(self, cfg):
|
||||
from fairseq import models
|
||||
|
||||
with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack():
|
||||
cfg.max_nodes = self.cfg.max_nodes
|
||||
|
||||
model = models.build_model(cfg, self)
|
||||
|
||||
return model
|
||||
|
||||
def max_nodes(self):
|
||||
return self.cfg.max_nodes
|
||||
|
||||
@property
|
||||
def source_dictionary(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def target_dictionary(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def label_dictionary(self):
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphPredictionWithFlagConfig(GraphPredictionConfig):
|
||||
flag_m: int = field(
|
||||
default=3,
|
||||
metadata={
|
||||
"help": "number of iterations to optimize the perturbations with flag objectives"
|
||||
},
|
||||
)
|
||||
|
||||
flag_step_size: float = field(
|
||||
default=1e-3,
|
||||
metadata={
|
||||
"help": "learing rate of iterations to optimize the perturbations with flag objective"
|
||||
},
|
||||
)
|
||||
|
||||
flag_mag: float = field(
|
||||
default=1e-3,
|
||||
metadata={"help": "magnitude bound for perturbations in flag objectives"},
|
||||
)
|
||||
|
||||
|
||||
@register_task("graph_prediction_with_flag", dataclass=GraphPredictionWithFlagConfig)
|
||||
class GraphPredictionWithFlagTask(GraphPredictionTask):
|
||||
"""
|
||||
Graph prediction (classification or regression) task.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self.dm = GraphormerDataset(
|
||||
dataset_spec=cfg.dataset_name,
|
||||
dataset_source=cfg.dataset_source,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
self.flag_m = cfg.flag_m
|
||||
self.flag_step_size = cfg.flag_step_size
|
||||
self.flag_mag = cfg.flag_mag
|
||||
|
||||
def train_step(
|
||||
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
|
||||
):
|
||||
"""
|
||||
Do forward and backward, and return the loss as computed by *criterion*
|
||||
for the given *model* and *sample*.
|
||||
|
||||
Args:
|
||||
sample (dict): the mini-batch. The format is defined by the
|
||||
:class:`~fairseq.data.FairseqDataset`.
|
||||
model (~fairseq.models.BaseFairseqModel): the model
|
||||
criterion (~fairseq.criterions.FairseqCriterion): the criterion
|
||||
optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
|
||||
update_num (int): the current update
|
||||
ignore_grad (bool): multiply loss by 0 if this is set to True
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- the loss
|
||||
- the sample size, which is used as the denominator for the
|
||||
gradient
|
||||
- logging outputs to display while training
|
||||
"""
|
||||
model.train()
|
||||
model.set_num_updates(update_num)
|
||||
|
||||
batched_data = sample["net_input"]["batched_data"]["x"]
|
||||
n_graph, n_node = batched_data.shape[:2]
|
||||
perturb_shape = n_graph, n_node, model.encoder_embed_dim
|
||||
if self.flag_mag > 0:
|
||||
perturb = (
|
||||
torch.FloatTensor(*perturb_shape)
|
||||
.uniform_(-1, 1)
|
||||
.to(batched_data.device)
|
||||
)
|
||||
perturb = perturb * self.flag_mag / math.sqrt(perturb_shape[-1])
|
||||
else:
|
||||
perturb = (
|
||||
torch.FloatTensor(*perturb_shape)
|
||||
.uniform_(-self.flag_step_size, self.flag_step_size)
|
||||
.to(batched_data.device)
|
||||
)
|
||||
perturb.requires_grad_()
|
||||
sample["perturb"] = perturb
|
||||
with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
|
||||
loss, sample_size, logging_output = criterion(
|
||||
model, sample
|
||||
)
|
||||
if ignore_grad:
|
||||
loss *= 0
|
||||
loss /= self.flag_m
|
||||
total_loss = 0
|
||||
for _ in range(self.flag_m - 1):
|
||||
optimizer.backward(loss)
|
||||
total_loss += loss.detach()
|
||||
perturb_data = perturb.detach() + self.flag_step_size * torch.sign(
|
||||
perturb.grad.detach()
|
||||
)
|
||||
if self.flag_mag > 0:
|
||||
perturb_data_norm = torch.norm(perturb_data, dim=-1).detach()
|
||||
exceed_mask = (perturb_data_norm > self.flag_mag).to(perturb_data)
|
||||
reweights = (
|
||||
self.flag_mag / perturb_data_norm * exceed_mask
|
||||
+ (1 - exceed_mask)
|
||||
).unsqueeze(-1)
|
||||
perturb_data = (perturb_data * reweights).detach()
|
||||
perturb.data = perturb_data.data
|
||||
perturb.grad[:] = 0
|
||||
sample["perturb"] = perturb
|
||||
with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
|
||||
loss, sample_size, logging_output = criterion(
|
||||
model, sample
|
||||
)
|
||||
if ignore_grad:
|
||||
loss *= 0
|
||||
loss /= self.flag_m
|
||||
optimizer.backward(loss)
|
||||
total_loss += loss.detach()
|
||||
logging_output["loss"] = total_loss
|
||||
return total_loss, sample_size, logging_output
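
# A minimal illustrative sketch (not part of the original module): the inner update
# performed by the FLAG loop above, isolated as a pure function. The name and the
# defaults are assumptions chosen to mirror the loop, not an official API.
def _flag_ascent_step(perturb, grad, step_size=1e-3, mag=1e-3):
    """Take one sign-gradient ascent step and, when ``mag`` is positive, project
    the perturbation back onto the per-node L2 ball of radius ``mag``."""
    perturb_data = perturb.detach() + step_size * torch.sign(grad.detach())
    if mag > 0:
        norm = torch.norm(perturb_data, dim=-1).detach()
        exceed = (norm > mag).to(perturb_data)
        reweights = (mag / norm * exceed + (1 - exceed)).unsqueeze(-1)
        perturb_data = (perturb_data * reweights).detach()
    return perturb_data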
@ -0,0 +1,321 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from pathlib import Path
from typing import Sequence, Union

import pickle
from functools import lru_cache

import lmdb

import numpy as np
import torch
from torch import Tensor
from fairseq.data import (
    FairseqDataset,
    BaseWrapperDataset,
    NestedDictionaryDataset,
    data_utils,
)
from fairseq.tasks import FairseqTask, register_task


class EpochShuffleDataset(BaseWrapperDataset):
    def __init__(self, dataset, num_samples, seed):
        super().__init__(dataset)
        self.num_samples = num_samples
        self.seed = seed
        self.set_epoch(1)

    def set_epoch(self, epoch):
        with data_utils.numpy_seed(self.seed + epoch - 1):
            # a random ordered_indices breaks fairseq's bucket-by-size batch
            # iterator, but we keep it to reproduce the original setup
            self.sort_order = np.random.permutation(self.num_samples)

    def ordered_indices(self):
        return self.sort_order

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False
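
# Illustrative sketch (not part of the original file): with a fixed seed the
# per-epoch permutation is reproducible, while different epochs reshuffle.
def _example_epoch_shuffle():
    ds = EpochShuffleDataset(list(range(10)), num_samples=10, seed=1)
    first = ds.ordered_indices().copy()
    ds.set_epoch(1)
    assert (ds.ordered_indices() == first).all()  # same epoch -> same order
    ds.set_epoch(2)  # a different, but equally reproducible, order
    return ds.ordered_indices()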
class LMDBDataset:
    def __init__(self, db_path):
        super().__init__()
        assert Path(db_path).exists(), f"{db_path}: No such file or directory"
        self.env = lmdb.Environment(
            db_path,
            map_size=(1024 ** 3) * 256,
            subdir=False,
            readonly=True,
            readahead=True,
            meminit=False,
        )
        self.len: int = self.env.stat()["entries"]

    def __len__(self):
        return self.len

    @lru_cache(maxsize=16)
    def __getitem__(self, idx: int) -> dict[str, Union[Tensor, float]]:
        if idx < 0 or idx >= self.len:
            raise IndexError
        data = pickle.loads(self.env.begin().get(f"{idx}".encode()))
        return dict(
            pos=torch.as_tensor(data["pos"]).float(),
            pos_relaxed=torch.as_tensor(data["pos_relaxed"]).float(),
            cell=torch.as_tensor(data["cell"]).float().view(3, 3),
            atoms=torch.as_tensor(data["atomic_numbers"]).long(),
            tags=torch.as_tensor(data["tags"]).long(),
            relaxed_energy=data["y_relaxed"],  # python float
        )


class PBCDataset:
    def __init__(self, dataset: LMDBDataset):
        self.dataset = dataset
        self.cell_offsets = torch.tensor(
            [
                [-1, -1, 0],
                [-1, 0, 0],
                [-1, 1, 0],
                [0, -1, 0],
                [0, 1, 0],
                [1, -1, 0],
                [1, 0, 0],
                [1, 1, 0],
            ],
        ).float()
        self.n_cells = self.cell_offsets.size(0)
        self.cutoff = 8
        self.filter_by_tag = True

    def __len__(self):
        return len(self.dataset)

    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        data = self.dataset[idx]

        pos = data["pos"]
        pos_relaxed = data["pos_relaxed"]
        cell = data["cell"]
        atoms = data["atoms"]
        tags = data["tags"]

        offsets = torch.matmul(self.cell_offsets, cell).view(self.n_cells, 1, 3)
        expand_pos = (pos.unsqueeze(0).expand(self.n_cells, -1, -1) + offsets).view(
            -1, 3
        )
        expand_pos_relaxed = (
            pos_relaxed.unsqueeze(0).expand(self.n_cells, -1, -1) + offsets
        ).view(-1, 3)
        src_pos = pos[tags > 1] if self.filter_by_tag else pos

        dist: Tensor = (src_pos.unsqueeze(1) - expand_pos.unsqueeze(0)).norm(dim=-1)
        used_mask = (dist < self.cutoff).any(dim=0) & tags.ne(2).repeat(
            self.n_cells
        )  # do not make periodic copies of the adsorbate (tag 2)
        used_expand_pos = expand_pos[used_mask]
        used_expand_pos_relaxed = expand_pos_relaxed[used_mask]

        used_expand_tags = tags.repeat(self.n_cells)[
            used_mask
        ]  # the original implementation uses zeros here; needs testing
        return dict(
            pos=torch.cat([pos, used_expand_pos], dim=0),
            atoms=torch.cat([atoms, atoms.repeat(self.n_cells)[used_mask]]),
            tags=torch.cat([tags, used_expand_tags]),
            real_mask=torch.cat(
                [
                    torch.ones_like(tags, dtype=torch.bool),
                    torch.zeros_like(used_expand_tags, dtype=torch.bool),
                ]
            ),
            deltapos=torch.cat(
                [pos_relaxed - pos, used_expand_pos_relaxed - used_expand_pos], dim=0
            ),
            relaxed_energy=data["relaxed_energy"],
        )
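
# Illustrative sketch (not part of the original file): how the eight in-plane cell
# offsets above turn into Cartesian translation vectors; a 10 Angstrom cubic cell
# and a single atom at the origin are assumed toy inputs.
def _example_cell_offsets():
    cell = 10.0 * torch.eye(3)
    cell_offsets = torch.tensor(
        [[-1, -1, 0], [-1, 0, 0], [-1, 1, 0], [0, -1, 0],
         [0, 1, 0], [1, -1, 0], [1, 0, 0], [1, 1, 0]]
    ).float()
    offsets = torch.matmul(cell_offsets, cell).view(8, 1, 3)
    pos = torch.zeros(1, 3)  # one atom at the origin
    expand_pos = (pos.unsqueeze(0).expand(8, -1, -1) + offsets).view(-1, 3)
    return expand_pos  # eight periodic images, shifted by +/-10 A along x and y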
def pad_1d(samples: Sequence[Tensor], fill=0, multiplier=8):
    max_len = max(x.size(0) for x in samples)
    max_len = (max_len + multiplier - 1) // multiplier * multiplier
    n_samples = len(samples)
    out = torch.full(
        (n_samples, max_len, *samples[0].shape[1:]), fill, dtype=samples[0].dtype
    )
    for i in range(n_samples):
        x_len = samples[i].size(0)
        out[i][:x_len] = samples[i]
    return out
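
# Illustrative sketch (not part of the original file): pad_1d right-pads every
# sample to the next multiple of ``multiplier`` along the first dimension.
def _example_pad_1d():
    a = torch.arange(3)
    b = torch.arange(5)
    out = pad_1d([a, b], fill=0, multiplier=8)
    assert out.shape == (2, 8)  # both samples padded up to length 8
    return out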
class AtomDataset(FairseqDataset):
    def __init__(self, dataset, keyword):
        super().__init__()
        self.dataset = dataset
        self.keyword = keyword
        self.atom_list = [
            1, 5, 6, 7, 8, 11, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25,
            26, 27, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 42, 43, 44,
            45, 46, 47, 48, 49, 50, 51, 52, 55, 72, 73, 74, 75, 76, 77, 78, 79,
            80, 81, 82, 83,
        ]
        # map all other atomic numbers to the unk index
        unk_idx = len(self.atom_list) + 1
        self.atom_mapper = torch.full((128,), unk_idx)
        for idx, atom in enumerate(self.atom_list):
            self.atom_mapper[atom] = idx + 1  # reserve 0 for padding

    @lru_cache(maxsize=16)
    def __getitem__(self, index):
        atoms: Tensor = self.dataset[index][self.keyword]
        return self.atom_mapper[atoms]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        return pad_1d(samples)


class KeywordDataset(FairseqDataset):
    def __init__(self, dataset, keyword, is_scalar=False, pad_fill=0):
        super().__init__()
        self.dataset = dataset
        self.keyword = keyword
        self.is_scalar = is_scalar
        self.pad_fill = pad_fill

    @lru_cache(maxsize=16)
    def __getitem__(self, index):
        return self.dataset[index][self.keyword]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if self.is_scalar:
            return torch.tensor(samples)
        return pad_1d(samples, fill=self.pad_fill)


@register_task("is2re")
class IS2RETask(FairseqTask):
    @classmethod
    def add_args(cls, parser):
        parser.add_argument("data", metavar="FILE", help="directory for data")

    @property
    def target_dictionary(self):
        return None

    def load_dataset(self, split, combine=False, **kwargs):
        assert split in [
            "train",
            "val_id",
            "val_ood_ads",
            "val_ood_cat",
            "val_ood_both",
            "test_id",
            "test_ood_ads",
            "test_ood_cat",
            "test_ood_both",
        ], "invalid split: {}!".format(split)
        print(" > Loading {} ...".format(split))

        db_path = str(Path(self.cfg.data) / split / "data.lmdb")
        lmdb_dataset = LMDBDataset(db_path)
        pbc_dataset = PBCDataset(lmdb_dataset)

        atoms = AtomDataset(pbc_dataset, "atoms")
        tags = KeywordDataset(pbc_dataset, "tags")
        real_mask = KeywordDataset(pbc_dataset, "real_mask")

        pos = KeywordDataset(pbc_dataset, "pos")

        relaxed_energy = KeywordDataset(pbc_dataset, "relaxed_energy", is_scalar=True)
        deltapos = KeywordDataset(pbc_dataset, "deltapos")

        dataset = NestedDictionaryDataset(
            {
                "net_input": {
                    "pos": pos,
                    "atoms": atoms,
                    "tags": tags,
                    "real_mask": real_mask,
                },
                "targets": {
                    "relaxed_energy": relaxed_energy,
                    "deltapos": deltapos,
                },
            },
            sizes=[np.zeros(len(atoms))],
        )

        if split == "train":
            dataset = EpochShuffleDataset(
                dataset,
                num_samples=len(atoms),
                seed=self.cfg.seed,
            )

        print("| Loaded {} with {} samples".format(split, len(dataset)))
        self.datasets[split] = dataset
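
# Illustrative smoke test (not part of the original file); the db path is a
# placeholder. It builds the same pipeline load_dataset uses, outside of fairseq,
# and collates two samples (atom ids and positions are padded to a multiple of 8).
def _example_is2re_pipeline(db_path="/path/to/is2re/train/data.lmdb"):
    lmdb_dataset = LMDBDataset(db_path)
    pbc_dataset = PBCDataset(lmdb_dataset)
    atoms = AtomDataset(pbc_dataset, "atoms")
    pos = KeywordDataset(pbc_dataset, "pos")
    batched_atoms = atoms.collater([atoms[0], atoms[1]])
    batched_pos = pos.collater([pos[0], pos[1]])
    return batched_atoms.shape, batched_pos.shape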
@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
@ -1,51 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn.functional as F
import math


def flag_bounded(
    model_forward, perturb_shape, y, optimizer, device, criterion,
    m=3, step_size=1e-3, mag=1e-3, mask=None,
):
    model, forward = model_forward
    model.train()
    optimizer.zero_grad()

    if mag > 0:
        perturb = torch.FloatTensor(*perturb_shape).uniform_(-1, 1).to(device)
        perturb = perturb * mag / math.sqrt(perturb_shape[-1])
    else:
        perturb = (
            torch.FloatTensor(*perturb_shape)
            .uniform_(-step_size, step_size)
            .to(device)
        )
    perturb.requires_grad_()
    out = forward(perturb).view(-1)
    if mask is not None:
        out = out[mask]
    loss = criterion(out, y)
    loss /= m

    for _ in range(m - 1):
        # loss.backward()
        model.manual_backward(loss)
        perturb_data = perturb.detach() + step_size * torch.sign(perturb.grad.detach())
        if mag > 0:
            perturb_data_norm = torch.norm(perturb_data, dim=-1).detach()
            exceed_mask = (perturb_data_norm > mag).to(perturb_data)
            reweights = (
                mag / perturb_data_norm * exceed_mask + (1 - exceed_mask)
            ).unsqueeze(-1)
            perturb_data = (perturb_data * reweights).detach()

        perturb.data = perturb_data.data
        perturb.grad[:] = 0

        out = forward(perturb).view(-1)
        if mask is not None:
            out = out[mask]
        loss = criterion(out, y)
        loss /= m

    # loss.backward()
    model.manual_backward(loss)
    optimizer.step()

    return loss, out
@ -1,104 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import numpy as np
import torch_geometric.datasets
from ogb.graphproppred import PygGraphPropPredDataset
from ogb.lsc.pcqm4m_pyg import PygPCQM4MDataset
import pyximport

pyximport.install(setup_args={"include_dirs": np.get_include()})
import algos


def convert_to_single_emb(x, offset=512):
    feature_num = x.size(1) if len(x.size()) > 1 else 1
    feature_offset = 1 + torch.arange(0, feature_num * offset, offset, dtype=torch.long)
    x = x + feature_offset
    return x


def preprocess_item(item):
    edge_attr, edge_index, x = item.edge_attr, item.edge_index, item.x
    N = x.size(0)
    x = convert_to_single_emb(x)

    # node adj matrix [N, N] bool
    adj = torch.zeros([N, N], dtype=torch.bool)
    adj[edge_index[0, :], edge_index[1, :]] = True

    # edge feature here
    if len(edge_attr.size()) == 1:
        edge_attr = edge_attr[:, None]
    attn_edge_type = torch.zeros([N, N, edge_attr.size(-1)], dtype=torch.long)
    attn_edge_type[edge_index[0, :], edge_index[1, :]] = (
        convert_to_single_emb(edge_attr) + 1
    )

    shortest_path_result, path = algos.floyd_warshall(adj.numpy())
    max_dist = np.amax(shortest_path_result)
    edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy())
    spatial_pos = torch.from_numpy(shortest_path_result).long()
    attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float)  # with graph token

    # combine
    item.x = x
    item.adj = adj
    item.attn_bias = attn_bias
    item.attn_edge_type = attn_edge_type
    item.spatial_pos = spatial_pos
    item.in_degree = adj.long().sum(dim=1).view(-1)
    item.out_degree = adj.long().sum(dim=0).view(-1)
    item.edge_input = torch.from_numpy(edge_input).long()

    return item
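
# Illustrative sketch (not part of the original file): convert_to_single_emb shifts
# each categorical feature column into its own id range so that one shared embedding
# table can encode all columns; the input tensor below is an assumed toy example.
def _example_convert_to_single_emb():
    x = torch.tensor([[0, 3, 7], [2, 0, 1]])  # 2 nodes, 3 categorical features
    out = convert_to_single_emb(x, offset=512)
    # per-column offsets are 1, 513 and 1025
    assert out.tolist() == [[1, 516, 1032], [3, 513, 1026]]
    return out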
class MyGraphPropPredDataset(PygGraphPropPredDataset):
    def download(self):
        super(MyGraphPropPredDataset, self).download()

    def process(self):
        super(MyGraphPropPredDataset, self).process()

    def __getitem__(self, idx):
        if isinstance(idx, int):
            item = self.get(self.indices()[idx])
            item.idx = idx
            return preprocess_item(item)
        else:
            return self.index_select(idx)


class MyPygPCQM4MDataset(PygPCQM4MDataset):
    def download(self):
        super(MyPygPCQM4MDataset, self).download()

    def process(self):
        super(MyPygPCQM4MDataset, self).process()

    def __getitem__(self, idx):
        if isinstance(idx, int):
            item = self.get(self.indices()[idx])
            item.idx = idx
            return preprocess_item(item)
        else:
            return self.index_select(idx)


class MyZINCDataset(torch_geometric.datasets.ZINC):
    def download(self):
        super(MyZINCDataset, self).download()

    def process(self):
        super(MyZINCDataset, self).process()

    def __getitem__(self, idx):
        if isinstance(idx, int):
            item = self.get(self.indices()[idx])
            item.idx = idx
            return preprocess_item(item)
        else:
            return self.index_select(idx)
@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# create a new environment
conda create --name graphormerv2 python=3.9
conda activate graphormerv2

# install requirements
pip install torch==1.9.1+cu111 torchaudio -f https://download.pytorch.org/whl/cu111/torch_stable.html
# torchaudio is installed up front so that installing fairseq does not pull in the
# newest torchaudio and torch (which would replace torch-1.9.1)
pip install lmdb
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.1+cu111.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.1+cu111.html
pip install torch-geometric==1.7.2
pip install tensorboardX==2.4.1
pip install ogb==1.3.2
pip install rdkit-pypi==2021.9.3
pip install dgl==0.7.2 -f https://data.dgl.ai/wheels/repo.html

cd fairseq
# if the fairseq submodule has not been checked out yet, run:
# git submodule update --init --recursive
pip install . --use-feature=in-tree-build
python setup.py build_ext --inplace
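
# Optional sanity check (an assumed addition, not part of the original script):
# verify that the pinned packages import and that CUDA is visible before training.
python - <<'EOF'
import torch, torch_geometric, dgl
print("torch:", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("torch_geometric:", torch_geometric.__version__)
print("dgl:", dgl.__version__)
EOF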