Skip to content

Commit eac15bc

Browse files
committed
add tests for einsumdense
1 parent 0a2bdfb commit eac15bc

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

test/pytest/test_einsum_dense.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from pathlib import Path
2+
3+
import keras
4+
import numpy as np
5+
import pytest
6+
7+
from hls4ml.converters import convert_from_keras_model
8+
9+
if keras.__version__ < '3.0.0':
10+
pytest.skip('Only keras v3 is supported for now', allow_module_level=True)
11+
12+
from keras.api.layers import EinsumDense, Input
13+
14+
test_root_path = Path(__file__).parent
15+
16+
17+
@pytest.mark.parametrize('strategy', ['latency'])
18+
@pytest.mark.parametrize('io_type', ['io_parallel'])
19+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
20+
@pytest.mark.parametrize(
21+
'operation',
22+
[
23+
# eq, inp, out
24+
('bi,j->bij', (8,), (8, 7), None),
25+
('bi,j->bij', (8,), (8, 7), 'i'),
26+
('bi,j->bij', (8,), (8, 7), 'j'),
27+
('bi,io->bo', (8,), 7, None),
28+
('...i,oi->...o', (4, 3), (5,), None),
29+
('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), None),
30+
('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'aeb'),
31+
('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'ab'),
32+
('...abcd,bcde->...aeb', (5, 4, 3, 2), (5, 6, 4), 'a'),
33+
],
34+
)
35+
def test_einsum_dense(backend, io_type, strategy, operation):
36+
eq, inp_shape, out_shape, bias_axes = operation
37+
model = keras.Sequential(
38+
[Input(inp_shape), EinsumDense(eq, output_shape=out_shape, bias_axes=bias_axes, name='einsum_dense')]
39+
)
40+
41+
if bias_axes is not None:
42+
layer = model.get_layer('einsum_dense')
43+
layer.bias.assign(keras.ops.convert_to_tensor(np.random.rand(*layer.bias.shape)))
44+
45+
data = np.random.rand(1000, *inp_shape)
46+
eq_name = eq.replace(',', '_').replace('->', '_') + ('' if bias_axes is None else f'_{bias_axes}')
47+
output_dir = str(test_root_path / f'hls4mlprj_einsum_dense_{eq_name}_{backend}_{io_type}_{strategy}')
48+
hls_config = {'Model': {'Precision': 'ap_fixed<32,8>', 'ReuseFactor': 1}, 'Strategy': strategy}
49+
model_hls = convert_from_keras_model(
50+
model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type
51+
)
52+
53+
model_hls.compile()
54+
r_keras = model.predict(data, verbose=0, batch_size=1000) # type: ignore
55+
r_hls = model_hls.predict(data).reshape(r_keras.shape) # type: ignore
56+
57+
np.testing.assert_allclose(r_hls, r_keras, atol=1e-2, rtol=0)

0 commit comments

Comments
 (0)