
Wav2Vec2Pretrain (HFTransformersInterface implementation) samples padded values for mask_time_indices and negative_sample_indices #2386

Open
porfirythelaw opened this issue Feb 2, 2024 · 2 comments

@porfirythelaw

porfirythelaw commented Feb 2, 2024

Describe the bug

I've been using the SpeechBrain Wav2Vec2 training recipe (with the HF integration) on my own data, and I noticed that I get significantly different metrics with the same model on the validation dataset depending on the amount of padding in the batch. My hypothesis was that padding is somehow not ignored during the index sampling process, and I think that is in fact what is happening.

mask_time_indices = _compute_mask_indices(
    (batch_size, sequence_length),
    mask_prob=self.mask_prob,
    mask_length=self.mask_length,
)

As you can see, this call does not provide an attention mask, so masked indices can be drawn from padded positions as well.
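
For illustration, here is a minimal, self-contained sketch (toy shapes and padding made up for the example, not recipe code) of what the attention_mask argument of _compute_mask_indices changes:

import numpy as np
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

batch_size, sequence_length = 2, 100
# Pretend the second utterance has only 40 real feature frames; the rest is padding.
feat_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)
feat_attention_mask[1, 40:] = 0

np.random.seed(0)
mask_no_attn = _compute_mask_indices(
    (batch_size, sequence_length), mask_prob=0.65, mask_length=10
)
mask_with_attn = _compute_mask_indices(
    (batch_size, sequence_length),
    mask_prob=0.65,
    mask_length=10,
    attention_mask=feat_attention_mask,
)

print(mask_no_attn[1, 40:].any())    # typically True: padded frames get masked
print(mask_with_attn[1, 40:].any())  # False: masking stays within the real frames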

The same holds for the negative sample indices, which are drawn from the whole sequence:

negative_sample_indices = torch.tensor(
    transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices(
        (batch_size, sequence_length),
        num_negatives=self.config.num_negatives,
        mask_time_indices=full_sentence_indices,
    ),
    device=wav.device,
    dtype=torch.long,
)
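
Again as an illustration only (same toy shapes as above, nothing from the recipe), the mask_time_indices argument of _sample_negative_indices is what restricts the candidate pool, so passing an all-ones matrix lets padded frames be drawn as negatives:

import numpy as np
from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices

batch_size, sequence_length, num_negatives = 2, 100, 20
real_frames = np.ones((batch_size, sequence_length))
real_frames[1, 40:] = 0  # second utterance: only 40 real feature frames

neg_all_frames = _sample_negative_indices(
    (batch_size, sequence_length),
    num_negatives,
    mask_time_indices=np.ones((batch_size, sequence_length)),
)
neg_real_frames = _sample_negative_indices(
    (batch_size, sequence_length),
    num_negatives,
    mask_time_indices=real_frames,
)

# The function offsets indices by batch_idx * sequence_length; undo that to see
# which frame positions were picked as negatives for the second (padded) utterance.
print((neg_all_frames[1] - sequence_length >= 40).any())   # True: padded frames used
print((neg_real_frames[1] - sequence_length >= 40).any())  # False: only real frames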

The attention mask is provided in the call to the model:

return (
    self.model(
        wav,
        mask_time_indices=torch_mask_time_indices,
        sampled_negative_indices=negative_sample_indices,
        attention_mask=padding_mask,
    ),
    torch_mask_time_indices,
)

However, if you check the Hugging Face source code, the attention mask does not affect the loss calculation; it only affects the encoder self-attention.

I'm not sure if this behavior was intended or not.

Expected behaviour

Padded values should not influence the model loss / metrics.

To Reproduce

No response

Environment Details

SpeechBrain v0.5.16

Relevant Log Output

No response

Additional Context

No response

porfirythelaw added the bug label on Feb 2, 2024
@Adel-Moumen
Collaborator

Hey @TParcollet, could you please have a look?

@porfirythelaw
Author

porfirythelaw commented Feb 14, 2024

My local fix is something like this (using features_padding_mask):

padding_mask = make_padding_masks(wav, wav_len=wav_lens)
# Reduce the sample-level padding mask to the feature-frame level.
features_padding_mask = self.model._get_feature_vector_attention_mask(
    sequence_length, padding_mask, add_adapter=False
)

# 1. Compute the indices that will be masked, excluding padded frames.
mask_time_indices = _compute_mask_indices(
    (batch_size, sequence_length),
    mask_prob=self.mask_prob,
    mask_length=self.mask_length,
    attention_mask=features_padding_mask,
)
torch_mask_time_indices = torch.tensor(
    mask_time_indices, device=wav.device, dtype=torch.long,
)

# 2. Sample the negative samples from the entire sequence.
# Fairseq does it only on the masked indices, but this only works if you
# have long sentences. For more versatility, we sample on the entire
# (non-padded) sequence.
full_sentence_indices = np.ones((batch_size, sequence_length))  # no longer used

negative_sample_indices = torch.tensor(
    transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices(
        (batch_size, sequence_length),
        num_negatives=self.config.num_negatives,
        # mask_time_indices=full_sentence_indices,
        mask_time_indices=features_padding_mask.detach().cpu().numpy(),
    ),
    device=wav.device,
    dtype=torch.long,
)
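
As a quick sanity check (optional, meant to run right after the snippet above), no masked position should fall on a padded frame once the feature-level mask is passed:

# All masked positions must lie on real (non-padded) feature frames.
assert not torch_mask_time_indices.bool()[~features_padding_mask.bool()].any(), \
    "masking leaked into padded frames"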
