# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses.deform import DiffusionLoss
# Exercise the loss on GPU when one is present, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _linear_ramp(shape):
    """Return the values 0..4 along the last axis, broadcast (as a view) to ``shape``."""
    ramp = torch.arange(0, 5, device=device)
    return ramp.reshape((1,) * (len(shape) - 1) + (5,)).expand(shape)


# A (1, 2, 5, 3) displacement field whose rows slide by one step: row i is [i, i+1, i+2],
# i.e. tensor([[[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]], <same>]]).
_slanted = torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)

TEST_CASES = [
    # constant field: every first partial is zero, so the diffusion loss is zero
    [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
    # unit slope along the last axis: every first partial is one, so the loss is one
    [{}, {"pred": _linear_ramp((1, 3, 5, 5, 5))}, 1.0],
    # squared ramp: the pre-expansion first partials are 2, 4, 6,
    # so the loss is (2^2 + 4^2 + 6^2) / 3 = 56 / 3 ~= 18.67
    [{"normalize": False}, {"pred": _linear_ramp((1, 3, 5, 5, 5)) ** 2}, 56.0 / 3.0],
    # same value for the 4-d case
    [{"normalize": False}, {"pred": _linear_ramp((1, 2, 5, 5)) ** 2}, 56.0 / 3.0],
    # same value for the 3-d case
    [{"normalize": False}, {"pred": _linear_ramp((1, 1, 5)) ** 2}, 56.0 / 3.0],
    # as shown in the demo notebook, the diffusion loss is scale-invariant when all
    # axes share the same resolution, so normalization leaves these values unchanged
    [{"normalize": True}, {"pred": _linear_ramp((1, 3, 5, 5, 5)) ** 2}, 56.0 / 3.0],
    [{"normalize": True}, {"pred": _linear_ramp((1, 2, 5, 5)) ** 2}, 56.0 / 3.0],
    [{"normalize": True}, {"pred": _linear_ramp((1, 1, 5)) ** 2}, 56.0 / 3.0],
    # slanted field: the first partials wrt x are all ones, and so are those wrt y,
    # so the unnormalized diffusion loss is 1^2 + 1^2 = 2
    [{"normalize": False}, {"pred": _slanted}, 2.0],
    # same field with normalization: using the demo-notebook notation, the coefficients
    # divided out are (1, 5/3) for the x partials and (3/5, 1) for the y partials, giving
    # ((1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2) / 2 = (1 + 9/25 + 25/9 + 1) / 2 ~= 2.5689
    [{"normalize": True}, {"pred": _slanted}, (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0],
]
| 79 | + |
| 80 | + |
class TestDiffusionLoss(unittest.TestCase):
    """Unit tests for ``monai.losses.deform.DiffusionLoss``."""

    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_data, expected_val):
        """Each parameterized case must produce the analytically derived loss value."""
        result = DiffusionLoss(**input_param).forward(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)

    def test_ill_shape(self):
        """Invalid ranks, spatial sizes, and channel counts must raise ValueError."""
        loss = DiffusionLoss()
        # rank outside the supported 3-d, 4-d, 5-d range
        with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
            loss.forward(torch.ones((1, 3), device=device))
        with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
            loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
        # a spatial dimension that is too small, tried at each spatial position
        with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
            loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))
        with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
            loss.forward(torch.ones((1, 3, 5, 2, 5), device=device))
        with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
            loss.forward(torch.ones((1, 3, 5, 5, 2), device=device))

        # number of vector components unequal to number of spatial dims:
        # 2 components for 3 spatial dims (5-d input)
        with self.assertRaisesRegex(ValueError, "Number of vector components"):
            loss.forward(torch.ones((1, 2, 5, 5, 5), device=device))
        # 3 components for 2 spatial dims (4-d input); the original repeated the
        # 5-d case here, so the 4-d mismatch path was never exercised
        with self.assertRaisesRegex(ValueError, "Number of vector components"):
            loss.forward(torch.ones((1, 3, 5, 5), device=device))

    def test_ill_opts(self):
        """Unsupported ``reduction`` options must raise ValueError."""
        pred = torch.rand(1, 3, 5, 5, 5, device=device)
        with self.assertRaisesRegex(ValueError, ""):
            DiffusionLoss(reduction="unknown")(pred)
        with self.assertRaisesRegex(ValueError, ""):
            DiffusionLoss(reduction=None)(pred)
| 113 | + |
| 114 | + |
| 115 | +if __name__ == "__main__": |
| 116 | + unittest.main() |
0 commit comments