JiaYuanChng/PA-Graph-Transformer

Path-Augmented Graph Transformer Network

This is the GitHub repository for the paper "Path-Augmented Graph Transformer Network".

All data (and the splits used for the experiments) are in data.zip.

These are the required packages and the setup steps for a conda environment (the exact commands may differ slightly depending on your system).

conda create -c rdkit -n prop_predictor rdkit
source activate prop_predictor
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
conda install scikit-learn tqdm
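
After creating the environment, it can help to confirm that the packages are importable before running anything else. The following is a minimal sketch, not part of the repository: the helper name `missing_packages` is hypothetical, and the import names are assumed from the packages installed above (scikit-learn is imported as `sklearn`).

```python
import importlib.util

# Import names for the packages installed above. These names are an
# assumption based on the conda commands; scikit-learn imports as "sklearn".
REQUIRED = ["rdkit", "torch", "torchvision", "sklearn", "tqdm"]


def missing_packages(names):
    """Return the subset of `names` that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages found.")
```

Run it inside the activated `prop_predictor` environment; any name it lists still needs to be installed.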

Add the repo to PYTHONPATH:

export PYTHONPATH=path_to_repo:$PYTHONPATH

To compute the shortest paths, run the following:

python preprocess/shortest_paths.py -data_dir path_to_data
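
To illustrate what a shortest-path preprocessing step computes, here is a minimal sketch that uses breadth-first search on a molecular graph given as a bond list. It is not the code in `preprocess/shortest_paths.py`. The function name, the input format, and the convention that path length is counted in bonds are all assumptions made for illustration.

```python
from collections import deque


def shortest_paths(n_atoms, bonds, max_path_length):
    """Return {(src, dst): [src, ..., dst]} for every ordered atom pair whose
    shortest path uses at most `max_path_length` bonds.

    Hypothetical helper: `bonds` is a list of (atom_a, atom_b) index pairs.
    """
    # Build an undirected adjacency list from the bond list.
    adj = {i: [] for i in range(n_atoms)}
    for a, b in bonds:
        adj[a].append(b)
        adj[b].append(a)

    paths = {}
    for src in range(n_atoms):
        parent = {src: None}  # BFS tree, used to reconstruct paths
        depth = {src: 0}
        queue = deque([src])
        while queue:
            u = queue.popleft()
            if depth[u] == max_path_length:
                continue  # do not expand beyond the length limit
            for v in adj[u]:
                if v not in parent:
                    parent[v] = u
                    depth[v] = depth[u] + 1
                    queue.append(v)
        # Walk parent pointers back from each reached atom to the source.
        for dst in parent:
            if dst == src:
                continue
            path = [dst]
            while parent[path[-1]] is not None:
                path.append(parent[path[-1]])
            paths[(src, dst)] = path[::-1]
    return paths


if __name__ == "__main__":
    # A five-atom chain 0-1-2-3-4 with a limit of 3 bonds.
    chain = [(0, 1), (1, 2), (2, 3), (3, 4)]
    result = shortest_paths(5, chain, max_path_length=3)
    print(result[(0, 3)])
    print((0, 4) in result)
```

In the chain example, atoms 0 and 3 are connected by a three-bond path, while atoms 0 and 4 are four bonds apart and are left out. In ring systems several shortest paths can exist between the same pair; this sketch keeps whichever one BFS finds first.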

To run the (transformer) model:

dataset=path_to_dataset
python train/train_prop.py -cuda \
    -data $dataset -loss_type mse \
    -max_grad_norm 10 -batch_size 50 -num_epochs 100 \
    -output_dir output_test/sol_transformer -n_rounds 10 \
    -model_type transformer -hidden_size 160 \
    -p_embed -ring_embed -max_path_length 3 -lr 5e-4 \
    -no_share -n_heads 2 -d_k 80 -dropout 0.2

To run the local model, add the option:

-mask_neigh

To run the conv net model:

dataset=path_to_dataset
python train/train_prop.py -cuda \
    -data $dataset -loss_type mse \
    -max_grad_norm 10 -batch_size 50 -num_epochs 100 \
    -output_dir output_test/sol_conv_net -n_rounds 10 \
    -model_type conv_net -hidden_size 160 -lr 5e-4 -dropout 0.2
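
The transformer and conv net commands share most of their flags, so when launching several runs it may be convenient to assemble the command in Python. The sketch below is not part of the repository. The helper `build_train_cmd` is hypothetical, it uses only flags that appear in this README, and it assumes the script's argument parser does not depend on flag order.

```python
import shlex

# Flags that appear in both commands in this README.
COMMON = [
    "-cuda", "-loss_type", "mse", "-max_grad_norm", "10",
    "-batch_size", "50", "-num_epochs", "100", "-n_rounds", "10",
    "-hidden_size", "160", "-lr", "5e-4", "-dropout", "0.2",
]

# Flags that appear only in the transformer command.
TRANSFORMER_ONLY = [
    "-p_embed", "-ring_embed", "-max_path_length", "3",
    "-no_share", "-n_heads", "2", "-d_k", "80",
]


def build_train_cmd(dataset, output_dir, model_type, extra_flags=()):
    """Build the argument list for train/train_prop.py (hypothetical helper)."""
    cmd = [
        "python", "train/train_prop.py",
        "-data", dataset,
        "-output_dir", output_dir,
        "-model_type", model_type,
    ]
    cmd += COMMON
    if model_type == "transformer":
        cmd += TRANSFORMER_ONLY
    cmd += list(extra_flags)
    return cmd


if __name__ == "__main__":
    # Local variant of the transformer: append the -mask_neigh option.
    cmd = build_train_cmd(
        "path_to_dataset", "output_test/sol_transformer",
        "transformer", extra_flags=["-mask_neigh"],
    )
    print(shlex.join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root, with PYTHONPATH set as described above.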

All the scripts used to generate the results in the paper are under the scripts directory.
