@@ -1718,29 +1718,51 @@ def test_isin_mps_friendly(self):
1718
1718
1719
1719
def test_can_generate (self ):
1720
1720
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
1721
+ logger = logging .get_logger ("transformers.modeling_utils" )
1722
+ logger .warning_once .cache_clear ()
1723
+
1721
1724
# 1 - By default, a model CAN'T generate
1722
- self .assertFalse (BertModel .can_generate ())
1725
+ can_generate = BertModel .can_generate ()
1726
+ self .assertFalse (can_generate )
1723
1727
1724
1728
# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
1725
1729
class DummyBertWithMixin (BertModel , GenerationMixin ):
1726
1730
pass
1727
1731
1728
- self .assertTrue (DummyBertWithMixin .can_generate ())
1732
+ with CaptureLogger (logger ) as cl :
1733
+ can_generate = DummyBertWithMixin .can_generate ()
1734
+ self .assertTrue ("" == cl .out )
1735
+ self .assertTrue (can_generate )
1729
1736
1730
1737
# 3 - Alternatively, a model can implement a `generate` method
1731
1738
class DummyBertWithGenerate (BertModel ):
1732
1739
def generate (self ):
1733
1740
pass
1734
1741
1735
- self .assertTrue (DummyBertWithGenerate .can_generate ())
1742
+ with CaptureLogger (logger ) as cl :
1743
+ can_generate = DummyBertWithGenerate .can_generate ()
1744
+ self .assertTrue ("" == cl .out )
1745
+ self .assertTrue (can_generate )
1746
+
1747
+ # 4 - Finally, it can inherit from a model that can generate
1748
+ class DummyBertWithParent (DummyBertWithMixin ):
1749
+ pass
1750
+
1751
+ with CaptureLogger (logger ) as cl :
1752
+ can_generate = DummyBertWithParent .can_generate ()
1753
+ self .assertTrue ("" == cl .out )
1754
+ self .assertTrue (can_generate )
1736
1755
1737
- # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
1756
+ # 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
1738
1757
# `GenerationMixin`)
1739
1758
class DummyBertWithPrepareInputs (BertModel ):
1740
1759
def prepare_inputs_for_generation (self ):
1741
1760
pass
1742
1761
1743
- self .assertTrue (DummyBertWithPrepareInputs .can_generate ())
1762
+ with CaptureLogger (logger ) as cl :
1763
+ can_generate = DummyBertWithPrepareInputs .can_generate ()
1764
+ self .assertTrue ("it doesn't directly inherit from `GenerationMixin`" in cl .out )
1765
+ self .assertTrue (can_generate )
1744
1766
1745
1767
def test_save_and_load_config_with_custom_generation (self ):
1746
1768
"""
0 commit comments