Skip to content

Commit da83708

Browse files
Can-ZhaoKumoLiu
andauthored
update generative func to monai core func (#1768)
Fixes # . ### Description update generative func to monai core func ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` --------- Signed-off-by: Can-Zhao <[email protected]> Signed-off-by: Can Zhao <[email protected]> Signed-off-by: YunLiu <[email protected]> Co-authored-by: YunLiu <[email protected]>
1 parent f444983 commit da83708

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

.github/workflows/test-modified.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
pip uninstall -y monai-weekly
3333
pip uninstall -y monai-weekly # make sure there's no existing installation
3434
BUILD_MONAI=0 python -m pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
35-
python -m pip install -r https://raw.githubusercontent.com/Project-MONAI/MONAI/main/requirements-dev.txt
35+
python -m pip install -r https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/requirements-dev.txt
3636
python -m pip install -U torch torchvision torchaudio
3737
- uses: actions/checkout@v3
3838
- name: Notebook quick check

generative/maisi/configs/config_maisi.json

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@
9292
"conditioning_embedding_num_channels": [8, 32, 64]
9393
},
9494
"mask_generation_autoencoder_def": {
95-
"_target_": "generative.networks.nets.AutoencoderKL",
96-
"spatial_dims": 3,
95+
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
96+
"spatial_dims": "@spatial_dims",
9797
"in_channels": 8,
9898
"out_channels": 125,
99-
"latent_channels": 4,
99+
"latent_channels": "@latent_channels",
100100
"num_channels": [
101101
32,
102102
64,
@@ -114,13 +114,16 @@
114114
"with_decoder_nonlocal_attn": false,
115115
"use_flash_attention": false,
116116
"use_checkpointing": true,
117-
"use_convtranspose": true
117+
"use_convtranspose": true,
118+
"norm_float16": true,
119+
"num_splits": 8,
120+
"dim_split": 1
118121
},
119122
"mask_generation_diffusion_def": {
120-
"_target_": "generative.networks.nets.DiffusionModelUNet",
121-
"spatial_dims": 3,
122-
"in_channels": 4,
123-
"out_channels": 4,
123+
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
124+
"spatial_dims": "@spatial_dims",
125+
"in_channels": "@latent_channels",
126+
"out_channels": "@latent_channels",
124127
"num_channels":[64, 128, 256, 512],
125128
"attention_levels":[false, false, true, true],
126129
"num_head_channels":[0, 0, 32, 32],
@@ -132,15 +135,15 @@
132135
},
133136
"mask_generation_scale_factor": 1.0055984258651733,
134137
"noise_scheduler": {
135-
"_target_": "generative.networks.schedulers.DDPMScheduler",
138+
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
136139
"num_train_timesteps": 1000,
137140
"beta_start": 0.0015,
138141
"beta_end": 0.0195,
139142
"schedule": "scaled_linear_beta",
140143
"clip_sample": false
141144
},
142145
"mask_generation_noise_scheduler": {
143-
"_target_": "generative.networks.schedulers.DDPMScheduler",
146+
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
144147
"num_train_timesteps": 1000,
145148
"beta_start": 0.0015,
146149
"beta_end": 0.0195,

generative/maisi/maisi_inference_tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@
397397
"controlnet.load_state_dict(checkpoint_controlnet[\"controlnet_state_dict\"], strict=True)\n",
398398
"\n",
399399
"mask_generation_autoencoder = define_instance(args, \"mask_generation_autoencoder_def\").to(device)\n",
400-
"checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path)\n",
400+
"checkpoint_mask_generation_autoencoder = load_autoencoder_ckpt(args.trained_mask_generation_autoencoder_path)\n",
401401
"mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)\n",
402402
"\n",
403403
"mask_generation_diffusion_unet = define_instance(args, \"mask_generation_diffusion_def\").to(device)\n",

0 commit comments

Comments
 (0)