
Extend multimodal/speech_llm with lhotse, T5, and Bestow support #9169

Open
wants to merge 453 commits into base: main

Conversation

zhehuaichen (Collaborator)

What does this PR do?

In multimodal/speech_llm, add lhotse dataloader support and two models, SALM-T5 and Bestow-GPT. Include example configs.

Main features under speech_llm:

  • Lhotse dataloader support for speech SFT in speech_llm
  • SALM-style architecture with T5 LLM backbone
  • Bestow-style architecture (cross-attention based) with GPT LLM backbone

Minor edits in the nlp collection:

  • megatron_base_model.py: handle the case where tokenizer.type is not set
  • megatron_lm_encoder_decoder_model.py: handle the case where encoder_input is used
  • megatron_base_prompt_learning_model.py: group the LLM init code under an init_model function (following the pattern from megatron_gpt_prompt_learning_model.py) so that it can be overridden by subclasses when needed; see the sketch after this list
  • megatron/utils.py: in gradient accumulation, handle the case where the batch size from dynamic bucketing is not divisible by the number of microbatches. This happens when using the lhotse dataloader with batch_duration
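A minimal sketch of the init_model pattern referred to in the third bullet. The class is a simplified stand-in for the NeMo base class, and the attribute assignments are placeholders for illustration only:

```python
class MegatronBasePromptLearningModel:  # simplified stand-in for the NeMo base class
    def __init__(self, cfg, trainer):
        # Keep __init__ thin and delegate LLM setup to init_model(),
        # mirroring the pattern in megatron_gpt_prompt_learning_model.py.
        self.init_model(cfg, trainer)

    def init_model(self, cfg, trainer):
        """All frozen-LLM initialization lives here so subclasses can override it."""
        self.frozen_model = None  # placeholder for the actual LLM loading


class SpeechPromptModel(MegatronBasePromptLearningModel):
    def init_model(self, cfg, trainer):
        super().init_model(cfg, trainer)
        self.audio_encoder = None  # placeholder for modality-specific setup
```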

Collection: [common,nlp,multimodal]

PR Type:

  • New Feature
  • Bugfix
  • Documentation

@github-advanced-security bot left a comment

CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@zhehuaichen marked this pull request as ready for review May 11, 2024 04:03
@@ -0,0 +1,361 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Collaborator

could you put this file under conf/bestow/*?

# See the License for the specific language governing permissions and
# limitations under the License.

name: megatron_audio_gpt_salm_lhotse
Collaborator

could you put this file under conf/salm/*?

# See the License for the specific language governing permissions and
# limitations under the License.

name: megatron_audio_t5_salm_lhotse
Collaborator

could you put this file under conf/salm/*?

# See the License for the specific language governing permissions and
# limitations under the License.

import copy
Collaborator

maybe rename this file to modular_t5_models.py and the previous one to modular_gpt_models.py, or even change the modular prefix to something else

vectors = collate_vectors_lhotse(items, padding_value=padding_value)
if max_length > vectors.size(1):
    vectors = torch.cat(
        [vectors, padding_value * torch.ones(vectors.size(0), max_length - vectors.size(1), dtype=vectors.dtype)],
Collaborator

why do we need to enforce a static shape with padding for every example here?
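For context, a minimal sketch of the dynamic alternative the question implies: pad only to the longest item in each batch rather than to a global max_length. The function name is illustrative, not part of the PR:

```python
import torch

def collate_to_longest(items: list, padding_value: float = 0.0) -> torch.Tensor:
    """Pad a list of 1-D tensors only to the longest item in this batch,
    rather than to a global max_length."""
    max_len = max(t.size(0) for t in items)
    out = torch.full((len(items), max_len), padding_value, dtype=items[0].dtype)
    for i, t in enumerate(items):
        out[i, : t.size(0)] = t
    return out
```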

return (n + m - 1) // m * m


class TextProcessing:
@pzelasko (Collaborator) commented May 15, 2024

This class needs more documentation on what it is doing, how to use its API, and what the expected input and output formats are. Also, it only has private methods right now; the main API method should be public (no underscore at the beginning).

I'd expect a docstring along the lines of: this class is used to convert X to Y; in order to do so, it performs A, B, C, and D; the expected format of X is ...; the expected format of Y is ...

Since it's used to convert text to prompts to token ids, I'd like to see full documentation of the prompt template/schema.

The options to init also need documentation; if some are unused/unnecessary they may be removed.
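A sketch of the shape such documentation could take; the method name, field names, and docstring contents below are illustrative assumptions, not the PR's actual API:

```python
class TextProcessing:
    """Converts a (question, answer) text pair into token ids for speech SFT.

    To do so it (A) fills the prompt template with the question and any
    context, (B) tokenizes the prompt and the answer, (C) truncates to
    max_seq_length, and (D) builds the answer-only loss mask.

    Expected input: dict with 'question' and 'answer' strings.
    Expected output: dict with 'input_ids', 'labels', and 'loss_mask' tensors.
    """

    def process_example(self, example: dict) -> dict:
        """Public entry point (no leading underscore), as the review asks."""
        raise NotImplementedError  # illustrative stub
```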

return processed_example


def convert_canary_prompt_to_text(prompt, is_canary_tokens_augment):
Collaborator

I understand why this function is built the way it is, but for future experiments let's try to move away from canary special-token conversion and design a configurable prompting setup instead.
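As a strawman for such a configurable setup; the template keys, slot names, and template syntax here are invented for illustration and are not an existing NeMo API:

```python
# Hypothetical prompt templates keyed by task, replacing hard-coded
# special-token conversion with a configurable mapping.
PROMPT_TEMPLATES = {
    "asr": "Transcribe the following audio in {source_lang}:",
    "ast": "Translate the following {source_lang} audio to {target_lang}:",
}

def build_prompt(task: str, **slots: str) -> str:
    """Fill the configured template for `task` with the given slot values."""
    return PROMPT_TEMPLATES[task].format(**slots)

# e.g. build_prompt("ast", source_lang="French", target_lang="English")
```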

by Lhotse samplers instead.
"""

def __init__(
Collaborator

init args are not documented

tokens_to_generate: int,
pad_to_max_length: bool,
max_seq_length: int,
noise_cuts: Optional = None,
Collaborator

this is a leftover from before the lhotse NeMo dataset was refactored; remove noise_cuts from this class.

conf['manifest_filepath'] = cur_manifest_filepath
question_file_set = data_cfg.get('question_file_set', None)
if question_file_set is not None:
    conf['question_file_set'] = [question_file_set[dataset_idx]]
Collaborator

unless you add question_file_set to LhotseDataLoadingConfig, it will be discarded at the beginning of get_lhotse_dataloader_from_config
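To make the point concrete, a sketch of what declaring the field might look like. The surrounding dataclass is NeMo's; the exact field type chosen here is an assumption:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class LhotseDataLoadingConfig:
    # ... existing lhotse dataloading fields ...
    # Keys absent from this structured config are dropped when the user
    # config is merged in get_lhotse_dataloader_from_config, so a new key
    # must be declared here to survive (type is illustrative).
    question_file_set: Optional[list] = None
```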

from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config

# for eval, we need to create separate datasets so as to report split numbers
if data_cfg.get('is_tarred', False) or (is_eval == False and is_predict == False):
Collaborator

you shouldn't rely on the is_tarred flag to choose the lhotse dataloader; lhotse supports more formats than NeMo JSON and NeMo tar, and auto-deduces is_tarred
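A sketch of the suggested gating, keyed on an explicit use_lhotse option (already present in this PR's configs) rather than is_tarred, and assuming the get_lhotse_dataloader_from_config signature imported above; the non-lhotse branch is deliberately elided:

```python
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config

def build_dataloader(data_cfg, dataset, global_rank: int, world_size: int):
    # Gate on an explicit option; lhotse itself auto-detects tarred vs.
    # plain manifests, so is_tarred need not drive this decision.
    if data_cfg.get('use_lhotse', False):
        return get_lhotse_dataloader_from_config(
            data_cfg, global_rank=global_rank, world_size=world_size, dataset=dataset
        )
    raise NotImplementedError("non-lhotse path elided in this sketch")
```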

freeze_modality_adapter: False
load_audio_encoder: True

global_batch_size: 128
Collaborator

how do these settings of global/micro batch size work with lhotse dynamic batch sizes?
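For reference, the standard Megatron-style relation these settings imply; with lhotse's batch_duration the realized per-step batch size is dynamic, so this arithmetic is exactly where a mismatch can appear. A sketch only, not NeMo's implementation:

```python
def infer_num_microbatches(global_batch_size: int, micro_batch_size: int, data_parallel_size: int) -> int:
    """global = micro * num_microbatches * data_parallel; with dynamic
    lhotse batches the realized micro batch size can differ per step."""
    assert global_batch_size % (micro_batch_size * data_parallel_size) == 0
    return global_batch_size // (micro_batch_size * data_parallel_size)
```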

bucketing_batch_size: null
use_lhotse: True
duration_bins: [2,4,6,8,10,12,14,16,18]
lhotse:
Collaborator

this is an old config from before we merged lhotse dataloading into main; please update the lhotse-related options here (and in other configs if needed)

average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
num_classes: null

# test_ds:
Collaborator

remove?

@stevehuang52 (Collaborator) left a comment

Thanks for the great work, please address the CodeQL issues and see the minor comments.

from nemo.utils import logging


def build_salm_dataset(model_instance, data_cfg, is_train):
Collaborator

better not to include salm in the function name, to keep it more general

return (n + m - 1) // m * m


class TextProcessing:
Collaborator

Is this a copy or a modified version of the TextProcessing class in `audio_text_dataset`? If it's a modified version, we should inherit from the parent class and only override the necessary functions.
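If it is a modified copy, the inheritance could look roughly like this; the import path and overridden method name are assumptions based on the comment, not verified against the PR:

```python
# Hypothetical sketch: reuse the existing class and override only what the
# lhotse path needs, instead of keeping a diverging copy.
from nemo.collections.multimodal.speech_llm.data.audio_text_dataset import (
    TextProcessing as BaseTextProcessing,
)

class LhotseTextProcessing(BaseTextProcessing):
    def _process_example(self, context: str, output: str) -> dict:
        # lhotse-specific adjustments go here; everything else stays inherited.
        return super()._process_example(context, output)
```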

return processed_example


def convert_canary_prompt_to_text(prompt, is_canary_tokens_augment):
Collaborator

  1. can we make this function directly scalable to more languages, and flexible enough to adapt to new changes in the canary prompt?
  2. where is the format "<|fr|>" and "<|transcribe|>" defined?

random_context_prob: float = 0.0,
random_context_positive_percent: float = 0.1,
):
from lhotse.dataset import AudioSamples, CutMix
Collaborator

please move the import to the top of the file

self.random_context_prob = random_context_prob
self.random_context_positive_percent = random_context_positive_percent

def _inject_random_context_into_question(self, cut, random_context_num=8, random_context_positive_percent=0.1):
Collaborator

shall we remove random context? its usage is limited to word boosting and might hurt performance when doing multi-task training

*args,
**kwargs,
):
assert input_embeds.shape[-1] == encoder_states.shape[-1]
Collaborator

please add docstrings

@@ -155,3 +155,15 @@ def align_feat_seq_list(
new_seq_list.append(new_seq)
new_seq_len_list.append(new_seq_len)
return new_seq_list, new_seq_len_list


def to_cuda(inputs, non_blocking=True):
@@ -417,8 +417,8 @@ def split_list(inputs, num_chunks):
"""
Split a list into equal sized chunks
"""
# if len(inputs) % chunk_size != 0, round down the chunk size
Collaborator

remove this line if it's not used

@@ -442,8 +442,11 @@ def get_iterator_k_split(batch: Union[Dict, List[torch.Tensor]], num_microbatche

# Split tensor items
items = list(tensor_items.items())
assert items[0][1].shape[0] % num_microbatches == 0, "Issue with batch size configuration!"
Collaborator

need someone from NLP to review the change removing the constraint `assert items[0][1].shape[0] % num_microbatches == 0`
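For reference, one standard way to relax that constraint; a sketch, not necessarily this PR's exact change:

```python
import torch

def split_tensor_uneven(t: torch.Tensor, num_microbatches: int) -> list:
    """torch.tensor_split tolerates a leading dimension that is not divisible
    by num_microbatches: the first chunks get one extra element instead of
    the call failing."""
    return list(torch.tensor_split(t, num_microbatches, dim=0))

# e.g. a batch of 10 split 4 ways yields chunk sizes [3, 3, 2, 2]
```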

@@ -246,12 +246,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
self.use_fsdp = cfg.get('fsdp', False)

def setup_transformer_engine_tp_groups(self):
""" This should be called after model parallel groups have been initialized
Collaborator

please go through the changes and undo anything that's not necessary
