@@ -1142,13 +1142,16 @@ def _generate_and_score_completions(
             completions = completions_text
 
         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+
+        # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
+        keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
+        reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+
         for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
             zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
         ):
             with profiling_context(self, reward_func_name):
-                if isinstance(
-                    reward_func, nn.Module
-                ):  # Module instead of PretrainedModel for compat with compiled models
+                if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                     if is_conversational(inputs[0]):
                         messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                         texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
@@ -1161,10 +1164,6 @@ def _generate_and_score_completions(
                 with torch.inference_mode():
                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
             else:
-                # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the number
-                # of generations
-                keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
-                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                 output_reward_func = reward_func(
                     prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
                 )
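Note on the change: hoisting `keys` and `reward_kwargs` above the loop builds the dict of extra dataset columns once, rather than once per custom reward function, and puts it in scope for every branch of the loop. For reference, below is a minimal sketch of the kind of custom (non-`nn.Module`) reward function that receives these forwarded columns, matching the call site in the diff; the `ground_truth` column name and the exact-match logic are hypothetical illustrations, not part of this commit, and the sketch assumes plain-text (non-conversational) completions.

# Minimal sketch: every dataset column except "prompt", "completion", and
# "completion_ids" is forwarded to custom reward functions via **reward_kwargs.
# "ground_truth" is a hypothetical column name used for illustration only.
def exact_match_reward(prompts, completions, completion_ids, ground_truth, **kwargs):
    # Assumes completions are plain strings; return one float per
    # (prompt, completion) pair, as the trainer expects.
    return [float(c.strip() == gt.strip()) for c, gt in zip(completions, ground_truth)]

Because the forwarding is purely keyword-based, adding a new column to the training dataset makes it available to such a function with no trainer changes.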