File tree Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Original file line number Diff line number Diff line change @@ -324,7 +324,8 @@ def get_torch_context_manager_or_global_device():
324
324
is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
325
325
"""
326
326
device_in_context = torch .tensor ([]).device
327
- default_device = torch .get_default_device ()
327
+ # `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior
328
+ default_device = torch .get_default_device () if is_torch_greater_or_equal ("2.3" ) else torch .device ("cpu" )
328
329
# This case means no context manager was used -> we still check if the default that was potentially set is not cpu
329
330
if device_in_context == default_device :
330
331
if default_device != torch .device ("cpu" ):
You can’t perform that action at this time.
0 commit comments