6
6
import coremltools as ct
7
7
8
8
import logging
9
+ import json
9
10
10
11
logging .basicConfig ()
11
12
logger = logging .getLogger (__name__ )
@@ -21,14 +22,47 @@ class CoreMLModel:
21
22
""" Wrapper for running CoreML models using coremltools
22
23
"""
23
24
24
- def __init__ (self , model_path , compute_unit ):
25
- assert os .path .exists (model_path ) and model_path .endswith (".mlpackage" )
25
+ def __init__ (self , model_path , compute_unit , sources = 'packages' ):
26
26
27
27
logger .info (f"Loading { model_path } " )
28
28
29
29
start = time .time ()
30
- self .model = ct .models .MLModel (
31
- model_path , compute_units = ct .ComputeUnit [compute_unit ])
30
+ if sources == 'packages' :
31
+ assert os .path .exists (model_path ) and model_path .endswith (".mlpackage" )
32
+
33
+ self .model = ct .models .MLModel (
34
+ model_path , compute_units = ct .ComputeUnit [compute_unit ])
35
+ DTYPE_MAP = {
36
+ 65552 : np .float16 ,
37
+ 65568 : np .float32 ,
38
+ 131104 : np .int32 ,
39
+ }
40
+ self .expected_inputs = {
41
+ input_tensor .name : {
42
+ "shape" : tuple (input_tensor .type .multiArrayType .shape ),
43
+ "dtype" : DTYPE_MAP [input_tensor .type .multiArrayType .dataType ],
44
+ }
45
+ for input_tensor in self .model ._spec .description .input
46
+ }
47
+ elif sources == 'compiled' :
48
+ assert os .path .exists (model_path ) and model_path .endswith (".mlmodelc" )
49
+
50
+ self .model = ct .models .CompiledMLModel (model_path , ct .ComputeUnit [compute_unit ])
51
+
52
+ # Grab expected inputs from metadata.json
53
+ with open (os .path .join (model_path , 'metadata.json' ), 'r' ) as f :
54
+ config = json .load (f )[0 ]
55
+
56
+ self .expected_inputs = {
57
+ input_tensor ['name' ]: {
58
+ "shape" : tuple (eval (input_tensor ['shape' ])),
59
+ "dtype" : np .dtype (input_tensor ['dataType' ].lower ()),
60
+ }
61
+ for input_tensor in config ['inputSchema' ]
62
+ }
63
+ else :
64
+ raise ValueError (f'Expected `packages` or `compiled` for sources, received { sources } ' )
65
+
32
66
load_time = time .time () - start
33
67
logger .info (f"Done. Took { load_time :.1f} seconds." )
34
68
@@ -38,21 +72,6 @@ def __init__(self, model_path, compute_unit):
38
72
"The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
39
73
)
40
74
41
-
42
- DTYPE_MAP = {
43
- 65552 : np .float16 ,
44
- 65568 : np .float32 ,
45
- 131104 : np .int32 ,
46
- }
47
-
48
- self .expected_inputs = {
49
- input_tensor .name : {
50
- "shape" : tuple (input_tensor .type .multiArrayType .shape ),
51
- "dtype" : DTYPE_MAP [input_tensor .type .multiArrayType .dataType ],
52
- }
53
- for input_tensor in self .model ._spec .description .input
54
- }
55
-
56
75
def _verify_inputs (self , ** kwargs ):
57
76
for k , v in kwargs .items ():
58
77
if k in self .expected_inputs :
@@ -72,7 +91,7 @@ def _verify_inputs(self, **kwargs):
72
91
f"Expected shape { expected_shape } , got { v .shape } for input: { k } "
73
92
)
74
93
else :
75
- raise ValueError ("Received unexpected input kwarg: {k}" )
94
+ raise ValueError (f "Received unexpected input kwarg: { k } " )
76
95
77
96
def __call__ (self , ** kwargs ):
78
97
self ._verify_inputs (** kwargs )
@@ -82,21 +101,77 @@ def __call__(self, **kwargs):
82
101
LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds
83
102
84
103
85
- def _load_mlpackage (submodule_name , mlpackages_dir , model_version ,
86
- compute_unit ):
87
- """ Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
104
+ def get_resource_type (resources_dir : str ) -> str :
105
+ """
106
+ Detect resource type based on filepath extensions.
107
+ returns:
108
+ `packages`: for .mlpackage resources
109
+ 'compiled`: for .mlmodelc resources
88
110
"""
89
- logger . info ( f"Loading { submodule_name } mlpackage" )
111
+ directories = [ f for f in os . listdir ( resources_dir ) if os . path . isdir ( os . path . join ( resources_dir , f ))]
90
112
91
- fname = f"Stable_Diffusion_version_{ model_version } _{ submodule_name } .mlpackage" .replace (
92
- "/" , "_" )
93
- mlpackage_path = os .path .join (mlpackages_dir , fname )
113
+ # consider directories ending with extension
114
+ extensions = set ([os .path .splitext (e )[1 ] for e in directories if os .path .splitext (e )[1 ]])
94
115
95
- if not os .path .exists (mlpackage_path ):
96
- raise FileNotFoundError (
97
- f"{ submodule_name } CoreML model doesn't exist at { mlpackage_path } " )
116
+ # if one extension present we may be able to infer sources type
117
+ if len (set (extensions )) == 1 :
118
+ extension = extensions .pop ()
119
+ else :
120
+ raise ValueError (f'Multiple file extensions found at { resources_dir } .'
121
+ f'Cannot infer resource type from contents.' )
122
+
123
+ if extension == '.mlpackage' :
124
+ sources = 'packages'
125
+ elif extension == '.mlmodelc' :
126
+ sources = 'compiled'
127
+ else :
128
+ raise ValueError (f'Did not find .mlpackage or .mlmodelc at { resources_dir } ' )
129
+
130
+ return sources
131
+
132
+
133
+ def _load_mlpackage (submodule_name ,
134
+ mlpackages_dir ,
135
+ model_version ,
136
+ compute_unit ,
137
+ sources = None ):
138
+ """
139
+ Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
140
+
141
+ """
142
+
143
+ # if sources not provided, attempt to infer `packages` or `compiled` from the
144
+ # resources directory
145
+ if sources is None :
146
+ sources = get_resource_type (mlpackages_dir )
147
+
148
+ if sources == 'packages' :
149
+ logger .info (f"Loading { submodule_name } mlpackage" )
150
+ fname = f"Stable_Diffusion_version_{ model_version } _{ submodule_name } .mlpackage" .replace (
151
+ "/" , "_" )
152
+ mlpackage_path = os .path .join (mlpackages_dir , fname )
153
+
154
+ if not os .path .exists (mlpackage_path ):
155
+ raise FileNotFoundError (
156
+ f"{ submodule_name } CoreML model doesn't exist at { mlpackage_path } " )
157
+
158
+ elif sources == 'compiled' :
159
+ logger .info (f"Loading { submodule_name } mlmodelc" )
160
+
161
+ # FixMe: Submodule names and compiled resources names differ. Can change if names match in the future.
162
+ submodule_names = ["text_encoder" , "text_encoder_2" , "unet" , "vae_decoder" ]
163
+ compiled_names = ['TextEncoder' , 'TextEncoder2' , 'Unet' , 'VAEDecoder' , 'VAEEncoder' ]
164
+ name_map = dict (zip (submodule_names , compiled_names ))
165
+
166
+ cname = name_map [submodule_name ] + '.mlmodelc'
167
+ mlpackage_path = os .path .join (mlpackages_dir , cname )
168
+
169
+ if not os .path .exists (mlpackage_path ):
170
+ raise FileNotFoundError (
171
+ f"{ submodule_name } CoreML model doesn't exist at { mlpackage_path } " )
172
+
173
+ return CoreMLModel (mlpackage_path , compute_unit , sources = sources )
98
174
99
- return CoreMLModel (mlpackage_path , compute_unit )
100
175
101
176
def _load_mlpackage_controlnet (mlpackages_dir , model_version , compute_unit ):
102
177
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
@@ -115,5 +190,6 @@ def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
115
190
116
191
return CoreMLModel (mlpackage_path , compute_unit )
117
192
193
+
118
194
def get_available_compute_units ():
119
195
return tuple (cu for cu in ct .ComputeUnit ._member_names_ )
0 commit comments