[BUG FIX] Fix world_size bug in QuickStart Example #747

Open
wants to merge 1 commit into main

Conversation

Mr-Philo

BUG Description

While following the developer guide at https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start and running the provided example Python file run_simple_mcore_train_loop.py, the terminal did not respond for nearly an hour and then threw an exception:

/data/Megatron-LM/examples/# torchrun --nproc-per-node=2 simple_megatron_transformer.py                            
[2024-03-22 03:18:07,974] torch.distributed.run: [WARNING]                                                                                              
[2024-03-22 03:18:07,974] torch.distributed.run: [WARNING] *****************************************                                                    
[2024-03-22 03:18:07,974] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2024-03-22 03:18:07,974] torch.distributed.run: [WARNING] *****************************************                                                    
Traceback (most recent call last):                                                                                                                      
  File "/data/Megatron-LM/examples/run_simple_mcore_train_loop.py", line 102, in <module>                                          
    initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)                                                                
  File "/data/Megatron-LM/examples/run_simple_mcore_train_loop.py", line 25, in initialize_distributed                             
    torch.distributed.init_process_group(world_size=world_size, rank=rank)                                                                              
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 74, in wrapper                                                  
    func_return = func(*args, **kwargs)                                     
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1153, in init_process_group                                
    default_pg, _ = _new_process_group_helper(                              
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1269, in _new_process_group_helper                         
    backend_class = ProcessGroupGloo(backend_prefix_store, group_rank, group_size, timeout=timeout)                                                     
RuntimeError: Socket Timeout

BUG Reason

In the initialize_distributed() function in Megatron-LM/examples/run_simple_mcore_train_loop.py, world_size is set to torch.cuda.device_count(). This is incorrect whenever the script runs on, say, an 8-GPU node but torchrun --nproc-per-node is set to any number other than 8: world_size then no longer matches the number of processes actually launched. Worse, it can leave the terminal unresponsive for a very long time before the process-group rendezvous finally times out.
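A minimal sketch of the mismatch (variable names here are illustrative, not taken from the script):

import os
import torch

# torchrun --nproc-per-node=2 launches 2 processes and exports WORLD_SIZE=2 and RANK=0/1.
world_size_from_env = int(os.environ["WORLD_SIZE"])   # 2: processes actually launched
world_size_from_cuda = torch.cuda.device_count()      # 8: GPUs physically on the node

# Passing the CUDA count (8) to torch.distributed.init_process_group() makes each rank
# wait for 8 peers to join the rendezvous, but only 2 ever start, so the call blocks
# until it finally raises the Socket Timeout shown above.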

BUG Reproduce

Set torchrun --nproc-per-node to any value that is not equal to the total number of GPUs on the machine you are using, then run the example. For instance:
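On an 8-GPU node (script path as in the repository; adjust if yours differs):

torchrun --nproc-per-node=2 examples/run_simple_mcore_train_loop.py

Any --nproc-per-node value other than 8 triggers the same hang followed by the Socket Timeout shown above.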

BUG Fix

-   world_size = torch.cuda.device_count()
+   world_size = int(os.environ["WORLD_SIZE"])

This change fixes the bug and avoids the exception without requiring any change to the launch command. The same change is applied both to the script examples/run_simple_mcore_train_loop.py and to the documentation file megatron/core/QuickStart.md.
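For context, a minimal sketch of initialize_distributed() with the fix applied (a sketch only: the rank lookup and the model-parallel call are assumed to follow the QuickStart example and may differ slightly from the current script):

import os
import torch
from megatron.core import parallel_state

def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1):
    # torchrun exports RANK and WORLD_SIZE for every process it launches,
    # so reading them keeps world_size consistent with --nproc-per-node.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(rank % torch.cuda.device_count())
    torch.distributed.init_process_group(world_size=world_size, rank=rank)

    # Megatron-Core model-parallel initialization (as in the QuickStart example)
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size, pipeline_model_parallel_size
    )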

github-actions bot

Marking as stale. No activity in 60 days.

github-actions bot added the stale label (No activity in 60 days on issue or PR) on May 21, 2024