[torch/distributed] Bugfix: wait for all child procs to exit before c… #125969
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125969
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 1 Pending (as of commit 0bd79bd with merge base 7f1d5ab).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
# At this point workers have finished running the user function,
# but the child processes might not have exited yet. Wait for them.
# pc.join() blocks until "a" proc exits; it returns True only once
# all of them have exited. Loop until then.
while not self._pc.join():
should we have a timeout on this? wondering what happens if we have a dead/hung worker process?
When you reach this line, we've already validated that either:
- the entrypoint function actually ran and returned a result
- -- or -- at least one of the child procs has failed (and a SIGTERM was sent to the rest)

We're waiting for the spawned child procs to exit after the user-provided function has already returned. This could potentially hang, but we were waiting on _pc.join() indefinitely before this change as well.
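If a bounded wait were ever desired here, `ProcessContext.join()` does accept a `timeout` argument. A minimal sketch of one way to cap the total wait (the `join_all` helper and the deadline value are hypothetical, not part of this PR):

```python
import time


def join_all(pc, timeout_s: float = 300.0):
    """Wait for ALL child procs of a torch.multiprocessing.ProcessContext
    to exit, with an overall deadline (illustrative helper)."""
    deadline = time.monotonic() + timeout_s
    # pc.join(timeout=...) returns after at most `timeout` seconds, or as
    # soon as one more child proc exits, whichever comes first; it returns
    # True only once every child proc has exited.
    while not pc.join(timeout=1.0):
        if time.monotonic() >= deadline:
            raise TimeoutError("child processes did not exit in time")
```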
got it -- sgtm
…losing torch.distributed.elastic.multiprocessing.api.ProcessContext
@pytorchbot merge
Merge failed
Reason: This PR needs a `release notes:` label. If not, please add the `topic: not user facing` label.
To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the wiki.
Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA: 0-4 hours). Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
[torch/distributed] Bugfix: wait for all child procs to exit before closing torch.distributed.elastic.multiprocessing.api.ProcessContext (pytorch#125969)

Observed Problem
----------------

When `torchrun` has finished running the main trainer function (aka the entrypoint/user function) successfully, I noticed that it sometimes SIGTERMs the child processes. Then `torchrun` exits successfully. This results in misleading warning log messages towards the end of the job, like the ones below:

```
W0510 14:52:48.185934 672413 api.py:513] Closing process 675171 via signal SIGTERM
W0510 14:52:48.185984 672413 api.py:513] Closing process 675172 via signal SIGTERM
W0510 14:52:48.186013 672413 api.py:513] Closing process 675174 via signal SIGTERM
# <---- ^^^ ??? everything runs successfully but child still SIGTERM'ed? ^^^ --->
I0510 14:52:48.229119 672413 api.py:877] [main] worker group successfully finished. Waiting 300 seconds for other agents to finish.
I0510 14:52:48.229161 672413 api.py:922] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
I0510 14:52:48.229395 672413 api.py:936] Done waiting for other agents. Elapsed: 0.0001709461212158203 seconds
I0510 14:52:48.257544 672413 dynamic_rendezvous.py:1131] The node 'localhost_672413_0' has closed the rendezvous 'torchrun_qpfd'.
I0510 14:52:48.568198 672413 distributed.py:200] Deleting temp log directory: /tmp/torchrun_udgp8zoq
I0510 14:52:48.568989 672413 distributed.py:202] Finished running `main`
```

Root Cause
----------

This was due to incorrect usage of `torch.multiprocessing.ProcessContext.join()` in `torch.distributed.elastic.multiprocessing.api.MultiprocessingContext`.

`torch.multiprocessing.ProcessContext.join()` does not actually wait for ALL child procs to exit; it waits for **at-least-one** child proc to exit. If only a subset of the child procs have exited, it returns `False`; if all of them have exited, it returns `True`. `torch.distributed.elastic.multiprocessing.api.MultiprocessingContext` was assuming that `torch.multiprocessing.ProcessContext.join()` blocks indefinitely until all child procs have exited.

Fix
---

The fix is simple: loop, continuing to call `pc.join()`, until it returns `True`.

> **NOTE**: the indefinite blocking is NOT an issue, since by the time `torch.distributed.elastic.multiprocessing.api.MultiprocessingContext` calls `pc.join()` it has already validated that the entrypoint functions either returned successfully or that one of them failed. So we are really just waiting for the unix processes to exit after running the entrypoint function.

> **NOTE**: since `pc.join()` already blocks until at-least-one child proc exits, there is no need to add a polling interval in the body of the loop, and the debug log line is printed at most `nproc_per_node` times, so no log spamming is observed.

Pull Request resolved: pytorch#125969
Approved by: https://github.com/d4l3k
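To make the `join()` semantics above concrete, here is a minimal, self-contained sketch (not PyTorch source; the worker function and proc count are illustrative) exercising the fixed wait pattern against `torch.multiprocessing` directly:

```python
# A hedged sketch, assuming torch is installed. start_processes() with
# join=False returns a ProcessContext whose join() behaves as described
# in "Root Cause" above.
import time

import torch.multiprocessing as mp


def worker(rank):
    # Illustrative entrypoint: workers finish (and their processes exit)
    # at different times, so a single join() call can observe "some, but
    # not all, procs exited" and return False.
    time.sleep(rank * 0.5)


if __name__ == "__main__":
    pc = mp.start_processes(worker, nprocs=4, join=False)
    # Buggy assumption: a single pc.join() waits for everyone. In fact it
    # may return False as soon as the FIRST proc exits. Fixed pattern:
    # loop until it returns True, i.e. until ALL child procs have exited.
    while not pc.join():
        pass
    print("all child procs exited")
```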
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k