Skip to content

Commit 8b9b682

Browse files
committed
Minor changes, has_eps=False missing for bnb lion
1 parent 61305cc commit 8b9b682

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

timm/optim/_optim_factory.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,9 @@ def create_optimizer(
270270

271271
if param_group_fn:
272272
# run custom fn to generate param groups from nn.Module
273-
parameters = param_group_fn(model_or_params)
273+
params = param_group_fn(model_or_params)
274274
elif layer_decay is not None:
275-
parameters = param_groups_layer_decay(
275+
params = param_groups_layer_decay(
276276
model_or_params,
277277
weight_decay=weight_decay,
278278
layer_decay=layer_decay,
@@ -281,17 +281,17 @@ def create_optimizer(
281281
)
282282
weight_decay = 0.
283283
elif weight_decay and weight_decay_exclude_1d:
284-
parameters = param_groups_weight_decay(
284+
params = param_groups_weight_decay(
285285
model_or_params,
286286
weight_decay=weight_decay,
287287
no_weight_decay_list=no_weight_decay,
288288
)
289289
weight_decay = 0.
290290
else:
291-
parameters = model_or_params.parameters()
291+
params = model_or_params.parameters()
292292
else:
293293
# pass parameters / parameter groups through to optimizer
294-
parameters = model_or_params
294+
params = model_or_params
295295

296296
# Parse optimizer name
297297
opt_split = opt.lower().split('_')
@@ -330,7 +330,7 @@ def create_optimizer(
330330

331331
# Create optimizer
332332
opt_class = self.get_optimizer_class(opt_info, bind_defaults=False)
333-
optimizer = opt_class(parameters, **opt_args)
333+
optimizer = opt_class(params, **opt_args)
334334

335335
# Apply Lookahead if requested
336336
if use_lookahead:
@@ -685,12 +685,14 @@ def _register_bnb_optimizers(registry: OptimizerRegistry) -> None:
685685
'bnblion',
686686
'bitsandbytes.optim.Lion',
687687
description='bitsandbytes Lion',
688+
has_eps=False,
688689
has_betas=True
689690
),
690691
OptimInfo(
691692
'bnblion8bit',
692693
'bitsandbytes.optim.Lion8bit',
693694
description='bitsandbytes 8-bit Lion with dynamic quantization',
695+
has_eps=False,
694696
has_betas=True
695697
),
696698
OptimInfo(

0 commit comments

Comments
 (0)