Skip to content

Commit c4efd51

Browse files
committed
add test
1 parent 225f6cb commit c4efd51

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

test/pytest/test_transpose_concat.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,37 @@ def test_accuracy(data, keras_model, hls_model):
5454
y_hls4ml = hls_model.predict(X).reshape(y_keras.shape)
5555
# "accuracy" of hls4ml predictions vs keras
5656
np.testing.assert_allclose(y_keras, y_hls4ml, rtol=0, atol=1e-04, verbose=True)
57+
58+
59+
@pytest.fixture(scope='module')
60+
def keras_model_highdim():
61+
inp = Input(shape=(2, 3, 4, 5, 6), name='input_1')
62+
out = Permute((3, 5, 4, 1, 2))(inp)
63+
model = Model(inputs=inp, outputs=out)
64+
return model
65+
66+
67+
@pytest.fixture(scope='module')
68+
def data_highdim():
69+
X = np.random.randint(-128, 127, (100, 2, 3, 4, 5, 6)) / 128
70+
X = X.astype(np.float32)
71+
return X
72+
73+
74+
@pytest.mark.parametrize('io_type', ['io_stream', 'io_parallel'])
75+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
76+
def test_highdim_permute(data_highdim, keras_model_highdim, io_type, backend):
77+
X = data_highdim
78+
model = keras_model_highdim
79+
80+
model_hls = hls4ml.converters.convert_from_keras_model(
81+
model,
82+
io_type=io_type,
83+
backend=backend,
84+
output_dir=str(test_root_path / f'hls4mlprj_highdim_transpose_{backend}_{io_type}'),
85+
)
86+
model_hls.compile()
87+
y_keras = model.predict(X)
88+
y_hls4ml = model_hls.predict(X).reshape(y_keras.shape) # type: ignore
89+
90+
assert np.all(y_keras == y_hls4ml)

0 commit comments

Comments
 (0)