
Add llama3 and distributed checkpoint support in NeVA #9101

Merged
yaoyu-33 merged 67 commits into main from yuya/neva_llama3 on May 22, 2024
Conversation

Collaborator
@yaoyu-33 yaoyu-33 commented May 2, 2024

What does this PR do?

Add llama3 and distributed checkpoint support in NeVA

Collection: [multimodal]

Changelog

  • Add specific line-by-line info of high-level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 
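A hypothetical usage sketch (untested): only mm_cfg.llm.model_type and data.conv_template are read in this PR's diff; the config path and the "llama_3" value strings are assumptions.

from omegaconf import OmegaConf

# Hypothetical overrides for NeVA pretraining/fine-tuning; the "llama_3" values are assumptions.
overrides = OmegaConf.create(
    {
        "model": {
            "mm_cfg": {"llm": {"model_type": "llama_3"}},  # read via cfg.mm_cfg.llm.get("model_type", "nvgpt")
            "data": {"conv_template": "llama_3"},          # read via data_cfg.get("conv_template", "nvgpt")
        }
    }
)

# Merge into the usual NeVA config (path assumed) before launching training; checkpoints
# are then saved/loaded in the distributed (sharded) format this PR adds.
base_cfg = OmegaConf.load("examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml")
cfg = OmegaConf.merge(base_cfg, overrides)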

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs in various areas.

Additional Information

  • Related to # (issue)

yaoyu-33 and others added 30 commits March 15, 2024 18:27
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
# Conflicts:
#	examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
# Conflicts:
#	nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
yaoyu-33 and others added 8 commits May 9, 2024 19:52
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
# Conflicts:
#	nemo/collections/multimodal/data/neva/neva_dataset.py
#	nemo/collections/nlp/parts/nlp_overrides.py
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33 yaoyu-33 changed the title Add llama3 support in NeVA Add llama3 and distributed checkpoint support in NeVA May 13, 2024
@yaoyu-33 yaoyu-33 requested a review from mikolajblaz May 13, 2024 17:59
@@ -169,7 +169,7 @@ def eval_model(args):
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.json")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
-parser.add_argument("--conv-mode", type=str, default="llava_v0")
+parser.add_argument("--conv-mode", type=str, default="llava_v0")  # this flag has no use!
Collaborator

Then should we get rid of it?

@@ -487,6 +544,7 @@ def __init__(self, model):
is_multimodal=self.data_cfg.is_multimodal,
sep_image_conv_front=self.data_cfg.sep_image_conv_front,
conv_template=self.data_cfg.get("conv_template", "nvgpt"),
+model_type=self.cfg.mm_cfg.llm.get("model_type", "nvgpt"),
Collaborator

Should we rename this to nemotron?

Collaborator Author

That will break many old checkpoints; we might just keep it this way for a while...

yaoyu-33 and others added 3 commits May 13, 2024 16:24
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@@ -75,19 +80,24 @@ def load_checkpoint(
        else:
            sharded_strategy = None

        if not strict:
            for key in list(sharded_state_dict['state_dict'].keys()):
Collaborator

I assume this is a temporary implementation of the strict flag, because we also need to notify the user which keys are skipped.

Also, this will work only with the Zarr ckpt format; for PyT Distributed it will be different.

Collaborator

Also, I don't think it's correct, because we should use the sharded key, not the state dict key (it also doesn't account for nested dicts).
Can you try something like this (I didn't run this code, so there might be errors)?
This should work for all backends and nested dicts, and use the correct keys.


from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
from megatron.core.dist_checkpointing.mapping import ShardedBase


    if not strict:
        sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
    ...

    def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
        ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
        loaded_keys = []
        missing_keys = []
        unexpected_keys = []

        def should_remove_missing_sharded_base(x: Any):
            if isinstance(x, ShardedBase):
                if x.key in ckpt_sharded_metadata:
                    loaded_keys.append(x.key)
                    return False
                else:
                    unexpected_keys.append(x.key)
                    return True
            return False

        _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
        logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')

        # TODO: compute missing_keys by:
        #  1. all_gather_object of loaded_keys
        #  2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
        return sharded_state_dict
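For the missing_keys TODO above, a possible follow-up sketch (untested; compute_missing_keys is a hypothetical helper name, and it assumes torch.distributed is initialized):

import torch.distributed as dist

def compute_missing_keys(loaded_keys, ckpt_sharded_metadata):
    # 1. all_gather_object of loaded_keys so every rank sees what was loaded on any rank
    all_loaded = [None] * dist.get_world_size()
    dist.all_gather_object(all_loaded, loaded_keys)
    globally_loaded = {key for rank_keys in all_loaded for key in rank_keys}
    # 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys (checkpoint keys no rank requested)
    return set(ckpt_sharded_metadata.keys()) - globally_loaded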

sharded_state_dict = super().sharded_state_dict(prefix=prefix, sharded_offsets=sharded_offsets, **kwargs)

state_dict = self.state_dict(prefix='', keep_vars=True)
state_dict.pop('weight')
Collaborator

Is weight not needed at all?

Collaborator Author

The weight is already taken care of in super().

Comment on lines 249 to 256
for layer_name in state_dict.keys():
    tensor = state_dict[layer_name]
    layer_key = f'{prefix}{layer_name}'
    sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint(
        tensor,
        layer_key,
        prepend_offsets=sharded_offsets,
    )
Collaborator

I think this can be replaced with

from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
     ...
     # prefix and sharded_offsets likely need to be forwarded as well
     sharded_state_dict.update(make_sharded_tensors_for_checkpoint(state_dict, prefix, sharded_offsets=sharded_offsets))

@yaoyu-33 yaoyu-33 removed the Run CICD label May 16, 2024
yaoyu-33 and others added 4 commits May 16, 2024 10:49
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu-33@users.noreply.github.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
# Conflicts:
#	nemo/utils/callbacks/dist_ckpt_io.py
def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
    ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
    loaded_keys = []
    missing_keys = []

Check notice — Code scanning / CodeQL: Unused local variable — Variable missing_keys is not used.
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import os

Check notice — Code scanning / CodeQL: Unused import — Import of 'os' is not used.
# Conflicts:
#	nemo/collections/multimodal/parts/utils.py
@yaoyu-33 yaoyu-33 merged commit d7bb403 into main May 22, 2024
133 checks passed
@yaoyu-33 yaoyu-33 deleted the yuya/neva_llama3 branch May 22, 2024 03:14