Skip to content

Commit 289edd9

Browse files
ganteArthurZucker
authored andcommitted
Generate: can_generate() recursive check (#33718)
* add recursive check and test warnings * missing space * models without can_generate
1 parent c64be31 commit 289edd9

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

src/transformers/modeling_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,6 +1645,12 @@ def can_generate(cls) -> bool:
16451645
# Model class overwrites `generate` (e.g. time series models) -> can generate
16461646
if str(cls.__name__) in str(cls.generate):
16471647
return True
1648+
# The class inherits from a class that can generate (recursive check) -> can generate
1649+
for base in cls.__bases__:
1650+
if not hasattr(base, "can_generate"):
1651+
continue
1652+
if "PreTrainedModel" not in str(base) and base.can_generate():
1653+
return True
16481654
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
16491655
# was how we detected whether a model could generate.
16501656
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):

tests/utils/test_modeling_utils.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,29 +1718,51 @@ def test_isin_mps_friendly(self):
17181718

17191719
def test_can_generate(self):
17201720
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
1721+
logger = logging.get_logger("transformers.modeling_utils")
1722+
logger.warning_once.cache_clear()
1723+
17211724
# 1 - By default, a model CAN'T generate
1722-
self.assertFalse(BertModel.can_generate())
1725+
can_generate = BertModel.can_generate()
1726+
self.assertFalse(can_generate)
17231727

17241728
# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
17251729
class DummyBertWithMixin(BertModel, GenerationMixin):
17261730
pass
17271731

1728-
self.assertTrue(DummyBertWithMixin.can_generate())
1732+
with CaptureLogger(logger) as cl:
1733+
can_generate = DummyBertWithMixin.can_generate()
1734+
self.assertTrue("" == cl.out)
1735+
self.assertTrue(can_generate)
17291736

17301737
# 3 - Alternatively, a model can implement a `generate` method
17311738
class DummyBertWithGenerate(BertModel):
17321739
def generate(self):
17331740
pass
17341741

1735-
self.assertTrue(DummyBertWithGenerate.can_generate())
1742+
with CaptureLogger(logger) as cl:
1743+
can_generate = DummyBertWithGenerate.can_generate()
1744+
self.assertTrue("" == cl.out)
1745+
self.assertTrue(can_generate)
1746+
1747+
# 4 - Finally, it can inherit from a model that can generate
1748+
class DummyBertWithParent(DummyBertWithMixin):
1749+
pass
1750+
1751+
with CaptureLogger(logger) as cl:
1752+
can_generate = DummyBertWithParent.can_generate()
1753+
self.assertTrue("" == cl.out)
1754+
self.assertTrue(can_generate)
17361755

1737-
# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
1756+
# 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
17381757
# `GenerationMixin`)
17391758
class DummyBertWithPrepareInputs(BertModel):
17401759
def prepare_inputs_for_generation(self):
17411760
pass
17421761

1743-
self.assertTrue(DummyBertWithPrepareInputs.can_generate())
1762+
with CaptureLogger(logger) as cl:
1763+
can_generate = DummyBertWithPrepareInputs.can_generate()
1764+
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
1765+
self.assertTrue(can_generate)
17441766

17451767
def test_save_and_load_config_with_custom_generation(self):
17461768
"""

0 commit comments

Comments
 (0)