Skip to content

Commit aaeebd6

Browse files
authored
2635 enhance UNet doc for the typical use case (#2659)
* [DLMED] enhance doc-string Signed-off-by: Nic Ma <[email protected]> * [DLMED] enhance the sanity check Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]>
1 parent 8d4f45f commit aaeebd6

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

monai/networks/nets/unet.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,21 @@ def __init__(
6464
Note: The acceptable spatial size of input data depends on the parameters of the network,
6565
to set appropriate spatial size, please check the tutorial for more details:
6666
https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb.
67-
Typically, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data.
67+
Typically, when using a stride of 2 in down / up sampling, the output dimensions are either half of the
68+
input when downsampling, or twice when upsampling. In this case with N numbers of layers in the network,
69+
the inputs must have spatial dimensions that are all multiples of 2^N.
70+
Usually, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data.
6871
6972
"""
7073
super().__init__()
7174

7275
if len(channels) < 2:
7376
raise ValueError("the length of `channels` should be no less than 2.")
74-
delta = len(strides) - len(channels)
75-
if delta < -1:
77+
delta = len(strides) - (len(channels) - 1)
78+
if delta < 0:
7679
raise ValueError("the length of `strides` should equal to `len(channels) - 1`.")
77-
if delta >= 0:
78-
warnings.warn(f"`len(strides) >= len(channels)`, the last {delta + 1} values of strides will not be used.")
80+
if delta > 0:
81+
warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.")
7982
if isinstance(kernel_size, Sequence):
8083
if len(kernel_size) != dimensions:
8184
raise ValueError("the length of `kernel_size` should equal to `dimensions`.")

0 commit comments

Comments
 (0)