diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py
index 3f664de1b3..59570821b8 100644
--- a/tests/test_ppo_trainer.py
+++ b/tests/test_ppo_trainer.py
@@ -176,3 +176,58 @@ def test_peft_training(self):
         self.assertTrue(critic_weights_updated, "Critic weights were not updated during training")
         self.assertTrue(policy_weights_updated, "Policy LoRA weights were not updated during training")
+
+    def test_generate(self):
+        """Test various configurations of the generate method in PPOTrainer."""
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Configure training args
+            training_args = PPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=4,
+                per_device_eval_batch_size=2,
+                num_ppo_epochs=2,
+                report_to="none",
+            )
+
+            # Create trainer
+            trainer = PPOTrainer(
+                args=training_args,
+                processing_class=self.tokenizer,
+                model=self.model,
+                ref_model=self.ref_model,
+                reward_model=self.reward_model,
+                value_model=self.value_model,
+                train_dataset=self.raw_dataset["train"],
+                eval_dataset=self.raw_dataset["test"],
+            )
+
+            query_txt = "This morning I went to the "
+            query_tensor = torch.flatten(self.tokenizer.encode(query_txt, return_tensors="pt")).to(self.model.device)
+            query_list = list(self.tokenizer.encode(query_txt, return_tensors="pt").to(self.model.device))
+
+            test_cases = [
+                # (input_type, input_data, return_logits)
+                ("tensor", query_tensor, False),
+                ("tensor", query_tensor, True),
+                ("list", query_list, False),
+                ("list", query_list, True),
+            ]
+
+            for input_type, query, return_logits in test_cases:
+                with self.subTest(input_type=input_type, return_logits=return_logits):
+                    response = trainer.generate(queries=query, return_logits=return_logits)
+                    if input_type == "tensor" and return_logits is False:
+                        self.assertIsInstance(response, torch.Tensor)
+                        self.assertEqual(len(response.shape), 2)
+                    elif input_type == "tensor" and return_logits is True:
+                        self.assertIsInstance(response, tuple)
+                        self.assertEqual(len(response[0].shape), 2)
+                        # TODO: also assert that the last logits dimension equals the vocabulary size
+                    elif input_type == "list" and return_logits is True:
+                        self.assertIsInstance(response, list)
+                        self.assertIsInstance(response[0], tuple)
+                        self.assertEqual(len(response), 1)
+                    elif input_type == "list" and return_logits is False:
+                        self.assertIsInstance(response, list)
+                        self.assertIsInstance(response[0], torch.Tensor)
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 580e5864b2..3a7a335ba2 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -58,6 +58,7 @@
     exact_div,
     first_true_indices,
     forward,
+    generate,
    generate_model_card,
     get_comet_experiment_url,
     get_reward,
@@ -166,6 +167,7 @@ def __init__(
         if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
             peft_module_casting_to_bf16(self.policy_model)
 
+        self.is_encoder_decoder = model.config.is_encoder_decoder
         self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
         self.model_adapter_name = args.model_adapter_name
         self.ref_adapter_name = args.ref_adapter_name
@@ -359,7 +361,7 @@ def repeat_generator():
                 yield from dataloader
 
         iter_dataloader = iter(repeat_generator())
-        generation_config = GenerationConfig(
+        self.generation_config = GenerationConfig(
             max_new_tokens=args.response_length,
             temperature=(args.temperature + 1e-7),
             top_k=0.0,
@@ -428,7 +430,7 @@ def repeat_generator():
                         queries,
                         args.local_rollout_forward_batch_size,
                         processing_class.pad_token_id,
-                        generation_config,
+                        self.generation_config,
                     )
 
                 for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
@@ -677,12 +679,80 @@ def repeat_generator():
             )
         empty_cache()
 
         # HF trainer specifics
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
         if self.control.should_save:
             self._save_checkpoint(model, trial=None, metrics=None)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
+    def generate(
+        self,
+        queries: Union[torch.Tensor, list[torch.Tensor]],
+        generation_config: Optional[GenerationConfig] = None,
+        return_logits: bool = False,
+    ) -> Union[
+        torch.Tensor, tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
+    ]:
+        """
+        Generates responses to the given queries.
+
+        Args:
+            queries (`torch.Tensor` or `list[torch.Tensor]`):
+                A batch of query tensors or a list of query tensors. Each query tensor should be a 1D or 2D
+                tensor of token IDs.
+            generation_config (`GenerationConfig` or `None`, *optional*):
+                Generation config. Defaults to the config built from the `PPOConfig` during training, or to the
+                policy model's default generation config.
+            return_logits (`bool`, *optional*, defaults to `False`):
+                Whether to return the logits along with the generated sequences.
+
+        Returns:
+            `Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]]`:
+                - If `queries` is a batch tensor and `return_logits` is `False`: a tensor of generated sequences.
+                - If `queries` is a batch tensor and `return_logits` is `True`: a tuple containing the tensor of
+                  generated sequences and the tensor of logits.
+                - If `queries` is a list of tensors and `return_logits` is `False`: a list of tensors, where each
+                  tensor is a generated sequence.
+                - If `queries` is a list of tensors and `return_logits` is `True`: a list of tuples, where each
+                  tuple contains the generated sequence tensor and the logits tensor.
+        """
+
+        def _reshape_query(query: torch.Tensor) -> torch.Tensor:
+            # Add a batch dimension to 1D queries
+            if len(query.shape) == 1:
+                query = query.reshape(1, -1)
+            return query
+
+        generation_config = generation_config or getattr(
+            self, "generation_config", self.policy_model.generation_config
+        )
+        with unwrap_model_for_generation(
+            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+        ) as unwrapped_model:
+            if isinstance(queries, list):
+                result = []
+                for query in queries:
+                    query = _reshape_query(query)
+                    single_result = batch_generation(
+                        unwrapped_model.policy,
+                        query,
+                        self.args.local_rollout_forward_batch_size,
+                        self.processing_class.pad_token_id,
+                        generation_config,
+                    )
+                    if return_logits:
+                        result.append(single_result)
+                    else:
+                        result.append(single_result[0])
+            else:
+                queries = _reshape_query(queries)
+                result = generate(
+                    unwrapped_model.policy, queries, self.processing_class.pad_token_id, generation_config
+                )
+                if not return_logits:
+                    result = result[0]
+        return result
+
     def generate_completions(self, sampling: bool = False):
         args = self.args
         processing_class = self.processing_class
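
Usage sketch (not part of the patch): the snippet below shows how the new `PPOTrainer.generate` API is called, mirroring the `test_generate` case above. The `trainer` and `tokenizer` objects are assumed to be set up as in that test; the query text is illustrative.

    import torch

    query_txt = "This morning I went to the "
    # 1D tensor of token IDs; generate() adds the batch dimension internally
    query_tensor = torch.flatten(tokenizer.encode(query_txt, return_tensors="pt"))

    # Tensor input, no logits -> a single 2D tensor of generated sequences
    sequences = trainer.generate(queries=query_tensor)

    # Tensor input with logits -> a (sequences, logits) tuple
    sequences, logits = trainer.generate(queries=query_tensor, return_logits=True)

    # List input -> a list with one result per query, here a single (sequences, logits) tuple
    per_query_results = trainer.generate(queries=[query_tensor], return_logits=True)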