
Commit 5489803: keras v3 converter clean-up

Parent: 1fb23b9

4 files changed, +53 -19 lines

hls4ml/converters/keras_v3/_base.py
Lines changed: 47 additions & 9 deletions

```diff
@@ -1,6 +1,6 @@
 import typing
 from types import FunctionType
-from typing import Any, Callable, Sequence, TypedDict
+from typing import Any, Callable, Sequence, TypedDict, overload


 class DefaultConfig(TypedDict, total=False):
@@ -26,6 +26,14 @@ class DefaultConfig(TypedDict, total=False):
 registry: dict[str, T_kv3_handler] = {}


+@overload
+def register(cls: type) -> type: ...
+
+
+@overload
+def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ...
+
+
 def register(cls: str | type):
     """Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class.

@@ -51,11 +59,13 @@ def my_layer_handler(layer, inp_tensors, out_tensors):
     ```
     """

-    def deco(func: T_kv3_handler):
+    def deco(func):
         if isinstance(cls, str):
             registry[cls] = func
         for k in getattr(func, 'handles', ()):
             registry[k] = func
+        if isinstance(cls, type):
+            return cls
         return func

     if isinstance(cls, type):
```
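The `@overload` stubs make the decorator's two call styles explicit: `register('SomeLayer')` returns a decorator for a handler function, while a bare `@register` on a handler class returns the class itself, which the new `return cls` branch in `deco` guarantees. Below is a minimal standalone sketch of the pattern; the handler names are hypothetical, and the body of the trailing `isinstance(cls, type)` branch is cut off in the hunk, so it is an assumption here:

```python
from typing import Callable, overload

registry: dict[str, Callable] = {}


@overload
def register(cls: type) -> type: ...
@overload
def register(cls: str) -> Callable[[Callable], Callable]: ...


def register(cls):
    def deco(func):
        if isinstance(cls, str):
            registry[cls] = func
        # also register every key the handler advertises via `handles`
        for k in getattr(func, 'handles', ()):
            registry[k] = func
        # bare class decoration must hand the class back, not `func`
        if isinstance(cls, type):
            return cls
        return func

    # assumption: when given a class, apply deco to it immediately
    return deco(cls) if isinstance(cls, type) else deco


@register
class DenseHandler:  # hypothetical handler class
    handles = ('Dense',)


@register('Activation')  # string form: returns a decorator
def activation_handler(layer, in_tensors, out_tensors):
    return {}


print(sorted(registry))  # ['Activation', 'Dense']
```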
```diff
@@ -79,7 +89,7 @@ def __call__(
         layer: 'keras.Layer',
         in_tensors: Sequence['KerasTensor'],
         out_tensors: Sequence['KerasTensor'],
-    ):
+    ) -> tuple[dict[str, Any], ...]:
         """Handle a keras layer. Return a tuple of dictionaries, each
         dictionary representing a layer (module) in the HLS model. One
         layer may correspond one or more dictionaries (e.g., layers with
@@ -114,8 +124,7 @@ def __call__(
         dict[str, Any] | tuple[dict[str, Any], ...]
             layer configuration(s) for the HLS model to be consumed by
             the ModelGraph constructor
-        """  # noqa: E501
-        import keras
+        """

         name = layer.name
         class_name = layer.__class__.__name__
@@ -150,12 +159,23 @@ def __call__(
             ret = (config,)

         # If activation exists, append it
+
+        act_config, intermediate_tensor_name = self.maybe_get_activation_config(layer, out_tensors)
+        if act_config is not None:
+            ret[0]['output_keras_tensor_names'] = [intermediate_tensor_name]
+            ret = *ret, act_config
+
+        return ret
+
+    def maybe_get_activation_config(self, layer, out_tensors):
+        import keras
+
         activation = getattr(layer, 'activation', None)
+        name = layer.name
         if activation not in (keras.activations.linear, None):
             assert len(out_tensors) == 1, f"Layer {name} has more than one output, but has an activation function"
             assert isinstance(activation, FunctionType), f"Activation function for layer {name} is not a function"
             intermediate_tensor_name = f'{out_tensors[0].name}_activation'
-            ret[0]['output_keras_tensor_names'] = [intermediate_tensor_name]
             act_cls_name = activation.__name__
             act_config = {
                 'class_name': 'Activation',
@@ -164,9 +184,8 @@ def __call__(
                 'input_keras_tensor_names': [intermediate_tensor_name],
                 'output_keras_tensor_names': [out_tensors[0].name],
             }
-            ret = *ret, act_config
-
-        return ret
+            return act_config, intermediate_tensor_name
+        return None, None

     def handle(
         self,
```
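With this refactor, `__call__` no longer builds the activation config inline: `maybe_get_activation_config` returns either `(act_config, intermediate_tensor_name)` or `(None, None)`, and `__call__` rewires the base config's output through the synthesized intermediate tensor before appending the activation. For a hypothetical `Dense` layer with `activation='relu'` and output tensor `dense_out`, the handler would now return something shaped like this (field values beyond those visible in the hunks are illustrative):

```python
ret = (
    {
        # base layer config produced by handle(); other fields omitted
        'class_name': 'Dense',
        # rewired by __call__ from ['dense_out'] once act_config is not None
        'output_keras_tensor_names': ['dense_out_activation'],
    },
    {
        # appended activation config from maybe_get_activation_config
        'class_name': 'Activation',
        'input_keras_tensor_names': ['dense_out_activation'],
        'output_keras_tensor_names': ['dense_out'],
    },
)
```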
```diff
@@ -175,3 +194,22 @@ def handle(
         out_tensors: Sequence['KerasTensor'],
     ) -> dict[str, Any] | tuple[dict[str, Any], ...]:
         return {}
+
+    def load_weight(self, layer: 'keras.Layer', key: str):
+        """Load a weight from a layer.
+
+        Parameters
+        ----------
+        layer : keras.Layer
+            The layer to load the weight from.
+        key : str
+            The key of the weight to load.
+
+        Returns
+        -------
+        np.ndarray
+            The weight.
+        """
+        import keras
+
+        return keras.ops.convert_to_numpy(getattr(layer, key))
```
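The new `load_weight` helper goes through `keras.ops.convert_to_numpy`, so it yields a `np.ndarray` regardless of which backend holds the variable. A minimal usage sketch, assuming keras 3 is installed and that the handler can be instantiated directly:

```python
import keras

from hls4ml.converters.keras_v3._base import KerasV3LayerHandler

layer = keras.layers.Dense(4, use_bias=True)
layer.build((None, 3))  # materializes layer.kernel and layer.bias

handler = KerasV3LayerHandler()
kernel = handler.load_weight(layer, 'kernel')  # np.ndarray of shape (3, 4)
bias = handler.load_weight(layer, 'bias')      # np.ndarray of shape (4,)
print(kernel.shape, bias.shape)
```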

hls4ml/converters/keras_v3/conv.py
Lines changed: 3 additions & 5 deletions

```diff
@@ -2,8 +2,6 @@
 from math import ceil
 from typing import Sequence

-import numpy as np
-
 from ._base import KerasV3LayerHandler, register

 if typing.TYPE_CHECKING:
@@ -40,9 +38,9 @@ def handle(
         assert all(isinstance(x, int) for x in in_shape), f"Layer {layer.name} has non-fixed size input: {in_shape}"
         assert all(isinstance(x, int) for x in out_shape), f"Layer {layer.name} has non-fixed size output: {out_shape}"

-        kernel = np.array(layer.kernel)
+        kernel = self.load_weight(layer, 'kernel')
         if layer.use_bias:
-            bias = np.array(layer.bias)
+            bias = self.load_weight(layer, 'bias')
         else:
             bias = None

@@ -113,7 +111,7 @@ def handle(
             config['depth_multiplier'] = layer.depth_multiplier
         elif isinstance(layer, BaseSeparableConv):
             config['depthwise_data'] = kernel
-            config['pointwise_data'] = np.array(layer.pointwise_kernel)
+            config['pointwise_data'] = self.load_weight(layer, 'pointwise_kernel')
             config['depth_multiplier'] = layer.depth_multiplier
         elif isinstance(layer, BaseConv):
             config['weight_data'] = kernel
```
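These call sites (and the matching ones in `core.py` and `einsum_dense.py` below) previously mixed `np.array(...)` with `keras.ops.convert_to_numpy(...)`; routing everything through `self.load_weight` keeps the backend-to-numpy conversion in one place. A sketch of why the indirection matters, assuming the torch backend is selected; with any other backend the same call still returns a plain array:

```python
import os

os.environ.setdefault('KERAS_BACKEND', 'torch')  # must be set before importing keras

import keras

layer = keras.layers.Conv1D(8, kernel_size=3)
layer.build((None, 32, 4))  # kernel is now a torch-backed keras Variable

# convert_to_numpy handles the detach/device details that a bare
# np.array(...) call can trip over for non-TensorFlow backends
kernel = keras.ops.convert_to_numpy(layer.kernel)
print(type(kernel).__name__, kernel.shape)  # ndarray (3, 4, 8)
```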

hls4ml/converters/keras_v3/core.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -28,7 +28,7 @@ def handle(
         config = {
             'data_format': 'channels_last',
             'weight_data': kernel,
-            'bias_data': np.array(layer.bias) if layer.use_bias else None,
+            'bias_data': self.load_weight(layer, 'bias') if layer.use_bias else None,
             'n_out': kernel.shape[1],
             'n_in': kernel.shape[0],
         }
```

hls4ml/converters/keras_v3/einsum_dense.py
Lines changed: 2 additions & 4 deletions

```diff
@@ -39,8 +39,6 @@ def handle(
         in_tensors: Sequence['KerasTensor'],
         out_tensors: Sequence['KerasTensor'],
     ):
-        import keras
-
         assert len(in_tensors) == 1, 'EinsumDense layer must have exactly one input tensor'
         assert len(out_tensors) == 1, 'EinsumDense layer must have exactly one output tensor'

@@ -56,11 +54,11 @@ def handle(
         equation = strip_batch_dim(layer.equation)

-        kernel = keras.ops.convert_to_numpy(layer.kernel)
+        kernel = self.load_weight(layer, 'kernel')

         bias = None
         if layer.bias_axes:
-            bias = keras.ops.convert_to_numpy(layer.bias)
+            bias = self.load_weight(layer, 'bias')

         return {
             'class_name': 'EinsumDense',
```
