```python
            # we have no information about whether the segments follow on sequentially
            # so we just ensure the same speaker as we concatenate across files
            audio_sample = np.append(audio_sample, audio[idx])
            # extra spaces in the text transcription don't matter, since we only use it for the WER computation
            text_sample += " " + text[idx]
        else:
            # speakers do not follow sequentially, save the audio and start looping again
            concatenated_audio.append(audio_sample)
            concatenated_text.append(text_sample)
            concatenated_speaker.append(speaker)
            condition_on_prev.append(0)
            audio_sample = audio[idx]
            text_sample = text[idx]
    else:
        # concatenated audio exceeds max length, save the audio and start looping again
        concatenated_audio.append(audio_sample)
        concatenated_text.append(text_sample)
        concatenated_speaker.append(speaker)
        condition_on_prev.append(1)
        audio_sample = audio[idx]
        text_sample = text[idx]
```
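The snippet above is missing its loop head and initialisation, so here is a minimal, self-contained reconstruction of how it appears to fit together. This is a sketch, not the actual distil-whisper code: the loop head, the `prev_speaker` bookkeeping, and the length check are assumptions, audio clips are modelled as plain Python lists, and a small `MAX_LEN` stands in for the 30-second budget.

```python
# Hypothetical reconstruction of the concatenation loop as reported.
# Assumptions: audio clips are plain lists (duration = len(clip)),
# MAX_LEN stands in for the 30 s budget, and prev_speaker tracking
# mirrors what the hidden loop head presumably does.
MAX_LEN = 5

def concatenate_as_reported(audio, text, speakers):
    concatenated_audio, concatenated_text = [], []
    concatenated_speaker, condition_on_prev = [], []
    audio_sample, text_sample = audio[0], text[0]
    prev_speaker = speakers[0]
    for idx in range(1, len(audio)):
        speaker = speakers[idx]
        if len(audio_sample) + len(audio[idx]) <= MAX_LEN:
            if speaker == prev_speaker:
                # same speaker and within budget: keep accumulating
                audio_sample = audio_sample + audio[idx]
                text_sample += " " + text[idx]
            else:
                # speaker changed: flush the accumulated sample
                concatenated_audio.append(audio_sample)
                concatenated_text.append(text_sample)
                concatenated_speaker.append(speaker)  # current speaker, not previous
                condition_on_prev.append(0)
                audio_sample, text_sample = audio[idx], text[idx]
        else:
            # length budget exceeded: flush the accumulated sample
            concatenated_audio.append(audio_sample)
            concatenated_text.append(text_sample)
            concatenated_speaker.append(speaker)
            condition_on_prev.append(1)
            audio_sample, text_sample = audio[idx], text[idx]
        prev_speaker = speaker
    # note: the final (audio_sample, text_sample) is never flushed
    return concatenated_audio, concatenated_text, concatenated_speaker, condition_on_prev
```

Running this on a toy batch makes the reported behaviour visible: a flushed sample gets labelled with the *next* utterance's speaker, the `condition_on_prev` flags sit one position too early, and the last accumulated pair is silently dropped.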
From my understanding, the logic in the for loop is: if either

- adding the current utterance to `audio_sample` would exceed 30 s, or
- the current speaker differs from the previous one (`prev_speaker`),

then save the concatenation up to the previous utterance (`audio_sample`), excluding the current utterance. Since the saved sample does not contain the current utterance:

- The appended speaker should be `prev_speaker` rather than `speaker`.
- `condition_on_prev` signifies continuity at the start of the current utterance, so it should be shifted to the right by 1 (e.g. initialize it as `condition_on_prev = [0]`).

Meanwhile, the very last accumulated sample in each batch never gets appended: when the for loop exits, there is still an `(audio_sample, text_sample)` pair of ≤ 30 s that should have been saved but wasn't.

These issues may not seem significant, but when fine-tuning on a custom dataset with diverse speakers, where `condition_on_prev` is expected to be true much of the time, they produce incorrect training signals.
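The three fixes above can be sketched as follows. This is again a hypothetical, self-contained version (plain lists instead of numpy audio, `MAX_LEN` in place of 30 s, assumed variable names), not a patch against the actual `concatenate_dataset()`:

```python
# Sketch of the corrected loop, under the same toy assumptions:
# clips are plain lists, MAX_LEN stands in for the 30 s budget.
MAX_LEN = 5

def concatenate_fixed(audio, text, speakers):
    concatenated_audio, concatenated_text = [], []
    concatenated_speaker, condition_on_prev = [], []
    audio_sample, text_sample = audio[0], text[0]
    prev_speaker = speakers[0]
    # fix 2: the flag describes the *start* of the sample being saved,
    # so the very first saved sample is never conditioned on a predecessor
    pending_condition = 0

    for idx in range(1, len(audio)):
        speaker = speakers[idx]
        if len(audio_sample) + len(audio[idx]) <= MAX_LEN:
            if speaker == prev_speaker:
                audio_sample = audio_sample + audio[idx]
                text_sample += " " + text[idx]
            else:
                # speaker changed: flush the accumulated sample
                concatenated_audio.append(audio_sample)
                concatenated_text.append(text_sample)
                concatenated_speaker.append(prev_speaker)  # fix 1: label with prev_speaker
                condition_on_prev.append(pending_condition)
                pending_condition = 0  # next sample starts after a speaker change
                audio_sample, text_sample = audio[idx], text[idx]
        else:
            # length budget exceeded: flush the accumulated sample
            concatenated_audio.append(audio_sample)
            concatenated_text.append(text_sample)
            concatenated_speaker.append(prev_speaker)
            condition_on_prev.append(pending_condition)
            pending_condition = 1  # next sample continues the same stream
            audio_sample, text_sample = audio[idx], text[idx]
        prev_speaker = speaker

    # fix 3: flush the final accumulated sample when the loop ends
    concatenated_audio.append(audio_sample)
    concatenated_text.append(text_sample)
    concatenated_speaker.append(prev_speaker)
    condition_on_prev.append(pending_condition)
    return concatenated_audio, concatenated_text, concatenated_speaker, condition_on_prev
```

On the same toy batch, every flushed sample now carries the speaker who actually produced it, each `condition_on_prev` entry aligns with the sample it describes, and the trailing sample is kept.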
In `concatenate_dataset()`: `distil-whisper/training/run_pseudo_labelling.py`, lines 644 to 671 at commit `66ac8dd`.