[PEFT] Support low_cpu_mem_usage option for PEFT loading adapters #33725

Merged
25 changes: 23 additions & 2 deletions src/transformers/integrations/peft.py
@@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
import inspect
import warnings
from typing import Any, Dict, List, Optional, Union

+from packaging import version
+
from ..utils import (
    check_peft_version,
    find_adapter_config_file,
@@ -77,6 +80,7 @@ def load_adapter(
        offload_index: Optional[int] = None,
        peft_config: Dict[str, Any] = None,
        adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
+        low_cpu_mem_usage: bool = False,
        adapter_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
@@ -129,12 +133,27 @@ def load_adapter(
            adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
                The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
                dicts.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
+                Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
+                Requires PEFT version 0.13.0 or higher.
            adapter_kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
                the `find_adapter_config_file` method.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

+        # peft only supports low_cpu_mem_usage starting from v0.13.0
+        peft_load_kwargs = {}

[Review thread on `peft_load_kwargs = {}`]

Collaborator: LGTM, not sure we need all the kwargs, as you are not taking them from the kwargs of this function 😉

Member Author: Not sure if I get you. Do you mean this should be merged with adapter_kwargs? That wouldn't work, as these are kwargs used for a different purpose. Or that using a dict for a single argument is overkill?

Collaborator (replying to "Or that using a dict for a single argument is overkill?"): this 🤗

Member Author: I found this more readable and extensible than having:

if low_cpu_mem_usage:
    inject_adapter_in_model(peft_config, self, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
else:
    inject_adapter_in_model(peft_config, self, adapter_name)

(same with the set_peft_model_state_dict call)

If you want me to change it to this instead or have an alternative idea, let me know and I'll change it.

Collaborator: no worries
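
For illustration, a minimal, runnable sketch of the kwargs-dict pattern discussed in this thread; `do_load` is a hypothetical stand-in for PEFT's `inject_adapter_in_model`/`set_peft_model_state_dict` calls, not part of any real API:

def do_load(state_dict, low_cpu_mem_usage=False):
    # Stand-in for a loader that only recently gained this parameter.
    print(f"loading with low_cpu_mem_usage={low_cpu_mem_usage}")

low_cpu_mem_usage = True
peft_load_kwargs = {}
if low_cpu_mem_usage:
    # Only forward the argument when it was requested, so older callees
    # that lack the parameter never receive it.
    peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

# Each call site stays a single line, and future optional arguments can be
# added to the dict without multiplying if/else branches.
do_load({}, **peft_load_kwargs)

[End of review thread]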

+        if low_cpu_mem_usage:
+            min_version_lcmu = "0.13.0"
+            if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
+                peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+            else:
+                raise ValueError(
+                    "The version of PEFT you are using does not support `low_cpu_mem_usage` yet, "
+                    f"please install PEFT >= {min_version_lcmu}."
+                )

        adapter_name = adapter_name if adapter_name is not None else "default"
        if adapter_kwargs is None:
            adapter_kwargs = {}
@@ -192,7 +211,7 @@ def load_adapter(
            )

        # Create and add fresh new adapters into the model.
-        inject_adapter_in_model(peft_config, self, adapter_name)
+        inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)

        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
@@ -211,7 +230,9 @@ def load_adapter(
            processed_adapter_state_dict[new_key] = value

        # Load state dict
-        incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)
+        incompatible_keys = set_peft_model_state_dict(
+            self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
+        )

        if incompatible_keys is not None:
            # check only for unexpected keys
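Taken together, the integration change lets a user opt in when loading an adapter onto a transformers model. A minimal usage sketch, assuming PEFT >= 0.13.0 is installed; the adapter repo ID is a placeholder, not a real checkpoint:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model.load_adapter(
    "some-user/opt-350m-lora",  # hypothetical adapter repo ID
    low_cpu_mem_usage=True,  # with PEFT < 0.13.0 this raises the ValueError above
)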
44 changes: 44 additions & 0 deletions tests/peft_integration/test_peft_integration.py
@@ -12,11 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
import os
import tempfile
import unittest

from huggingface_hub import hf_hub_download
+from packaging import version

from transformers import AutoModelForCausalLM, OPTForCausalLM
from transformers.testing_utils import (
@@ -478,6 +480,48 @@ def test_peft_add_adapter_with_state_dict(self):
        # dummy generation
        _ = model.generate(input_ids=dummy_input)

+    def test_peft_add_adapter_with_state_dict_low_cpu_mem_usage(self):
+        """
+        Check the usage of low_cpu_mem_usage, which is supported in PEFT >= 0.13.0
+        """
+        from peft import LoraConfig
+
+        min_version_lcmu = "0.13.0"
+        is_lcmu_supported = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu)
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # this should always work
+                model.load_adapter(
+                    adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                )
+
+                if is_lcmu_supported:
+                    # if supported, this should not raise an error
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        adapter_name="other",
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=True,
+                    )
+                    # after loading, no meta device should be remaining
+                    self.assertFalse(any((p.device.type == "meta") for p in model.parameters()))
+                else:
+                    err_msg = r"The version of PEFT you are using does not support `low_cpu_mem_usage` yet"
+                    with self.assertRaisesRegex(ValueError, err_msg):
+                        model.load_adapter(
+                            adapter_state_dict=dummy_state_dict,
+                            adapter_name="other",
+                            peft_config=peft_config,
+                            low_cpu_mem_usage=True,
+                        )

    def test_peft_from_pretrained_hub_kwargs(self):
        """
        Tests different combinations of PEFT model + from_pretrained + hub kwargs
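The test gates its expectations on the installed PEFT version, and the same guard is useful in user code before opting in. A small helper sketch, reusing only the version check the diff itself relies on (the helper name is made up for illustration):

import importlib.metadata

from packaging import version

def peft_supports_low_cpu_mem_usage(min_version: str = "0.13.0") -> bool:
    # Mirrors the check in load_adapter: low_cpu_mem_usage needs PEFT >= 0.13.0.
    return version.parse(importlib.metadata.version("peft")) >= version.parse(min_version)

if peft_supports_low_cpu_mem_usage():
    print("Safe to pass low_cpu_mem_usage=True to load_adapter.")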