Open
Description
Reproduction
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, GenerationConfig, AutoModelForCausalLM
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
# Reproduction script: builds a PPOTrainer around GPT-2 and then calls
# `ppo_trainer.generate(...)` followed by `ppo_trainer.step(...)`.
# The traceback reported after this script shows the `generate` call raising
# AttributeError, so the statements after it are never reached.
# from trl.core import respond_to_batch
# Load the IMDB train split, rename its columns to `review` / `sentiment`,
# and keep only rows whose `review` value has length greater than 200.
ds = load_dataset('imdb', split='train')
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
ds = ds.filter(lambda x: len(x["review"])>200, batched=False)
# Policy model, loaded from the 'gpt2' checkpoint through TRL's
# AutoModelForCausalLMWithValueHead class.
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
# A second, separately loaded 'gpt2' causal LM, passed to the trainer below as `value_model`.
value_model = AutoModelForCausalLM.from_pretrained('gpt2')
# Both the reference model and the reward model are produced by calling
# create_reference_model on the policy model.
# NOTE(review): using a reference copy of the policy as `reward_model` is
# presumably a placeholder for this repro — confirm what PPOTrainer expects here.
model_ref = create_reference_model(model)
reward_model = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1,output_dir="./output/")
# Single hand-written prompt, encoded to a PyTorch tensor (return_tensors="pt").
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
# response_tensor = respond_to_batch(model, query_tensor)
# Replace the model's generation config with one that only sets eos_token_id.
model.generation_config = GenerationConfig(eos_token_id=tokenizer.eos_token_id)
ppo_trainer = PPOTrainer(
args=ppo_config,
processing_class=tokenizer,
model=model,
ref_model=model_ref,
reward_model=reward_model,
train_dataset=ds,
value_model=value_model,
)
# This is the failing call: per the reported traceback, PPOTrainer has no
# attribute `generate`.
# NOTE(review): the generate()/step() calling pattern presumably comes from an
# earlier TRL PPO API — verify against the documentation for the installed version.
response_tensor, values_and_logits = ppo_trainer.generate(
query_tensor,
return_prompt=False,
return_values_and_logits=True,
)
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]
# train model for one step with ppo
# (not reached in the reported run, because the `generate` call above raises first)
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
Output:
Traceback (most recent call last):
File "/workspace/multimodal/RL/ppo_test.py", line 37, in <module>
response_tensor, values_and_logits = ppo_trainer.generate(
AttributeError: 'PPOTrainer' object has no attribute 'generate'
System Info
transformers
Version: 4.48.0.dev0
trl
Version: 0.16.0
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
- Any traceback provided is complete