Skip to content

Commit d1174ad

Browse files
🛠️ Initialize reward_kwargs to prevent UnboundLocalError in GRPOTrainer (#3459)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent cd83841 commit d1174ad

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

‎trl/trainer/grpo_trainer.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,13 +1142,16 @@ def _generate_and_score_completions(
11421142
completions = completions_text
11431143

11441144
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
1145+
1146+
# Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
1147+
keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
1148+
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
1149+
11451150
for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
11461151
zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
11471152
):
11481153
with profiling_context(self, reward_func_name):
1149-
if isinstance(
1150-
reward_func, nn.Module
1151-
): # Module instead of PretrainedModel for compat with compiled models
1154+
if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models
11521155
if is_conversational(inputs[0]):
11531156
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
11541157
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
@@ -1161,10 +1164,6 @@ def _generate_and_score_completions(
11611164
with torch.inference_mode():
11621165
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
11631166
else:
1164-
# Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the number
1165-
# of generations
1166-
keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
1167-
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
11681167
output_reward_func = reward_func(
11691168
prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
11701169
)

0 commit comments

Comments
 (0)