diff --git a/README.md b/README.md index 304816699..33e3191dd 100644 --- a/README.md +++ b/README.md @@ -414,6 +414,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[Table Transformer](https://huggingface.co/docs/transformers/model_doc/table-transformer)** (from Microsoft Research) released with the paper [PubTables-1M: Towards Comprehensive Table Extraction From Unstructured Documents](https://arxiv.org/abs/2110.00061) by Brandon Smock, Rohith Pesala, Robin Abraham. 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. +1. **Ultravox** (from Fixie.ai) released with the repository [fixie-ai/ultravox](https://github.com/fixie-ai/ultravox) by the Fixie.ai team. 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang. 1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu. 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index 6efb933ad..41e95a685 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -129,6 +129,7 @@ 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. 
**[Table Transformer](https://huggingface.co/docs/transformers/model_doc/table-transformer)** (from Microsoft Research) released with the paper [PubTables-1M: Towards Comprehensive Table Extraction From Unstructured Documents](https://arxiv.org/abs/2110.00061) by Brandon Smock, Rohith Pesala, Robin Abraham. 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. +1. **Ultravox** (from Fixie.ai) released with the repository [fixie-ai/ultravox](https://github.com/fixie-ai/ultravox) by the Fixie.ai team. 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang. 1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu. 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 
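For context, below is a minimal usage sketch of the Ultravox support added in this change. It is a hypothetical example, not part of the patch: the model id, the audio file name, and the `<|audio|>` placeholder string are assumptions (the actual placeholder comes from the processor config's `audio_placeholder` field), while `UltravoxProcessor`, `UltravoxModel`, and `read_audio` are the classes and helper introduced by this diff or already provided by the library.

```js
// Hypothetical usage sketch: the model id, audio file, and "<|audio|>" placeholder are assumptions.
import { UltravoxProcessor, UltravoxModel, read_audio } from '@huggingface/transformers';

const model_id = 'fixie-ai/ultravox'; // illustrative id only
const processor = await UltravoxProcessor.from_pretrained(model_id);
const model = await UltravoxModel.from_pretrained(model_id);

// Load 16 kHz mono audio as a Float32Array and build a prompt containing the audio placeholder.
const audio = await read_audio('speech.wav', 16000);
const text = 'Transcribe the following audio: <|audio|>';

// The processor expands the placeholder and attaches `audio_values` and `audio_token_len`.
const inputs = await processor(text, audio);

// Generate a reply conditioned on both the text and the audio.
// Note: the returned sequence also contains the prompt tokens.
const output = await model.generate({ ...inputs, max_new_tokens: 64 });
console.log(processor.tokenizer.batch_decode(output, { skip_special_tokens: true }));
```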
diff --git a/src/configs.js b/src/configs.js index fd8ca9188..94f0b31f0 100644 --- a/src/configs.js +++ b/src/configs.js @@ -70,6 +70,7 @@ function getNormalizedConfig(config) { case 'florence2': case 'llava_onevision': case 'idefics3': + case 'ultravox': case 'smolvlm': // @ts-expect-error TS2339 init_normalized_config = getNormalizedConfig(config.text_config); diff --git a/src/models.js b/src/models.js index 24e02864d..976d1c000 100644 --- a/src/models.js +++ b/src/models.js @@ -133,6 +133,7 @@ const MODEL_TYPES = { Musicgen: 7, MultiModality: 8, Phi3V: 9, + AudioTextToText: 10, } ////////////////////////////////////////////////// @@ -549,7 +550,7 @@ async function encoderForward(self, model_inputs) { const dims = encoderFeeds.pixel_values.dims; encoderFeeds.pixel_mask = ones([dims[0], dims[2], dims[3]]); } - + return await sessionRun(session, encoderFeeds); } @@ -587,58 +588,98 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) { -function default_merge_input_ids_with_image_features({ - image_token_id, +function default_merge_input_ids_with_features({ + modality_token_id, inputs_embeds, - image_features, + modality_features, input_ids, attention_mask, }) { - const image_tokens = input_ids.tolist().map(ids => + const token_positions = input_ids.tolist().map(ids => ids.reduce((acc, x, idx) => { - if (x == image_token_id) acc.push(idx); + if (x == modality_token_id) acc.push(idx); return acc; }, []) ); - const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0); - const n_image_features = image_features.dims[0]; - if (n_image_tokens !== n_image_features) { - throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`); + const n_tokens = token_positions.reduce((acc, x) => acc + x.length, 0); + const n_features = modality_features.dims[0]; + if (n_tokens !== n_features) { + throw new Error(`Number of tokens and features do not match: tokens: ${n_tokens}, features ${n_features}`); } // Equivalent to performing a masked_scatter let img = 0; - for (let i = 0; i < image_tokens.length; ++i) { - const tokens = image_tokens[i]; + for (let i = 0; i < token_positions.length; ++i) { + const tokens = token_positions[i]; const embeds = inputs_embeds[i]; for (let j = 0; j < tokens.length; ++j) { - embeds[tokens[j]].data.set(image_features[img++].data) + embeds[tokens[j]].data.set(modality_features[img++].data) } } return { inputs_embeds, attention_mask } } -/** - * Forward pass of an image-text-to-text model. - * @param {Object} self The image-text-to-text model model. - * @param {Object} model_inputs The input data to be used for the forward pass. 
- * @param {Tensor} [model_inputs.input_ids=null] - * @param {Tensor} [model_inputs.attention_mask=null] - * @param {Tensor} [model_inputs.pixel_values=null] - * @param {Tensor} [model_inputs.position_ids=null] - * @param {Tensor} [model_inputs.inputs_embeds=null] - * @param {Tensor} [model_inputs.past_key_values=null] - * @param {Object} [model_inputs.generation_config=null] - * @param {Object} [model_inputs.logits_processor=null] +function default_merge_input_ids_with_image_features({ + image_token_id, + inputs_embeds, + image_features, + input_ids, + attention_mask, +}) { + return default_merge_input_ids_with_features({ + modality_token_id: image_token_id, + inputs_embeds, + modality_features: image_features, + input_ids, + attention_mask, + }) +} + +function default_merge_input_ids_with_audio_features({ + audio_token_id, + inputs_embeds, + audio_features, + input_ids, + attention_mask, +}) { + return default_merge_input_ids_with_features({ + modality_token_id: audio_token_id, + inputs_embeds, + modality_features: audio_features, + input_ids, + attention_mask, + }) +} + +/** + * Abstract forward pass function for image-text-to-text or audio-text-to-text models. + * @param {Object} self The model object. + * @param {Object} params Additional parameters. + * @param {Function} [params.encode_function] The function to encode the modality values. + * @param {Function} [params.merge_function] The function to merge the modality features with the input embeddings. + * @param {string} [params.modality_input_name] The modality input name. + * @param {string} [params.modality_output_name] The modality output name. + * @param {Tensor} [params.input_ids=null] + * @param {Tensor} [params.attention_mask=null] + * @param {Tensor} [params.position_ids=null] + * @param {Tensor} [params.inputs_embeds=null] + * @param {Tensor} [params.past_key_values=null] + * @param {Object} [params.generation_config=null] + * @param {Object} [params.logits_processor=null] * @returns {Promise} The model's output tensor * @private */ -async function imageTextToTextForward(self, { +async function genericTextToTextForward(self, { + // Generic parameters: + encode_function, + merge_function, + modality_input_name, + modality_output_name, + // Produced by the tokenizer/processor: input_ids = null, attention_mask = null, - pixel_values = null, // Used during generation: position_ids = null, @@ -649,27 +690,31 @@ async function imageTextToTextForward(self, { generation_config = null, logits_processor = null, - // TODO: needed? + // Additional parameters ...kwargs }) { - + const modality_values = kwargs[modality_input_name]; if (!inputs_embeds) { - // 1. Extract the input embeddings + // 1. Extract the text embeddings. inputs_embeds = await self.encode_text({ input_ids, ...kwargs }); - // 2. Possibly, merge text and images - if (pixel_values && input_ids.dims[1] !== 1) { - const image_features = await self.encode_image({ pixel_values, ...kwargs }); - - ({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({ - image_features, + // 2. Possibly, merge text and modality values + if (modality_values && input_ids.dims[1] !== 1) { + const modality_features = await encode_function({ + // Pass the modality values under its expected key. + // The caller knows whether this is audio or image. 
+ [modality_input_name]: modality_values, + ...kwargs + }); + ({ inputs_embeds, attention_mask } = merge_function({ + [modality_output_name]: modality_features, inputs_embeds, input_ids, attention_mask, })); - } else if (past_key_values && pixel_values && input_ids.dims[1] === 1) { - // This is the case when we are generating with cache + } else if (past_key_values && modality_values && input_ids.dims[1] === 1) { + // This branch handles the cache case. const target_length = input_ids.dims[1]; // always 1 const past_length = Object.values(past_key_values)[0].dims.at(-2); @@ -690,6 +735,7 @@ async function imageTextToTextForward(self, { } } + // 3. Call the decoder forward using the updated inputs. const outputs = await decoderForward(self, { inputs_embeds, past_key_values, @@ -701,6 +747,40 @@ async function imageTextToTextForward(self, { return outputs; } +/** + * Forward pass of an audio-text-to-text model. + * @param {Object} self The audio-text-to-text model. + * @param {Object} params The inputs for the audio-text-to-text forward pass. + * @returns {Promise} The model's output tensor. + * @private + */ +async function audioTextToTextForward(self, params) { + return await genericTextToTextForward(self, { + ...params, + modality_input_name: 'audio_values', + modality_output_name: 'audio_features', + encode_function: self.encode_audio.bind(self), + merge_function: self._merge_input_ids_with_audio_features.bind(self), + }); +} + +/** + * Forward pass of an image-text-to-text model. + * @param {Object} self The image-text-to-text model. + * @param {Object} params The inputs for the image-text-to-text forward pass. + * @returns {Promise} The model's output tensor. + * @private + */ +async function imageTextToTextForward(self, params) { + return await genericTextToTextForward(self, { + ...params, + modality_input_name: 'pixel_values', + modality_output_name: 'image_features', + encode_function: self.encode_image.bind(self), + merge_function: self._merge_input_ids_with_image_features.bind(self), + }); +} + /** * Helper function to perform the following: * ```python @@ -814,7 +894,7 @@ function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_in }; } -function image_text_to_text_prepare_inputs_for_generation(self, ...args) { +function multimodal_text_to_text_prepare_inputs_for_generation(self, ...args) { if (self.config.is_encoder_decoder) { return encoder_decoder_prepare_inputs_for_generation(self, ...args); } else { @@ -918,11 +998,16 @@ export class PreTrainedModel extends Callable { case MODEL_TYPES.ImageTextToText: this.can_generate = true; this._forward = imageTextToTextForward; - this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation; + this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation; + break; + case MODEL_TYPES.AudioTextToText: + this.can_generate = true; + this._forward = audioTextToTextForward; + this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation; break; case MODEL_TYPES.Phi3V: this.can_generate = true; - this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation; + this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation; break; case MODEL_TYPES.MultiModality: @@ -1061,6 +1146,19 @@ export class PreTrainedModel extends Callable { }, options), ]); + } else if (modelType === MODEL_TYPES.AudioTextToText) { + const sessions = { + embed_tokens: 'embed_tokens', + audio_encoder: 
'audio_encoder', + decoder_model_merged: 'decoder_model_merged', + } + info = await Promise.all([ + constructSessions(pretrained_model_name_or_path, sessions, options), + getOptionalConfigs(pretrained_model_name_or_path, { + generation_config: 'generation_config.json', + }, options), + ]); + } else if (modelType === MODEL_TYPES.Musicgen) { info = await Promise.all([ constructSessions(pretrained_model_name_or_path, { @@ -1878,6 +1976,11 @@ export class PreTrainedModel extends Callable { // text_inputs === { input_ids, attention_mask } return (await sessionRun(this.sessions['embed_tokens'], { input_ids })).inputs_embeds; } + + async encode_audio({ audio_values }) { + // audio_inputs === { audio_values } + return (await sessionRun(this.sessions['audio_encoder'], { audio_values })).audio_features; + } } ////////////////////////////////////////////////// @@ -6971,6 +7074,34 @@ export class PatchTSMixerModel extends PatchTSMixerPreTrainedModel { } export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +export class UltravoxPreTrainedModel extends PreTrainedModel { + forward_params = [ + 'input_ids', + 'attention_mask', + 'position_ids', + 'audio_values', + 'past_key_values', + ]; +} + +export class UltravoxModel extends UltravoxPreTrainedModel { + + _merge_input_ids_with_audio_features(kwargs) { + const audio_hidden_size = kwargs.audio_features.dims.at(-1); + const reshaped_audio_features = kwargs.audio_features.view(-1, audio_hidden_size); + + return default_merge_input_ids_with_audio_features({ + // @ts-ignore + audio_token_id: this.config.ignore_index, + ...kwargs, + audio_features: reshaped_audio_features, + }) + } +} +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // AutoModels, used to simplify construction of PreTrainedModels @@ -7337,6 +7468,11 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([ ['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]], ]); +const MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES = new Map([ + ['ultravox', ['UltravoxModel', UltravoxModel]], +]); + + const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], ]); @@ -7480,6 +7616,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq], [MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.ImageTextToText], + [MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.AudioTextToText], [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], @@ -7771,6 +7908,14 @@ export class AutoModelForImageFeatureExtraction extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]; } +export class AutoModelForImageTextToText extends PretrainedMixin { + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES]; +} + +export class AutoModelForAudioTextToText extends PretrainedMixin { + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES]; +} + ////////////////////////////////////////////////// ////////////////////////////////////////////////// diff 
--git a/src/models/processors.js b/src/models/processors.js index eb9119158..e64273123 100644 --- a/src/models/processors.js +++ b/src/models/processors.js @@ -13,6 +13,7 @@ export * from './qwen2_vl/processing_qwen2_vl.js'; export * from './sam/processing_sam.js'; export * from './smolvlm/processing_smolvlm.js'; export * from './speecht5/processing_speecht5.js'; +export * from './ultravox/processing_ultravox.js'; export * from './wav2vec2/processing_wav2vec2.js'; export * from './wav2vec2_with_lm/processing_wav2vec2_with_lm.js'; export * from './whisper/processing_whisper.js'; diff --git a/src/models/ultravox/processing_ultravox.js b/src/models/ultravox/processing_ultravox.js new file mode 100644 index 000000000..b525b723f --- /dev/null +++ b/src/models/ultravox/processing_ultravox.js @@ -0,0 +1,54 @@ +import { AutoFeatureExtractor } from "../auto/feature_extraction_auto.js" +import { AutoTokenizer } from "../../tokenizers.js" +import { Processor } from "../../base/processing_utils.js" + +/** + * Represents an UltravoxProcessor that prepares text (and, optionally, audio) inputs for the Ultravox model. + */ +export class UltravoxProcessor extends Processor { + static tokenizer_class = AutoTokenizer + static feature_extractor_class = AutoFeatureExtractor + static uses_processor_config = true; + + /** + * @param {string} text The text input to process. + * @param {Float32Array} [audio=null] The audio input to process. + */ + async _call(text, audio = null, kwargs = {}) { + // TODO: Support batched inputs + if (Array.isArray(text)) { + throw new Error("Batched inputs are not supported yet."); + } + + let audio_inputs = {}; + if (audio) { + const audio_len = audio.length; + const { input_features } = await this.feature_extractor(audio, { + ...kwargs, + max_length: audio_len, + }); + const nb_encoder_frames = Math.round(audio_len / this.config.encoder_ds_factor + 1e-4); + + // NOTE: The Python version appears to have an off-by-one error. + const audio_embed_frames = 1 + Math.ceil(nb_encoder_frames / this.config.stack_factor); + audio_inputs["audio_token_len"] = [audio_embed_frames]; + audio_inputs["audio_values"] = input_features; + + const audio_token = this.config.audio_placeholder; + if (!text.includes(audio_token)) { + throw new Error(`The input text does not contain the audio placeholder token ${audio_token}.`); + } + text = text.replaceAll(audio_token, audio_token.repeat(audio_embed_frames)); + } + + const text_inputs = this.tokenizer(text, { + add_special_tokens: false, + ...kwargs, + }); + + return { + ...text_inputs, + ...audio_inputs, + } + } +} diff --git a/src/models/whisper/feature_extraction_whisper.js b/src/models/whisper/feature_extraction_whisper.js index 6f7bcdd56..a52a18e33 100644 --- a/src/models/whisper/feature_extraction_whisper.js +++ b/src/models/whisper/feature_extraction_whisper.js @@ -39,7 +39,10 @@ export class WhisperFeatureExtractor extends FeatureExtractor { log_mel: 'log10', // Custom - max_num_frames: this.config.nb_max_frames, // 3000 + max_num_frames: Math.min( + Math.floor(waveform.length / this.config.hop_length), + this.config.nb_max_frames, // 3000 + ) } ) @@ -58,20 +61,25 @@ export class WhisperFeatureExtractor extends FeatureExtractor { * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array. * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor.
*/ - async _call(audio) { + async _call(audio, { + max_length = null, + } = {}) { validate_audio_inputs(audio, 'WhisperFeatureExtractor'); let waveform; - if (audio.length > this.config.n_samples) { - console.warn( - "Attempting to extract features for audio longer than 30 seconds. " + - "If using a pipeline to extract transcript from a long audio clip, " + - "remember to specify `chunk_length_s` and/or `stride_length_s`." - ); - waveform = audio.slice(0, this.config.n_samples); + const length = max_length ?? this.config.n_samples; + if (audio.length > length) { + if (audio.length > this.config.n_samples) { + console.warn( + "Attempting to extract features for audio longer than 30 seconds. " + + "If using a pipeline to extract transcript from a long audio clip, " + + "remember to specify `chunk_length_s` and/or `stride_length_s`." + ); + } + waveform = audio.slice(0, length); } else { // pad with zeros - waveform = new Float32Array(this.config.n_samples); + waveform = new Float32Array(length); waveform.set(audio); } diff --git a/tests/models/whisper/test_feature_extraction_whisper.js b/tests/models/whisper/test_feature_extraction_whisper.js index 20e132ff6..fe817bd8b 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.js +++ b/tests/models/whisper/test_feature_extraction_whisper.js @@ -29,5 +29,39 @@ export default () => { }, MAX_TEST_EXECUTION_TIME, ); + + it( + "max_length (max_length < audio.length < max_num_samples)", + async () => { + const audio = await load_cached_audio("mlk"); + const { input_features } = await feature_extractor(audio, { max_length: 5 * 16000 }); + const { dims, data } = input_features; + expect(dims).toEqual([1, 80, 500]); + expect(input_features.mean().item()).toBeCloseTo(0.20474646985530853); + expect(data[0]).toBeCloseTo(0.33168578147888184); + expect(data[1]).toBeCloseTo(0.30986475944519043); + expect(data[81]).toBeCloseTo(0.10727238655090332); + expect(data[3001]).toBeCloseTo(0.4018087387084961); + expect(data.at(-1)).toBeCloseTo(-0.41003990173339844); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "max_length (audio.length < max_length < max_num_samples)", + async () => { + const audio = await load_cached_audio("mlk"); + const { input_features } = await feature_extractor(audio, { max_length: 25 * 16000 }); + const { dims, data } = input_features; + expect(dims).toEqual([1, 80, 2500]); + expect(input_features.mean().item()).toBeCloseTo(-0.20426231622695923); + expect(data[0]).toBeCloseTo(0.33168578147888184); + expect(data[1]).toBeCloseTo(0.30986475944519043); + expect(data[81]).toBeCloseTo(0.10727238655090332); + expect(data[3001]).toBeCloseTo(0.18040966987609863); + expect(data.at(-1)).toBeCloseTo(-0.6668410897254944); + }, + MAX_TEST_EXECUTION_TIME, + ); }); };
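The `max_length` option exercised by the new Whisper feature-extraction tests above can also be used directly. A minimal sketch follows, assuming `AutoFeatureExtractor` is available from the package entry point and using `Xenova/whisper-tiny.en` purely as an example checkpoint; `max_length` is given in samples, so 5 seconds at 16 kHz maps to 500 mel frames with the default hop length of 160.

```js
// Minimal sketch of the new `max_length` option (the checkpoint id is an example).
import { AutoFeatureExtractor } from '@huggingface/transformers';

const feature_extractor = await AutoFeatureExtractor.from_pretrained('Xenova/whisper-tiny.en');

// 5 seconds of 16 kHz audio (silence here; substitute real samples).
const audio = new Float32Array(5 * 16000);

// Default behaviour: pad or truncate to 30 seconds, i.e. dims [1, 80, 3000].
const { input_features: full } = await feature_extractor(audio);
console.log(full.dims); // [1, 80, 3000]

// With `max_length` (in samples), shorter clips yield proportionally fewer frames:
// 5 s * 16000 Hz / hop_length 160 = 500 frames.
const { input_features: trimmed } = await feature_extractor(audio, { max_length: 5 * 16000 });
console.log(trimmed.dims); // [1, 80, 500]
```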