1
1
import typing
2
2
from types import FunctionType
3
- from typing import Any , Callable , Sequence , TypedDict
3
+ from typing import Any , Callable , Sequence , TypedDict , overload
4
4
5
5
6
6
class DefaultConfig (TypedDict , total = False ):
@@ -26,6 +26,14 @@ class DefaultConfig(TypedDict, total=False):
26
26
# Registry of layer handlers, keyed by Keras layer class name (str).
# Populated by the `register` decorator below (both the string key passed
# to `register` and any keys listed in a handler's `handles` attribute).
registry : dict [str , T_kv3_handler ] = {}
27
27
28
28
29
# Typing overloads for `register`:
#  - applied directly to a handler class (``cls`` is a ``type``), it returns
#    that class unchanged;
#  - given a layer-class-name string, it returns a decorator that registers
#    and returns the handler function.
@overload
def register(cls: type) -> type: ...


@overload
def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ...
29
37
def register (cls : str | type ):
30
38
"""Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class.
31
39
@@ -51,11 +59,13 @@ def my_layer_handler(layer, inp_tensors, out_tensors):
51
59
```
52
60
"""
53
61
54
- def deco (func : T_kv3_handler ):
62
+ def deco (func ):
55
63
if isinstance (cls , str ):
56
64
registry [cls ] = func
57
65
for k in getattr (func , 'handles' , ()):
58
66
registry [k ] = func
67
+ if isinstance (cls , type ):
68
+ return cls
59
69
return func
60
70
61
71
if isinstance (cls , type ):
@@ -79,7 +89,7 @@ def __call__(
79
89
layer : 'keras.Layer' ,
80
90
in_tensors : Sequence ['KerasTensor' ],
81
91
out_tensors : Sequence ['KerasTensor' ],
82
- ):
92
+ ) -> tuple [ dict [ str , Any ], ...] :
83
93
"""Handle a keras layer. Return a tuple of dictionaries, each
84
94
dictionary representing a layer (module) in the HLS model. One
85
95
layer may correspond one or more dictionaries (e.g., layers with
@@ -114,8 +124,7 @@ def __call__(
114
124
dict[str, Any] | tuple[dict[str, Any], ...]
115
125
layer configuration(s) for the HLS model to be consumed by
116
126
the ModelGraph constructor
117
- """ # noqa: E501
118
- import keras
127
+ """
119
128
120
129
name = layer .name
121
130
class_name = layer .__class__ .__name__
@@ -150,12 +159,23 @@ def __call__(
150
159
ret = (config ,)
151
160
152
161
# If activation exists, append it
162
+
163
+ act_config , intermediate_tensor_name = self .maybe_get_activation_config (layer , out_tensors )
164
+ if act_config is not None :
165
+ ret [0 ]['output_keras_tensor_names' ] = [intermediate_tensor_name ]
166
+ ret = * ret , act_config
167
+
168
+ return ret
169
+
170
+ def maybe_get_activation_config (self , layer , out_tensors ):
171
+ import keras
172
+
153
173
activation = getattr (layer , 'activation' , None )
174
+ name = layer .name
154
175
if activation not in (keras .activations .linear , None ):
155
176
assert len (out_tensors ) == 1 , f"Layer { name } has more than one output, but has an activation function"
156
177
assert isinstance (activation , FunctionType ), f"Activation function for layer { name } is not a function"
157
178
intermediate_tensor_name = f'{ out_tensors [0 ].name } _activation'
158
- ret [0 ]['output_keras_tensor_names' ] = [intermediate_tensor_name ]
159
179
act_cls_name = activation .__name__
160
180
act_config = {
161
181
'class_name' : 'Activation' ,
@@ -164,9 +184,8 @@ def __call__(
164
184
'input_keras_tensor_names' : [intermediate_tensor_name ],
165
185
'output_keras_tensor_names' : [out_tensors [0 ].name ],
166
186
}
167
- ret = * ret , act_config
168
-
169
- return ret
187
+ return act_config , intermediate_tensor_name
188
+ return None , None
170
189
171
190
def handle (
172
191
self ,
@@ -175,3 +194,22 @@ def handle(
175
194
out_tensors : Sequence ['KerasTensor' ],
176
195
) -> dict [str , Any ] | tuple [dict [str , Any ], ...]:
177
196
return {}
197
+
198
+ def load_weight (self , layer : 'keras.Layer' , key : str ):
199
+ """Load a weight from a layer.
200
+
201
+ Parameters
202
+ ----------
203
+ layer : keras.Layer
204
+ The layer to load the weight from.
205
+ key : str
206
+ The key of the weight to load.
207
+
208
+ Returns
209
+ -------
210
+ np.ndarray
211
+ The weight.
212
+ """
213
+ import keras
214
+
215
+ return keras .ops .convert_to_numpy (getattr (layer , key ))
0 commit comments