@@ -55,8 +55,11 @@
     PromptEncoderConfig,
     TaskType,
     get_peft_model,
+    get_peft_model_state_dict,
+    inject_adapter_in_model,
     prepare_model_for_kbit_training,
     replace_lora_weights_loftq,
+    set_peft_model_state_dict,
 )
 from peft.tuners import boft
 from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
@@ -3226,3 +3229,51 @@ def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):
 
         torch.testing.assert_close(output_loaded, output_peft)
         torch.testing.assert_close(gen_loaded, gen_peft)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.single_gpu_tests
+class TestLowCpuMemUsageDifferentDevices:
+    """Tests for the low CPU memory usage option when loading PEFT models.
+
+    There are already tests for this in test_initialization.py, but here we specifically test diverging devices
+    for the model and the state_dict.
+    """
+
+    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+
+    @pytest.mark.parametrize("device_model, device_sd", [("cpu", "cuda"), ("cuda", "cpu")])
+    def test_low_cpu_mem_usage_different_devices_works(self, device_model, device_sd):
+        inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
+        inputs = {k: v.to(device_model) for k, v in inputs.items()}
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
+        lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
+        model = get_peft_model(model, lora_config)
+        model.eval()
+        logits_not_low_cpu_mem = model(**inputs).logits
+
+        state_dict = get_peft_model_state_dict(model)
+        peft_model_state_dict = {}
+        # remap the state dict so that it can be correctly loaded, and move the weights to the other device
+        prefix = "base_model.model."
+        for k, v in state_dict.items():
+            k = k[len(prefix) :]
+            peft_model_state_dict[k] = v.to(device_sd)
+
+        del model
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
+        model.eval()
+        inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)
+        load_result = set_peft_model_state_dict(model, peft_model_state_dict, low_cpu_mem_usage=True)
+
+        # sanity check: all LoRA keys are matched
+        assert not any("lora" in k for k in load_result.missing_keys)
+        assert not any("lora" in k for k in load_result.unexpected_keys)
+
+        logits_low_cpu_mem = model(**inputs).logits
+
+        assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
+        assert {p.device.type for p in model.parameters()} == {device_model}