
Commit ea68d7c (1 parent: ad754e6)

Fixes EMAModel "from_pretrained" method (#9779)

* fix from_pretrained and added test
* make style

Co-authored-by: Sayak Paul <[email protected]>

The commit swaps model_cls.load_config(path, ...) for model_cls.from_config(path, ...) when recovering the saved EMA kwargs, and adds a round-trip test_from_pretrained covering both the standard and foreach update paths.

File tree: 2 files changed (+39, -1 lines)

src/diffusers/training_utils.py (1 addition, 1 deletion)

@@ -379,7 +379,7 @@ def __init__(
 
     @classmethod
     def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
-        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
         model = model_cls.from_pretrained(path)
 
         ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
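For context, here is a minimal round-trip sketch of the patched code path. The tiny UNet2DConditionModel config below is an arbitrary stand-in chosen for illustration, not taken from the commit:

import tempfile

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Build a tiny UNet purely for illustration; these config values are
# assumptions, not part of the commit.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# model_cls and model_config must be set for save_pretrained to work.
ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
ema_unet.step(unet.parameters())  # one (hypothetical) EMA update

with tempfile.TemporaryDirectory() as tmpdir:
    ema_unet.save_pretrained(tmpdir)
    # With this fix, the EMA state (shadow params, optimization_step,
    # decay) round-trips through from_pretrained.
    loaded = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)

assert loaded.optimization_step == ema_unet.optimization_step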

tests/others/test_ema.py (38 additions, 0 deletions)

@@ -59,6 +59,25 @@ def simulate_backprop(self, unet):
         unet.load_state_dict(updated_state_dict)
         return unet
 
+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
+
+            # Check that the shadow parameters of the loaded model match the original EMA model
+            for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+                assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+            # Verify that the optimization step is also preserved
+            assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+            # Check the decay value
+            assert loaded_ema_unet.decay == ema_unet.decay
+
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.

@@ -194,6 +213,25 @@ def simulate_backprop(self, unet):
         unet.load_state_dict(updated_state_dict)
         return unet
 
+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
+
+            # Check that the shadow parameters of the loaded model match the original EMA model
+            for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+                assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+            # Verify that the optimization step is also preserved
+            assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+            # Check the decay value
+            assert loaded_ema_unet.decay == ema_unet.decay
+
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
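As a hypothetical follow-up to the round-trip the tests exercise, a typical consumer of this API copies the averaged weights back into a live model for evaluation; ckpt_dir below is a placeholder path, not something from the commit:

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Placeholder: a directory previously written by ema_unet.save_pretrained(...)
ckpt_dir = "path/to/ema_checkpoint"

# Reload the underlying model and its EMA state, then copy the averaged
# weights into the live model before evaluation/inference.
unet = UNet2DConditionModel.from_pretrained(ckpt_dir)
ema_unet = EMAModel.from_pretrained(ckpt_dir, model_cls=UNet2DConditionModel)
ema_unet.copy_to(unet.parameters())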
