diff --git a/hls4ml/model/optimizer/passes/move_scales.py b/hls4ml/model/optimizer/passes/move_scales.py index 8fba1ec405..03bb0f3b77 100644 --- a/hls4ml/model/optimizer/passes/move_scales.py +++ b/hls4ml/model/optimizer/passes/move_scales.py @@ -12,6 +12,9 @@ from hls4ml.model.layers import ApplyAlpha, Constant, Conv, MatMul, Merge from hls4ml.model.optimizer import OptimizerPass +# These attributes should not be copied. (Should add the output name to this) +_attrs_not_to_copy = ['trace', 'precision', 'scale', 'bias', 'scale_data', 'bias_data'] + class ScaleDownMatMul(OptimizerPass): '''Shift an ApplyAlpha below a MatMul''' @@ -62,7 +65,7 @@ def transform(self, model, node): output = node.get_output_variable() # to remove warning, since these get set again - new_attrs = {k: v for k, v in apply_alpha.attributes.items() if k not in ('trace', 'precision')} + new_attrs = {k: v for k, v in apply_alpha.attributes.items() if k not in _attrs_not_to_copy + apply_alpha.outputs} can_propagate = False if not bias.shape and bias == 0: @@ -258,7 +261,7 @@ def transform(self, model, node): return False # to remove warning, since these get set again - new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs} new_name = in0.name model.remove_node(in0) @@ -305,7 +308,7 @@ def transform(self, model, node): return False # to remove warning, since these get set again - new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs} new_name = in1.name model.remove_node(in1) @@ -329,7 +332,7 @@ def transform(self, model, node): return False # to remove warning, since these get set again - new_attrs = {k: v for k, v in in2.attributes.items() if k not in ('trace', 'precision')} + new_attrs = {k: v for k, v in in2.attributes.items() if k not in _attrs_not_to_copy + in2.outputs} new_name = in2.name model.remove_node(in2) @@ -391,7 +394,7 @@ def transform(self, model, node): return False # to remove warning, since these get set again - new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs} new_name = in1.name model.remove_node(in0) model.remove_node(in1) @@ -415,7 +418,7 @@ def transform(self, model, node): return False # to remove warning, since these get set again - new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs} new_name = in0.name model.remove_node(in0) model.remove_node(in2) @@ -442,7 +445,7 @@ def transform(self, model, node): return False # to remove warning, since these get set again - new_attrs = {k: v for k, v in in1.attributes.items() if k not in ('trace', 'precision')} + new_attrs = {k: v for k, v in in1.attributes.items() if k not in _attrs_not_to_copy + in1.outputs} new_name = in1.name model.remove_node(in1) model.remove_node(in2) @@ -478,7 +481,7 @@ def transform(self, model, node): return False # to remove warning, since these get set again - new_attrs = {k: v for k, v in in0.attributes.items() if k not in ('trace', 'precision')} + new_attrs = {k: v for k, v in in0.attributes.items() if k not in _attrs_not_to_copy + in0.outputs} new_name = in0.name model.remove_node(in0) model.remove_node(in1)