     exact_div,
     first_true_indices,
     forward,
+    generate,
     generate_model_card,
     get_comet_experiment_url,
     get_reward,
@@ -359,7 +360,7 @@ def repeat_generator():
                 yield from dataloader
 
         iter_dataloader = iter(repeat_generator())
-        generation_config = GenerationConfig(
+        self.generation_config = GenerationConfig(
             max_new_tokens=args.response_length,
             temperature=(args.temperature + 1e-7),
             top_k=0.0,
@@ -428,7 +429,7 @@ def repeat_generator():
                         queries,
                         args.local_rollout_forward_batch_size,
                         processing_class.pad_token_id,
-                        generation_config,
+                        self.generation_config,
                     )
 
                 for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
@@ -677,12 +678,81 @@ def repeat_generator():
             )
             torch.cuda.empty_cache()
 
+        torch.cuda.empty_cache()
+
         # HF trainer specifics
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
         if self.control.should_save:
             self._save_checkpoint(model, trial=None, metrics=None)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
+    def generate(
+        self,
+        queries: Union[torch.Tensor, list[torch.Tensor]],
+        generation_config: Optional[GenerationConfig] = None,
+        return_logits: bool = False,
+    ) -> Union[
+        torch.Tensor, tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
+    ]:
+        """
+        Generates responses to the given queries.
+
+        Args:
+            queries (`torch.Tensor` or `list[torch.Tensor]`): A batch of query tensors or a list of query tensors.
+                Each query tensor should be a 1D or 2D tensor of token IDs.
+            generation_config (`GenerationConfig` or `None`): Generation config; defaults to the config built from
+                the `PPOConfig` during training, or the policy model's default config.
+            return_logits (`bool`): Whether to return the logits along with the generated sequences. Defaults to
+                `False`.
+
+        Returns:
+            Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]]:
+                If `queries` is a batch tensor and `return_logits` is `False`: a tensor of generated sequences.
+                If `queries` is a batch tensor and `return_logits` is `True`: a tuple of the generated-sequence
+                tensor and the logits tensor.
+                If `queries` is a list of tensors and `return_logits` is `False`: a list of generated-sequence
+                tensors.
+                If `queries` is a list of tensors and `return_logits` is `True`: a list of tuples, each containing
+                a generated-sequence tensor and a logits tensor.
+        """
+
+        def _reshape_query(query: torch.Tensor) -> torch.Tensor:
+            if len(query.shape) == 1:
+                query = query.reshape(1, -1)
+            return query
+
+        generation_config = generation_config or getattr(
+            self, "generation_config", self.policy_model.generation_config
+        )
+        with unwrap_model_for_generation(
+            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+        ) as unwrapped_model:
+            if isinstance(queries, list):
+                result = []
+                for query in queries:
+                    query = _reshape_query(query)
+                    single_result = batch_generation(
+                        unwrapped_model.policy,
+                        query,
+                        self.args.local_rollout_forward_batch_size,
+                        self.processing_class.pad_token_id,
+                        generation_config,
+                    )
+                    if return_logits:
+                        result.append(single_result)
+                    else:
+                        result.append(single_result[0])
+            else:
+                queries = _reshape_query(queries)
+                result = generate(
+                    unwrapped_model.policy, queries, self.processing_class.pad_token_id, generation_config
+                )
+                if not return_logits:
+                    result = result[0]
+        return result
+
     def generate_completions(self, sampling: bool = False):
         args = self.args
         processing_class = self.processing_class
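For review context, a minimal usage sketch of the new method, following its docstring. `trainer` (a constructed `PPOTrainer`) and `tokenizer` are assumed to exist and are not part of this diff.

import torch

# Batch form: a 2D tensor of token IDs in, one tensor of sequences out.
queries = tokenizer(
    ["What is PPO?", "Explain RLHF."], return_tensors="pt", padding=True
).input_ids
responses = trainer.generate(queries)

# Same batch with logits: a (sequences, logits) tuple comes back.
responses, logits = trainer.generate(queries, return_logits=True)

# List form: 1D query tensors of varying lengths in, a list of sequence
# tensors out (each query is reshaped to (1, seq_len) internally).
query_list = [
    tokenizer(q, return_tensors="pt").input_ids.squeeze(0)
    for q in ["Hello", "A longer prompt goes here."]
]
per_query_responses = trainer.generate(query_list)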
@@ -862,81 +932,3 @@ def _generate_batched(
 
         self.tokenizer.padding_side = padding_side_default
         return outputs
-
-    def generate(
-        self,
-        query_tensor: Union[torch.Tensor, list[torch.Tensor]],
-        length_sampler: Optional[Callable] = None,
-        batch_size: int = 4,
-        return_prompt: bool = True,
-        generate_ref_response: bool = False,
-        **generation_kwargs,
-    ):
-        """
-        Generate response with the model given the query tensor.
-        call the `generate` method of the model.
-
-        Args:
-            query_tensor (`torch.LongTensor`):
-                A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
-            length_sampler (`Callable`, *optional*):
-                Callable that returns the number of newly generated tokens.
-            batch_size (`int`, *optional*):
-                Batch size used for generation, defaults to `4`.
-            return_prompt (`bool`, *optional*):
-                If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
-            generate_ref_response (`bool`, *optional*):
-                If set to `True` the reference response is also generated, defaults to `False`.
-            generation_kwargs (dict[str, Any]):
-                Keyword arguments for generation.
-
-        Returns:
-            `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
-        """
-        if generate_ref_response:
-            ref_model = self.model if self.is_peft_model else self.ref_model
-        if isinstance(query_tensor, list):
-            response = self._generate_batched(
-                self.model,
-                query_tensor,
-                length_sampler=length_sampler,
-                batch_size=batch_size,
-                return_prompt=return_prompt,
-                **generation_kwargs,
-            )
-            if generate_ref_response:
-                ref_response = self._generate_batched(
-                    ref_model,
-                    query_tensor,
-                    length_sampler=length_sampler,
-                    batch_size=batch_size,
-                    return_prompt=return_prompt,
-                    **generation_kwargs,
-                )
-
-        else:
-            if len(query_tensor.shape) == 2:
-                raise ValueError(
-                    "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
-                )
-
-            if length_sampler is not None:
-                generation_kwargs["max_new_tokens"] = length_sampler()
-
-            with unwrap_model_for_generation(self.policy_model, self.accelerator) as unwrapped_model:
-                response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
-
-            if generate_ref_response:
-                with unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_model:
-                    ref_response = unwrapped_model.generate(
-                        input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
-                    )
-
-        if not return_prompt and not self.is_encoder_decoder:
-            response = response[:, query_tensor.shape[0] :]
-            if generate_ref_response:
-                ref_response = ref_response[:, query_tensor.shape[0] :]
-
-        if generate_ref_response:
-            return response, ref_response
-        return response
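Since the removed method's `length_sampler`, `return_prompt`, and `generate_ref_response` options have no direct replacement, a call site migrating to the new API might look like the sketch below. Variable names are illustrative; based on how `batch_generation` and the `generate` utility are used elsewhere in this file, the returned tensor appears to contain prompt plus completion, so stripping the prompt is left to the caller.

from transformers import GenerationConfig

# Old (removed): completion only, prompt stripped via return_prompt=False.
# response = trainer.generate(query_tensor, return_prompt=False, max_new_tokens=64)

# New: pass sampling knobs through a GenerationConfig; a 1D query tensor
# is reshaped to (1, seq_len) internally.
config = GenerationConfig(max_new_tokens=64, temperature=0.7, do_sample=True)
full_sequences = trainer.generate(query_tensor, generation_config=config)
completion_only = full_sequences[:, query_tensor.shape[0]:]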