Skip to content

Commit d5585c3

Browse files
kvtttKumoLiu
andauthored
7263 add diffusion loss (#7272)
Fixes #7263. ### Description Add diffusion loss. I also made a [demo notebook](https://github.com/kvttt/deep-atlas/blob/main/diffusion_loss_scale_test.ipynb) to provide some explanations and analyses of diffusion loss. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: kaibo <[email protected]> Co-authored-by: YunLiu <[email protected]>
1 parent 88f8dd2 commit d5585c3

File tree

4 files changed

+204
-1
lines changed

4 files changed

+204
-1
lines changed

docs/source/losses.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ Registration Losses
9696
.. autoclass:: BendingEnergyLoss
9797
:members:
9898

99+
`DiffusionLoss`
100+
~~~~~~~~~~~~~~~
101+
.. autoclass:: DiffusionLoss
102+
:members:
103+
99104
`LocalNormalizedCrossCorrelationLoss`
100105
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
101106
.. autoclass:: LocalNormalizedCrossCorrelationLoss

monai/losses/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .adversarial_loss import PatchAdversarialLoss
1515
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
1616
from .contrastive import ContrastiveLoss
17-
from .deform import BendingEnergyLoss
17+
from .deform import BendingEnergyLoss, DiffusionLoss
1818
from .dice import (
1919
Dice,
2020
DiceCELoss,

monai/losses/deform.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,85 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
116116
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
117117

118118
return energy
119+
120+
121+
class DiffusionLoss(_Loss):
122+
"""
123+
Calculate the diffusion based on first-order differentiation of pred using central finite difference.
124+
For the original paper, please refer to
125+
VoxelMorph: A Learning Framework for Deformable Medical Image Registration,
126+
Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
127+
IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.
128+
129+
Adapted from:
130+
VoxelMorph (https://github.com/voxelmorph/voxelmorph)
131+
"""
132+
133+
def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None:
134+
"""
135+
Args:
136+
normalize:
137+
Whether to divide out spatial sizes in order to make the computation roughly
138+
invariant to image scale (i.e. vector field sampling resolution). Defaults to False.
139+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
140+
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
141+
142+
- ``"none"``: no reduction will be applied.
143+
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
144+
- ``"sum"``: the output will be summed.
145+
"""
146+
super().__init__(reduction=LossReduction(reduction).value)
147+
self.normalize = normalize
148+
149+
def forward(self, pred: torch.Tensor) -> torch.Tensor:
150+
"""
151+
Args:
152+
pred:
153+
Predicted dense displacement field (DDF) with shape BCH[WD],
154+
where C is the number of spatial dimensions.
155+
Note that diffusion loss can only be calculated
156+
when the sizes of the DDF along all spatial dimensions are greater than 2.
157+
158+
Raises:
159+
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
160+
ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
161+
ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2.
162+
ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.
163+
164+
"""
165+
if pred.ndim not in [3, 4, 5]:
166+
raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
167+
for i in range(pred.ndim - 2):
168+
if pred.shape[-i - 1] <= 2:
169+
raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}")
170+
if pred.shape[1] != pred.ndim - 2:
171+
raise ValueError(
172+
f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, "
173+
f"does not match number of spatial dimensions, {pred.ndim - 2}"
174+
)
175+
176+
# first order gradient
177+
first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]
178+
179+
# spatial dimensions in a shape suited for broadcasting below
180+
if self.normalize:
181+
spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))
182+
183+
diffusion = torch.tensor(0)
184+
for dim_1, g in enumerate(first_order_gradient):
185+
dim_1 += 2
186+
if self.normalize:
187+
# We divide the partial derivative for each vector component at each voxel by the spatial size
188+
# corresponding to that component relative to the spatial size of the vector component with respect
189+
# to which the partial derivative is taken.
190+
g *= pred.shape[dim_1] / spatial_dims
191+
diffusion = diffusion + g**2
192+
193+
if self.reduction == LossReduction.MEAN.value:
194+
diffusion = torch.mean(diffusion) # the batch and channel average
195+
elif self.reduction == LossReduction.SUM.value:
196+
diffusion = torch.sum(diffusion) # sum over the batch and channel dims
197+
elif self.reduction != LossReduction.NONE.value:
198+
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
199+
200+
return diffusion

tests/test_diffusion_loss.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import numpy as np
17+
import torch
18+
from parameterized import parameterized
19+
20+
from monai.losses.deform import DiffusionLoss
21+
22+
device = "cuda" if torch.cuda.is_available() else "cpu"
23+
24+
TEST_CASES = [
25+
# all first partials are zero, so the diffusion loss is also zero
26+
[{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
27+
# all first partials are one, so the diffusion loss is also one
28+
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0],
29+
# before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67
30+
[
31+
{"normalize": False},
32+
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
33+
56.0 / 3.0,
34+
],
35+
# same as the previous case
36+
[
37+
{"normalize": False},
38+
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
39+
56.0 / 3.0,
40+
],
41+
# same as the previous case
42+
[{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
43+
# we have shown in the demo notebook that
44+
# diffusion loss is scale-invariant when the all axes have the same resolution
45+
[
46+
{"normalize": True},
47+
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
48+
56.0 / 3.0,
49+
],
50+
[
51+
{"normalize": True},
52+
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
53+
56.0 / 3.0,
54+
],
55+
[{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
56+
# for the following case, consider the following 2D matrix:
57+
# tensor([[[[0, 1, 2],
58+
# [1, 2, 3],
59+
# [2, 3, 4],
60+
# [3, 4, 5],
61+
# [4, 5, 6]],
62+
# [[0, 1, 2],
63+
# [1, 2, 3],
64+
# [2, 3, 4],
65+
# [3, 4, 5],
66+
# [4, 5, 6]]]])
67+
# the first partials wrt x are all ones, and so are the first partials wrt y
68+
# the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2
69+
[{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0],
70+
# consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook,
71+
# the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y
72+
# the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689
73+
[
74+
{"normalize": True},
75+
{"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)},
76+
(1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0,
77+
],
78+
]
79+
80+
81+
class TestDiffusionLoss(unittest.TestCase):
82+
@parameterized.expand(TEST_CASES)
83+
def test_shape(self, input_param, input_data, expected_val):
84+
result = DiffusionLoss(**input_param).forward(**input_data)
85+
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
86+
87+
def test_ill_shape(self):
88+
loss = DiffusionLoss()
89+
# not in 3-d, 4-d, 5-d
90+
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
91+
loss.forward(torch.ones((1, 3), device=device))
92+
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
93+
loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
94+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
95+
loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))
96+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
97+
loss.forward(torch.ones((1, 3, 5, 2, 5)))
98+
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
99+
loss.forward(torch.ones((1, 3, 5, 5, 2)))
100+
101+
# number of vector components unequal to number of spatial dims
102+
with self.assertRaisesRegex(ValueError, "Number of vector components"):
103+
loss.forward(torch.ones((1, 2, 5, 5, 5)))
104+
with self.assertRaisesRegex(ValueError, "Number of vector components"):
105+
loss.forward(torch.ones((1, 2, 5, 5, 5)))
106+
107+
def test_ill_opts(self):
108+
pred = torch.rand(1, 3, 5, 5, 5).to(device=device)
109+
with self.assertRaisesRegex(ValueError, ""):
110+
DiffusionLoss(reduction="unknown")(pred)
111+
with self.assertRaisesRegex(ValueError, ""):
112+
DiffusionLoss(reduction=None)(pred)
113+
114+
115+
if __name__ == "__main__":
116+
unittest.main()

0 commit comments

Comments
 (0)