Skip to content

Reintroduce generate method for PPOTrainer #3374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions tests/test_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,62 @@ def test_peft_training(self):

self.assertTrue(critic_weights_updated, "Critic weights were not updated during training")
self.assertTrue(policy_weights_updated, "Policy LoRA weights were not updated during training")

def test_generate(self):
    """Exercise ``PPOTrainer.generate`` across its four documented return shapes.

    Covers every combination of input container (single 1-D query tensor vs.
    list of query tensors) and the ``return_logits`` flag, asserting the
    return-type contract for each case.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Configure training args
        training_args = PPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=2,
            num_ppo_epochs=2,
            report_to="none",
        )

        # Create trainer
        trainer = PPOTrainer(
            args=training_args,
            processing_class=self.tokenizer,
            model=self.model,
            ref_model=self.ref_model,
            reward_model=self.reward_model,
            value_model=self.value_model,
            train_dataset=self.raw_dataset["train"],
            eval_dataset=self.raw_dataset["test"],
        )

        query_txt = "This morning I went to the "
        query_tensor = torch.flatten(self.tokenizer.encode(query_txt, return_tensors="pt")).to(self.model.device)
        query_list = list(self.tokenizer.encode(query_txt, return_tensors="pt").to(self.model.device))

        test_cases = [
            # (input_type, input_data, return_logits)
            ("tensor", query_tensor, False),
            ("tensor", query_tensor, True),
            ("list", query_list, False),
            ("list", query_list, True),
        ]

        for input_type, query, return_logits in test_cases:
            with self.subTest(input_type=input_type, return_logits=return_logits):
                # Single call: a generation failure must fail the test loudly.
                # (Previously the call was retried inside a bare `except Exception`,
                # which silently masked the first error.)
                response = trainer.generate(queries=query, return_logits=return_logits)
                if input_type == "tensor" and not return_logits:
                    self.assertIsInstance(response, torch.Tensor)
                    self.assertEqual(response.dim(), 2)
                elif input_type == "tensor" and return_logits:
                    self.assertIsInstance(response, tuple)
                    self.assertEqual(response[0].dim(), 2)
                    # NOTE(review): assumes the second element of the tuple is the
                    # logits tensor, per generate()'s documented contract — confirm.
                    self.assertIsInstance(response[1], torch.Tensor)
                elif input_type == "list" and return_logits:
                    self.assertIsInstance(response, list)
                    # query_list holds a single query, so exactly one result is expected.
                    self.assertEqual(len(response), 1)
                    self.assertIsInstance(response[0], tuple)
                elif input_type == "list" and not return_logits:
                    self.assertIsInstance(response, list)
                    self.assertIsInstance(response[0], torch.Tensor)
75 changes: 73 additions & 2 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
exact_div,
first_true_indices,
forward,
generate,
generate_model_card,
get_comet_experiment_url,
get_reward,
Expand Down Expand Up @@ -166,6 +167,7 @@ def __init__(
if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(self.policy_model)

self.is_encoder_decoder = model.config.is_encoder_decoder
self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
self.model_adapter_name = args.model_adapter_name
self.ref_adapter_name = args.ref_adapter_name
Expand Down Expand Up @@ -359,7 +361,7 @@ def repeat_generator():
yield from dataloader

iter_dataloader = iter(repeat_generator())
generation_config = GenerationConfig(
self.generation_config = GenerationConfig(
max_new_tokens=args.response_length,
temperature=(args.temperature + 1e-7),
top_k=0.0,
Expand Down Expand Up @@ -428,7 +430,7 @@ def repeat_generator():
queries,
args.local_rollout_forward_batch_size,
processing_class.pad_token_id,
generation_config,
self.generation_config,
)

for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
Expand Down Expand Up @@ -677,12 +679,81 @@ def repeat_generator():
)
empty_cache()

torch.cuda.empty_cache()

# HF trainer specifics
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None, metrics=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def generate(
    self,
    queries: Union[torch.Tensor, list[torch.Tensor]],
    generation_config: Optional[GenerationConfig] = None,
    return_logits=False,
) -> Union[
    torch.Tensor, tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
]:
    """
    Generate responses for the given queries with the (unwrapped) policy model.

    Args:
        queries (`torch.Tensor` or `list[torch.Tensor]]`): Either one batch of query
            token IDs (1- or 2-D tensor) or a list of such tensors, each handled
            independently.
        generation_config (`GenerationConfig` or `None`): Generation settings to use.
            When `None`, falls back to the config created during training if available,
            otherwise to the policy model's own default config.
        return_logits (`bool`): If True, also return the logits produced alongside each
            generated sequence. Defaults to False.

    Returns:
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]]:
            - batch tensor input, `return_logits=False`: a tensor of generated sequences.
            - batch tensor input, `return_logits=True`: a `(sequences, logits)` tuple.
            - list input, `return_logits=False`: a list of generated-sequence tensors.
            - list input, `return_logits=True`: a list of `(sequences, logits)` tuples.
    """

    def _as_batch(q: torch.Tensor) -> torch.Tensor:
        # Promote a 1-D token-id tensor to a (1, seq_len) batch; 2-D passes through.
        return q.reshape(1, -1) if len(q.shape) == 1 else q

    # Explicit config wins; otherwise prefer the training-time config, then the
    # policy model's default (self.generation_config only exists after train()).
    active_config = generation_config or getattr(
        self, "generation_config", self.policy_model.generation_config
    )
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        if isinstance(queries, list):
            # Each query in the list is generated independently.
            outputs = []
            for single_query in queries:
                generated = batch_generation(
                    unwrapped_model.policy,
                    _as_batch(single_query),
                    self.args.local_rollout_forward_batch_size,
                    self.processing_class.pad_token_id,
                    active_config,
                )
                # generated is (sequences, logits); keep only sequences unless asked.
                outputs.append(generated if return_logits else generated[0])
            return outputs

        generated = generate(
            unwrapped_model.policy, _as_batch(queries), self.processing_class.pad_token_id, active_config
        )
        return generated if return_logits else generated[0]

def generate_completions(self, sampling: bool = False):
args = self.args
processing_class = self.processing_class
Expand Down