diff --git a/3d_segmentation/spleen_segmentation_3d_lightning.ipynb b/3d_segmentation/spleen_segmentation_3d_lightning.ipynb index 107d1c6326..225e1929c6 100755 --- a/3d_segmentation/spleen_segmentation_3d_lightning.ipynb +++ b/3d_segmentation/spleen_segmentation_3d_lightning.ipynb @@ -60,7 +60,7 @@ "source": [ "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel]\"\n", "!python -c \"import matplotlib\" || pip install -q matplotlib\n", - "!pip install -q pytorch-lightning==0.9.0\n", + "!pip install -q pytorch-lightning==1.4.0\n", "%matplotlib inline" ] }, @@ -145,8 +145,6 @@ "from monai.apps import download_and_extract\n", "import torch\n", "import pytorch_lightning\n", - "from pytorch_lightning.callbacks.model_checkpoint \\\n", - " import ModelCheckpoint\n", "import matplotlib.pyplot as plt\n", "import tempfile\n", "import shutil\n", @@ -424,17 +422,13 @@ "tb_logger = pytorch_lightning.loggers.TensorBoardLogger(\n", " save_dir=log_dir\n", ")\n", - "checkpoint_callback = ModelCheckpoint(\n", - " filepath=os.path.join(\n", - " log_dir, \"{epoch}-{val_loss:.2f}-{val_dice:.2f}\")\n", - ")\n", "\n", "# initialise Lightning's trainer.\n", "trainer = pytorch_lightning.Trainer(\n", " gpus=[0],\n", " max_epochs=600,\n", " logger=tb_logger,\n", - " checkpoint_callback=checkpoint_callback,\n", + " checkpoint_callback=True,\n", " num_sanity_val_steps=1,\n", ")\n", "\n", @@ -710,7 +704,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.7.10" } }, "nbformat": 4,