Describe the bug
Running the BERT pretraining example hit two issues:
1. TransformerEngine fails with "TransformerEngine only supports softmax compute in FP32". The fix is to add `--attention-softmax-in-fp32` to the model arguments. This applies to the GPT pretraining script `pretrain_gpt.sh` too.
2. The attention mask has shape `[B, 1, max_seqlen, max_seqlen]`, but the function `get_cu_seqlens` expects `[B, 1, 1, max_seqlen]`, so training crashes. A minimal sketch of the mismatch follows this list; see also the log below.
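For issue 2, here is a minimal sketch of the mismatch. The helper below is a hypothetical stand-in for TE's `get_cu_seqlens`, reconstructed from the traceback rather than copied from TransformerEngine source, so the exact reduction it performs is an assumption:

```python
import torch

def mock_get_cu_seqlens(attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical reconstruction (NOT TE source): assumes a
    # [B, 1, 1, max_seqlen] padding mask with True = masked position.
    seqlens = (~attention_mask).squeeze(1).squeeze(1).sum(dim=1, dtype=torch.int32)
    zero = torch.zeros(1, dtype=torch.int32)
    # With a [B, 1, max_seqlen, max_seqlen] mask, the squeezes leave a 3-D
    # tensor, seqlens stays 2-D, and this torch.cat fails exactly like the
    # log below: "Tensors must have same number of dimensions: got 1 and 2"
    return torch.cat((zero, torch.cumsum(seqlens, dim=0).to(torch.int32)))

B, S = 4, 512
ok_mask = torch.zeros(B, 1, 1, S, dtype=torch.bool)   # shape get_cu_seqlens expects
print(mock_get_cu_seqlens(ok_mask))                   # works: tensor of shape [B + 1]

bad_mask = torch.zeros(B, 1, S, S, dtype=torch.bool)  # shape pretrain_bert.py passes
mock_get_cu_seqlens(bad_mask)                         # raises the RuntimeError above
```

A possible workaround (untested assumption) would be to collapse the padding mask back to `[B, 1, 1, max_seqlen]` before it reaches TE, but the proper fix likely belongs in how `pretrain_bert.py` builds the mask.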
To Reproduce
Run the example `./examples/pretrain_bert.sh` in the Docker image `nvcr.io/nvidia/pytorch:24.02-py3` with the `main` branch of Megatron-LM. The issues were found in the `core_r0.6.0` branch too.
Expected behavior
The example is expected to run out of the box.
Stack trace/logs
```
[after dataloaders are built] datetime: 2024-04-23 00:29:39
done with setup ...
(min, max) time across ranks (ms):
model-and-optimizer-setup ......................: (5967.29, 5967.29)
train/valid/test-data-iterators-setup ..........: (128.70, 128.70)
training ...
[before the start of training step] datetime: 2024-04-23 00:29:39
torch.Size([4, 1, 512, 512])
Traceback (most recent call last):
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/pretrain_bert.py", line 194, in <module>
pretrain(train_valid_test_datasets_provider, model_provider,
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/training/training.py", line 270, in pretrain
iteration, num_floating_point_operations_so_far = train(
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/training/training.py", line 990, in train
train_step(forward_step_func,
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/training/training.py", line 541, in train_step
losses_reduced = forward_backward_func(
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 356, in forward_backward_no_pipelining
output_tensor = forward_step(
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 192, in forward_step
output_tensor, loss_func = forward_step_func(data_iterator, model)
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/pretrain_bert.py", line 139, in forward_step
output_tensor = model(tokens, padding_mask,
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 179, in forward
return self.module(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/legacy/model/module.py", line 190, in forward
outputs = self.module(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/legacy/model/bert_model.py", line 182, in forward
lm_output = self.language_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/legacy/model/language_model.py", line 493, in forward
encoder_output = self.encoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/pscratch/sd/x/xju/LLMTracking/Megatron-LM/megatron/legacy/model/transformer.py", line 1777, in forward
hidden_states = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 625, in forward
self_attention_outputs = self.self_attention(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 3461, in forward
context_layer = self.core_attention(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 2724, in forward
return self.fused_attention(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 2055, in forward
_cu_seqlens_q = get_cu_seqlens(attention_mask)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 166, in get_cu_seqlens
cu_seqlens = torch.cat((zero, cu_seqlens))
RuntimeError: Tensors must have same number of dimensions: got 1 and 2
```
Environment (please complete the following information):
- Docker image: `nvcr.io/nvidia/pytorch:24.02-py3`
- Megatron-LM commit ID: ccfeda4
- PyTorch version: 2.3.0a0+ebedce2
- CUDA version: 12.3
- NCCL version: 2.20.3
Proposed fix
N/A
Additional context
N/A