Skip to content

Commit c536857

Browse files
PrafulBxenova
andauthored
Return buffer instead of file_path if cache unavailable for model loading (#1280)
* Return buffer if cache unavailable for model file loading * Minor cleanup * Add back check * Actually, original approach was better * Support custom cache returning response or path --------- Co-authored-by: Joshua Lochner <[email protected]>
1 parent d0fe828 commit c536857

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

src/env.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ const localModelPath = RUNNING_LOCALLY
118118
* @property {string} cacheDir The directory to use for caching files with the file system. By default, it is `./.cache`.
119119
* @property {boolean} useCustomCache Whether to use a custom cache system (defined by `customCache`), defaults to `false`.
120120
* @property {Object} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which
121-
* implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache
121+
* implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache.
122+
* If you wish, you may also return a `Promise<string>` from the `match` function if you'd like to use a file path instead of `Promise<Response>`.
122123
*/
123124

124125
/** @type {TransformersEnvironment} */

src/models.js

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ import { RawImage } from './utils/image.js';
116116
import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
117117
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
118118
import { LogitsSampler } from './generation/logits_sampler.js';
119-
import { apis } from './env.js';
119+
import { apis, env } from './env.js';
120120

121121
import { WhisperGenerationConfig } from './models/whisper/generation_whisper.js';
122122
import { whisper_language_to_code } from './models/whisper/common_whisper.js';
@@ -248,7 +248,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
248248
);
249249
}
250250

251-
const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, apis.IS_NODE_ENV);
251+
const return_path = apis.IS_NODE_ENV && env.useFSCache;
252+
const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, return_path);
252253

253254
// handle onnx external data files
254255
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
@@ -276,7 +277,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
276277
const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
277278
const fullPath = `${options.subfolder ?? ''}/${path}`;
278279
externalDataPromises.push(new Promise(async (resolve, reject) => {
279-
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV);
280+
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, return_path);
280281
resolve(data instanceof Uint8Array ? { path, data } : path);
281282
}));
282283
}

src/utils/hub.js

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
638638
});
639639

640640
if (result) {
641-
if (return_path) {
641+
if (!apis.IS_NODE_ENV && return_path) {
642642
throw new Error("Cannot return path in a browser environment.")
643643
}
644644
return result;
@@ -647,12 +647,18 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
647647
return response.filePath;
648648
}
649649

650-
const path = await cache.match(cacheKey);
651-
if (path instanceof FileResponse) {
652-
return path.filePath;
650+
// Otherwise, return the cached response (most likely a `FileResponse`).
651+
// NOTE: A custom cache may return a Response, or a string (file path)
652+
const cachedResponse = await cache?.match(cacheKey);
653+
if (cachedResponse instanceof FileResponse) {
654+
return cachedResponse.filePath;
655+
} else if (cachedResponse instanceof Response) {
656+
return new Uint8Array(await cachedResponse.arrayBuffer());
657+
} else if (typeof cachedResponse === 'string') {
658+
return cachedResponse;
653659
}
654-
throw new Error("Unable to return path for response.");
655660

661+
throw new Error("Unable to get model file path or buffer.");
656662
}
657663

658664
/**

0 commit comments

Comments
 (0)