This repository is the official PyTorch implementation of the paper "Geometric Representation Condition Improves Equivariant Molecule Generation".
To install the required packages, run:
bash ./install.sh
You can download all the checkpoints used in the paper from this link and place them in the checkpoints directory.
- QM9 dataset: This dataset is downloaded automatically when you run the EDM training script.
- GEOM-DRUG dataset: Follow the instructions on the EDM GitHub to download this dataset.
- GEOM-DRUG dataset (Semla): SemlaFlow uses a smaller version of GEOM-DRUG with lower-energy conformations. You can download it from the SemlaFlow GitHub.
💡 You can train the RDM and the Molecule Generator simultaneously, because Molecule Generator training does not need the RDM for sampling (the training dataset is used for sampling instead).
- QM9 Dataset (Unconditional):
To train an RDM that produces Frad representations on the QM9 dataset:
python src_GeoRCG/train_RDM.py experiments_RDM=qm9_uncond
We use the Frad encoder's public checkpoint, available here. It is also included in our provided checkpoints at checkpoints/encoder_ckpts/qm9_frad.ckpt.
- QM9 Dataset (Conditional):
To train a conditional RDM on the second half of the QM9 dataset:
python src_GeoRCG/train_RDM.py experiments_RDM=qm9_cond experiments_RDM.rdm_args.conditioning=["lumo"] experiments_RDM.rdm_args.exp_name="qm9_cond_lumo"
Replace "lumo" with another property, such as "alpha" or "homo", to condition on that property.
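When training one conditional RDM per property, it can help to generate the commands programmatically. The sketch below is a hypothetical helper (not part of the repository) that rebuilds the command above for any property; the argv values are written as the shell delivers them after stripping the double quotes, and naming each run qm9_cond_<property> is our own assumption.

```python
import shlex


def rdm_cond_command(prop: str) -> list[str]:
    """Hypothetical helper: argv for training a conditional RDM on one QM9 property."""
    return [
        "python",
        "src_GeoRCG/train_RDM.py",
        "experiments_RDM=qm9_cond",
        # Same override as in the README command, after the shell removes the quotes.
        f"experiments_RDM.rdm_args.conditioning=[{prop}]",
        # Assumption: one experiment name per property keeps outputs separate.
        f"experiments_RDM.rdm_args.exp_name=qm9_cond_{prop}",
    ]


if __name__ == "__main__":
    for prop in ["alpha", "homo", "lumo"]:
        # shlex.join quotes each argument so the printed line can be pasted into a shell.
        print(shlex.join(rdm_cond_command(prop)))
        # To launch directly instead: subprocess.run(rdm_cond_command(prop), check=True)
```

Passing a list to subprocess.run avoids shell quoting issues with the square brackets.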
- GEOM-DRUG Dataset:
To train an RDM on the GEOM-DRUG dataset:
python src_GeoRCG/train_RDM.py experiments_RDM=drug
We trained the Unimol encoder ourselves for two main reasons:
- The checkpoint released by Unimol was trained with hydrogen (H) atoms removed, whereas our setup requires hydrogens to be included.
- The released checkpoint was trained on datasets that lack some of the rarer elements found in GEOM-DRUG. We therefore pretrained the encoder on a more comprehensive dataset that includes GEOM-DRUG.
Our initial experiments used a 15-layer Unimol encoder, which should reproduce the results reported in the GeoRCG paper. In later experiments, we found that a 6-layer Unimol encoder works better for the generation task, possibly because its representations have a lower Lipschitz constant with respect to the input conformations. The default configuration now uses the 6-layer encoder, whose checkpoint is at checkpoints/encoder_ckpts/drug_unimol_6layers_noise1.5.pt. Alternatively, you can pretrain the encoder yourself using the script in EDM_based/models_GeoRCG/unimol/unimol.
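To make the Lipschitz argument concrete, the sketch below estimates how strongly an encoder's output reacts to small random perturbations of atom coordinates, via the ratio ||f(x + d) - f(x)|| / ||d||. It is only an illustration under stated assumptions: the encoders are hypothetical linear toy functions on a flat coordinate list, not Unimol, and the result is a local lower bound on the true Lipschitz constant, not the measurement used in the paper.

```python
import math
import random


def l2_norm(v):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(c * c for c in v))


def empirical_lipschitz(encoder, coords, n_trials=100, eps=1e-3, seed=0):
    """Largest ||f(x + d) - f(x)|| / ||d|| observed over random small perturbations d.

    `coords` is a flat list of atom coordinates [x1, y1, z1, x2, ...].
    The value is a local lower bound on the encoder's Lipschitz constant.
    """
    rng = random.Random(seed)
    base = encoder(coords)
    worst = 0.0
    for _ in range(n_trials):
        delta = [rng.gauss(0.0, eps) for _ in coords]
        perturbed = encoder([x + d for x, d in zip(coords, delta)])
        diff = [a - b for a, b in zip(perturbed, base)]
        worst = max(worst, l2_norm(diff) / l2_norm(delta))
    return worst


def smooth_encoder(xs):
    # Hypothetical "gentle" encoder: output moves half as much as the input.
    return [0.5 * x for x in xs]


def sharp_encoder(xs):
    # Hypothetical "sensitive" encoder: output moves four times as much.
    return [4.0 * x for x in xs]


if __name__ == "__main__":
    coords = [0.0, 0.0, 0.0, 1.1, 0.0, 0.0]  # toy two-atom conformation
    print(empirical_lipschitz(smooth_encoder, coords))
    print(empirical_lipschitz(sharp_encoder, coords))
```

A generator conditioned on the smoother encoder's representations sees smaller representation changes for the same conformational change, which is the intuition behind the hypothesis above.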
- SemlaFlow Setting:
To train an RDM on the GEOM-DRUG dataset using the SemlaFlow data:
python src_GeoRCG/train_RDM.py experiments_RDM=drug experiments_RDM.rdm_args.semlaflow_data=true experiments_RDM.rdm_args.encoder_type=unimol_global experiments_RDM.rdm_args.encoder_path=../checkpoints/encoder_ckpts/drug_unimol_global_iter1.8M.pt
In the SemlaFlow setting, we experiment with two configurations:
- Unimol Global Configuration (default setting in this code):
We train a modified version of the 15-layer Unimol model, referred to as unimol_global, and use its first 4 layers for representations. The modified model adds a global output head and a global pretext pretraining task, which helps achieve a lower Lipschitz constant with respect to conformations and benefits GeoRCG. The checkpoint for this configuration is at checkpoints/encoder_ckpts/drug_unimol_global_iter1.8M.pt. You can also pretrain this version using the task file at models_GeoRCG/unimol/unimol/tasks/unimol_global.py.
- Standard Unimol Configuration:
In this configuration, we train a standard 15-layer Unimol model and use its first 4 layers for representations. It yields better Strain and Energy metrics at a slight cost in stability and validity. The checkpoint is at ../checkpoints/encoder_ckpts/drug_unimol_15layers_noise1_iter2M.pt. To train the RDM under this configuration, set encoder_type to unimol_truncated and point encoder_path to that checkpoint: encoder_path=../checkpoints/encoder_ckpts/drug_unimol_15layers_noise1_iter2M.pt. Likewise, set the encoder_type and encoder_path options in SemlaFlow molecule generator training to match the chosen configuration.
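Because encoder_type and encoder_path must always be changed together, one option is to keep each pairing in a single table. The sketch below is a hypothetical helper (not part of the repository) that produces the RDM overrides from the SemlaFlow command above for either configuration; it covers only the RDM arguments shown in this README, since the generator's option names are not listed here.

```python
# Hypothetical mapping of the two SemlaFlow encoder configurations described above.
SEMLA_ENCODERS = {
    "unimol_global": "../checkpoints/encoder_ckpts/drug_unimol_global_iter1.8M.pt",
    "unimol_truncated": "../checkpoints/encoder_ckpts/drug_unimol_15layers_noise1_iter2M.pt",
}


def semla_rdm_overrides(encoder_type: str = "unimol_global") -> list[str]:
    """Return train_RDM.py overrides with a matching encoder_type/encoder_path pair."""
    if encoder_type not in SEMLA_ENCODERS:
        raise ValueError(f"unknown encoder_type: {encoder_type!r}")
    return [
        "experiments_RDM=drug",
        "experiments_RDM.rdm_args.semlaflow_data=true",
        f"experiments_RDM.rdm_args.encoder_type={encoder_type}",
        f"experiments_RDM.rdm_args.encoder_path={SEMLA_ENCODERS[encoder_type]}",
    ]


if __name__ == "__main__":
    print(" ".join(["python", "src_GeoRCG/train_RDM.py", *semla_rdm_overrides("unimol_truncated")]))
```

Looking the path up from the encoder type makes a mismatched pair (for example, unimol_truncated with the global checkpoint) impossible to construct.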
- QM9 Dataset:
To train a molecule generator on the QM9 dataset:
python src_GeoRCG/train_gen.py experiments_gen=qm9
- GEOM-DRUG Dataset:
To train on the GEOM-DRUG dataset using 4 GPUs:
python -m torch.distributed.run --nproc_per_node=4 --master-port=20001 src_GeoRCG/train_gen.py experiments_gen=drug
For single-GPU training, disable distributed processing:
python src_GeoRCG/train_gen.py experiments_gen=drug experiments_gen.gen_args.dp=false
To train on the GEOM-DRUG dataset using SemlaFlow:
python src_GeoRCG/train_drug.py experiments_gen=drug
- Unconditionally Generated QM9 Molecules:
python eval_src/eval_analyze.py cfg=1.0 inv_temp=1.0 gen_model_path=../checkpoints/gen_ckpts/edm_qm9_frad_noise0.3 rdm_ckpt=../checkpoints/rdm_ckpts/rdm_qm9_frad_uncond/model/checkpoint-last.pth
Feel free to adjust the cfg and inv_temp parameters as needed. The cfg parameter sets the classifier-free guidance strength, while inv_temp sets the inverse temperature for sampling.
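For intuition, the sketch below shows the textbook forms of these two knobs. It is not taken from the GeoRCG code: the guidance convention (1 + w) * cond - w * uncond is one common choice among several, so the mapping from this w to the repository's cfg value is an assumption, and scaling Gaussian noise by 1 / sqrt(inv_temp) is the usual low-temperature heuristic rather than a confirmed detail of this implementation.

```python
import math
import random


def cfg_combine(pred_cond, pred_uncond, w):
    """Classifier-free guidance in the (1 + w) * cond - w * uncond convention.

    w = 0 returns the purely conditional prediction; larger w pushes
    samples further toward the condition.
    """
    return [(1.0 + w) * c - w * u for c, u in zip(pred_cond, pred_uncond)]


def scaled_noise(dim, inv_temp, rng):
    """Gaussian noise with std 1 / sqrt(inv_temp), i.e. temperature 1 / inv_temp.

    inv_temp > 1 lowers the temperature and concentrates samples near
    high-density regions; inv_temp = 1 leaves the noise unchanged.
    """
    std = 1.0 / math.sqrt(inv_temp)
    return [rng.gauss(0.0, std) for _ in range(dim)]


if __name__ == "__main__":
    print(cfg_combine([1.0, 2.0], [0.5, 1.0], w=1.0))
    print(scaled_noise(3, inv_temp=2.0, rng=random.Random(0)))
```

For a standard Gaussian, raising the density to the power inv_temp yields exactly a Gaussian with variance 1 / inv_temp, which is why the noise is scaled this way.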
- Unconditionally Generated GEOM-DRUG Molecules:
python eval_src/eval_analyze.py cfg=1.0 inv_temp=1.0 gen_model_path=../checkpoints/gen_ckpts/edm_drug_unimol6layers_noise0.5 rdm_ckpt=../checkpoints/rdm_ckpts/rdm_drug_unimol_6layers/checkpoint-98.pth
You can change the number of sampling steps in the command above by adding ddim_S=10 (or another value).
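As a rough guide to what a smaller step count means, the sketch below shows how a DDIM-style sampler commonly selects S timesteps out of T training steps with uniform spacing. T = 1000 and the uniform, zero-based spacing are assumptions for illustration; the schedule actually used with ddim_S in this repository may differ (some implementations offset every timestep by one).

```python
def ddim_timesteps(T: int, S: int) -> list[int]:
    """Pick S uniformly spaced timesteps from T, ordered from noisiest to cleanest."""
    if not 1 <= S <= T:
        raise ValueError("need 1 <= S <= T")
    stride = T // S
    # range(0, T, stride) can yield more than S values when T % S != 0, so truncate.
    return list(range(0, T, stride))[:S][::-1]


if __name__ == "__main__":
    print(ddim_timesteps(1000, 10))
```

Fewer steps means fewer network evaluations and faster sampling, usually at some cost in sample quality.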
- Conditionally Generated QM9 Molecules:
python eval_src/eval_conditional_qm9.py classifiers_path=../checkpoints/classifiers_ckpts/exp_class_alpha property=alpha gen_model_path=../checkpoints/gen_ckpts/edm_qm9_second_half_frad_noise0.3 rdm_ckpt=../checkpoints/rdm_ckpts/rdm_qm9_frad_alpha/model/checkpoint-last.pth cfg=2.0 inv_temp=1.0
To evaluate a different property, replace alpha with another property such as lumo or homo, both in the property argument and in the corresponding paths (e.g., classifiers_path and rdm_ckpt).
To evaluate unconditionally generated GEOM-DRUG molecules using SemlaFlow:
python evaluate.py ckpt_path=../checkpoints/gen_ckpts/semlaflow_unimol_global_truncated4_noise0.3_reploss0.5/checkpoints/last.ckpt rdm_ckpt=../checkpoints/rdm_ckpts/rdm_drug_semla_unimol_global_truncated4/model/checkpoint-last.pth n_molecules=10000 n_replicates=3 cfg_coef=-0.9 batch_cost=2048 integration_steps=100
Change integration_steps to adjust the number of sampling steps.
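For a flow-matching sampler, the step count controls how finely the generative ODE is discretized. The sketch below assumes fixed-step Euler integration of a velocity field from t = 0 (noise) to t = 1 (data); the toy velocity is hypothetical and is the exact velocity of a straight-line path toward a single target point, so it is not SemlaFlow's learned model.

```python
def euler_sample(velocity, x0, integration_steps):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with fixed Euler steps."""
    dt = 1.0 / integration_steps
    x = list(x0)
    for i in range(integration_steps):
        t = i * dt
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


def toy_velocity(target):
    """Hypothetical field: v(x, t) = (target - x) / (1 - t).

    Along x_t = (1 - t) * x0 + t * target this equals target - x0, a constant,
    so Euler integration follows the path exactly for any step count.
    """
    def v(x, t):
        return [(g - xi) / (1.0 - t) for g, xi in zip(target, x)]
    return v


if __name__ == "__main__":
    print(euler_sample(toy_velocity([1.0, -2.0]), [0.3, 0.1], integration_steps=100))
```

With a learned, curved velocity field, fewer integration steps are faster but introduce discretization error; the straight toy path hides that effect by construction.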
Another version of GeoRCG (Semla), based on the Standard Unimol Configuration described above, produces better Strain and Energy metrics but slightly worse stability and validity. To evaluate it, run the command above with the following checkpoint arguments instead:
`ckpt_path=../checkpoints/gen_ckpts/semlaflow_unimol_truncated4_noise0.3_reploss0.1/checkpoints/last.ckpt rdm_ckpt=../checkpoints/rdm_ckpts/rdm_drug_semla_unimol_truncated4`
If you find this repository useful and use it in your research, please cite our paper:
@article{li2024geometric,
title={Geometric Representation Condition Improves Equivariant Molecule Generation},
author={Li, Zian and Zhou, Cai and Wang, Xiyuan and Peng, Xingang and Zhang, Muhan},
journal={arXiv preprint arXiv:2410.03655},
year={2024}
}
This code repository is built upon the following works:
Thanks to all the authors for their contributions!