Skip to content

Commit 0886650

Browse files
committed
har_trees: Specify dataset config via a YAML file
So we do not have to modify har_train.py
1 parent c766066 commit 0886650

File tree

2 files changed

+19
-71
lines changed

2 files changed

+19
-71
lines changed

examples/har_trees/har_train.py

Lines changed: 18 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import json
99
import itertools
1010

11+
import yaml
1112
import pandas
1213
import numpy
1314
import structlog
@@ -285,76 +286,23 @@ def export_model(path, out):
285286
cmodel.save(name='harmodel', format='csv', file=out)
286287

287288

289+
def load_config(file_path):
290+
291+
with open(file_path, 'r') as f:
292+
data = yaml.safe_load(f)
293+
return data
294+
288295
def run_pipeline(run, hyperparameters, dataset,
296+
config,
289297
data_dir,
290298
out_dir,
291299
model_settings=dict(),
292300
n_splits=5,
293301
features='timebased',
294302
):
295303

296-
dataset_config = {
297-
'uci_har': dict(
298-
groups=['subject', 'experiment'],
299-
data_columns = ['acc_x', 'acc_y', 'acc_z'],
300-
classes = [
301-
#'STAND_TO_LIE',
302-
#'SIT_TO_LIE',
303-
#'LIE_TO_SIT',
304-
#'STAND_TO_SIT',
305-
#'LIE_TO_STAND',
306-
#'SIT_TO_STAND',
307-
'STANDING', 'LAYING', 'SITTING',
308-
'WALKING', 'WALKING_UPSTAIRS', 'WALKING_DOWNSTAIRS',
309-
],
310-
),
311-
'pamap2': dict(
312-
groups=['subject'],
313-
data_columns = ['hand_acceleration_16g_x', 'hand_acceleration_16g_y', 'hand_acceleration_16g_z'],
314-
classes = [
315-
#'transient',
316-
'walking', 'ironing', 'lying', 'standing',
317-
'Nordic_walking', 'sitting', 'vacuum_cleaning',
318-
'cycling', 'ascending_stairs', 'descending_stairs',
319-
'running', 'rope_jumping',
320-
],
321-
),
322-
'har_exercise_1': dict(
323-
groups=['file'],
324-
data_columns = ['x', 'y', 'z'],
325-
classes = [
326-
#'mixed',
327-
'squat', 'jumpingjack', 'lunge', 'other',
328-
],
329-
),
330-
'toothbrush_hussain2021': dict(
331-
groups=['subject'],
332-
label_column = 'is_brushing',
333-
time_column = 'elapsed',
334-
data_columns = ['acc_x', 'acc_y', 'acc_z'],
335-
#data_columns = ['gravity_x', 'gravity_y', 'gravity_z'],
336-
#data_columns = ['motion_x', 'motion_y', 'motion_z'],
337-
classes = [
338-
#'mixed',
339-
'True', 'False',
340-
],
341-
),
342-
'toothbrush_jonnor': dict(
343-
groups=['session'],
344-
label_column = 'is_brushing',
345-
time_column = 'time',
346-
data_columns = ['x', 'y', 'z'],
347-
#data_columns = ['gravity_x', 'gravity_y', 'gravity_z'],
348-
#data_columns = ['motion_x', 'motion_y', 'motion_z'],
349-
classes = [
350-
#'mixed',
351-
'True', 'False',
352-
],
353-
),
354-
}
304+
dataset_config = load_config(config)
355305

356-
if not dataset in dataset_config.keys():
357-
raise ValueError(f"Unknown dataset {dataset}")
358306

359307
if not os.path.exists(out_dir):
360308
os.makedirs(out_dir)
@@ -368,12 +316,12 @@ def run_pipeline(run, hyperparameters, dataset,
368316
#print(data.index.names)
369317
#print(data.columns)
370318

371-
groups = dataset_config[dataset]['groups']
372-
data_columns = dataset_config[dataset]['data_columns']
373-
enabled_classes = dataset_config[dataset]['classes']
374-
label_column = dataset_config[dataset].get('label_column', 'activity')
375-
time_column = dataset_config[dataset].get('time_column', 'time')
376-
sensitivity = dataset_config[dataset].get('sensitivity', 4.0)
319+
groups = dataset_config['groups']
320+
data_columns = dataset_config['data_columns']
321+
enabled_classes = dataset_config['classes']
322+
label_column = dataset_config.get('label_column', 'activity')
323+
time_column = dataset_config.get('time_column', 'time')
324+
sensitivity = dataset_config.get('sensitivity', 4.0)
377325

378326
data[label_column] = data[label_column].astype(str)
379327

@@ -486,6 +434,8 @@ def parse():
486434

487435
parser.add_argument('--dataset', type=str, default='uci_har',
488436
help='Which dataset to use')
437+
parser.add_argument('--config', type=str, default='data/configurations/uci_har.yaml',
438+
help='Which dataset/training config to use')
489439
parser.add_argument('--data-dir', metavar='DIRECTORY', type=str, default='./data/processed',
490440
help='Where the input data is stored')
491441
parser.add_argument('--out-dir', metavar='DIRECTORY', type=str, default='./',
@@ -506,9 +456,6 @@ def parse():
506456
def main():
507457

508458
args = parse()
509-
dataset = args.dataset
510-
out_dir = args.out_dir
511-
data_dir = args.data_dir
512459

513460
run_id = uuid.uuid4().hex.upper()[0:6]
514461

@@ -524,6 +471,7 @@ def main():
524471
}
525472

526473
results = run_pipeline(dataset=args.dataset,
474+
config=args.config,
527475
out_dir=args.out_dir,
528476
data_dir=args.data_dir,
529477
run=run_id,

examples/har_trees/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ setuptools
77
structlog
88
matplotlib
99
mpremote
10-
10+
pyyaml

0 commit comments

Comments (0)