Initial commit -- Adding calibration loss specific to segmentation #7819
Merged
Changes from 10 commits
Commits
54 commits
8fbec82 Initial commit -- Adding calibration loss specific to segmentation (Bala93)
23b897b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
b2ec62b Update __init__.py (Bala93)
93ee114 Update segcalib.py (Bala93)
42e732b Update segcalib.py (Bala93)
187053d Update segcalib.py (Bala93)
1d27ec5 Update segcalib.py (Bala93)
d499134 Update segcalib.py (Bala93)
1e3f755 Update segcalib.py (Bala93)
9dedfba Update segcalib.py (Bala93)
59959ce Update monai/losses/segcalib.py (Bala93)
cf1d044 Update monai/losses/segcalib.py (Bala93)
0926851 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
5317706 Update segcalib.py (Bala93)
3155433 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
7c121a0 Add specific to gaussian for both 2d and 3d (Bala93)
24efd85 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
0067953 Merge branch 'Project-MONAI:dev' into model-calibration (Bala93)
dccde47 Add mean loss and resolve formatting (Bala93)
44e8065 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
57686d7 Merge branch 'dev' into model-calibration (Bala93)
5cd9a33 Update segcalib.py (Bala93)
b547c4e [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
42a0215 Update segcalib.py (Bala93)
7e36ca1 Update segcalib.py (Bala93)
6dbd53d Update segcalib.py (Bala93)
354056c Update segcalib.py (Bala93)
7eb911f Update segcalib.py (Bala93)
0b1209b Update segcalib.py (Bala93)
035c92e Update segcalib.py (Bala93)
c1de5f1 Rename segcalib.py to nacl_loss.py (Bala93)
91dd1b9 Update __init__.py (Bala93)
9702c02 Update test_nacl_loss.py (Bala93)
4462379 Update nacl_loss.py (Bala93)
c4f8283 Update test_nacl_loss.py (Bala93)
bc6b995 Update test_nacl_loss.py (Bala93)
51e15fe Added missing parameters in doc (Bala93)
3a00aec Formatting check with monai (Bala93)
818b42b Update test_nacl_loss.py (Bala93)
6647708 Added mypy fixes (Bala93)
7e579dd DCO Remediation Commit for bala93 <[email protected]> (Bala93)
4f8abf1 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
b72e478 Update docs/source/losses.rst (Bala93)
747681d * Include test cases covering more cases (Bala93)
3b15554 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
877139c Update monai/losses/nacl_loss.py (Bala93)
4679456 Update monai/losses/nacl_loss.py (Bala93)
7c5217e * Add docstring with better explanations (Bala93)
d33f435 * Maintain the dimension consistency. (Bala93)
7deb2cc Update nacl_loss.py (Bala93)
91ce50b Update nacl_loss.py (Bala93)
7f87e0c Merge branch 'model-calibration' of https://github.com/Bala93/MONAI i… (Bala93)
0e880a8 Modify docstring (Bala93)
db9daeb Merge branch 'dev' into model-calibration (KumoLiu)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import math | ||
import warnings | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.nn.modules.loss import _Loss | ||
|
||
from monai.utils import pytorch_after | ||
|
||
|
||
def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor: | ||
x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize) | ||
y_grid = x_grid.t() | ||
xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() | ||
mean = (ksize - 1) / 2.0 | ||
variance = sigma**2.0 | ||
gaussian_kernel = (1.0 / (2.0 * math.pi * variance + 1e-16)) * torch.exp( | ||
-torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2 * variance + 1e-16) | ||
) | ||
return gaussian_kernel / torch.sum(gaussian_kernel) | ||
|
||
|
||
class GaussianFilter(torch.nn.Module): | ||
|
||
def __init__(self, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> None: | ||
super().__init__() | ||
gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma) | ||
center = ksize // 2  # index of the kernel centre, valid for any odd ksize | ||
neighbors_sum = (1 - gkernel[center, center]) + 1e-16 | ||
gkernel[center, center] = neighbors_sum | ||
self.svls_kernel = gkernel / neighbors_sum | ||
svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize) | ||
svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1) | ||
padding = int(ksize / 2) | ||
self.svls_layer = torch.nn.Conv2d( | ||
in_channels=channels, | ||
out_channels=channels, | ||
kernel_size=ksize, | ||
groups=channels, | ||
bias=False, | ||
padding=padding, | ||
padding_mode="replicate", | ||
) | ||
self.svls_layer.weight.data = svls_kernel_2d | ||
self.svls_layer.weight.requires_grad = False | ||
|
||
def forward(self, x): | ||
return self.svls_layer(x) / self.svls_kernel.sum() | ||
|
||
|
||
class NACLLoss(_Loss): | ||
""" | ||
Murugesan, Balamurali, et al. | ||
|
||
"Trust your neighbours: Penalty-based constraints for model calibration." | ||
International Conference on Medical Image Computing and Computer-Assisted Intervention, 2023. | ||
https://arxiv.org/abs/2303.06268 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
classes, | ||
kernel_size: int = 3, | ||
kernel_ops: str = "mean", | ||
distance_type: str = "l1", | ||
alpha: float = 0.1, | ||
sigma: float = 1.0, | ||
) -> None: | ||
""" | ||
Args: | ||
classes: number of classes | ||
kernel_size: size of the spatial kernel | ||
kernel_ops: type of kernel operation (mean/gaussian) | ||
distance_type: l1/l2 distance between spatial kernel and predicted logits | ||
alpha: weightage between cross entropy and logit constraint | ||
sigma: sigma if the kernel type is gaussian | ||
""" | ||
|
||
super().__init__() | ||
|
||
if kernel_ops not in ["mean", "gaussian"]: | ||
|
||
raise ValueError("Kernel ops must be either mean or gaussian") | ||
|
||
if distance_type not in ["l1", "l2"]: | ||
raise ValueError("Distance type must be either L1 or L2") | ||
|
||
self.kernel_ops = kernel_ops | ||
self.distance_type = distance_type | ||
self.alpha = alpha | ||
|
||
self.nc = classes | ||
self.ks = kernel_size | ||
self.cross_entropy = nn.CrossEntropyLoss() | ||
|
||
if kernel_ops == "gaussian": | ||
self.svls_layer = GaussianFilter(ksize=kernel_size, sigma=sigma, channels=classes) | ||
|
||
self.old_pt_ver = not pytorch_after(1, 10) | ||
|
||
def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Compute CrossEntropy loss for the input logits and target. | ||
Will remove the channel dim according to PyTorch CrossEntropyLoss: | ||
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss. | ||
|
||
""" | ||
n_pred_ch, n_target_ch = input.shape[1], target.shape[1] | ||
if n_pred_ch != n_target_ch and n_target_ch == 1: | ||
target = torch.squeeze(target, dim=1) | ||
target = target.long() | ||
elif self.old_pt_ver: | ||
warnings.warn( | ||
f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. " | ||
"Using argmax (as a workaround) to convert target to a single channel." | ||
) | ||
target = torch.argmax(target, dim=1) | ||
elif not torch.is_floating_point(target): | ||
target = target.to(dtype=input.dtype) | ||
|
||
return self.cross_entropy(input, target) # type: ignore[no-any-return] | ||
|
||
def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: | ||
mask = mask.unsqueeze(1) # unfold works for 4d. | ||
|
||
bs, _, h, w = mask.shape | ||
unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks), padding=self.ks // 2) | ||
|
||
rmask = [] | ||
|
||
if self.kernel_ops == "mean": | ||
umask = unfold(mask.float()) | ||
|
||
for ii in range(self.nc): | ||
rmask.append(torch.sum(umask == ii, 1) / self.ks**2) | ||
|
||
if self.kernel_ops == "gaussian": | ||
oh_labels = ( | ||
F.one_hot(mask[:, 0].to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() | ||
) | ||
rmask = self.svls_layer(oh_labels) | ||
|
||
return rmask | ||
|
||
rmask = torch.stack(rmask, dim=1) | ||
rmask = rmask.reshape(bs, self.nc, h, w) | ||
|
||
return rmask | ||
|
||
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: | ||
loss_ce = self.ce(inputs, targets) | ||
|
||
utargets = self.get_constr_target(targets) | ||
|
||
if self.distance_type == "l1": | ||
loss_conf = torch.abs(utargets - inputs).mean() | ||
|
||
if self.distance_type == "l2": | ||
loss_conf = (torch.abs(utargets - inputs) ** 2).mean() | ||
|
||
|
||
loss = loss_ce + self.alpha * loss_conf | ||
|
||
return loss | ||
|
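The `mean` branch of `get_constr_target` can be read as follows: for every pixel, count how often each class label occurs in the surrounding `ks` x `ks` window and divide by the window area. The dependency-free sketch below reproduces that idea on nested lists rather than tensors. `neighbour_class_fractions` is a hypothetical helper name that is not part of the PR, and treating out-of-image positions as label 0 is an assumption chosen to mirror the zero padding applied by `torch.nn.Unfold`.

```python
from __future__ import annotations


def neighbour_class_fractions(mask: list[list[int]], num_classes: int, ksize: int = 3) -> list[list[list[float]]]:
    """Per-pixel fraction of each class inside a ksize x ksize window.

    Assumption: positions outside the image count as label 0, which is what
    zero padding followed by an `== 0` comparison would produce.
    """
    h, w = len(mask), len(mask[0])
    pad = ksize // 2
    # counts[c][i][j] = number of window positions around (i, j) with label c
    counts = [[[0] * w for _ in range(h)] for _ in range(num_classes)]
    for i in range(h):
        for j in range(w):
            for di in range(-pad, pad + 1):
                for dj in range(-pad, pad + 1):
                    y, x = i + di, j + dj
                    inside = 0 <= y < h and 0 <= x < w
                    label = mask[y][x] if inside else 0
                    if 0 <= label < num_classes:
                        counts[label][i][j] += 1
    area = ksize * ksize
    return [[[c / area for c in row] for row in plane] for plane in counts]


# Illustrative input: a 3 x 3 mask where every pixel belongs to class 1.
fractions = neighbour_class_fractions([[1, 1, 1], [1, 1, 1], [1, 1, 1]], num_classes=2)
```

On this all-ones mask the centre pixel gets a class-1 fraction of 1.0, while a corner pixel gets 4/9, because five of its nine window positions fall in the padding and are counted as class 0. This border effect is worth keeping in mind when interpreting the penalty near image edges.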
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
from parameterized import parameterized | ||
|
||
from monai.losses import NACLLoss | ||
|
||
TEST_CASES = [ | ||
[ # shape: (2, 3, 3), (2, 3, 3) | ||
{"classes": 2}, | ||
{ | ||
"inputs": torch.tensor( | ||
[ | ||
[[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], | ||
[[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], | ||
] | ||
), | ||
"targets": torch.tensor( | ||
[ | ||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], | ||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], | ||
] | ||
), | ||
}, | ||
3.3611, # expected loss with the defaults: mean kernel, l1 distance, alpha=0.1 | ||
], | ||
[ # shape: (2, 3, 3), (2, 3, 3) | ||
{"classes": 2, "kernel_ops": "gaussian"}, | ||
{ | ||
"inputs": torch.tensor( | ||
[ | ||
[[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], | ||
[[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], | ||
] | ||
), | ||
"targets": torch.tensor( | ||
[ | ||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], | ||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], | ||
] | ||
), | ||
|
||
}, | ||
3.3963, # expected loss with the gaussian kernel | ||
], | ||
[ # shape: (2, 3, 3), (2, 3, 3) | ||
{"classes": 2, "distance_type": "l2"}, | ||
{ | ||
"inputs": torch.tensor( | ||
[ | ||
[[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], | ||
[[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], | ||
] | ||
), | ||
"targets": torch.tensor( | ||
[ | ||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], | ||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], | ||
] | ||
), | ||
}, | ||
3.3459, # expected loss with the l2 distance | ||
], | ||
[ # shape: (2, 3, 3), (2, 3, 3) | ||
{"classes": 2, "alpha": 0.2}, | ||
{ | ||
"inputs": torch.tensor( | ||
[ | ||
[[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]], | ||
[[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]], | ||
] | ||
), | ||
"targets": torch.tensor( | ||
[ | ||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], | ||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], | ||
] | ||
), | ||
}, | ||
3.3836, # expected loss with alpha=0.2 | ||
], | ||
] | ||
|
||
|
||
class TestNACLLoss(unittest.TestCase): | ||
|
||
@parameterized.expand(TEST_CASES) | ||
def test_result(self, input_param, input_data, expected_val): | ||
loss = NACLLoss(**input_param) | ||
result = loss(**input_data) | ||
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
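Because the loss is `loss_ce + alpha * loss_conf`, the first and last test cases (same inputs, `alpha` left at its default of 0.1 versus 0.2) are enough to back out the two terms from the expected values. A small sketch of that arithmetic follows. `split_loss` is a hypothetical helper for illustration only, and since the expected values are rounded to four decimals, the recovered terms are approximate.

```python
from __future__ import annotations


def split_loss(loss_a: float, alpha_a: float, loss_b: float, alpha_b: float) -> tuple[float, float]:
    """Solve loss = ce + alpha * penalty for (ce, penalty) from two evaluations."""
    if alpha_a == alpha_b:
        raise ValueError("the two alpha values must differ")
    penalty = (loss_b - loss_a) / (alpha_b - alpha_a)
    ce = loss_a - alpha_a * penalty
    return ce, penalty


# Expected values taken from the first (alpha=0.1) and last (alpha=0.2) test cases.
ce, penalty = split_loss(3.3611, 0.1, 3.3836, 0.2)
```

This gives a cross-entropy term of roughly 3.3386 and an L1 penalty of roughly 0.225. The `l2` case (3.3459 at the default `alpha`) is consistent with the same cross-entropy term plus a smaller squared penalty of about 0.073.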