Skip to content

Commit 0a6f962

Browse files
committed
Update discriminator option
1 parent 53fbb51 commit 0a6f962

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

hypernets/experiment/_maker.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def make_experiment(hyper_model_cls,
206206
- nf
207207
optimize_direction : str, optional
208208
Hypernets search reward metric direction, default is detected from reward_metric.
209-
discriminator : instance of hypernets.discriminator.BaseDiscriminator, optional
210-
Discriminator is used to determine whether to continue training
209+
discriminator : instance of hypernets.discriminator.BaseDiscriminator or bool, optional
210+
Discriminator is used to determine whether to continue training, set False to disable it.
211211
hyper_model_options: dict, optional
212212
Options to initlize HyperModel except *reward_metric*, *task*, *callbacks*, *discriminator*.
213213
evaluation_metrics: str, list, or None (default='auto'),
@@ -365,10 +365,14 @@ def append_early_stopping_callbacks(cbs):
365365
report_render = to_report_render_object(report_render, report_render_options)
366366
callbacks.append(MLReportCallback(report_render))
367367

368-
if discriminator is None and cfg.experiment_discriminator is not None and len(cfg.experiment_discriminator) > 0:
368+
if ((discriminator is None or discriminator is True)
369+
and cfg.experiment_discriminator is not None
370+
and len(cfg.experiment_discriminator) > 0):
369371
discriminator = make_discriminator(cfg.experiment_discriminator,
370372
optimize_direction=optimize_direction,
371373
**(cfg.experiment_discriminator_options or {}))
374+
elif discriminator is False:
375+
discriminator = None
372376

373377
if id is None:
374378
hasher = tb.data_hasher()

0 commit comments

Comments
 (0)