This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit b3064c5

xinyu-intel authored and pengzhao-intel committed
[MKLDNN]Enhance Quantization APIs and Tutorial (#15448)
* enhance api and new tutorial
* Update MKLDNN_QUANTIZATION.md update
* fix lint
* modify pics
* skip test
* add quantize layer in graph
* update
* remove center css flag
* change requantize color
* fix markdown pics
* change to use png
* Update MKLDNN_QUANTIZATION.md update
* enable ipython script
* fix png
* fix lint
* Update MKLDNN_QUANTIZATION.md
* change title
* trigger
* use lower case
* some typo
* some typo
* use dmlc web data
* trigger
* trigger
1 parent 773f4dc commit b3064c5

6 files changed

+657
-108
lines changed

docs/tutorials/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ Select API: 
105105
* [Module to Gluon API](/tutorials/python/module_to_gluon.html)
106106
* [Gluon end to end from training to inference](/tutorials/gluon/gluon_from_experiment_to_deployment.html)
107107
* [Automatic Mixed Precision in Gluon](/tutorials/amp/amp_tutorial.html)
108+
* [How to build and install MXNet with MKL-DNN backend](/tutorials/mkldnn/MKLDNN_README.html)
109+
* [How to quantize custom models with MKL-DNN backend](/tutorials/mkldnn/mkldnn_quantization.html)<span style="color:red"> (new!) </span>
108110
* API Guides
109111
* Core APIs
110112
* NDArray
@@ -157,7 +159,6 @@ Select API:
157159
* [Large-Scale Multi-Host Multi-GPU Image Classification](/tutorials/vision/large_scale_classification.html)
158160
* [Importing an ONNX model into MXNet](/tutorials/onnx/super_resolution.html)
159161
* [Optimizing Deep Learning Computation Graphs with TensorRT](/tutorials/tensorrt/inference_with_trt.html)
160-
* [How to build and install MXNet with MKL-DNN backend](/tutorials/mkldnn/MKLDNN_README.html)
161162
* API Guides
162163
* Core APIs
163164
* NDArray
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
2+
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
3+
<!--- or more contributor license agreements. See the NOTICE file -->
4+
<!--- distributed with this work for additional information -->
5+
<!--- regarding copyright ownership. The ASF licenses this file -->
6+
<!--- to you under the Apache License, Version 2.0 (the -->
7+
<!--- "License"); you may not use this file except in compliance -->
8+
<!--- with the License. You may obtain a copy of the License at -->
9+
10+
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
11+
12+
<!--- Unless required by applicable law or agreed to in writing, -->
13+
<!--- software distributed under the License is distributed on an -->
14+
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
15+
<!--- KIND, either express or implied. See the License for the -->
16+
<!--- specific language governing permissions and limitations -->
17+
<!--- under the License. -->
18+
19+
# Quantize custom models with MKL-DNN backend
20+
21+
This document describes how to quantize custom models from FP32 to INT8 with the Apache MXNet toolkit and APIs on Intel CPUs.
22+
23+
If you are not familiar with the Apache MXNet quantization flow, please read the [quantization blog](https://medium.com/apache-mxnet/model-quantization-for-production-level-neural-network-inference-f54462ebba05) first. Performance data is shown in [Apache/MXNet C++ interface](https://github.com/apache/incubator-mxnet/tree/master/cpp-package/example/inference) and [GluonCV](https://gluon-cv.mxnet.io/build/examples_deployment/int8_inference.html).
24+
25+
## Installation and Prerequisites
26+
27+
Installing MXNet with the MKL-DNN backend is a simple but essential step. You can follow [How to build and install MXNet with MKL-DNN backend](https://mxnet.incubator.apache.org/tutorials/mkldnn/MKLDNN_README.html) to build and install MXNet from source. Alternatively, you can install the release or nightly version directly from PyPI with pip by running:
28+
29+
```
# release version
pip install mxnet-mkl
# nightly version
pip install mxnet-mkl --pre
```
35+
36+
## Image Classification Demo
37+
38+
A quantization script, [imagenet_gen_qsym_mkldnn.py](https://github.com/apache/incubator-mxnet/blob/master/example/quantization/imagenet_gen_qsym_mkldnn.py), is provided to launch quantization for image-classification models. This script is integrated with the [Gluon-CV modelzoo](https://gluon-cv.mxnet.io/model_zoo/classification.html), so that all pre-trained models can be downloaded from Gluon-CV and then converted for quantization. For details, refer to [Model Quantization with Calibration Examples](https://github.com/apache/incubator-mxnet/blob/master/example/quantization/README.md).
39+
40+
## Integrate Quantization Flow to Your Project
41+
42+
The quantization flow works for both symbolic and Gluon models. If you're using Gluon, first refer to [Saving and Loading Gluon Models](https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/save_load_params.html) to hybridize your computation graph and export it as a symbol before running quantization.
43+
44+
In general, the quantization flow includes 4 steps. Steps 1 to 3 produce a quantized model with acceptable accuracy with minimal effort. Most of this stage works out of the box, so data scientists and researchers only need to focus on how to represent the data and layers in their model. After a quantized model is generated, you may want to deploy it online, and performance becomes the next key concern. Step 4, calibration, can improve performance significantly by removing much of the runtime calculation.
45+
46+
![quantization flow](https://github.com/dmlc/web-data/raw/master/mxnet/tutorials/mkldnn/quantization/quantization.png)
47+
48+
Now, we will take Gluon ResNet18 as an example to show how each step works.
49+
50+
### Initialize Model
51+
52+
```python
import logging
import mxnet as mx
from mxnet.gluon.model_zoo import vision
from mxnet.contrib.quantization import *

logging.basicConfig()
logger = logging.getLogger('logger')
logger.setLevel(logging.INFO)

batch_shape = (1, 3, 224, 224)
resnet18 = vision.resnet18_v1(pretrained=True)
resnet18.hybridize()
resnet18.forward(mx.nd.zeros(batch_shape))
resnet18.export('resnet18_v1')
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18_v1', 0)
# (optional) visualize float32 model
mx.viz.plot_network(sym)
```
71+
First, we download the resnet18-v1 model from the Gluon model zoo and export it as a symbol. Optionally, you can visualize the float32 model. Below is a raw residual block.
72+
73+
![float32 model](https://github.com/dmlc/web-data/raw/master/mxnet/tutorials/mkldnn/quantization/fp32_raw.png)
74+
75+
#### Model Fusion
76+
77+
```python
sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')
# (optional) visualize fused float32 model
mx.viz.plot_network(sym)
```
82+
It's important to add this line to enable graph fusion before quantization, because it gives better performance. Below is a fused residual block, in which BatchNorm, Activation and elemwise_add are fused into Convolution.
83+
84+
![float32 fused model](https://github.com/dmlc/web-data/raw/master/mxnet/tutorials/mkldnn/quantization/fp32_fusion.png)
85+
86+
### Quantize Model
87+
88+
A Python interface, `quantize_graph`, is provided for the user. It gives data scientists the flexibility to construct the models they need for the different requirements of a real deployment.
89+
90+
```python
# quantize configs
# set exclude layers
excluded_names = []
# set calib mode.
calib_mode = 'none'
# set calib_layer
calib_layer = None
# set quantized_dtype
quantized_dtype = 'auto'
logger.info('Quantizing FP32 model Resnet18-V1')
qsym, qarg_params, aux_params, collector = quantize_graph(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                          excluded_sym_names=excluded_names,
                                                          calib_mode=calib_mode, calib_layer=calib_layer,
                                                          quantized_dtype=quantized_dtype, logger=logger)
# (optional) visualize quantized model
mx.viz.plot_network(qsym)
# save quantized model
mx.model.save_checkpoint('quantized-resnet18_v1', 0, qsym, qarg_params, aux_params)
```
110+
111+
Applying `quantize_graph` to the symbolic model generates a new quantized model, named `qsym`, along with its parameters. We can see that `_contrib_requantize` operators are inserted after `Convolution` to convert the INT32 output to INT8.
112+
113+
![none calibrated model](https://github.com/dmlc/web-data/raw/master/mxnet/tutorials/mkldnn/quantization/none_calib.png)
114+
115+
The table below describes the main parameters of `quantize_graph`.
116+
117+
| param | type | description |
|--------------------|-----------------|-------------|
| excluded_sym_names | list of strings | A list of strings representing the names of the symbols that users want to exclude from being quantized. |
| calib_mode | str | If calib_mode='none', no calibration will be used and the thresholds for requantization after the corresponding layers will be calculated at runtime by calling min and max operators. The quantized models generated in this mode are normally 10-20% slower than those with calibrations during inference.<br>If calib_mode='naive', the min and max values of the layer outputs from a calibration dataset will be directly taken as the thresholds for quantization.<br>If calib_mode='entropy', the thresholds for quantization will be derived such that the KL divergence between the distributions of FP32 layer outputs and quantized layer outputs is minimized based upon the calibration dataset. |
| calib_layer | function | Given a layer's output name as a string, return True or False to decide whether to calibrate this layer.<br>If True, the statistics of the layer's output will be collected; otherwise, no information about the layer's output will be collected.<br>If not provided, the outputs of all layers that need requantization will be collected. |
| quantized_dtype | str | The quantized destination type for input data. Currently supports 'int8', 'uint8' and 'auto'.<br>'auto' means the output type is selected automatically according to the calibration result. |
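As the table notes, `calib_layer` is a plain function from a layer's output name to a boolean. The following is a minimal sketch of such a predicate. The naming convention it relies on (outputs ending in `_output`, with `conv` or `dense` in the name) and the sample names are assumptions made for illustration; check the actual output names of your own symbol before using a filter like this.

```python
def calib_layer(name):
    """Return True for layer outputs whose statistics should be collected."""
    # Assumption: outputs of interest end with "_output" and contain
    # "conv" or "dense" in their name. Adjust to your model's naming.
    return name.endswith('_output') and ('conv' in name or 'dense' in name)


# Hypothetical layer output names, for illustration only.
names = [
    'resnetv10_conv0_fwd_output',
    'resnetv10_pool0_fwd_output',
    'resnetv10_dense0_fwd_output',
]
selected = [n for n in names if calib_layer(n)]
print(selected)
```

A function like this could then be passed as `calib_layer=calib_layer` instead of `None`.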
123+
124+
### Evaluate & Tune
125+
126+
You now have a quantized symbol file and a params file for inference. For Gluon inference, the only difference is that the model and params are loaded through a SymbolBlock, as in the example below:
127+
128+
```python
quantized_net = mx.gluon.SymbolBlock.imports('quantized-resnet18_v1-symbol.json', 'data', 'quantized-resnet18_v1-0000.params')
quantized_net.hybridize(static_shape=True, static_alloc=True)
batch_size = 1
data = mx.nd.ones((batch_size,3,224,224))
quantized_net(data)
```
135+
136+
Now you can measure the accuracy of the quantized network. You can also choose which layers or operators are quantized through the `excluded_sym_names` parameter, and iterate until the accuracy is acceptable.
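To judge whether the accuracy is acceptable, it helps to compare the FP32 and INT8 models on the same labelled samples. The sketch below computes top-1 accuracy from plain Python lists of class scores. The helper name `top1_accuracy` and the toy scores are hypothetical; in practice the scores would come from running both networks on your validation data.

```python
def top1_accuracy(scores, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return correct / len(labels)


# Toy scores standing in for real network outputs (illustration only).
labels = [1, 0, 2, 0]
fp32_scores = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
int8_scores = [[0.2, 0.6, 0.2], [0.7, 0.2, 0.1], [0.4, 0.35, 0.25], [0.5, 0.4, 0.1]]

fp32_acc = top1_accuracy(fp32_scores, labels)
int8_acc = top1_accuracy(int8_scores, labels)
accuracy_drop = fp32_acc - int8_acc
print(fp32_acc, int8_acc, accuracy_drop)
```

If the drop is larger than you can accept, excluding the most sensitive layers and measuring again is one way to proceed.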
137+
138+
### Calibrate Model (optional for performance)
139+
140+
The quantized model generated in the previous steps can be slow during inference because it calculates min and max values at runtime. We recommend using offline calibration for better performance by setting `calib_mode` to `naive` or `entropy`, and then calling the `set_monitor_callback` API to collect layer information with a subset of the validation dataset before INT8 inference.
141+
142+
```python
# quantization configs
# set exclude layers
excluded_names = []
# set calib mode.
calib_mode = 'naive'
# set calib_layer
calib_layer = None
# set quantized_dtype
quantized_dtype = 'auto'
logger.info('Quantizing FP32 model resnet18-V1')
cqsym, cqarg_params, aux_params, collector = quantize_graph(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                            excluded_sym_names=excluded_names,
                                                            calib_mode=calib_mode, calib_layer=calib_layer,
                                                            quantized_dtype=quantized_dtype, logger=logger)

# download imagenet validation dataset
mx.test_utils.download('http://data.mxnet.io/data/val_256_q90.rec', 'dataset.rec')
# set rgb info for data
mean_std = {'mean_r': 123.68, 'mean_g': 116.779, 'mean_b': 103.939, 'std_r': 58.393, 'std_g': 57.12, 'std_b': 57.375}
# set batch size
batch_size = 16
# create DataIter
data = mx.io.ImageRecordIter(path_imgrec='dataset.rec', batch_size=batch_size, data_shape=batch_shape[1:], rand_crop=False, rand_mirror=False, **mean_std)
# create module
mod = mx.mod.Module(symbol=sym, label_names=None, context=mx.cpu())
mod.bind(for_training=False, data_shapes=data.provide_data, label_shapes=None)
mod.set_params(arg_params, aux_params)

# calibration configs
# set num_calib_batches
num_calib_batches = 5
max_num_examples = num_calib_batches * batch_size
# monitor FP32 Inference
mod._exec_group.execs[0].set_monitor_callback(collector.collect, monitor_all=True)
num_batches = 0
num_examples = 0
for batch in data:
    mod.forward(data_batch=batch, is_train=False)
    num_batches += 1
    num_examples += batch_size
    if num_examples >= max_num_examples:
        break
if logger is not None:
    logger.info("Collected statistics from %d batches with batch_size=%d"
                % (num_batches, batch_size))
```
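Conceptually, naive calibration keeps a running minimum and maximum for each monitored layer output across the calibration batches. The sketch below illustrates that idea in plain Python. `MinMaxCollector` is a hypothetical stand-in written for this explanation, not the collector object returned by `quantize_graph`, and it works on flat lists of numbers rather than on NDArrays.

```python
class MinMaxCollector:
    """Hypothetical stand-in that tracks running min/max per layer output."""

    def __init__(self, include_layer=None):
        # include_layer plays the role of calib_layer: name -> bool.
        self.include_layer = include_layer
        self.min_max = {}

    def collect(self, name, values):
        """Update the running (min, max) of `name` with one batch of values."""
        if self.include_layer is not None and not self.include_layer(name):
            return
        lo, hi = min(values), max(values)
        if name in self.min_max:
            old_lo, old_hi = self.min_max[name]
            lo, hi = min(lo, old_lo), max(hi, old_hi)
        self.min_max[name] = (lo, hi)


# Two toy "batches" for one output; the weight entry is filtered out.
stats = MinMaxCollector(include_layer=lambda n: n.endswith('_output'))
stats.collect('conv0_output', [-1.5, 0.2, 3.0])
stats.collect('conv0_output', [-0.5, 4.25])
stats.collect('conv0_weight', [0.1])
print(stats.min_max)
```

The `entropy` mode collects more than a min/max pair per layer, so this sketch only mirrors the `naive` mode.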
189+
190+
After that, the layer information is stored in the `collector` returned by the `quantize_graph` API. You then need to write this layer information into the INT8 model by calling the `calib_graph` API.
191+
192+
193+
```python
# write scaling factor into quantized symbol
cqsym, cqarg_params, aux_params = calib_graph(qsym=cqsym, arg_params=arg_params, aux_params=aux_params,
                                              collector=collector, calib_mode=calib_mode,
                                              quantized_dtype=quantized_dtype, logger=logger)
# (optional) visualize quantized model
mx.viz.plot_network(cqsym)
```
201+
202+
Below is a quantized residual block with naive calibration. We can see that `min_calib_range` and `max_calib_range` are written into the `_contrib_requantize` operators.
203+
204+
![naive calibrated model](https://github.com/dmlc/web-data/raw/master/mxnet/tutorials/mkldnn/quantization/naive_calib.png)
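To see why a fixed calibration range removes runtime work, consider how a range turns into an INT8 scale. The sketch below shows symmetric linear quantization in plain Python, assuming the scale is derived from the larger absolute bound of the range. The function names are hypothetical, and the exact rounding and clipping used inside MXNet and MKL-DNN may differ; this is an illustration of the idea, not the library's implementation.

```python
def quantize_int8(values, min_range, max_range):
    """Symmetric linear quantization of floats to the range [-127, 127]."""
    real_range = max(abs(min_range), abs(max_range))
    if real_range == 0:
        raise ValueError('calibration range must not be empty')
    scale = 127.0 / real_range
    # Values outside the calibrated range saturate at +/-127.
    quantized = [max(-127, min(127, int(round(v * scale)))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Map INT8 values back to approximate floats."""
    return [q / scale for q in quantized]


# Calibrated range of 0.0 to 3.562147; 5.0 lies outside it and saturates.
q, scale = quantize_int8([0.0, 1.0, 3.562147, 5.0], 0.0, 3.562147)
restored = dequantize(q, scale)
print(q, restored)
```

With the range known ahead of time, the scale is a constant, so no min or max pass over the layer output is needed during inference.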
205+
206+
Once you have a calibrated quantized model, make sure to call the fusion API again, since this can fuse some `requantize` or `dequantize` operators for a further performance improvement.
207+
208+
```python
# perform post-quantization fusion
cqsym = cqsym.get_backend_symbol('MKLDNN_QUANTIZE')
# (optional) visualize post-quantized model
mx.viz.plot_network(cqsym)
# save quantized model
mx.model.save_checkpoint('quantized-resnet18_v1', 0, cqsym, cqarg_params, aux_params)
```
216+
217+
Below is a post-quantized residual block. We can see that the `_contrib_requantize` operators are fused into the `Convolution` operators.
218+
219+
![post-quantized model](https://github.com/dmlc/web-data/raw/master/mxnet/tutorials/mkldnn/quantization/post_quantize.png)
220+
221+
You can also modify `min_calib_range` and `max_calib_range` directly in the JSON file.
222+
223+
```
{
  "op": "_sg_mkldnn_conv",
  "name": "quantized_sg_mkldnn_conv_bn_act_6",
  "attrs": {
    "max_calib_range": "3.562147",
    "min_calib_range": "0.000000",
    "quantized": "true",
    "with_act": "true",
    "with_bn": "true"
  },
......
```
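Editing the file by hand works, but the same change can be scripted. The sketch below uses only the standard `json` module and assumes the symbol file holds its operators in a top-level `nodes` list whose entries look like the snippet above. The helper `set_calib_range` is hypothetical, and the graph here is a small in-memory stand-in rather than a real exported symbol file.

```python
import json


def set_calib_range(graph, node_name, min_range, max_range):
    """Overwrite the calibration range of the named node; return True if found."""
    for node in graph['nodes']:
        if node['name'] == node_name:
            attrs = node.setdefault('attrs', {})
            # Attribute values are stored as strings in the JSON file.
            attrs['min_calib_range'] = '%f' % min_range
            attrs['max_calib_range'] = '%f' % max_range
            return True
    return False


# In-memory stand-in for a symbol file (assumed layout, illustration only).
graph = json.loads('''
{"nodes": [{"op": "_sg_mkldnn_conv",
            "name": "quantized_sg_mkldnn_conv_bn_act_6",
            "attrs": {"max_calib_range": "3.562147",
                      "min_calib_range": "0.000000",
                      "quantized": "true"}}]}
''')
updated = set_calib_range(graph, 'quantized_sg_mkldnn_conv_bn_act_6', 0.0, 4.0)
text = json.dumps(graph, indent=2)
print(text)
```

For a real model you would load the symbol JSON from disk with `json.load`, apply the change, and write it back with `json.dump`.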
236+
237+
### Tips for Model Calibration
238+
239+
#### Accuracy Tuning
240+
241+
- Try the `entropy` calib mode;
242+
243+
- Try to exclude layers that cause an obvious accuracy drop;
244+
245+
- Change the calibration dataset by setting a different `num_calib_batches` or by shuffling your validation dataset;
246+
247+
#### Performance Tuning
248+
249+
- Make sure to perform graph fusion before quantization;
250+
251+
- If many `requantize` layers exist, make sure to perform post-quantization fusion after calibration;
252+
253+
- Compare the MXNet profile or `MKLDNN_VERBOSE` output of float32 and int8 inference;
254+
255+
## Deploy with Python/C++
256+
257+
MXNet also supports deploying quantized models with C++. Refer to [MXNet C++ Package](https://github.com/apache/incubator-mxnet/blob/master/cpp-package/README.md) for more details.
258+
259+
<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
