
Commit f63ad84

[DO NOT MERGE] Code example for testing ORT-Web WebNN EP
This is a very rough example of enabling WebNN in transformers.js. I just added some hard-coded changes to make the "Image classification w/ google/vit-base-patch16-224" fp32 model work with the ORT Web WebNN EP. This PR depends on huggingface#596.
1 parent 2a95f48 commit f63ad84
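
Since the WebNN EP only works where the browser actually exposes WebNN, a feature-detection step is useful before opting in. The sketch below is not part of this commit; it assumes the WebNN API is exposed as navigator.ml (per the WebNN spec) in a suitably flagged Chromium build.

// Sketch: feature-detect WebNN before choosing execution providers.
// `navigator.ml` and `createContext` come from the WebNN spec; availability
// depends on browser flags, so treat any failure as "fall back to wasm".
async function pickExecutionProviders() {
    if (typeof navigator !== 'undefined' && 'ml' in navigator) {
        try {
            await navigator.ml.createContext({ deviceType: 'gpu' });
            return [{ name: 'webnn', deviceType: 'gpu' }, 'wasm'];
        } catch (e) {
            // WebNN is exposed but no usable device; fall through to wasm.
        }
    }
    return ['wasm'];
}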

File tree

6 files changed (+38 -16 lines)

examples/demo-site/src/index.html

Lines changed: 2 additions & 2 deletions

@@ -86,7 +86,7 @@ <h2 class="fw-bolder">Demo</h2>
     <label>Task: </label>
     <div class="col-12 mt-1">
         <select id="task" class="form-select">
-            <option value="translation" selected>
+            <option value="translation">
                 Translation w/ t5-small (78 MB)
             </option>
             <option value="text-generation">
@@ -119,7 +119,7 @@ <h2 class="fw-bolder">Demo</h2>
             <option value="image-to-text">
                 Image to text w/ vit-gpt2-image-captioning (246 MB)
             </option>
-            <option value="image-classification">
+            <option value="image-classification" selected>
                 Image classification w/ google/vit-base-patch16-224 (88 MB)
             </option>
             <option value="zero-shot-image-classification">

examples/demo-site/src/worker.js

Lines changed: 6 additions & 2 deletions

@@ -45,13 +45,15 @@ self.addEventListener('message', async (event) => {
 class PipelineFactory {
     static task = null;
     static model = null;
+    static quantized = true;
 
     // NOTE: instance stores a promise that resolves to the pipeline
     static instance = null;
 
-    constructor(tokenizer, model) {
+    constructor(tokenizer, model, quantized) {
         this.tokenizer = tokenizer;
         this.model = model;
+        this.quantized = quantized;
     }
 
     /**
@@ -65,7 +67,8 @@ class PipelineFactory {
         }
         if (this.instance === null) {
             this.instance = pipeline(this.task, this.model, {
-                progress_callback: progressCallback
+                progress_callback: progressCallback,
+                quantized: this.quantized,
             });
         }
 
@@ -131,6 +134,7 @@ class ImageToTextPipelineFactory extends PipelineFactory {
 class ImageClassificationPipelineFactory extends PipelineFactory {
     static task = 'image-classification';
     static model = 'Xenova/vit-base-patch16-224';
+    static quantized = false;
 }
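
With the quantized flag threaded through the factory, callers can request the fp32 weights explicitly. A minimal usage sketch (the image URL is a placeholder; the quantized option mirrors what ImageClassificationPipelineFactory sets above):

import { pipeline } from '@xenova/transformers';

// Sketch: request the non-quantized (fp32) model, matching
// `static quantized = false` in ImageClassificationPipelineFactory.
const classifier = await pipeline('image-classification', 'Xenova/vit-base-patch16-224', {
    quantized: false,
});
const result = await classifier('https://example.com/cat.jpg'); // placeholder image URL
console.log(result); // e.g. [{ label: '...', score: ... }]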

package.json

Lines changed: 2 additions & 2 deletions

@@ -38,12 +38,12 @@
     },
     "homepage": "https://github.com/xenova/transformers.js#readme",
     "dependencies": {
-        "onnxruntime-web": "1.17.0",
+        "onnxruntime-web": "1.18.0-dev.20240130-9f68a27c7a",
         "sharp": "^0.32.0",
         "@huggingface/jinja": "^0.1.0"
     },
     "optionalDependencies": {
-        "onnxruntime-node": "1.17.0"
+        "onnxruntime-node": "1.18.0-dev.20240130-9f68a27c7a"
     },
     "devDependencies": {
         "@types/jest": "^29.5.1",

src/backends/onnx.js

Lines changed: 2 additions & 1 deletion

@@ -19,12 +19,13 @@
 // NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
 // In either case, we select the default export if it exists, otherwise we use the named export.
 import * as ONNX_NODE from 'onnxruntime-node';
-import * as ONNX_WEB from 'onnxruntime-web';
+import * as ONNX_WEB from 'onnxruntime-web/experimental';
 
 /** @type {import('onnxruntime-web')} The ONNX runtime module. */
 export let ONNX;
 
 export const executionProviders = [
+    // 'webnn',
     // 'webgpu',
     'wasm'
 ];
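
ORT-Web attempts the listed execution providers in order, so un-commenting 'webnn' here would make it the first candidate with 'wasm' as the fallback. A standalone sketch of that ordering (the model path is a placeholder):

import * as ort from 'onnxruntime-web/experimental';

// Sketch: EPs are tried in array order; if WebNN is unavailable,
// session creation falls back to the wasm EP.
const modelBuffer = new Uint8Array(
    await (await fetch('/models/model.onnx')).arrayBuffer() // placeholder path
);
const session = await ort.InferenceSession.create(modelBuffer, {
    executionProviders: ['webnn', 'wasm'],
});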

src/env.js

Lines changed: 3 additions & 2 deletions

@@ -59,8 +59,9 @@ const localModelPath = RUNNING_LOCALLY
 // In practice, users should probably self-host the necessary .wasm files.
 onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
     ? path.join(__dirname, '/dist/')
-    : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${VERSION}/dist/`;
-
+    // : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${VERSION}/dist/`;
+    // Copy ort-web wasm files to examples/demo-site/src/dist/
+    : location.origin + '/dist/';
 
 /**
  * Global variable used to control execution. This provides users a simple way to configure Transformers.js.
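
The comment in this diff implies a manual staging step. One way to script it (a sketch using Node's fs.cpSync; the source path assumes the standard onnxruntime-web package layout):

// copy-ort-wasm.mjs — sketch: stage ORT-Web's dist/ artifacts (including the
// .wasm binaries) into the demo site so location.origin + '/dist/' resolves.
import { cpSync, mkdirSync } from 'node:fs';

mkdirSync('examples/demo-site/src/dist', { recursive: true });
cpSync('node_modules/onnxruntime-web/dist', 'examples/demo-site/src/dist', {
    recursive: true,
});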

src/models.js

Lines changed: 23 additions & 7 deletions

@@ -123,9 +123,25 @@ async function constructSession(pretrained_model_name_or_path, fileName, options) {
     let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
 
     try {
-        return await InferenceSession.create(buffer, {
-            executionProviders,
-        });
+        let sessionOptions = { executionProviders };
+        if (pretrained_model_name_or_path == 'Xenova/vit-base-patch16-224') {
+            // Hard-coded example to use WebNN for Xenova/vit-base-patch16-224
+            sessionOptions = {
+                executionProviders: [{
+                    name: "webnn",
+                    deviceType: "gpu",
+                }],
+                // input name: pixel_values, tensor: float32[batch_size,num_channels,height,width]
+                // WebNN only supports static shape models; use the freeDimensionOverrides option to fix the input shape.
+                freeDimensionOverrides: {
+                    batch_size: 1,
+                    num_channels: 3,
+                    height: 224,
+                    width: 224,
+                },
+            }
+        }
+        return await InferenceSession.create(buffer, sessionOptions);
     } catch (err) {
         // If the execution provider was only wasm, throw the error
         if (executionProviders.length === 1 && executionProviders[0] === 'wasm') {
@@ -205,13 +221,13 @@ async function sessionRun(session, inputs) {
     try {
         // pass the original ort tensor
         const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
-        let output = await session.run(ortFeed);
+        let output = await session.run(ortFeed);
         output = replaceTensors(output);
         for (const [name, t] of Object.entries(checkedInputs)) {
             // if we use gpu buffers for kv_caches, we own them and need to dispose()
-            if (name.startsWith('past_key_values')) {
-                t.dispose();
-            };
+            // if (name.startsWith('past_key_values')) {
+            //     t.dispose();
+            // };
         }
         return output;
     } catch (e) {
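
Because WebNN compiles a static graph, every symbolic input dimension has to be pinned before session creation, and inference must then feed tensors of exactly that shape. A standalone sketch of both halves (model path hypothetical, option names as in the diff above):

import * as ort from 'onnxruntime-web/experimental';

// Sketch: pin the ViT input float32[batch_size,num_channels,height,width]
// to [1, 3, 224, 224] so the WebNN EP can compile a static graph.
const session = await ort.InferenceSession.create('/models/vit.onnx' /* placeholder */, {
    executionProviders: [{ name: 'webnn', deviceType: 'gpu' }],
    freeDimensionOverrides: { batch_size: 1, num_channels: 3, height: 224, width: 224 },
});

// The feed must match the pinned shape exactly; a differently shaped
// tensor would no longer fit the compiled graph.
const pixels = new ort.Tensor('float32', new Float32Array(1 * 3 * 224 * 224), [1, 3, 224, 224]);
const outputs = await session.run({ pixel_values: pixels });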
