
[Batched Whisper] ValueError on input mel features #30740

Closed
kerem0comert opened this issue May 10, 2024 · 3 comments
kerem0comert commented May 10, 2024

System Info

  • transformers version: 4.36.2
  • Platform: Linux-6.5.0-27-generic-x86_64-with-glibc2.35
  • Python version: 3.11.5
  • Huggingface_hub version: 0.20.2
  • Safetensors version: 0.4.0
  • Accelerate version: 0.26.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: NA

Who can help?

Hello,

I am using a fine-tuned Whisper model for transcription, and it works well. However, I get the warning:

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset

so I would like to take advantage of batching, given that I run this on a GPU. To that end I implemented the code shared in the Reproduction section below.
I wanted to do this via this fork, but its README recommends that I follow this instead. In my code snippet, self.model is an instance of:
<class 'transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration'>
and I have two problems:

  1. If I include the flags in the generate() call like so:
self.model.generate(
    **(processed_inputs.to(self.device, torch.float16)),
    condition_on_prev_tokens=False,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    logprob_threshold=-1.0,
    compression_ratio_threshold=1.35,
    return_timestamps=True,
)

I get:

ValueError: The following `model_kwargs` are not used by the model: ['condition_on_prev_tokens', 'logprob_threshold', 'compression_ratio_threshold'] (note: typos in the generate arguments will also show up in this list)

  2. If I exclude the flags (which I do not mind) like so:

self.model.generate(
    **(processed_inputs.to(self.device, torch.float16)),
)

This time I get:

ValueError: Whisper expects the mel input features to be of length 3000, but found 2382. Make sure to pad the input mel features to 3000.

for which I could not find a satisfactory explanation yet, so any help would be much appreciated. Thanks!

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

def predict_batch(
        self, df_batch: pd.DataFrame, column_to_transcribe_into: str
    ) -> pd.DataFrame:
        inputs: list[np.ndarray] = df_batch[COLUMN_AUDIO_DATA].tolist()
        processed_inputs: BatchFeature = self.processor(
            inputs,
            return_tensors="pt",
            truncation=False,
            padding="longest",
            return_attention_mask=True,
            sampling_rate=self.asr_params.sampling_rate,
        )
        # However at this line, we get the error:
        # ValueError: Whisper expects the mel input features to be of length 3000, but 
        # found 2382. Make sure to pad the input mel features to 3000.
        # To which I could not find a solution yet, hence this branch should remain unmerged to master.
        results = self.model.generate(
            **(processed_inputs.to(self.device, torch.float16)),
        ) 
        df_batch[column_to_transcribe_into] = [str(r["text"]).strip() for r in results]
        return df_batch

Expected behavior

Transcribed results of Whisper, ideally with timestamps

amyeroberts (Collaborator) commented:

cc @sanchit-gandhi

sanchit-gandhi (Contributor) commented May 16, 2024

Hey @kerem0comert! Thanks for the detailed reproducer. There are three paradigms for Whisper generation:

  1. Short-form generation: transcribe an audio segment less than 30-seconds by padding/truncating it to 30-seconds, and passing it to the model in one go (this is your standard generation strategy)
  2. Chunked long-form generation: chunk a long audio (>30-seconds) into 30-second segments, and pass each chunk to the model in parallel (i.e. through batching)
  3. Sequential long-form generation: take a long audio file, and generate for the first 30-seconds. Use the last predicted timestamp to slide your window forward, and predict another 30-second segment. Repeat this until you have generated for your full audio

I explain these decoding strategies in full in this YouTube video: https://www.youtube.com/live/92xX-E2y4GQ?si=GBxyimNo9-4z1tx_&t=1919

The corresponding code for each of these generation strategies are detailed on the Distil-Whisper model card: https://huggingface.co/distil-whisper/distil-large-v3#transformers-usage
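For reference, here is a minimal, untested sketch of strategy (2), chunked long-form generation through the pipeline API (the checkpoint name, audio file name, and chunk/batch sizes below are placeholders, not taken from your setup):

import torch
from transformers import pipeline

# Chunked long-form: the pipeline cuts the audio into 30-second chunks and
# transcribes the chunks in parallel on the GPU.
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",  # placeholder checkpoint
    torch_dtype=torch.float16,
    device="cuda:0",
)
result = pipe(
    "audio.mp3",            # placeholder audio file
    chunk_length_s=30,      # split into 30-second chunks
    batch_size=16,          # transcribe the chunks in parallel
    return_timestamps=True,
)
print(result["text"])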

It's not apparent from the documentation how to use each of these strategies; this is something we should definitely highlight better in the docs.

The problem you're facing is that you have a short-form audio (<30-seconds), but are not padding/truncating it to 30-seconds before passing it to the model. This throws an error, since Whisper expects fixed inputs of 30-seconds. You can remedy this quickly by changing your args to the feature extractor:

def predict_batch(
        self, df_batch: pd.DataFrame, column_to_transcribe_into: str
    ) -> pd.DataFrame:
        inputs: list[np.ndarray] = df_batch[COLUMN_AUDIO_DATA].tolist()
        processed_inputs: BatchFeature = self.processor(
            inputs,
            return_tensors="pt",
-           truncation=False,
-           padding="longest",
            return_attention_mask=True,
            sampling_rate=self.asr_params.sampling_rate,
        )

        results = self.model.generate(
            **(processed_inputs.to(self.device, torch.float16)),
        ) 
        df_batch[column_to_transcribe_into] = [str(r["text"]).strip() for r in results]
        return df_batch

Your code will now work for short-form generation, but not sequential long-form! To handle both automatically, I suggest you use the following code:

def predict_batch(
        self, df_batch: pd.DataFrame, column_to_transcribe_into: str
    ) -> pd.DataFrame:
        inputs: list[np.ndarray] = df_batch[COLUMN_AUDIO_DATA].tolist()
        # assume we have long-form audios
        processed_inputs: BatchFeature = self.processor(
            inputs,
            return_tensors="pt",
            truncation=False,
            padding="longest",
            return_attention_mask=True,
            sampling_rate=self.asr_params.sampling_rate,
        )
        if processed_inputs.input_features.shape[-1] < 3000:
            # we in fact have short-form -> pre-process accordingly
            processed_inputs = self.processor(
                inputs,
                return_tensors="pt",
                sampling_rate=self.asr_params.sampling_rate,
            )
        results = self.model.generate(
            **(processed_inputs.to(self.device, torch.float16)),
        ) 
        df_batch[column_to_transcribe_into] = [str(r["text"]).strip() for r in results]
        return df_batch

What we're doing is first assuming we have a long audio segment. If we compute the log-mel features and in fact find we have short-form audio, we re-compute the log-mel with padding and truncation to 30 seconds.

You can use similar logic to pass the long-form kwargs condition_on_prev_tokens, logprob_threshold and compression_ratio_threshold to the model if you're doing long-form generation.
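For example, something along these lines (an untested sketch reusing the names from the snippet above; the threshold values are simply the ones from your original generate() call):

# long-form inputs keep their "longest"-padded length (> 3000 frames), while
# short-form inputs were re-processed to exactly 3000 frames above
is_long_form = processed_inputs.input_features.shape[-1] > 3000

generate_kwargs = {"return_timestamps": True}
if is_long_form:
    # these arguments are only accepted for long-form generation
    generate_kwargs.update(
        condition_on_prev_tokens=False,
        temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        logprob_threshold=-1.0,
        compression_ratio_threshold=1.35,
    )

results = self.model.generate(
    **processed_inputs.to(self.device, torch.float16),
    **generate_kwargs,
)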

kerem0comert (Author) commented:

Thanks for your very detailed response, this was indeed it!

Just for completeness, I had to make a small change, since your version returns the tokens but I need the detokenized text:

def predict_batch(
        self, df_batch: pd.DataFrame, column_to_transcribe_into: str
    ) -> pd.DataFrame:
        inputs: list[np.ndarray] = df_batch[COLUMN_AUDIO_DATA].tolist()
        # assume we have long-form audios
        processed_inputs: BatchFeature = self.processor(
            inputs,
            return_tensors="pt",
            truncation=False,
            padding="longest",
            return_attention_mask=True,
            sampling_rate=self.asr_params.sampling_rate,
        )
        if processed_inputs.input_features.shape[-1] < 3000:
            # we in-fact have short-form -> pre-process accordingly
            processed_inputs = self.processor(
                inputs,
                return_tensors="pt",
                sampling_rate=self.asr_params.sampling_rate,
            )
        result_tokens = self.model.generate(
            **(processed_inputs.to(self.device, torch.float16)),
        )
        results_texts = self.processor.batch_decode(
            result_tokens, skip_special_tokens=True
        )
        df_batch[column_to_transcribe_into] = [
            str(result_text).strip() for result_text in results_texts
        ]
        return df_batch
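If the timestamps mentioned under "Expected behavior" are needed as well, one possible variant (assuming return_timestamps=True is passed to generate so that timestamp tokens are actually produced) is to keep the timestamp markers when decoding:

results_texts = self.processor.batch_decode(
    result_tokens,
    skip_special_tokens=True,
    decode_with_timestamps=True,  # keep the <|x.xx|> timestamp tokens in the text
)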
