Skip to content

Commit 2a95f48

Browse files
guschmueHonry
authored andcommitted
fixes for ort-1.17
1 parent 5ac17bd commit 2a95f48

File tree

4 files changed

+63
-48
lines changed

4 files changed

+63
-48
lines changed

package-lock.json

Lines changed: 22 additions & 35 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@
3838
},
3939
"homepage": "https://github.com/xenova/transformers.js#readme",
4040
"dependencies": {
41-
"onnxruntime-web": "1.14.0",
41+
"onnxruntime-web": "1.17.0",
4242
"sharp": "^0.32.0",
4343
"@huggingface/jinja": "^0.1.0"
4444
},
4545
"optionalDependencies": {
46-
"onnxruntime-node": "1.14.0"
46+
"onnxruntime-node": "1.17.0"
4747
},
4848
"devDependencies": {
4949
"@types/jest": "^29.5.1",

src/models.js

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,16 @@ function validateInputs(session, inputs) {
203203
async function sessionRun(session, inputs) {
204204
const checkedInputs = validateInputs(session, inputs);
205205
try {
206-
// @ts-ignore
207-
let output = await session.run(checkedInputs);
206+
// pass the original ort tensor
207+
const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
208+
let output = await session.run(ortFeed);
208209
output = replaceTensors(output);
210+
for (const [name, t] of Object.entries(checkedInputs)) {
211+
// if we use gpu buffers for kv_caches, we own them and need to dispose()
212+
if (name.startsWith('past_key_values')) {
213+
t.dispose();
214+
};
215+
}
209216
return output;
210217
} catch (e) {
211218
// This usually occurs when the inputs are of the wrong type.

src/utils/tensor.js

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import {
1717

1818
const DataTypeMap = Object.freeze({
1919
float32: Float32Array,
20+
float16: Uint16Array,
2021
float64: Float64Array,
2122
string: Array, // string[]
2223
int8: Int8Array,
@@ -39,33 +40,48 @@ const ONNXTensor = ONNX.Tensor;
3940

4041
export class Tensor {
4142
/** @type {number[]} Dimensions of the tensor. */
42-
dims;
43+
get dims() {
44+
// @ts-ignore
45+
return this.ort_tensor.dims;
46+
}
47+
set dims(value) {
48+
// FIXME: ONNXTensor declares dims as readonly so one needs to use the constructor() if dims change.
49+
// @ts-ignore
50+
this.ort_tensor.dims = value;
51+
}
4352

4453
/** @type {DataType} Type of the tensor. */
45-
type;
54+
get type() {
55+
return this.ort_tensor.type;
56+
};
4657

4758
/** @type {DataArray} The data stored in the tensor. */
48-
data;
59+
get data() {
60+
return this.ort_tensor.data;
61+
}
4962

5063
/** @type {number} The number of elements in the tensor. */
51-
size;
64+
get size() {
65+
return this.ort_tensor.size;
66+
};
67+
68+
ort_tensor;
5269

5370
/**
5471
* Create a new Tensor or copy an existing Tensor.
5572
* @param {[DataType, DataArray, number[]]|[import('onnxruntime-common').Tensor]} args
5673
*/
5774
constructor(...args) {
5875
if (args[0] instanceof ONNXTensor) {
59-
// Create shallow copy
60-
Object.assign(this, args[0]);
61-
76+
this.ort_tensor = args[0];
6277
} else {
6378
// Create new tensor
64-
Object.assign(this, new ONNXTensor(
79+
const t = new ONNXTensor(
6580
/** @type {DataType} */(args[0]),
6681
/** @type {Exclude<import('./maths.js').AnyTypedArray, Uint8ClampedArray>} */(args[1]),
6782
args[2]
68-
));
83+
);
84+
this.ort_tensor = t;
6985
}
7086

7187
return new Proxy(this, {
@@ -89,6 +105,11 @@ export class Tensor {
89105
});
90106
}
91107

108+
dispose() {
109+
this.ort_tensor.dispose();
110+
// this.ort_tensor = undefined;
111+
}
112+
92113
/**
93114
* Returns an iterator object for iterating over the tensor data in row-major order.
94115
* If the tensor has more than one dimension, the iterator will yield subarrays.

0 commit comments

Comments
 (0)