|
29 | 29 | import warnings
|
30 | 30 | from onnx import helper, numpy_helper
|
31 | 31 | from onnx import onnx_pb as onnx_proto
|
| 32 | +import onnxslim.third_party.onnx_graphsurgeon as gs |
32 | 33 |
|
33 | 34 |
|
34 | 35 | FLOAT32 = 1
|
@@ -179,8 +180,12 @@ def make_value_info_from_tensor(tensor):
|
179 | 180 | "Max",
|
180 | 181 | "Upsample",
|
181 | 182 | # NEW:
|
182 |
| - "Cast", |
183 | 183 | "RandomNormalLike",
|
| 184 | + # TODO: Ideally, "Cast" nodes should not be here, for the following reasons: |
| 185 | + # - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime". |
| 186 | + # - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around them. |
| 187 | + # However, without it, the graphs produced are invalid. Eventually, we will resolve this. |
| 188 | + "Cast", |
184 | 189 | ]
|
185 | 190 |
|
186 | 191 |
|
@@ -277,9 +282,14 @@ def convert_float_to_float16(
|
277 | 282 | is_top_level = False # Going to process sub-graph
|
278 | 283 | graph_stack = next_level
|
279 | 284 |
|
280 |
| - sort_topology(model.graph) |
281 | 285 | remove_unnecessary_cast_node(model.graph)
|
282 | 286 |
|
| 287 | + # Topologically sort the graph |
| 288 | + # NOTE: We do not perform another round of optimization as the model is already optimized |
| 289 | + graph = gs.import_onnx(model) |
| 290 | + graph.toposort() |
| 291 | + model = gs.export_onnx(graph) |
| 292 | + |
283 | 293 | return model
|
284 | 294 |
|
285 | 295 |
|
@@ -311,21 +321,26 @@ def process_graph_input(
|
311 | 321 | graph, graph_input.name
|
312 | 322 | )
|
313 | 323 | for d_node in downstream_nodes:
|
314 |
| - cast_node_name = graph_input.name + "_cast_to_" + d_node.name |
315 |
| - cast_node_output_name = graph_input.name + "_cast_to_" + d_node.name |
316 |
| - add_cast_node( |
317 |
| - graph, |
318 |
| - [graph_input.name], |
319 |
| - [cast_node_output_name], |
320 |
| - cast_node_name, |
321 |
| - FLOAT16, |
322 |
| - ) |
323 |
| - add_new_value_info( |
324 |
| - graph, |
325 |
| - graph_input, |
326 |
| - cast_node_output_name, |
327 |
| - onnx_proto.TensorProto.FLOAT16, |
328 |
| - ) |
| 324 | + # More than one node may consume the model input, so we only create |
| 325 | + # a single cast node, and then reuse this node when needed. |
| 326 | + cast_exists = graph_input.name in global_input_name_dict |
| 327 | + if cast_exists: |
| 328 | + cast_node_output_name = global_input_name_dict[graph_input.name] |
| 329 | + else: |
| 330 | + cast_node_output_name = graph_input.name + "_fp16" |
| 331 | + add_cast_node( |
| 332 | + graph, |
| 333 | + [graph_input.name], |
| 334 | + [cast_node_output_name], |
| 335 | + cast_node_output_name, # Set node name same as output name |
| 336 | + FLOAT16, |
| 337 | + ) |
| 338 | + add_new_value_info( |
| 339 | + graph, |
| 340 | + graph_input, |
| 341 | + cast_node_output_name, |
| 342 | + onnx_proto.TensorProto.FLOAT16, |
| 343 | + ) |
329 | 344 | for i, input_name in enumerate(d_node.input):
|
330 | 345 | if input_name == graph_input.name:
|
331 | 346 | d_node.input[i] = (
|
@@ -414,8 +429,7 @@ def process_node_in_block_list(
|
414 | 429 | def insert_cast32_before_node(
|
415 | 430 | graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
|
416 | 431 | ):
|
417 |
| - for i in range(len(node.input)): |
418 |
| - input_name = node.input[i] |
| 432 | + for i, input_name in enumerate(node.input): |
419 | 433 | for value_info in itertools.chain(graph.value_info, graph.input):
|
420 | 434 | if input_name == value_info.name:
|
421 | 435 | if (
|
@@ -443,8 +457,7 @@ def insert_cast32_before_node(
|
443 | 457 | def insert_cast16_after_node(
|
444 | 458 | graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
|
445 | 459 | ):
|
446 |
| - for i in range(len(node.output)): |
447 |
| - output_name = node.output[i] |
| 460 | + for i, output_name in enumerate(node.output): |
448 | 461 | for value_info in itertools.chain(graph.value_info, graph.output):
|
449 | 462 | if output_name == value_info.name:
|
450 | 463 | if (
|
@@ -693,56 +706,6 @@ def convert_float_to_float16_model_path(
|
693 | 706 | )
|
694 | 707 |
|
695 | 708 |
|
696 |
| -def sort_graph_node(graph_proto): |
697 |
| - # find the "first" node in Nodes that its input is not any node's output |
698 |
| - def find_first_node(output2node_dict): |
699 |
| - for node in org_nodes: |
700 |
| - is_not_first_node = any(item in output2node_dict for item in node.input) |
701 |
| - if not is_not_first_node: |
702 |
| - return node |
703 |
| - return None |
704 |
| - |
705 |
| - # remove the node from output2node_dict using output as key |
706 |
| - def remove_first_node_from_dict2(first_node): |
707 |
| - for output in first_node.output: |
708 |
| - if output in output2node_dict: |
709 |
| - del output2node_dict[output] |
710 |
| - |
711 |
| - org_nodes = graph_proto.node |
712 |
| - # create a dict to store output as key and node as value |
713 |
| - output2node_dict = {} |
714 |
| - for node in org_nodes: |
715 |
| - for output in node.output: |
716 |
| - output2node_dict[output] = node |
717 |
| - |
718 |
| - # save the final node after sorted |
719 |
| - sorted_node = [] |
720 |
| - # traverse the Nodes to find the first node |
721 |
| - while len(output2node_dict) > 0: |
722 |
| - first_node = find_first_node(output2node_dict) |
723 |
| - sorted_node.append(first_node) |
724 |
| - remove_first_node_from_dict2(first_node) |
725 |
| - # del node from original nodes list to avoid duplicate traverse |
726 |
| - org_nodes.remove(first_node) |
727 |
| - |
728 |
| - for new_node in sorted_node: |
729 |
| - graph_proto.node.extend([new_node]) |
730 |
| - |
731 |
| - |
732 |
| -# The input graph should be mode.graph |
733 |
| -# Recursively sort the topology for each sub-graph |
734 |
| -def sort_topology(graph_proto): |
735 |
| - assert isinstance(graph_proto, onnx_proto.GraphProto) |
736 |
| - sort_graph_node(graph_proto) # sort global graph |
737 |
| - for node in graph_proto.node: |
738 |
| - for attr in node.attribute: |
739 |
| - if isinstance(attr.g, onnx_proto.GraphProto) and len(attr.g.node) > 0: |
740 |
| - sort_topology(attr.g) # sort sub-graph |
741 |
| - for g in attr.graphs: |
742 |
| - if isinstance(g, onnx_proto.GraphProto): |
743 |
| - sort_topology(g) # sort sub-graph |
744 |
| - |
745 |
| - |
746 | 709 | def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
|
747 | 710 | # 1. find all cast nodes in the graph
|
748 | 711 | cast_node_list = []
|
@@ -837,8 +800,8 @@ def get_type(name: str) -> Optional[int]:
|
837 | 800 | else:
|
838 | 801 | if (
|
839 | 802 | downstream_node.op_type == "Cast"
|
840 |
| - and cast_node.attribute[0].i == 10 |
841 |
| - and downstream_node.attribute[0].i == 1 |
| 803 | + and cast_node.attribute[0].i == FLOAT16 |
| 804 | + and downstream_node.attribute[0].i == FLOAT32 |
842 | 805 | and downstream_node in cast_node_list
|
843 | 806 | and cast_node in cast_node_list
|
844 | 807 | ):
|
|
0 commit comments