Skip to content

Commit 70557b9

Browse files
aihao2000, sayakpaul, yiyixuxu, hlky
committed
update (#7067)
* add data_dir parameter to load_dataset

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: hlky <[email protected]>
1 parent af7efa0 commit 70557b9

File tree

3 files changed

+3 -10 lines changed

3 files changed

+3
-10
lines changed

examples/controlnet/train_controlnet.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,6 @@ def parse_args(input_args=None):
571571
if args.dataset_name is None and args.train_data_dir is None:
572572
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
573573

574-
if args.dataset_name is not None and args.train_data_dir is not None:
575-
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
576-
577574
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
578575
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
579576

@@ -615,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator):
615612
args.dataset_name,
616613
args.dataset_config_name,
617614
cache_dir=args.cache_dir,
615+
data_dir=args.train_data_dir,
618616
)
619617
else:
620618
if args.train_data_dir is not None:

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,6 @@ def parse_args(input_args=None):
598598
if args.dataset_name is None and args.train_data_dir is None:
599599
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
600600

601-
if args.dataset_name is not None and args.train_data_dir is not None:
602-
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
603-
604601
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
605602
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
606603

@@ -642,6 +639,7 @@ def get_train_dataset(args, accelerator):
642639
args.dataset_name,
643640
args.dataset_config_name,
644641
cache_dir=args.cache_dir,
642+
data_dir=args.train_data_dir,
645643
)
646644
else:
647645
if args.train_data_dir is not None:

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,6 @@ def parse_args(input_args=None):
483483
# Sanity checks
484484
if args.dataset_name is None and args.train_data_dir is None:
485485
raise ValueError("Need either a dataset name or a training folder.")
486-
487486
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
488487
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
489488

@@ -824,9 +823,7 @@ def load_model_hook(models, input_dir):
824823
if args.dataset_name is not None:
825824
# Downloading and loading a dataset from the hub.
826825
dataset = load_dataset(
827-
args.dataset_name,
828-
args.dataset_config_name,
829-
cache_dir=args.cache_dir,
826+
args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
830827
)
831828
else:
832829
data_files = {}

0 commit comments

Comments
 (0)