@@ -59,6 +59,25 @@ def simulate_backprop(self, unet):
        unet.load_state_dict(updated_state_dict)
        return unet

+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
+
+        # Check that the shadow parameters of the loaded model match the original EMA model
+        for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+            assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+        # Verify that the optimization step is also preserved
+        assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+        # Check the decay value
+        assert loaded_ema_unet.decay == ema_unet.decay
+
    def test_optimization_steps_updated(self):
        unet, ema_unet = self.get_models()
        # Take the first (hypothetical) EMA step.
@@ -194,6 +213,25 @@ def simulate_backprop(self, unet):
        unet.load_state_dict(updated_state_dict)
        return unet

+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
+
+        # Check that the shadow parameters of the loaded model match the original EMA model
+        for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+            assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+        # Verify that the optimization step is also preserved
+        assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+        # Check the decay value
+        assert loaded_ema_unet.decay == ema_unet.decay
+
    def test_optimization_steps_updated(self):
        unet, ema_unet = self.get_models()
        # Take the first (hypothetical) EMA step.
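
For context, here is a minimal sketch of the save/load round trip the two added tests exercise; they differ only in the `foreach` flag passed to `EMAModel.from_pretrained`. This assumes an already-instantiated `UNet2DConditionModel` named `unet`; the checkpoint path is hypothetical.

```python
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Track an EMA of the UNet's weights; model_cls/model_config are needed
# so the EMA state can be serialized with save_pretrained.
ema_unet = EMAModel(
    unet.parameters(),
    model_cls=UNet2DConditionModel,
    model_config=unet.config,
    foreach=False,  # the two tests cover both False and True
)

# In a training loop, update the shadow parameters after each optimizer step.
ema_unet.step(unet.parameters())

# Persist the EMA state (shadow parameters, optimization_step, decay, ...).
ema_unet.save_pretrained("ema-checkpoint")  # hypothetical path

# Restore it later; the tests assert the shadow parameters,
# optimization_step, and decay survive this round trip.
loaded_ema_unet = EMAModel.from_pretrained(
    "ema-checkpoint", model_cls=UNet2DConditionModel, foreach=False
)
loaded_ema_unet.copy_to(unet.parameters())  # write averaged weights into the model
```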