Skip to content

Commit f9e37fe

Browse files
authored
Add spatial size as iput for infer class - UNETR configuration (#707)
* Add spatial size as iput for infer class - UNETR configuration Signed-off-by: Andres <[email protected]> * Change default network to train from main Signed-off-by: Andres <[email protected]>
1 parent 22912ff commit f9e37fe

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

sample-apps/radiology/lib/configs/deepedit.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
6363
download_file(url, self.path[0])
6464

6565
# Network
66-
spatial_size = json.loads(self.conf.get("spatial_size", "[128, 128, 128]"))
66+
self.spatial_size = json.loads(self.conf.get("spatial_size", "[128, 128, 128]"))
6767
if network == "unetr":
6868
self.network = UNETR(
6969
spatial_dims=3,
7070
in_channels=len(self.labels) + 1,
7171
out_channels=len(self.labels),
72-
img_size=spatial_size,
72+
img_size=self.spatial_size,
7373
feature_size=64,
7474
hidden_size=1536,
7575
mlp_dim=3072,
@@ -93,9 +93,15 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
9393

9494
def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
9595
return {
96-
self.name: lib.infers.DeepEdit(path=self.path, network=self.network, labels=self.labels),
96+
self.name: lib.infers.DeepEdit(
97+
path=self.path, network=self.network, labels=self.labels, spatial_size=self.spatial_size
98+
),
9799
f"{self.name}_seg": lib.infers.DeepEdit(
98-
path=self.path, network=self.network, labels=self.labels, type=InferType.SEGMENTATION
100+
path=self.path,
101+
network=self.network,
102+
labels=self.labels,
103+
spatial_size=self.spatial_size,
104+
type=InferType.SEGMENTATION,
99105
),
100106
}
101107

sample-apps/radiology/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ def main():
219219
conf = {
220220
"models": "deepedit",
221221
"use_pretrained_model": "false",
222+
# "network": "unetr",
223+
# "spatial_size": "[128,128,128]",
222224
}
223225

224226
app = MyApp(app_dir, studies, conf)

0 commit comments

Comments
 (0)