Skip to content

Commit f3a2124

Browse files
authored
Xl python inference (#261)
* Updated pipeline.py for XL inference * cleaned up * Add shape handling for UNET time_ids shape * added support for loading from CompiledMLModel
1 parent 94dfc6b commit f3a2124

File tree

2 files changed

+415
-171
lines changed

2 files changed

+415
-171
lines changed

python_coreml_stable_diffusion/coreml_model.py

Lines changed: 107 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import coremltools as ct
77

88
import logging
9+
import json
910

1011
logging.basicConfig()
1112
logger = logging.getLogger(__name__)
@@ -21,14 +22,47 @@ class CoreMLModel:
2122
""" Wrapper for running CoreML models using coremltools
2223
"""
2324

24-
def __init__(self, model_path, compute_unit, sources='packages'):
    """ Load a Core ML model for inference.

    Args:
        model_path: Path to the model on disk. Must end in ".mlpackage"
            when sources='packages', or ".mlmodelc" when sources='compiled'.
        compute_unit: Name of a `ct.ComputeUnit` member (e.g. "ALL",
            "CPU_AND_GPU") selecting the hardware the model may run on.
        sources: Either 'packages' (load an .mlpackage via ct.models.MLModel)
            or 'compiled' (load a compiled .mlmodelc via
            ct.models.CompiledMLModel). Defaults to 'packages' for backward
            compatibility.

    Raises:
        ValueError: If `sources` is neither 'packages' nor 'compiled'.
        AssertionError: If `model_path` does not exist or has the wrong
            extension for the chosen `sources`.
    """
    logger.info(f"Loading {model_path}")

    start = time.time()
    if sources == 'packages':
        assert os.path.exists(model_path) and model_path.endswith(".mlpackage")

        self.model = ct.models.MLModel(
            model_path, compute_units=ct.ComputeUnit[compute_unit])

        # Numeric codes are Core ML spec MultiArray dataType enum values,
        # mapped to the numpy dtypes used for input verification
        DTYPE_MAP = {
            65552: np.float16,
            65568: np.float32,
            131104: np.int32,
        }
        self.expected_inputs = {
            input_tensor.name: {
                "shape": tuple(input_tensor.type.multiArrayType.shape),
                "dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
            }
            for input_tensor in self.model._spec.description.input
        }
    elif sources == 'compiled':
        assert os.path.exists(model_path) and model_path.endswith(".mlmodelc")

        self.model = ct.models.CompiledMLModel(model_path, ct.ComputeUnit[compute_unit])

        # Compiled models do not expose a spec; grab expected inputs from
        # the metadata.json bundled inside the .mlmodelc directory
        with open(os.path.join(model_path, 'metadata.json'), 'r') as f:
            config = json.load(f)[0]

        import ast  # local import: only needed on this branch

        self.expected_inputs = {
            input_tensor['name']: {
                # literal_eval safely parses the shape string (a Python
                # tuple/list literal) without the arbitrary-code-execution
                # risk that eval() would carry for untrusted metadata
                "shape": tuple(ast.literal_eval(input_tensor['shape'])),
                "dtype": np.dtype(input_tensor['dataType'].lower()),
            }
            for input_tensor in config['inputSchema']
        }
    else:
        raise ValueError(f'Expected `packages` or `compiled` for sources, received {sources}')

    load_time = time.time() - start
    logger.info(f"Done. Took {load_time:.1f} seconds.")
3468

@@ -38,21 +72,6 @@ def __init__(self, model_path, compute_unit):
3872
"The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
3973
)
4074

41-
42-
DTYPE_MAP = {
43-
65552: np.float16,
44-
65568: np.float32,
45-
131104: np.int32,
46-
}
47-
48-
self.expected_inputs = {
49-
input_tensor.name: {
50-
"shape": tuple(input_tensor.type.multiArrayType.shape),
51-
"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
52-
}
53-
for input_tensor in self.model._spec.description.input
54-
}
55-
5675
def _verify_inputs(self, **kwargs):
5776
for k, v in kwargs.items():
5877
if k in self.expected_inputs:
@@ -72,7 +91,7 @@ def _verify_inputs(self, **kwargs):
7291
f"Expected shape {expected_shape}, got {v.shape} for input: {k}"
7392
)
7493
else:
75-
raise ValueError("Received unexpected input kwarg: {k}")
94+
raise ValueError(f"Received unexpected input kwarg: {k}")
7695

7796
def __call__(self, **kwargs):
7897
self._verify_inputs(**kwargs)
@@ -82,21 +101,77 @@ def __call__(self, **kwargs):
82101
LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds
83102

84103

85-
def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
86-
compute_unit):
87-
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
104+
def get_resource_type(resources_dir: str) -> str:
    """ Detect resource type based on filepath extensions.

    Inspects the sub-directories of `resources_dir` (Core ML model bundles
    are directories) and infers the resource type from their extensions.

    Returns:
        'packages' for .mlpackage resources,
        'compiled' for .mlmodelc resources.

    Raises:
        ValueError: If no extension-bearing bundle directories are found,
            if multiple different extensions are present, or if the single
            extension found is neither .mlpackage nor .mlmodelc.
    """
    directories = [
        f for f in os.listdir(resources_dir)
        if os.path.isdir(os.path.join(resources_dir, f))
    ]

    # consider directories ending with extension
    extensions = {
        os.path.splitext(e)[1] for e in directories if os.path.splitext(e)[1]
    }

    # exactly one extension must be present to infer the sources type
    if not extensions:
        raise ValueError(f'No model bundles found at {resources_dir}. '
                         f'Cannot infer resource type from contents.')
    if len(extensions) > 1:
        raise ValueError(f'Multiple file extensions found at {resources_dir}. '
                         f'Cannot infer resource type from contents.')

    extension = extensions.pop()
    if extension == '.mlpackage':
        return 'packages'
    if extension == '.mlmodelc':
        return 'compiled'
    raise ValueError(f'Did not find .mlpackage or .mlmodelc at {resources_dir}')
131+
132+
133+
def _load_mlpackage(submodule_name,
                    mlpackages_dir,
                    model_version,
                    compute_unit,
                    sources=None):
    """
    Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)

    Args:
        submodule_name: Pipeline submodule, e.g. "text_encoder",
            "text_encoder_2", "unet", "vae_decoder", "vae_encoder".
        mlpackages_dir: Directory containing the model bundles.
        model_version: Version string embedded in .mlpackage file names.
        compute_unit: Name of a `ct.ComputeUnit` member.
        sources: 'packages', 'compiled', or None to auto-detect from the
            contents of `mlpackages_dir`.

    Raises:
        ValueError: For an unknown `sources` value, or a submodule with no
            known compiled resource name.
        FileNotFoundError: If the expected model bundle does not exist.
    """
    # if sources not provided, attempt to infer `packages` or `compiled` from
    # the resources directory
    if sources is None:
        sources = get_resource_type(mlpackages_dir)

    if sources == 'packages':
        logger.info(f"Loading {submodule_name} mlpackage")
        fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
            "/", "_")
        mlpackage_path = os.path.join(mlpackages_dir, fname)
    elif sources == 'compiled':
        logger.info(f"Loading {submodule_name} mlmodelc")

        # FixMe: Submodule names and compiled resources names differ.
        # Can change if names match in the future.
        # NOTE: a dict literal avoids the silently-truncating zip of two
        # unequal-length lists, which dropped the vae_encoder mapping.
        name_map = {
            "text_encoder": "TextEncoder",
            "text_encoder_2": "TextEncoder2",
            "unet": "Unet",
            "vae_decoder": "VAEDecoder",
            "vae_encoder": "VAEEncoder",
        }
        if submodule_name not in name_map:
            raise ValueError(
                f'No compiled resource name known for submodule {submodule_name}')

        cname = name_map[submodule_name] + '.mlmodelc'
        mlpackage_path = os.path.join(mlpackages_dir, cname)
    else:
        raise ValueError(
            f'Expected `packages` or `compiled` for sources, received {sources}')

    if not os.path.exists(mlpackage_path):
        raise FileNotFoundError(
            f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")

    return CoreMLModel(mlpackage_path, compute_unit, sources=sources)
98174

99-
return CoreMLModel(mlpackage_path, compute_unit)
100175

101176
def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
102177
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
@@ -115,5 +190,6 @@ def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
115190

116191
return CoreMLModel(mlpackage_path, compute_unit)
117192

193+
118194
def get_available_compute_units():
    """ Return the names of all `ct.ComputeUnit` members as a tuple. """
    return tuple(ct.ComputeUnit._member_names_)

0 commit comments

Comments
 (0)