Add support for optional conditioning in PatchInferer, SliceInferer, and SlidingWindowInferer #8400

Open · wants to merge 15 commits into base: dev
100 changes: 88 additions & 12 deletions monai/inferers/inferer.py
@@ -322,8 +322,36 @@ def __call__(
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.
            condition (torch.Tensor, optional): if provided via ``kwargs``, this tensor must match
                the shape of ``inputs`` and is split into patches alongside them, so each model call
                receives the condition patch corresponding to its input patch. When ``splitter`` is
                None, a list of pre-split ``(patch, location)`` pairs may be supplied instead.

        """
        # check if there is a conditioning signal
        condition = kwargs.pop("condition", None)
        # shape check for condition
        if condition is not None:
            if isinstance(inputs, torch.Tensor) and isinstance(condition, torch.Tensor):
                if condition.shape != inputs.shape:
                    raise ValueError(
                        f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}"
                    )
            elif isinstance(inputs, list) and isinstance(condition, list):
                if len(inputs) != len(condition):
                    raise ValueError(
                        f"Length of `condition` must match `inputs`. Got {len(inputs)} and {len(condition)}."
                    )
                for (in_patch, _), (cond_patch, _) in zip(inputs, condition):
                    if cond_patch.shape != in_patch.shape:
                        raise ValueError(
                            "Each `condition` patch must match the shape of the corresponding input patch. "
                            f"Got {cond_patch.shape} and {in_patch.shape}."
                        )
            else:
                raise ValueError(
                    "`condition` and `inputs` must be of the same type (both Tensor or both list of patches)."
                )

        patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
        if self.splitter is None:
            # handle situations where the splitter is not provided
@@ -344,20 +372,39 @@ def __call__(
f"The provided inputs type is {type(inputs)}."
)
patches_locations = inputs
if condition is not None:
condition_locations = condition
else:
# apply splitter
patches_locations = self.splitter(inputs)
if condition is not None:
# apply splitter to condition
condition_locations = self.splitter(condition)

        ratios: list[float] = []
        mergers: list[Merger] = []
        if condition is not None:
            for (patches, locations, batch_size), (condition_patches, _, _) in zip(
                self._batch_sampler(patches_locations), self._batch_sampler(condition_locations)
            ):
                # add patched condition to kwargs
                kwargs["condition"] = condition_patches
                # run inference
                outputs = self._run_inference(network, patches, *args, **kwargs)
                # initialize the mergers
                if not mergers:
                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
                # aggregate outputs
                self._aggregate(outputs, locations, batch_size, mergers, ratios)
        else:
            for patches, locations, batch_size in self._batch_sampler(patches_locations):
                # run inference
                outputs = self._run_inference(network, patches, *args, **kwargs)
                # initialize the mergers
                if not mergers:
                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
                # aggregate outputs
                self._aggregate(outputs, locations, batch_size, mergers, ratios)

        # finalize the mergers and get the results
        merged_outputs = [merger.finalize() for merger in mergers]
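For reference, a minimal usage sketch of this conditioned `PatchInferer` path (not part of the diff, and assuming this branch is installed): `toy_conditional_net` is a hypothetical stand-in for a model whose forward accepts a `condition` keyword; `PatchInferer` and `SlidingWindowSplitter` are MONAI's public API.

```python
import torch

from monai.inferers import PatchInferer, SlidingWindowSplitter

# hypothetical stand-in for a conditional model: it must accept the
# `condition` keyword that the inferer forwards with each patch batch
def toy_conditional_net(x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    return x + condition

inferer = PatchInferer(splitter=SlidingWindowSplitter(patch_size=(64, 64)), batch_size=4)
inputs = torch.randn(1, 1, 128, 128)
condition = torch.randn(1, 1, 128, 128)  # must match `inputs` shape exactly
# `condition` is split with the same splitter, so each model call sees
# the condition patch aligned with its input patch
output = inferer(inputs=inputs, network=toy_conditional_net, condition=condition)
```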
@@ -519,8 +566,14 @@ def __call__(
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.
            condition (torch.Tensor, optional): if provided via ``kwargs``, this tensor must match
                the shape of ``inputs`` and is windowed alongside them, so each model call receives
                the condition window corresponding to its input window.
        """
        # shape check for condition
        condition = kwargs.get("condition", None)
        if condition is not None and condition.shape != inputs.shape:
            raise ValueError(f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}")

        device = kwargs.pop("device", self.device)
        buffer_steps = kwargs.pop("buffer_steps", self.buffer_steps)
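As a sanity check of the calling convention for `SlidingWindowInferer` (a sketch under the same assumptions; `toy_net` is hypothetical):

```python
import torch

from monai.inferers import SlidingWindowInferer

# hypothetical conditional network taking a same-shaped condition window
def toy_net(x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    return x * condition

inferer = SlidingWindowInferer(roi_size=(96, 96), sw_batch_size=2, overlap=0.25)
img = torch.randn(1, 1, 160, 160)
cond = torch.randn(1, 1, 160, 160)  # shape-checked against `img` above
out = inferer(img, toy_net, condition=cond)  # windows of `cond` ride along
```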
@@ -728,7 +781,9 @@ def __call__(
            network: 2D model to execute inference on slices in the 3D input
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.
            condition (torch.Tensor, optional): if provided via ``kwargs``, this tensor must match
                the shape of ``inputs`` and is sliced alongside them, so each model call receives
                the condition slice corresponding to its input slice.
        """
        if self.spatial_dim > 2:
            raise ValueError("`spatial_dim` can only be `0, 1, 2` with `[H, W, D]` respectively.")

@@ -742,12 +797,28 @@ def __call__(
f"Currently, only 2D `roi_size` ({self.orig_roi_size}) with 3D `inputs` tensor (shape={inputs.shape}) is supported."
)

        # shape check for condition
        condition = kwargs.get("condition", None)
        if condition is not None and condition.shape != inputs.shape:
            raise ValueError(f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}")

        # check if there is a conditioning signal
        if condition is not None:
            return super().__call__(
                inputs=inputs,
                network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs),
                condition=condition,
            )
        else:
            return super().__call__(
                inputs=inputs, network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs)
            )
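A usage sketch for the conditioned `SliceInferer` path (hypothetical toy network; the 2D model receives the condition slice as a positional argument, per `network_wrapper` below):

```python
import torch

from monai.inferers import SliceInferer

# hypothetical 2D network; it receives 4D slices [N, C, H, W] plus the
# matching condition slice squeezed the same way
def toy_2d_net(x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    return x - condition

inferer = SliceInferer(roi_size=(64, 64), sw_batch_size=1, spatial_dim=2)
volume = torch.randn(1, 1, 64, 64, 8)  # 5D input [N, C, H, W, D]
cond = torch.randn(1, 1, 64, 64, 8)    # must match `volume` shape
out = inferer(volume, toy_2d_net, condition=cond)
```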

    def network_wrapper(
        self,
        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
        x: torch.Tensor,
        condition: torch.Tensor | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
@@ -756,7 +827,12 @@ def network_wrapper(
"""
# Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
x = x.squeeze(dim=self.spatial_dim + 2)
out = network(x, *args, **kwargs)

if condition is not None:
condition = condition.squeeze(dim=self.spatial_dim + 2)
out = network(x, condition, *args, **kwargs)
else:
out = network(x, *args, **kwargs)

# Unsqueeze the network output so it is [N, C, D, H, W] as expected by
# the default SlidingWindowInferer class
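The squeeze/unsqueeze contract can be checked in isolation; a minimal sketch with illustrative shapes for `spatial_dim=2`:

```python
import torch

# window produced by the sliding-window machinery for spatial_dim=2:
# the slice dimension is a singleton at index spatial_dim + 2 == 4
x = torch.randn(2, 1, 96, 96, 1)
cond = torch.randn(2, 1, 96, 96, 1)
# both tensors are squeezed to 4D before the 2D network runs ...
x2d, cond2d = x.squeeze(dim=4), cond.squeeze(dim=4)
out2d = x2d + cond2d  # stand-in for network(x, condition)
# ... and the output is unsqueezed back to 5D for merging
out = out2d.unsqueeze(dim=4)
assert out.shape == (2, 1, 96, 96, 1)
```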
16 changes: 13 additions & 3 deletions monai/inferers/utils.py
@@ -153,6 +153,8 @@ def sliding_window_inference(
    device = device or inputs.device
    sw_device = sw_device or inputs.device

    condition = kwargs.pop("condition", None)

    temp_meta = None
    if isinstance(inputs, MetaTensor):
        temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
@@ -168,6 +170,8 @@ def sliding_window_inference(
        pad_size.extend([half, diff - half])
    if any(pad_size):
        inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
        if condition is not None:
            condition = F.pad(condition, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

    # Store all slices
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
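The intent of the padding change above, shown with concrete numbers (a sketch; shapes and pad values are illustrative): both tensors must receive identical padding so window slices stay aligned.

```python
import torch
import torch.nn.functional as F

# if the image is smaller than roi_size, the condition gets the same
# symmetric padding as the inputs, keeping every window slice aligned
img = torch.randn(1, 1, 90, 90)
cond = torch.randn(1, 1, 90, 90)
pad_size = [3, 3, 3, 3]  # pad W and H from 90 up to a 96 roi
img = F.pad(img, pad=pad_size, mode="constant", value=0.0)
cond = F.pad(cond, pad=pad_size, mode="constant", value=0.0)
assert img.shape == cond.shape == (1, 1, 96, 96)
```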
@@ -220,13 +224,19 @@ def sliding_window_inference(
        ]
        if sw_batch_size > 1:
            win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
            if condition is not None:
                win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
                kwargs["condition"] = win_condition
        else:
            win_data = inputs[unravel_slice[0]].to(sw_device)
            if condition is not None:
                win_condition = condition[unravel_slice[0]].to(sw_device)
                kwargs["condition"] = win_condition

        if with_coord:
            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
        else:
            seg_prob_out = predictor(win_data, *args, **kwargs)
        # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
        dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
        if process_fn:
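Finally, the functional entry point end to end (a sketch; `predictor` here is a hypothetical conditional callable, not part of the diff):

```python
import torch

from monai.inferers import sliding_window_inference

def predictor(x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    # receives one window batch plus the matching condition windows
    return x + condition

img = torch.randn(1, 1, 128, 128)
cond = torch.randn(1, 1, 128, 128)
out = sliding_window_inference(
    inputs=img, roi_size=(96, 96), sw_batch_size=2, predictor=predictor, condition=cond
)
assert out.shape == img.shape
```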