|
78 | 78 | },
|
79 | 79 | {
|
80 | 80 | "cell_type": "code",
|
81 |
| - "execution_count": 69, |
| 81 | + "execution_count": null, |
82 | 82 | "metadata": {
|
83 | 83 | "colab": {
|
84 | 84 | "base_uri": "https://localhost:8080/"
|
|
94 | 94 | " Compose,\n",
|
95 | 95 | " LoadImageD,\n",
|
96 | 96 | " ScaleIntensityd,\n",
|
97 |
| - " RandGaussianNoiseD, \n", |
98 |
| - " RandGaussianSmoothD, \n", |
99 |
| - " RandAdjustContrastD, \n", |
| 97 | + " RandGaussianNoiseD,\n", |
| 98 | + " RandGaussianSmoothD,\n", |
100 | 99 | ")\n",
|
101 |
| - "import numpy as np\n", |
102 | 100 | "from monai.data import DataLoader, Dataset, CacheDataset\n",
|
103 | 101 | "from monai.config import print_config\n",
|
104 | 102 | "from monai.networks.nets.restormer import Restormer\n",
|
105 | 103 | "from monai.apps import MedNISTDataset\n",
|
106 | 104 | "\n",
|
107 |
| - "import numpy as np\n", |
108 | 105 | "import torch\n",
|
109 | 106 | "from monai.losses import SSIMLoss\n",
|
110 | 107 | "import matplotlib.pyplot as plt\n",
|
111 | 108 | "import os\n",
|
112 | 109 | "import tempfile\n",
|
113 | 110 | "\n",
|
114 |
| - "from tqdm.notebook import tqdm\n", |
115 |
| - "\n", |
116 | 111 | "\n",
|
117 |
| - "#print_config()\n", |
118 |
| - "#set_determinism(42)" |
| 112 | + "print_config()\n", |
| 113 | + "set_determinism(42)" |
119 | 114 | ]
|
120 | 115 | },
|
121 | 116 | {
|
|
361 | 356 | },
|
362 | 357 | {
|
363 | 358 | "cell_type": "code",
|
364 |
| - "execution_count": 70, |
| 359 | + "execution_count": null, |
365 | 360 | "metadata": {
|
366 | 361 | "id": "zHAj8nuHXG-D",
|
367 | 362 | "outputId": "462d37f3-b59e-4d88-ca18-60224f69076d"
|
|
374 | 369 | " device = torch.device(\"mps\")\n",
|
375 | 370 | "else:\n",
|
376 | 371 | " device = torch.device(\"cpu\")\n",
|
377 |
| - " \n", |
| 372 | + "\n", |
378 | 373 | "model = Restormer(\n",
|
379 | 374 | " spatial_dims=2,\n",
|
380 | 375 | " in_channels=1,\n",
|
|
399 | 394 | },
|
400 | 395 | {
|
401 | 396 | "cell_type": "code",
|
402 |
| - "execution_count": 72, |
| 397 | + "execution_count": null, |
403 | 398 | "metadata": {
|
404 | 399 | "id": "eyiL4ccmYsjt"
|
405 | 400 | },
|
|
432 | 427 | }
|
433 | 428 | ],
|
434 | 429 | "source": [
|
435 |
| - "max_epochs = 20\n", |
| 430 | + "max_epochs = 2\n", |
436 | 431 | "epoch_loss_values = []\n",
|
437 | 432 | "\n",
|
438 | 433 | "\n",
|
|
448 | 443 | " moving = batch_data[\"moving_hand\"].to(device)\n",
|
449 | 444 | " fixed = batch_data[\"fixed_hand\"].to(device)\n",
|
450 | 445 | " pred_image = model(moving)\n",
|
451 |
| - " pred_image=torch.sigmoid(pred_image)\n", |
| 446 | + " pred_image = torch.sigmoid(pred_image)\n", |
452 | 447 | "\n",
|
453 | 448 | " loss = image_loss(input=pred_image, target=fixed)\n",
|
454 | 449 | " loss.backward()\n",
|
|
514 | 509 | },
|
515 | 510 | {
|
516 | 511 | "cell_type": "code",
|
517 |
| - "execution_count": 76, |
| 512 | + "execution_count": null, |
518 | 513 | "metadata": {
|
519 | 514 | "colab": {
|
520 | 515 | "base_uri": "https://localhost:8080/"
|
|
534 | 529 | "source": [
|
535 | 530 | "val_ds = CacheDataset(data=training_datadict[2000:2500], transform=train_transforms, cache_rate=1.0, num_workers=0)\n",
|
536 | 531 | "val_loader = DataLoader(val_ds, batch_size=16, num_workers=0)\n",
|
537 |
| - "model.eval() # Set model to evaluation mode\n", |
| 532 | + "model.eval() # Set model to evaluation mode\n", |
538 | 533 | "\n",
|
539 |
| - "with torch.no_grad(): # Disable gradient calculation for inference\n", |
| 534 | + "with torch.no_grad(): # Disable gradient calculation for inference\n", |
540 | 535 | " for batch_data in val_loader:\n",
|
541 | 536 | " moving = batch_data[\"moving_hand\"].to(device)\n",
|
542 | 537 | " fixed = batch_data[\"fixed_hand\"].to(device)\n",
|
543 | 538 | " # Pass only the moving image, consistent with training\n",
|
544 | 539 | " pred_image = model(moving)\n",
|
545 | 540 | " pred_image = torch.sigmoid(pred_image)\n",
|
546 |
| - " break # Process only the first batch for visualization\n", |
| 541 | + " break # Process only the first batch for visualization\n", |
547 | 542 | "\n",
|
548 | 543 | "fixed_image = fixed.detach().cpu().numpy()[:, 0]\n",
|
549 | 544 | "moving_image = moving.detach().cpu().numpy()[:, 0]\n",
|
|
596 | 591 | "plt.axis(\"off\")\n",
|
597 | 592 | "plt.show()"
|
598 | 593 | ]
|
599 |
| - }, |
600 |
| - { |
601 |
| - "cell_type": "code", |
602 |
| - "execution_count": null, |
603 |
| - "metadata": {}, |
604 |
| - "outputs": [], |
605 |
| - "source": [] |
606 | 594 | }
|
607 | 595 | ],
|
608 | 596 | "metadata": {
|
|
0 commit comments