@@ -18,6 +18,7 @@ def view_data(
18
18
args ,
19
19
neox_args ,
20
20
batch_fn : callable = None ,
21
+ save_path : str = None ,
21
22
):
22
23
# fake MPU setup (needed to init dataloader without actual GPUs or parallelism)
23
24
mpu .mock_model_parallel ()
@@ -37,12 +38,14 @@ def view_data(
37
38
38
39
if args .mode == "save" :
39
40
# save full batches for each step in the range (WARNING: this may consume lots of storage!)
40
- np .save (f"./dump_data/batch{ i } _bs{ neox_args .train_micro_batch_size_per_gpu } " , batch )
41
+ filename = f"batch{ i } _bs{ neox_args .train_micro_batch_size_per_gpu } "
42
+ np .save (os .path .join (save_path , filename ), batch )
41
43
elif args .mode == "custom" :
42
44
# dump user_defined statistic to a jsonl file (save_fn must return a dict)
43
45
log = batch_fn (batch , i )
44
46
45
- with open ("./dump_data/stats.jsonl" , "w+" ) as f :
47
+ filename = "stats.jsonl"
48
+ with open (os .path .join (save_path , filename ), "w+" ) as f :
46
49
f .write (json .dumps (log ) + "\n " )
47
50
else :
48
51
raise ValueError (f'mode={ mode } not acceptable--please pass either "save" or "custom" !' )
@@ -74,6 +77,12 @@ def view_data(
74
77
choices = ["save" , "custom" ],
75
78
help = "Choose mode: 'save' to log all batches, and 'custom' to use user-defined statistic"
76
79
)
80
+ parser .add_argument (
81
+ "--save_path" ,
82
+ type = str ,
83
+ default = 0 ,
84
+ help = "Save path for files"
85
+ )
77
86
args = parser .parse_known_args ()[0 ]
78
87
79
88
# init neox args
@@ -86,10 +95,11 @@ def save_fn(batch: np.array, iteration: int):
86
95
# define your own logic here
87
96
return {"iteration" : iteration , "text" : None }
88
97
98
+ os .makedirs (args .save_path , exist_ok = True )
89
99
90
100
view_data (
91
101
args ,
92
102
neox_args ,
93
103
batch_fn = save_fn ,
104
+ save_path = args .save_path ,
94
105
)
95
-
0 commit comments