Skip to content

Commit e85580a

Browse files
authored
Fix 'torch.device' object has no attribute 'gpu_id' issue in trt export (#8019)
Part of #8017 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <[email protected]>
1 parent 4877767 commit e85580a

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

monai/networks/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ def _onnx_trt_compile(
822822
output_names = [] if not output_names else output_names
823823

824824
# set up the TensorRT builder
825-
torch_tensorrt.set_device(device)
825+
torch.cuda.set_device(device)
826826
logger = trt.Logger(trt.Logger.WARNING)
827827
builder = trt.Builder(logger)
828828
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
@@ -931,7 +931,7 @@ def convert_to_trt(
931931
warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.")
932932

933933
device = device if device else 0
934-
target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0")
934+
target_device = torch.device(f"cuda:{device}")
935935
convert_precision = torch.float32 if precision == "fp32" else torch.half
936936
inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
937937

@@ -986,7 +986,7 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int):
986986
ir_model,
987987
inputs=input_placeholder,
988988
enabled_precisions=convert_precision,
989-
device=target_device,
989+
device=torch_tensorrt.Device(f"cuda:{device}"),
990990
ir="torchscript",
991991
**kwargs,
992992
)

monai/utils/module.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def load_submodules(
214214
loader = mod_spec.loader
215215
loader.exec_module(mod)
216216
submodules.append(mod)
217-
218217
except OptionalImportError:
219218
pass # could not import the optional deps., they are ignored
220219
except ImportError as e:

0 commit comments

Comments
 (0)