Commit ca8462b

FIX low_cpu_mem_usage consolidates devices (#2113)
See: huggingface/diffusers#9510 (comment)

Right now, the low_cpu_mem_usage=True option does not consolidate the devices. For example, when the model is on GPU and the state_dict is on CPU, the adapter weights will be on CPU after loading, when they should be on GPU. This fix ensures that the devices are consolidated.
1 parent ae297f0 commit ca8462b
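
To illustrate the scenario this commit addresses, here is a minimal sketch built from the same public helpers the new test uses (inject_adapter_in_model, set_peft_model_state_dict, get_peft_model_state_dict) and the same tiny test model; the CUDA device is an assumption for illustration, not part of the commit:

from transformers import AutoModelForCausalLM
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    inject_adapter_in_model,
    set_peft_model_state_dict,
)

model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"  # tiny model used by the new test
config = LoraConfig(init_lora_weights=False, target_modules="all-linear")

# Create an adapter once and keep its state_dict on CPU (strip the "base_model.model."
# prefix so that set_peft_model_state_dict can match the keys, as done in the new test).
tmp = get_peft_model(AutoModelForCausalLM.from_pretrained(model_id), config)
prefix = "base_model.model."
cpu_sd = {k[len(prefix):]: v.cpu() for k, v in get_peft_model_state_dict(tmp).items()}
del tmp

# Load the CPU state_dict into a fresh copy of the model that lives on the GPU.
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")
inject_adapter_in_model(config, model, low_cpu_mem_usage=True)
set_peft_model_state_dict(model, cpu_sd, low_cpu_mem_usage=True)

# With this fix, the LoRA weights follow their base layers onto the GPU.
print({p.device.type for p in model.parameters()})  # expected: {'cuda'}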

File tree

src/peft/utils/save_and_load.py
tests/test_gpu_examples.py

2 files changed: +55 -0 lines changed

src/peft/utils/save_and_load.py

Lines changed: 4 additions & 0 deletions
@@ -443,6 +443,10 @@ def renamed_dora_weights(k):
     )
     if low_cpu_mem_usage:
         load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
+        # ensure that the correct device is set
+        for module in model.modules():
+            if hasattr(module, "_move_adapter_to_device_of_base_layer"):
+                module._move_adapter_to_device_of_base_layer(adapter_name)
     else:
         load_result = model.load_state_dict(peft_model_state_dict, strict=False)
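
The extra loop is needed because of how assign=True behaves: torch's load_state_dict then re-uses the tensors from the state_dict instead of copying their values into the existing parameters, so the loaded weights keep the state_dict's device. The added code therefore walks the modules and lets every adapter layer move its weights back to the device of its base layer. A torch-only sketch of that underlying mechanism (illustration only, not PEFT code; assumes a CUDA device is available):

from torch import nn

layer = nn.Linear(4, 4).to("cuda")            # module parameters live on the GPU
cpu_sd = {k: v.cpu() for k, v in layer.state_dict().items()}

layer.load_state_dict(cpu_sd)                 # default: values are copied into the existing GPU tensors
print(layer.weight.device)                    # cuda:0

layer.load_state_dict(cpu_sd, assign=True)    # assign=True: the CPU tensors are used directly
print(layer.weight.device)                    # cpu -- no longer on the module's original device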

tests/test_gpu_examples.py

Lines changed: 51 additions & 0 deletions
@@ -55,8 +55,11 @@
     PromptEncoderConfig,
     TaskType,
     get_peft_model,
+    get_peft_model_state_dict,
+    inject_adapter_in_model,
     prepare_model_for_kbit_training,
     replace_lora_weights_loftq,
+    set_peft_model_state_dict,
 )
 from peft.tuners import boft
 from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
@@ -3226,3 +3229,51 @@ def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):
 
         torch.testing.assert_close(output_loaded, output_peft)
         torch.testing.assert_close(gen_loaded, gen_peft)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.single_gpu_tests
+class TestLowCpuMemUsageDifferentDevices:
+    """Test for the low CPU memory usage option for loading PEFT models.
+
+    There are already tests for this in test_initialization.py but here we want to specifically test diverging devices
+    for the model and state_dict.
+
+    """
+
+    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+
+    @pytest.mark.parametrize("device_model, device_sd", [("cpu", "cuda"), ("cuda", "cpu")])
+    def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
+        inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
+        inputs = {k: v.to(device_model) for k, v in inputs.items()}
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
+        lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
+        model = get_peft_model(model, lora_config)
+        model.eval()
+        logits_not_low_cpu_mem = model(**inputs).logits
+
+        state_dict = get_peft_model_state_dict(model)
+        peft_model_state_dict = {}
+        # remap the state dict so that it can be correctly loaded, and move weights to the other device
+        prefix = "base_model.model."
+        for k, v in state_dict.items():
+            k = k[len(prefix) :]
+            peft_model_state_dict[k] = v.to(device_sd)
+
+        del model
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
+        model.eval()
+        inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)
+        load_result = set_peft_model_state_dict(model, peft_model_state_dict, low_cpu_mem_usage=True)
+
+        # sanity check: all lora keys are matched
+        assert not any("lora" in k for k in load_result.missing_keys)
+        assert not any("lora" in k for k in load_result.unexpected_keys)
+
+        logits_low_cpu_mem = model(**inputs).logits
+
+        assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
+        assert {p.device.type for p in model.parameters()} == {device_model}
