'PPOTrainer' object has no attribute 'generate' #3250

Open
@zoeChen119

Description

Reproduction

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, GenerationConfig, AutoModelForCausalLM
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
# from trl.core import respond_to_batch

ds = load_dataset('imdb', split='train')
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
value_model = AutoModelForCausalLM.from_pretrained('gpt2')
model_ref = create_reference_model(model)
reward_model = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

ppo_config = PPOConfig(batch_size=1, mini_batch_size=1, output_dir="./output/")

query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# response_tensor  = respond_to_batch(model, query_tensor)
model.generation_config = GenerationConfig(eos_token_id=tokenizer.eos_token_id)

ppo_trainer = PPOTrainer(
    args=ppo_config, 
    processing_class=tokenizer,
    model=model, 
    ref_model=model_ref, 
    reward_model=reward_model, 
    train_dataset=ds,
    value_model=value_model,
)

response_tensor, values_and_logits = ppo_trainer.generate(
    query_tensor,
    return_prompt=False,
    return_values_and_logits=True,
)
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

Outputs:

Traceback (most recent call last):
  File "/workspace/multimodal/RL/ppo_test.py", line 37, in <module>
    response_tensor, values_and_logits = ppo_trainer.generate(
AttributeError: 'PPOTrainer' object has no attribute 'generate'
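
The missing attribute can be confirmed without loading any models by asking the trainer object which methods it exposes. Below is a minimal sketch of that check. `missing_methods` and `StandInTrainer` are hypothetical names introduced only for illustration, and the stand-in class is not the real `PPOTrainer`; it just mimics an object that has `train` but neither `generate` nor `step`.

```python
from typing import Iterable, List


def missing_methods(obj: object, names: Iterable[str]) -> List[str]:
    """Return the names from `names` that are not callable attributes of `obj`."""
    return [n for n in names if not callable(getattr(obj, n, None))]


# Hypothetical stand-in for a trainer that only exposes `train`.
# This is an assumption for the demo, NOT the real trl.PPOTrainer.
class StandInTrainer:
    def train(self) -> None:
        pass


if __name__ == "__main__":
    trainer = StandInTrainer()
    # prints ['generate', 'step']
    print(missing_methods(trainer, ["generate", "step", "train"]))
```

Passing the `ppo_trainer` object from the reproduction in place of the stand-in should report which of these methods are absent in the installed trl version.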

System Info

transformers
Version: 4.48.0.dev0

trl
Version: 0.16.0

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete

Labels: 🏋 PPO (Related to PPO), 🐛 bug (Something isn't working)
