EMAModel class bug: "from_pretrained" method bug #9764

Closed
@wangyanhui666

Describe the bug

@classmethod
def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
    _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
    model = model_cls.from_pretrained(path)

    ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

    ema_model.load_state_dict(ema_kwargs)
    return ema_model

This is the from_pretrained method of the EMAModel class.
The first line, "_, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)", always returns an empty dict for "ema_kwargs".
I think this line should be "_, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)", which returns ema_kwargs correctly.
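
To make the difference between the two calls concrete, here is a minimal sketch that does not use diffusers. ToyModel and both of its methods are hypothetical stand-ins written for this illustration; they only mimic the behavior the logs in this issue suggest, namely that load_config treats "unused kwargs" as extra keyword arguments passed to the call, while from_config treats them as config keys the model constructor does not accept.

```python
import inspect


class ToyModel:
    """Hypothetical stand-in for a model class; not a diffusers class."""

    def __init__(self, in_channels=4, num_layers=28):
        self.in_channels = in_channels
        self.num_layers = num_layers

    @classmethod
    def load_config(cls, saved, return_unused_kwargs=False, **kwargs):
        # Returns the whole saved config. "Unused" only covers extra keyword
        # arguments passed to this call, so it is empty when none are given.
        config = dict(saved)
        return (config, kwargs) if return_unused_kwargs else config

    @classmethod
    def from_config(cls, saved, return_unused_kwargs=False):
        # Splits the saved config by the constructor signature. Keys the
        # constructor does not accept come back as "unused".
        accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
        init_kwargs = {k: v for k, v in saved.items() if k in accepted}
        unused = {k: v for k, v in saved.items() if k not in accepted}
        model = cls(**init_kwargs)
        return (model, unused) if return_unused_kwargs else model


# A saved config that mixes model keys with EMA keys, as in this issue.
saved_config = {
    "in_channels": 4,
    "num_layers": 28,
    "decay": 0.9999,
    "optimization_step": 280000,
}

_, unused_from_load = ToyModel.load_config(saved_config, return_unused_kwargs=True)
_, unused_from_config = ToyModel.from_config(saved_config, return_unused_kwargs=True)

print(unused_from_load)    # {}
print(unused_from_config)  # {'decay': 0.9999, 'optimization_step': 280000}
```

In this toy version only the from_config path hands the EMA keys back, which matches what I see with the real classes.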

Reproduction

This can be reproduced with the official example that uses EMA:
https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py

Logs

kwargs=model_cls.load_config(path, return_unused_kwargs=True)
print(kwargs)

({'_class_name': 'DiTTransformer2DModel', '_diffusers_version': '0.30.2', 'activation_fn': 'gelu-approximate', 'attention_bias': True, 'attention_head_dim': 72, 'decay': 0.9999, 'dropout': 0.0, 'in_channels': 4, 'inv_gamma': 1.0, 'min_decay': 0.0, 'norm_elementwise_affine': False, 'norm_eps': 1e-05, 'norm_num_groups': 32, 'norm_type': 'ada_norm_zero', 'num_attention_heads': 16, 'num_embeds_ada_norm': 1000, 'num_layers': 28, 'optimization_step': 280000, 'out_channels': 4, ...}, {})


As you can see, the second element (the unused kwargs) is always empty. The EMA-related config ('optimization_step': 280000, 'inv_gamma': 1.0, 'min_decay': 0.0, 'decay': 0.9999) ends up in the first element, the config dict.


However, if I use from_config instead:
kwargs=model_cls.from_config(path, return_unused_kwargs=True)
print(kwargs)
(<DiTTransformer2DModel>, {'decay': 0.9999, 'inv_gamma': 1.0, 'min_decay': 0.0, 'optimization_step': 280000, 'power': 0.75, 'update_after_step': 0, 'use_ema_warmup': True, '_class_name': 'DiTTransformer2DModel', '_diffusers_version': '0.30.2'})

Here the second element is the EMA config, which is correct.
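
As a possible stopgap until the method is changed, the EMA entries could also be picked out of the full config dict by name. This is only a sketch: split_ema_kwargs is a hypothetical helper, not a diffusers function, and EMA_KEYS is simply the set of EMA-related keys visible in the from_config output above, which I am assuming is the complete list.

```python
# Assumption: these are the EMA-related keys, taken from the output above.
EMA_KEYS = (
    "decay",
    "inv_gamma",
    "min_decay",
    "optimization_step",
    "power",
    "update_after_step",
    "use_ema_warmup",
)


def split_ema_kwargs(config):
    """Hypothetical helper: split a config dict into (model_config, ema_kwargs)."""
    ema_kwargs = {k: config[k] for k in EMA_KEYS if k in config}
    model_config = {k: v for k, v in config.items() if k not in EMA_KEYS}
    return model_config, ema_kwargs


# A shortened version of the config dict from the logs.
config = {
    "_class_name": "DiTTransformer2DModel",
    "in_channels": 4,
    "decay": 0.9999,
    "inv_gamma": 1.0,
    "min_decay": 0.0,
    "optimization_step": 280000,
}

model_config, ema_kwargs = split_ema_kwargs(config)
print(ema_kwargs)
```

The resulting ema_kwargs could then be passed to ema_model.load_state_dict in place of the empty dict.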

System Info

diffusers 0.30.2

Who can help?

@DN6

Labels: bug
