Skip to content

Commit 445d750

Browse files
KumoLiuericspod
andauthored
Give more useful exception when batch is considered during matrix multiplication (#7326)
Fixes #7323 ### Description Give more useful exception when batch is considered during matrix multiplication. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 8fa6931 commit 445d750

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

monai/transforms/inverse.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,17 @@ def track_transform_meta(
185185
# not lazy evaluation, directly update the metatensor affine (don't push to the stack)
186186
orig_affine = data_t.peek_pending_affine()
187187
orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
188-
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
188+
try:
189+
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
190+
except RuntimeError as e:
191+
if orig_affine.ndim > 2:
192+
if data_t.is_batch:
193+
msg = "Transform applied to batched tensor, should be applied to instances only"
194+
else:
195+
msg = "Mismatch affine matrix, ensured that the batch dimension is not included in the calculation."
196+
raise RuntimeError(msg) from e
197+
else:
198+
raise
189199
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64)
190200

191201
if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):

0 commit comments

Comments
 (0)