|
60 | 60 | "source": [
|
61 | 61 | "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel]\"\n",
|
62 | 62 | "!python -c \"import matplotlib\" || pip install -q matplotlib\n",
|
63 |
| - "!pip install -q pytorch-lightning==0.9.0\n", |
| 63 | + "!pip install -q pytorch-lightning==1.4.0\n", |
64 | 64 | "%matplotlib inline"
|
65 | 65 | ]
|
66 | 66 | },
|
|
145 | 145 | "from monai.apps import download_and_extract\n",
|
146 | 146 | "import torch\n",
|
147 | 147 | "import pytorch_lightning\n",
|
148 |
| - "from pytorch_lightning.callbacks.model_checkpoint \\\n", |
149 |
| - " import ModelCheckpoint\n", |
150 | 148 | "import matplotlib.pyplot as plt\n",
|
151 | 149 | "import tempfile\n",
|
152 | 150 | "import shutil\n",
|
|
424 | 422 | "tb_logger = pytorch_lightning.loggers.TensorBoardLogger(\n",
|
425 | 423 | " save_dir=log_dir\n",
|
426 | 424 | ")\n",
|
427 |
| - "checkpoint_callback = ModelCheckpoint(\n", |
428 |
| - " filepath=os.path.join(\n", |
429 |
| - " log_dir, \"{epoch}-{val_loss:.2f}-{val_dice:.2f}\")\n", |
430 |
| - ")\n", |
431 | 425 | "\n",
|
432 | 426 | "# initialise Lightning's trainer.\n",
|
433 | 427 | "trainer = pytorch_lightning.Trainer(\n",
|
434 | 428 | " gpus=[0],\n",
|
435 | 429 | " max_epochs=600,\n",
|
436 | 430 | " logger=tb_logger,\n",
|
437 |
| - " checkpoint_callback=checkpoint_callback,\n", |
| 431 | + " checkpoint_callback=True,\n", |
438 | 432 | " num_sanity_val_steps=1,\n",
|
439 | 433 | ")\n",
|
440 | 434 | "\n",
|
|
710 | 704 | "name": "python",
|
711 | 705 | "nbconvert_exporter": "python",
|
712 | 706 | "pygments_lexer": "ipython3",
|
713 |
| - "version": "3.8.10" |
| 707 | + "version": "3.7.10" |
714 | 708 | }
|
715 | 709 | },
|
716 | 710 | "nbformat": 4,
|
|
0 commit comments