Skip to content

Commit 86a0de8

Browse files
CyrilvallezArthurZucker
authored andcommitted
Protect get_default_device for torch<2.3 (#38376)
* Update modeling_utils.py * CIs
1 parent f5d15e6 commit 86a0de8

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ def get_torch_context_manager_or_global_device():
319319
is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
320320
"""
321321
device_in_context = torch.tensor([]).device
322-
default_device = torch.get_default_device()
322+
# `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior
323+
default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
323324
# This case means no context manager was used -> we still check if the default that was potentially set is not cpu
324325
if device_in_context == default_device:
325326
if default_device != torch.device("cpu"):

0 commit comments

Comments
 (0)