Skip to content

Commit aea46ff

Browse files
borisfompre-commit-ci[bot]KumoLiuyiheng-wang-nvbinliunls
authored
Trt compiler fixes (#8064)
Fixes #8061. ### Description Post-merge fixes for trt_compile() ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Boris Fomitchev <[email protected]> Signed-off-by: Yiheng Wang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <[email protected]> Co-authored-by: Yiheng Wang <[email protected]> Co-authored-by: Yiheng Wang <[email protected]> Co-authored-by: binliunls <[email protected]>
1 parent befb5f6 commit aea46ff

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

monai/networks/trt_compiler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ def forward(self, model, argv, kwargs):
342342
self._build_and_save(model, build_args)
343343
# This will reassign input_names from the engine
344344
self._load_engine()
345+
assert self.engine is not None
345346
except Exception as e:
346347
if self.fallback:
347348
self.logger.info(f"Failed to build engine: {e}")
@@ -403,8 +404,10 @@ def _onnx_to_trt(self, onnx_path):
403404

404405
build_args = self.build_args.copy()
405406
build_args["tf32"] = self.precision != "fp32"
406-
build_args["fp16"] = self.precision == "fp16"
407-
build_args["bf16"] = self.precision == "bf16"
407+
if self.precision == "fp16":
408+
build_args["fp16"] = True
409+
elif self.precision == "bf16":
410+
build_args["bf16"] = True
408411

409412
self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}")
410413
network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
@@ -502,6 +505,7 @@ def trt_compile(
502505
) -> torch.nn.Module:
503506
"""
504507
Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
508+
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x
505509
Args:
506510
model: module to patch with TrtCompiler object.
507511
base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.

tests/test_trt_compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from monai.handlers import TrtHandler
2121
from monai.networks import trt_compile
2222
from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
23-
from monai.utils import optional_import
23+
from monai.utils import min_version, optional_import
2424
from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows
2525

26-
trt, trt_imported = optional_import("tensorrt")
26+
trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
2727
polygraphy, polygraphy_imported = optional_import("polygraphy")
2828
build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")
2929

0 commit comments

Comments
 (0)