
Commit 4d61397

update generate

1 parent 334b31e

File tree

2 files changed: +97 -116 lines changed

  tests/test_ppo_trainer.py
  trl/trainer/ppo_trainer.py


tests/test_ppo_trainer.py

Lines changed: 25 additions & 36 deletions
@@ -179,6 +179,7 @@ def test_peft_training(self):
 
     def test_generate(self):
         """Test various configurations of the generate method in PPOTrainer."""
+
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Configure training args
             training_args = PPOConfig(
@@ -202,47 +203,32 @@ def test_generate(self):
             )
 
             query_txt = "This morning I went to the "
-            query_tensor = torch.flatten(self.tokenizer.encode(query_txt, return_tensors="pt")).to("cuda")
-            query_list = list(self.tokenizer.encode(query_txt, return_tensors="pt").to("cuda"))
+            query_tensor = torch.flatten(self.tokenizer.encode(query_txt, return_tensors="pt")).to(self.model.device)
+            query_list = list(self.tokenizer.encode(query_txt, return_tensors="pt").to(self.model.device))
 
             test_cases = [
-                # (input_type, input_data, return_prompt, generate_ref_response)
-                ("tensor", query_tensor, False, False),
-                ("tensor", query_tensor, True, False),
-                ("tensor", query_tensor, False, True),
-                ("tensor", query_tensor, True, True),
-                ("list", query_list, False, False),
-                ("list", query_list, True, False),
-                ("list", query_list, False, True),
-                ("list", query_list, True, True),
+                # (input_type, input_data, return_logits)
+                ("tensor", query_tensor, False),
+                ("tensor", query_tensor, True),
+                ("list", query_list, False),
+                ("list", query_list, True),
             ]
 
-            for input_type, query, return_prompt, generate_ref_response in test_cases:
-                with self.subTest(
-                    input_type=input_type, return_prompt=return_prompt, generate_ref_response=generate_ref_response
-                ):
-                    # Run generate with the current configuration
-                    if generate_ref_response:
-                        response, ref_response = trainer.generate(
-                            query, return_prompt=return_prompt, generate_ref_response=generate_ref_response
-                        )
-
-                        # Verify the reference response
-                        if input_type == "tensor":
-                            self.assertTrue(isinstance(ref_response, torch.Tensor))
-                            self.assertEqual(len(ref_response.shape), 2)
-                        else:
-                            self.assertTrue(isinstance(ref_response, list))
-                            self.assertEqual(len(ref_response), 1)
-                    else:
-                        response = trainer.generate(
-                            query, return_prompt=return_prompt, generate_ref_response=generate_ref_response
-                        )
-
-                    # Verify the response format based on input type
-                    if input_type == "tensor":
+            for input_type, query, return_logits in test_cases:
+                with self.subTest(input_type=input_type, return_logits=return_logits):
+                    response = trainer.generate(queries=query, return_logits=return_logits)
+                    if input_type == "tensor" and return_logits is False:
                         self.assertTrue(isinstance(response, torch.Tensor))
                         self.assertEqual(len(response.shape), 2)
-                    else:
+                    elif input_type == "tensor" and return_logits is True:
+                        self.assertTrue(isinstance(response, tuple))
+                        self.assertEqual(len(response[0].shape), 2)
+                        # equal to vocab size - 1
+                        # self.assertEqual(response[1].shape ==
+                    elif input_type == "list" and return_logits is True:
                         self.assertTrue(isinstance(response, list))
+                        self.assertTrue(isinstance(response[0], tuple))
                         self.assertEqual(len(response), 1)
+                    elif input_type == "list" and return_logits is False:
+                        self.assertTrue(isinstance(response, list))
+                        self.assertTrue(isinstance(response[0], torch.Tensor))
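
The commented-out check in the tensor-with-logits branch hints at validating the logits' vocabulary dimension. A minimal sketch of what that assertion might look like, assuming `response` is the `(sequences, logits)` tuple returned for a batch tensor with `return_logits=True`; the exact expected size is left open by the commit itself, whose comment says "equal to vocab size - 1":

    # Hypothetical completion of the commented-out assertion above; not part
    # of the commit. Assumes logits have shape (batch, seq_len, vocab_size)
    # and that the policy model's config carries the vocabulary size.
    sequences, logits = response
    self.assertEqual(logits.shape[-1], trainer.model.policy.config.vocab_size)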

trl/trainer/ppo_trainer.py

Lines changed: 72 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
     exact_div,
     first_true_indices,
     forward,
+    generate,
     generate_model_card,
     get_comet_experiment_url,
     get_reward,
@@ -359,7 +360,7 @@ def repeat_generator():
                 yield from dataloader
 
         iter_dataloader = iter(repeat_generator())
-        generation_config = GenerationConfig(
+        self.generation_config = GenerationConfig(
            max_new_tokens=args.response_length,
            temperature=(args.temperature + 1e-7),
            top_k=0.0,
@@ -428,7 +429,7 @@ def repeat_generator():
                     queries,
                     args.local_rollout_forward_batch_size,
                     processing_class.pad_token_id,
-                    generation_config,
+                    self.generation_config,
                 )
 
             for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
@@ -677,12 +678,81 @@ def repeat_generator():
             )
             torch.cuda.empty_cache()
 
+        torch.cuda.empty_cache()
+
         # HF trainer specifics
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
         if self.control.should_save:
             self._save_checkpoint(model, trial=None, metrics=None)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
 
+    def generate(
+        self,
+        queries: Union[torch.Tensor, list[torch.Tensor]],
+        generation_config: Optional[GenerationConfig] = None,
+        return_logits: bool = False,
+    ) -> Union[
+        torch.Tensor, tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
+    ]:
+        """
+        Generates responses to the given queries.
+
+        Args:
+            queries (`torch.Tensor` or `list[torch.Tensor]`): A batch of query tensors or a list of query tensors.
+                Each query tensor should be a 1D or 2D tensor of token IDs.
+            generation_config (`GenerationConfig` or `None`): Generation config; defaults to the one defined in
+                the `PPOConfig` or the model's default config.
+            return_logits (`bool`): Whether to return the logits along with the generated sequences.
+                Defaults to `False`.
+
+        Returns:
+            `Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]]`:
+                If `queries` is a batch tensor and `return_logits` is `False`:
+                    a tensor of generated sequences.
+                If `queries` is a batch tensor and `return_logits` is `True`:
+                    a tuple containing the tensor of generated sequences and the tensor of logits.
+                If `queries` is a list of tensors and `return_logits` is `False`:
+                    a list of tensors, where each tensor is a generated sequence.
+                If `queries` is a list of tensors and `return_logits` is `True`:
+                    a list of tuples, where each tuple contains the generated sequence tensor and the logits tensor.
+        """
+
+        def _reshape_query(query: torch.Tensor) -> torch.Tensor:
+            if len(query.shape) == 1:
+                query = query.reshape(1, -1)
+            return query
+
+        generation_config = generation_config or getattr(
+            self, "generation_config", self.policy_model.generation_config
+        )
+        with unwrap_model_for_generation(
+            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+        ) as unwrapped_model:
+            if isinstance(queries, list):
+                result = []
+                for query in queries:
+                    query = _reshape_query(query)
+                    single_result = batch_generation(
+                        unwrapped_model.policy,
+                        query,
+                        self.args.local_rollout_forward_batch_size,
+                        self.processing_class.pad_token_id,
+                        generation_config,
+                    )
+                    if return_logits:
+                        result.append(single_result)
+                    else:
+                        result.append(single_result[0])
+
+            else:
+                queries = _reshape_query(queries)
+                result = generate(
+                    unwrapped_model.policy, queries, self.processing_class.pad_token_id, generation_config
+                )
+                if not return_logits:
+                    result = result[0]
+        return result
+
     def generate_completions(self, sampling: bool = False):
         args = self.args
         processing_class = self.processing_class
@@ -862,81 +932,3 @@ def _generate_batched(
 
         self.tokenizer.padding_side = padding_side_default
         return outputs
-
-    def generate(
-        self,
-        query_tensor: Union[torch.Tensor, list[torch.Tensor]],
-        length_sampler: Optional[Callable] = None,
-        batch_size: int = 4,
-        return_prompt: bool = True,
-        generate_ref_response: bool = False,
-        **generation_kwargs,
-    ):
-        """
-        Generate response with the model given the query tensor.
-        call the `generate` method of the model.
-
-        Args:
-            query_tensor (`torch.LongTensor`):
-                A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
-            length_sampler (`Callable`, *optional*):
-                Callable that returns the number of newly generated tokens.
-            batch_size (`int`, *optional*):
-                Batch size used for generation, defaults to `4`.
-            return_prompt (`bool`, *optional*):
-                If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
-            generate_ref_response (`bool`, *optional*):
-                If set to `True` the reference response is also generated, defaults to `False`.
-            generation_kwargs (dict[str, Any]):
-                Keyword arguments for generation.
-
-        Returns:
-            `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
-        """
-        if generate_ref_response:
-            ref_model = self.model if self.is_peft_model else self.ref_model
-        if isinstance(query_tensor, list):
-            response = self._generate_batched(
-                self.model,
-                query_tensor,
-                length_sampler=length_sampler,
-                batch_size=batch_size,
-                return_prompt=return_prompt,
-                **generation_kwargs,
-            )
-            if generate_ref_response:
-                ref_response = self._generate_batched(
-                    ref_model,
-                    query_tensor,
-                    length_sampler=length_sampler,
-                    batch_size=batch_size,
-                    return_prompt=return_prompt,
-                    **generation_kwargs,
-                )
-
-        else:
-            if len(query_tensor.shape) == 2:
-                raise ValueError(
-                    "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
-                )
-
-            if length_sampler is not None:
-                generation_kwargs["max_new_tokens"] = length_sampler()
-
-            with unwrap_model_for_generation(self.policy_model, self.accelerator) as unwrapped_model:
-                response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
-
-            if generate_ref_response:
-                with unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_model:
-                    ref_response = unwrapped_model.generate(
-                        input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
-                    )
-
-        if not return_prompt and not self.is_encoder_decoder:
-            response = response[:, query_tensor.shape[0] :]
-            if generate_ref_response:
-                ref_response = ref_response[:, query_tensor.shape[0] :]
-
-        if generate_ref_response:
-            return response, ref_response
-        return response
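
For orientation, a minimal usage sketch of the new `PPOTrainer.generate` API added in this commit. It assumes an already-constructed `trainer` and `tokenizer` as in the test file; `model` and the variable names here are illustrative, not part of the commit:

    # Encode a prompt; generate expects token IDs, not raw text.
    # model: the policy's base PreTrainedModel, as in the test setup.
    query = tokenizer.encode("This morning I went to the ", return_tensors="pt").to(model.device)

    # Batch tensor in -> tensor of generated sequences out.
    sequences = trainer.generate(queries=query)

    # With return_logits=True, a (sequences, logits) tuple comes back.
    sequences, logits = trainer.generate(queries=query, return_logits=True)

    # A list of 1D query tensors yields a list of results, one per query.
    responses = trainer.generate(queries=[query.squeeze(0)])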
