diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 737ac5c4be..f90492e39d 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -289,9 +289,12 @@ def _forward_func(
         # 1st element is the total prob, rest are the target tokens
         # add a leading dim for batch even we only support single instance for now
         if self.include_per_token_attr:
-            target_log_probs = torch.stack(
-                [total_log_prob, *log_prob_list], dim=0  # type: ignore
-            ).unsqueeze(0)
+            try:
+                target_log_probs = torch.stack(
+                    [total_log_prob, *log_prob_list], dim=0  # type: ignore
+                ).unsqueeze(0)
+            except TypeError:
+                raise TypeError("Try using the skip_bos argument.")
         else:
             target_log_probs = total_log_prob  # type: ignore
         target_probs = torch.exp(target_log_probs)
@@ -325,6 +328,10 @@ def attribute(
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
         num_trials: int = 1,
+        skip_bos: bool = True,
+        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+        #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
+        #  errors.
         gen_args: Optional[Dict[str, Any]] = None,
         use_cached_outputs: bool = True,
         # internal callback hook can be used for logging
@@ -375,8 +382,11 @@ def attribute(
             assert gen_args is None, "gen_args must be None when target is given"
 
             if type(target) is str:
-                # exclude sos
-                target_tokens = self.tokenizer.encode(target)[1:]
+                # exclude sos / bos
+                if skip_bos:
+                    target_tokens = self.tokenizer.encode(target)[1:]
+                else:
+                    target_tokens = self.tokenizer.encode(target)
                 target_tokens = torch.tensor(target_tokens)
             elif type(target) is torch.Tensor:
                 target_tokens = target
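
A minimal usage sketch of the skip_bos argument introduced above, assuming the patch is applied. FeatureAblation, LLMAttribution, and TextTemplateInput are existing Captum APIs; the gpt2 model and tokenizer are placeholders for any causal LM, chosen because GPT-2's tokenizer does not prepend a BOS token.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

    # Placeholder model/tokenizer; any causal LM works the same way.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
    inp = TextTemplateInput("{} lives in {} and enjoys {}.", values=["Dave", "Palm Coast", "golf"])

    # Tokenizers that prepend a BOS/SOS token to encoded text: keep the default
    # skip_bos=True so the leading special token is dropped from the target.
    res = llm_attr.attribute(inp, target="playing golf", skip_bos=True)

    # Tokenizers that do NOT prepend a BOS token (e.g. GPT-2's): pass
    # skip_bos=False so the first real target token is not discarded.
    res = llm_attr.attribute(inp, target="playing golf", skip_bos=False)

    print(res.seq_attr)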