This repository was archived by the owner on Oct 13, 2021. It is now read-only.

support the random generator ops and fix the issues on tf.op #453

Merged
merged 4 commits on Apr 24, 2020
Changes from 3 commits
269 changes: 65 additions & 204 deletions keras2onnx/_builtin.py
@@ -7,166 +7,19 @@
import numbers
import tensorflow
import numpy as np
from typing import Union
from onnx import numpy_helper, mapping

from keras2onnx._consts import TYPES, NCHW_TO_NHWC, NHWC_TO_NCHW, HWCN_TO_NCHW
from onnx import numpy_helper
from .common.utils import count_dynamic_dim
from .common.onnx_ops import apply_identity, apply_reshape, OnnxOperatorBuilder
from .funcbook import converter_func, set_converters
from .proto import keras
from .proto.tfcompat import is_tf2


class TYPES:
# tf-node types:
Identity = 'Identity'
Const = 'Const'
AddN = 'AddN'
Any = 'Any'
All = 'All'
BatchMatMul = 'BatchMatMul'
BatchMatMulV2 = 'BatchMatMulV2'
BatchToSpaceND = 'BatchToSpaceND'
BiasAdd = 'BiasAdd'
BiasAddV1 = 'BiasAddV1'
Cast = 'Cast'
ConcatV2 = 'ConcatV2'
Conv1D = 'Conv1D'
Conv2D = 'Conv2D'
Cumsum = 'Cumsum'
DepthwiseConv2dNative = 'DepthwiseConv2dNative'
Einsum = 'Einsum'
ExpandDims = 'ExpandDims'
Fill = 'Fill'
FloorDiv = 'FloorDiv'
FusedBatchNorm = 'FusedBatchNorm'
FusedBatchNormV2 = 'FusedBatchNormV2'
FusedBatchNormV3 = 'FusedBatchNormV3'
GatherNd = 'GatherNd'
GatherV2 = 'GatherV2'
GreaterEqual = 'GreaterEqual'
LessEqual = 'LessEqual'
LogicalAnd = 'LogicalAnd'
LogicalNot = 'LogicalNot'
LogSoftmax = 'LogSoftmax'
MatMul = 'MatMul'
Max = 'Max'
Maximum = 'Maximum'
Mean = 'Mean'
Min = 'Min'
Minimum = 'Minimum'
NonMaxSuppressionV2 = 'NonMaxSuppressionV2'
NonMaxSuppressionV3 = 'NonMaxSuppressionV3'
NotEqual = 'NotEqual'
OneHot = 'OneHot'
Pack = 'Pack'
Pad = 'Pad'
PadV2 = 'PadV2'
Prod = 'Prod'
Range = 'Range'
ReadVariableOp = 'ReadVariableOp'
Reshape = 'Reshape'
ResizeBilinear = 'ResizeBilinear'
ResizeNearestNeighbor = 'ResizeNearestNeighbor'
Round = 'Round'
Rsqrt = 'Rsqrt'
ScatterNd = 'ScatterNd'
Select = 'Select'
Shape = 'Shape'
Size = 'Size'
Slice = 'Slice'
Softmax = 'Softmax'
SpaceToBatchND = 'SpaceToBatchND'
Split = 'Split'
SplitV = 'SplitV'
Square = 'Square'
SquaredDifference = 'SquaredDifference'
Squeeze = 'Squeeze'
StridedSlice = 'StridedSlice'
Sum = 'Sum'
Tile = 'Tile'
TopKV2 = 'TopKV2'
Transpose = 'Transpose'
Unpack = 'Unpack'
VarHandleOp = 'VarHandleOp'
VariableV2 = 'VariableV2'
Where = 'Where'
ZerosLike = 'ZerosLike'

# converter internal types:
TD_Reshape = '_reshape_timedistributed'


def is_placeholder_node(node):
return len(node.inputs) == 0 and node.type in ['Placeholder', "PlaceholderV2", 'PlaceholderWithDefault'] and \
node.outputs[0].dtype.name != 'resource'


def tsname_to_node(name):
return name.split(':')[0]


NCHW_TO_NHWC = [0, 2, 3, 1]
NHWC_TO_NCHW = [0, 3, 1, 2]
HWCN_TO_NCHW = [3, 2, 0, 1]
NCHW_TO_HWCN = [2, 3, 1, 0]
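
For intuition, a quick numpy check of these layout permutations (an illustrative snippet, not part of this diff): feeding NHWC_TO_NCHW to a transpose moves the channel axis ahead of the spatial dims.

import numpy as np

NHWC_TO_NCHW = [0, 3, 1, 2]
x = np.zeros((2, 5, 7, 3))                  # a batch in NHWC layout
print(np.transpose(x, NHWC_TO_NCHW).shape)  # (2, 3, 5, 7), i.e. NCHW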


def _is_nhwc(node):
return node.get_attr('data_format') == b'NHWC'


_MAX_FOLDING_NODE_NUMBER = 15


def _count_input_nodes(tensor): # type: (tensorflow.Tensor)->int
nodes_to_keep = set()
node_inputs = [tensor.op]
while node_inputs:
nd_ = node_inputs[0]
del node_inputs[0]
if nd_ in nodes_to_keep:
continue

if is_placeholder_node(nd_):
return -1
nodes_to_keep.add(nd_)
if len(nodes_to_keep) >= _MAX_FOLDING_NODE_NUMBER:
return -1

node_inputs.extend(in_.op for in_ in nd_.inputs)

return len(nodes_to_keep)


def _cal_tensor_value(tensor): # type: (tensorflow.Tensor)->Union[np.ndarray, None]
if _count_input_nodes(tensor) < 0:
return None

node = tensor.op
if node.type in ["Const", "ConstV2"]:
make_ndarray = tensorflow.make_ndarray
np_arr = make_ndarray(node.get_attr("value"))
return np_arr
else:
try:
cls_sess = tensorflow.Session if hasattr(tensorflow, 'Session') else tensorflow.compat.v1.Session
with cls_sess(graph=node.graph) as sess:
np_arr = sess.run(tensor)
return np_arr
except (ValueError, tensorflow.errors.InvalidArgumentError, tensorflow.errors.OpError):
return None


def _cal_tensor_shape(tensor):
if len(tensor.shape) > 0 and hasattr(tensor.shape[0], 'value'):
return [x.value for x in tensor.shape]
else:
return list(tensor.shape)
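
The hasattr guard above papers over a TF1/TF2 difference: TF1 shape entries are Dimension objects exposing .value (possibly None), while TF2 yields plain ints or None. A small self-contained check (illustrative, not PR code):

import tensorflow as tf

tf1 = tf if hasattr(tf, 'Session') else tf.compat.v1  # same guard this file uses
with tf1.Graph().as_default():
    p = tf1.placeholder(tf1.float32, shape=[None, 28, 28, 3])
    dims = [d.value if hasattr(d, 'value') else d for d in p.shape]
    print(dims)  # [None, 28, 28, 3] on both TF1 and TF2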


def _to_onnx_type(dt_type):
# TensorFlow data types integrate seamlessly with numpy
return mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dt_type.as_numpy_dtype)]
from ._tf_utils import (is_nhwc as _is_nhwc,
tf_attrs_to_onnx as _to_onnx_attrs,
cal_tensor_value as _cal_tensor_value,
cal_tensor_shape as _cal_tensor_shape,
to_onnx_type as _to_onnx_type)
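
The deleted helpers above now live in the _tf_utils module (with TYPES and the layout permutations moving to _consts, per the new imports). A minimal self-contained sketch of the constant-folding idea behind cal_tensor_value, assuming it still mirrors the deleted _cal_tensor_value: a tensor with no placeholder upstream can be evaluated in a throwaway session.

import tensorflow as tf

tf1 = tf if hasattr(tf, 'Session') else tf.compat.v1  # TF1/TF2 session guard
graph = tf1.Graph()
with graph.as_default():
    t = tf1.constant([1.0, 2.0]) * 3.0  # nothing upstream is a placeholder
with tf1.Session(graph=graph) as sess:
    print(sess.run(t))  # [3. 6.] -- the folded constant value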


def default_convert(scope, operator, container):
@@ -456,7 +309,7 @@ def convert_tf_cum_sum(scope, operator, container):
raise ValueError("CumSum op is not supported for opset < 11")
node = operator.raw_operator
oopb = OnnxOperatorBuilder(container, scope)
attrs = {'exclusive': node.get_attr('exclusive'), 'reverse': node.get_attr('reverse') }
attrs = {'exclusive': node.get_attr('exclusive'), 'reverse': node.get_attr('reverse')}
oopb.add_node_with_output('CumSum',
operator.input_full_names,
operator.output_full_names,
@@ -862,6 +715,7 @@ def convert_tf_fill(scope, operator, container):
name=operator.full_name,
**attrs)


@converter_func(TYPES.FloorDiv)
def convert_tf_floor_div(scope, operator, container):
node = operator.raw_operator
@@ -878,6 +732,7 @@ def convert_tf_floor_div(scope, operator, container):
operator.outputs[0].full_name,
name=operator.full_name)


@converter_func(TYPES.FusedBatchNorm)
def convert_tf_fused_batch_norm(scope, operator, container):
_convert_tf_fused_batch_norm_core(scope, operator, container)
@@ -981,7 +836,8 @@ def convert_tf_logsoftmax(scope, operator, container):
oopb = OnnxOperatorBuilder(container, scope)
node = operator.raw_operator
logits_rank = len(_cal_tensor_shape(node.inputs[0]))
axis = node.get_attr('axis') if hasattr(node, 'axis') else -1
attrs = _to_onnx_attrs(node)
axis = attrs['axis'] if 'axis' in attrs else -1
if operator.target_opset < 11 and axis < 0:
axis += logits_rank

@@ -1263,7 +1119,8 @@ def _convert_tf_pad(scope, operator, container):
desired_shape=[-1])[0]
else:
paddings = np.array(_cal_tensor_value(node.inputs[1])).transpose().flatten()
mode = node.get_attr("mode") if hasattr(node, 'mode') else None
attrs = _to_onnx_attrs(node)
mode = attrs["mode"] if hasattr(attrs, 'mode') else None

if mode:
mode = mode.s.decode("utf-8").lower()
@@ -1749,7 +1606,6 @@ def convert_tf_not_equal(scope, operator, container):
name=operator.full_name + '_not')



@converter_func(TYPES.OneHot)
def convert_tf_one_hot(scope, operator, container):
if operator.target_opset < 9:
@@ -1850,7 +1706,8 @@ def convert_tf_softmax(scope, operator, container):
oopb = OnnxOperatorBuilder(container, scope)
node = operator.raw_operator
logits_rank = len(_cal_tensor_shape(node.inputs[0]))
axis = node.get_attr('axis') if hasattr(node, 'axis') else -1
attrs = _to_onnx_attrs(node)
axis = attrs['axis'] if 'axis' in attrs else -1
if operator.target_opset < 11 and axis < 0:
axis += logits_rank

@@ -2129,58 +1986,61 @@ def convert_tf_where(scope, operator, container):
name=operator.full_name + '_transpose',
perm=list(reversed(range(len(node.outputs[0].shape)))))


@converter_func(TYPES.ZerosLike)
def convert_tf_zeros_like(scope, operator, container):
node = operator.raw_operator
oopb = OnnxOperatorBuilder(container, scope)
dtype = _to_onnx_type(node.outputs[0].dtype)
oopb.apply_op_with_output('apply_mul',
[ operator.inputs[0].full_name,
('_zero', dtype, np.zeros((), dtype=np.int64)) ],
[operator.inputs[0].full_name,
('_zero', dtype, np.zeros((), dtype=np.int64))],
operator.outputs[0].full_name,
name=operator.full_name)
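
The converter above lowers ZerosLike to a multiplication by a scalar zero cast to the output dtype; a tiny numpy illustration of the broadcast identity it relies on (illustrative, not PR code):

import numpy as np

x = np.array([[1.5, -2.0], [0.5, 3.0]])
zero = np.zeros((), dtype=x.dtype)  # scalar zero, matching dtype
print(x * zero)                     # zeros with x's shape and dtype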

direct_ops = {"Abs": ("apply_abs",),
"Acos": 7,
"Acosh": 9,
"Add": ("apply_add",),
"AddV2": ("apply_add",),
"Asin": 7,
"Asinh": 9,
"Atan": 7,
"Atanh": 9,
"Ceil": ("apply_ceil",),
"Cos": 7,
"Cosh": 9,
"Div": ("apply_div",),
"Elu": ("apply_elu",),
"Equal": 7,
"Erf": 9,
"Exp": ("apply_exp",),
"Floor": ("apply_floor",),
"Greater": ("apply_greater",),
"Less": ("apply_less",),
"Log": ("apply_log",),
"Mul": ("apply_mul",),
"Neg": ("apply_neg",),
"Pow": ("apply_pow",),
"RealDiv": ("apply_div",),
"Reciprocal": ("apply_reciprocal",),
"Relu": ("apply_relu",),
"Sigmoid": ("apply_sigmoid",),
"Sin": 7,
"Sinh": 9,
"Softplus": 1,
"Softsign": 1,
"Sqrt": ("apply_sqrt",),
"StopGradient": ("apply_identity",),
"Sub": ("apply_sub",),
"Tan": 7,
"Tanh": ("apply_tanh",)
}


def tf_op_convert(scope, operator, container):

direct_ops = {
"Abs": ("apply_abs",),
"Acos": 7,
"Acosh": 9,
"Add": ("apply_add",),
"AddV2": ("apply_add",),
"Asin": 7,
"Asinh": 9,
"Atan": 7,
"Atanh": 9,
"Ceil": ("apply_ceil",),
"Cos": 7,
"Cosh": 9,
"Div": ("apply_div",),
"Elu": ("apply_elu",),
"Equal": 7,
"Erf": 9,
"Exp": ("apply_exp",),
"Floor": ("apply_floor",),
"Greater": ("apply_greater",),
"Less": ("apply_less",),
"Log": ("apply_log",),
"Mul": ("apply_mul",),
"Neg": ("apply_neg",),
"Pow": ("apply_pow",),
"RealDiv": ("apply_div",),
"Reciprocal": ("apply_reciprocal",),
"Relu": ("apply_relu",),
"Sigmoid": ("apply_sigmoid",),
"Sin": 7,
"Sinh": 9,
"Softplus": 1,
"Softsign": 1,
"Sqrt": ("apply_sqrt",),
"StopGradient": ("apply_identity",),
"Sub": ("apply_sub",),
"Tan": 7,
"Tanh": ("apply_tanh",)
}


def direct_tf_op_convert(scope, operator, container):
oopb = OnnxOperatorBuilder(container, scope)
type = operator.raw_operator.type
item = direct_ops[type]
@@ -2201,4 +2061,5 @@ def tf_op_convert(scope, operator, container):
)


set_converters({k: tf_op_convert for k in direct_ops.keys()})
def register_direct_tf_ops():
set_converters({k: direct_tf_op_convert for k in direct_ops.keys()})
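
Reading the direct_ops table: a tuple names an apply_* helper (apply_add, apply_tanh, ...), while a bare integer appears to be the minimum opset at which the TF op maps one-to-one onto the identically named ONNX op. A self-contained sketch of that dispatch convention (an interpretation of the table, not the PR's code):

DIRECT_OPS = {'Add': ('apply_add',), 'Cos': 7, 'Erf': 9}  # trimmed copy

def resolve(op_type, target_opset):
    item = DIRECT_OPS[op_type]
    if isinstance(item, tuple):
        return 'helper', item[0]  # route to the named apply_* helper
    if target_opset < item:
        raise ValueError('%s needs opset >= %d, got %d' % (op_type, item, target_opset))
    return 'direct', op_type      # emit the op under its own ONNX name

print(resolve('Add', 10))  # ('helper', 'apply_add')
print(resolve('Erf', 11))  # ('direct', 'Erf')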