Skip to content

Commit c1e0ed3

Browse files
authored
Optimize encoder-decoder exports (including Whisper) (#1218)
* Use non-buggy onnx-graphsurgeon via onnxslim for toposort * fp16 conversion improvements * Formatting * Remove unnecessary lines
1 parent cdce6e8 commit c1e0ed3

File tree

3 files changed

+38
-81
lines changed

3 files changed

+38
-81
lines changed

scripts/float16.py

Lines changed: 36 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import warnings
3030
from onnx import helper, numpy_helper
3131
from onnx import onnx_pb as onnx_proto
32+
import onnxslim.third_party.onnx_graphsurgeon as gs
3233

3334

3435
FLOAT32 = 1
@@ -179,8 +180,12 @@ def make_value_info_from_tensor(tensor):
179180
"Max",
180181
"Upsample",
181182
# NEW:
182-
"Cast",
183183
"RandomNormalLike",
184+
# TODO: Ideally, "Cast" nodes should not be here, for the following reasons:
185+
# - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime".
186+
# - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around it.
187+
# However, without it, the graphs produced are invalid. Eventually, we will resolve this.
188+
"Cast",
184189
]
185190

186191

@@ -277,9 +282,14 @@ def convert_float_to_float16(
277282
is_top_level = False # Going to process sub-graph
278283
graph_stack = next_level
279284

280-
sort_topology(model.graph)
281285
remove_unnecessary_cast_node(model.graph)
282286

287+
# Topologically sort the graph
288+
# NOTE: We do not perform another round of optimization as the model is already optimized
289+
graph = gs.import_onnx(model)
290+
graph.toposort()
291+
model = gs.export_onnx(graph)
292+
283293
return model
284294

285295

@@ -311,21 +321,26 @@ def process_graph_input(
311321
graph, graph_input.name
312322
)
313323
for d_node in downstream_nodes:
314-
cast_node_name = graph_input.name + "_cast_to_" + d_node.name
315-
cast_node_output_name = graph_input.name + "_cast_to_" + d_node.name
316-
add_cast_node(
317-
graph,
318-
[graph_input.name],
319-
[cast_node_output_name],
320-
cast_node_name,
321-
FLOAT16,
322-
)
323-
add_new_value_info(
324-
graph,
325-
graph_input,
326-
cast_node_output_name,
327-
onnx_proto.TensorProto.FLOAT16,
328-
)
324+
# More than one node may consume the model input, so we only create
325+
# a single cast node, and then reuse this node when needed.
326+
cast_exists = graph_input.name in global_input_name_dict
327+
if cast_exists:
328+
cast_node_output_name = global_input_name_dict[graph_input.name]
329+
else:
330+
cast_node_output_name = graph_input.name + "_fp16"
331+
add_cast_node(
332+
graph,
333+
[graph_input.name],
334+
[cast_node_output_name],
335+
cast_node_output_name, # Set node name same as output name
336+
FLOAT16,
337+
)
338+
add_new_value_info(
339+
graph,
340+
graph_input,
341+
cast_node_output_name,
342+
onnx_proto.TensorProto.FLOAT16,
343+
)
329344
for i, input_name in enumerate(d_node.input):
330345
if input_name == graph_input.name:
331346
d_node.input[i] = (
@@ -414,8 +429,7 @@ def process_node_in_block_list(
414429
def insert_cast32_before_node(
415430
graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
416431
):
417-
for i in range(len(node.input)):
418-
input_name = node.input[i]
432+
for i, input_name in enumerate(node.input):
419433
for value_info in itertools.chain(graph.value_info, graph.input):
420434
if input_name == value_info.name:
421435
if (
@@ -443,8 +457,7 @@ def insert_cast32_before_node(
443457
def insert_cast16_after_node(
444458
graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
445459
):
446-
for i in range(len(node.output)):
447-
output_name = node.output[i]
460+
for i, output_name in enumerate(node.output):
448461
for value_info in itertools.chain(graph.value_info, graph.output):
449462
if output_name == value_info.name:
450463
if (
@@ -693,56 +706,6 @@ def convert_float_to_float16_model_path(
693706
)
694707

695708

696-
def sort_graph_node(graph_proto):
697-
# find the "first" node in Nodes that its input is not any node's output
698-
def find_first_node(output2node_dict):
699-
for node in org_nodes:
700-
is_not_first_node = any(item in output2node_dict for item in node.input)
701-
if not is_not_first_node:
702-
return node
703-
return None
704-
705-
# remove the node from output2node_dict using output as key
706-
def remove_first_node_from_dict2(first_node):
707-
for output in first_node.output:
708-
if output in output2node_dict:
709-
del output2node_dict[output]
710-
711-
org_nodes = graph_proto.node
712-
# create a dict to store output as key and node as value
713-
output2node_dict = {}
714-
for node in org_nodes:
715-
for output in node.output:
716-
output2node_dict[output] = node
717-
718-
# save the final node after sorted
719-
sorted_node = []
720-
# traverse the Nodes to find the first node
721-
while len(output2node_dict) > 0:
722-
first_node = find_first_node(output2node_dict)
723-
sorted_node.append(first_node)
724-
remove_first_node_from_dict2(first_node)
725-
# del node from original nodes list to avoid duplicate traverse
726-
org_nodes.remove(first_node)
727-
728-
for new_node in sorted_node:
729-
graph_proto.node.extend([new_node])
730-
731-
732-
# The input graph should be model.graph
733-
# Recursively sort the topology for each sub-graph
734-
def sort_topology(graph_proto):
735-
assert isinstance(graph_proto, onnx_proto.GraphProto)
736-
sort_graph_node(graph_proto) # sort global graph
737-
for node in graph_proto.node:
738-
for attr in node.attribute:
739-
if isinstance(attr.g, onnx_proto.GraphProto) and len(attr.g.node) > 0:
740-
sort_topology(attr.g) # sort sub-graph
741-
for g in attr.graphs:
742-
if isinstance(g, onnx_proto.GraphProto):
743-
sort_topology(g) # sort sub-graph
744-
745-
746709
def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
747710
# 1. find all cast nodes in the graph
748711
cast_node_list = []
@@ -837,8 +800,8 @@ def get_type(name: str) -> Optional[int]:
837800
else:
838801
if (
839802
downstream_node.op_type == "Cast"
840-
and cast_node.attribute[0].i == 10
841-
and downstream_node.attribute[0].i == 1
803+
and cast_node.attribute[0].i == FLOAT16
804+
and downstream_node.attribute[0].i == FLOAT32
842805
and downstream_node in cast_node_list
843806
and cast_node in cast_node_list
844807
):

scripts/quantize.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
from onnxruntime.quantization.registry import IntegerOpsRegistry
1515
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
1616
from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
17-
import onnx_graphsurgeon as gs
1817

1918
from . import float16
2019
from .utils import check_and_save_model
@@ -221,10 +220,6 @@ def quantize_fp16(
221220
disable_shape_infer=disable_shape_infer,
222221
op_block_list=blocked_ops,
223222
)
224-
225-
graph = gs.import_onnx(model_fp16)
226-
graph.toposort()
227-
model_fp16 = gs.export_onnx(graph)
228223
check_and_save_model(model_fp16, save_path)
229224

230225

scripts/requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
1-
transformers[torch]==4.48.3
1+
transformers[torch]==4.49.0
22
onnxruntime==1.20.1
3-
optimum@git+https://github.com/huggingface/optimum.git@ce533cf1a9e144d4040581947f301dc3f454b279
3+
optimum@git+https://github.com/huggingface/optimum.git@b04feaea78cda58d79b8da67dca3fd0c4ab33435
44
onnx==1.17.0
55
tqdm==4.67.1
66
onnxslim==0.1.48
7-
onnx-graphsurgeon==0.5.5

0 commit comments

Comments (0)