Skip to content

Commit 9c74971

Browse files
lucas-tucker
and
lucast2021
authored
[mypy] Forward pass function type hints in lora (#11740)
Signed-off-by: lucast2021 <[email protected]> Co-authored-by: lucast2021 <[email protected]>
1 parent 022c5c6 commit 9c74971

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

vllm/lora/layers.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,9 @@ def __init__(self, base_layer: ReplicatedLinear) -> None:
405405
self.output_size = self.base_layer.output_size
406406
self.n_slices = 1
407407

408-
def forward(self, input_):
408+
def forward(
409+
self, input_: torch.Tensor
410+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
409411
"""Forward of ReplicatedLinearWithLoRA
410412
411413
Args:
@@ -496,7 +498,9 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
496498
bias = bias[start_idx:end_idx]
497499
return bias
498500

499-
def forward(self, input_):
501+
def forward(
502+
self, input_: torch.Tensor
503+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
500504
"""Forward of ColumnParallelLinear
501505
502506
Args:
@@ -833,7 +837,9 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
833837
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
834838
return bias
835839

836-
def forward(self, input_):
840+
def forward(
841+
self, input_: torch.Tensor
842+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
837843
"""Forward of RowParallelLinear
838844
839845
Args:

vllm/lora/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import re
66
from dataclasses import dataclass, field
7-
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
7+
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
88

99
import safetensors.torch
1010
import torch
@@ -219,6 +219,7 @@ def from_local_checkpoint(
219219

220220
config["vllm_max_position_embeddings"] = max_position_embeddings
221221
peft_helper = PEFTHelper.from_dict(config)
222+
unexpected_modules: List[Union[list[str], str]]
222223
if os.path.isfile(lora_tensor_path):
223224
tensors: Dict[str, torch.Tensor] = {}
224225
# Find unexpected modules.

vllm/model_executor/layers/linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
238238
assert param.size() == loaded_weight.size()
239239
param.data.copy_(loaded_weight)
240240

241-
def forward(self, x: torch.Tensor) -> torch.Tensor:
241+
def forward(
242+
self, x: torch.Tensor
243+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
242244
bias = self.bias if not self.skip_bias_add else None
243245
assert self.quant_method is not None
244246
output = self.quant_method.apply(self, x, bias)

0 commit comments

Comments (0)