1
1
from __future__ import annotations
2
2
3
3
import dataclasses
4
- from typing import Dict , List , Set
4
+ from typing import Dict , List
5
5
6
6
import torch
7
7
from torch .onnx ._internal .fx import _pass , diagnostics
8
8
9
9
10
10
@dataclasses .dataclass
11
11
class UnsupportedFxNodesAnalysisResult (_pass .AnalysisResult ):
12
- unsupported_op_to_target_mapping : Dict [str , Set [str ]]
12
+ unsupported_op_to_target_mapping : Dict [str , Dict [str , None ]]
13
13
14
14
15
15
class UnsupportedFxNodesAnalysis (_pass .Analysis ):
@@ -25,7 +25,7 @@ def _lint(
25
25
return
26
26
27
27
normalized_op_targets_map = {
28
- op : [ str ( target ) for target in targets ]
28
+ op : list ( targets . keys ())
29
29
for op , targets in analysis_result .unsupported_op_to_target_mapping .items ()
30
30
}
31
31
@@ -63,12 +63,12 @@ def analyze(
63
63
except diagnostics .RuntimeErrorWithDiagnostic as e :
64
64
unsupported_nodes .append (node )
65
65
66
- op_to_target_mapping : Dict [str , Set [str ]] = {}
66
+ op_to_target_mapping : Dict [str , Dict [str , None ]] = {}
67
67
68
68
for node in unsupported_nodes :
69
69
op = node .op
70
70
target = node .target
71
- op_to_target_mapping .setdefault (op , set ()). add (str (target ))
71
+ op_to_target_mapping .setdefault (op , {}). setdefault (str (target ), None )
72
72
73
73
analysis_result = UnsupportedFxNodesAnalysisResult (op_to_target_mapping )
74
74
self ._lint (analysis_result , diagnostic_level )
0 commit comments