torchrun c10d backend doesn't seem to work with python 3.12, giving segmentation fault because of calling obmalloc without holding GIL #125990

Open
TanyaAdams1 opened this issue May 11, 2024 · 7 comments
Labels: high priority, oncall: distributed, triage review

@TanyaAdams1 commented May 11, 2024

🐛 Describe the bug

TLDR: It seems that Python 3.12 changed how the GIL interacts with the object allocator, and now using torch.distributed (especially the c10d rendezvous backend) triggers a segmentation fault. After debugging, I believe the fault is caused by calling an object allocation function (obmalloc) without holding the GIL.

To reproduce this bug, first create a new conda environment (conda create -n torch), then follow the installation instructions on the PyTorch website: conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia. During this step conda will, by default, download a very recent Python (3.12.3 for me). Then run torchrun with any random script name: torchrun --standalone --nproc-per-node 4 random_name.py (the program crashes even before launching the script!). Here's the error message I got:

[2024-05-10 22:43:34,776] torch.distributed.run: [WARNING] *****************************************
[2024-05-10 22:43:34,776] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-05-10 22:43:34,776] torch.distributed.run: [WARNING] *****************************************
Fatal Python error: Segmentation fault

Current thread 0x00002b7933234740 (most recent call first):
  File ".../lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 113 in _call_store
  File ".../lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 64 in __init__
  File ".../lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 253 in create_backend
  File ".../lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/registry.py", line 36 in _create_c10d_handler
  File ".../lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/api.py", line 258 in create_handler
  File ".../lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/registry.py", line 66 in get_rendezvous_handler
  File ".../lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 238 in launch_agent
  File ".../lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 135 in __call__
  File ".../lib/python3.12/site-packages/torch/distributed/run.py", line 803 in run
  File ".../lib/python3.12/site-packages/torch/distributed/run.py", line 812 in main
  File ".../lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347 in wrapper
  File ".../bin/torchrun", line 33 in <module>

Extension modules: mkl._mklinit, mkl._py_mkl_service, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special (total: 22)
Segmentation fault (core dumped)

I tried to debug this using gdb: gdb --args python -m torch.distributed.launch --standalone --nproc-per-node 4 random_name.py, and here's the output:

0x000055555574498d in _PyInterpreterState_GET () at /usr/local/src/conda/python-3.12.3/Include/internal/pycore_pystate.h:133
warning: 133    /usr/local/src/conda/python-3.12.3/Include/internal/pycore_pystate.h: No such file or directory
(gdb) bt
#0  0x000055555574498d in _PyInterpreterState_GET () at /usr/local/src/conda/python-3.12.3/Include/internal/pycore_pystate.h:133
#1  get_state () at /usr/local/src/conda/python-3.12.3/Objects/obmalloc.c:866
#2  _PyObject_Malloc (nbytes=45, ctx=<optimized out>) at /usr/local/src/conda/python-3.12.3/Objects/obmalloc.c:1563
#3  PyObject_Malloc (size=45) at /usr/local/src/conda/python-3.12.3/Objects/obmalloc.c:801
#4  0x000055555575d125 in _PyBytes_FromSize (use_calloc=0, size=12) at /usr/local/src/conda/python-3.12.3/Objects/bytesobject.c:102
#5  PyBytes_FromStringAndSize (str=0x5555582ba040 "Y2FuaW1hZGFtUU", size=12) at /usr/local/src/conda/python-3.12.3/Objects/bytesobject.c:134
#6  0x00002aaab4aeda35 in pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10d::Store&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)#28}, pybind11::bytes, c10d::Store&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::call_guard<pybind11::gil_scoped_release>, char [888]>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(c10d::Store&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)#28}&&, pybind11::bytes (*)(c10d::Store&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::call_guard<pybind11::gil_scoped_release> const&, char const (&) [888])::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) ()
   from .../lib/python3.12/site-packages/torch/lib/libtorch_python.so
#7  0x00002aaab42a7123 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) ()
   from .../lib/python3.12/site-packages/torch/lib/libtorch_python.so
...

Downgrading Python back to 3.10 solves the problem for me for now, but given that conda downloads 3.12.3 by default, updating how PyTorch handles the GIL seems like the right long-term fix.
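
For reference, frame #6 of the backtrace above shows a pybind11 binding wrapped in py::call_guard<py::gil_scoped_release> that returns a py::bytes, i.e. it constructs a Python object while the GIL is released. Below is a minimal, hypothetical sketch of that pattern; the module name gil_crash_demo, the helper fake_store_call, and the binding compare_set_like are made-up stand-ins, not PyTorch code.

// Hypothetical sketch, not PyTorch source. The call guard releases the GIL for the
// whole lambda, so constructing py::bytes inside it calls
// PyBytes_FromStringAndSize -> _PyObject_Malloc without a thread state, matching
// the PyBytes_FromStringAndSize / _PyObject_Malloc frames in the backtrace above.
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Stand-in for a blocking store round trip that legitimately runs without the GIL.
static std::string fake_store_call(const std::string& key) {
    return "stored-value-for-" + key;
}

PYBIND11_MODULE(gil_crash_demo, m) {
    m.def(
        "compare_set_like",
        [](const std::string& key) -> py::bytes {
            std::string result = fake_store_call(key);  // GIL is released here (intended)
            return py::bytes(result);                    // allocates a Python object while
                                                         // the GIL is still released
        },
        py::call_guard<py::gil_scoped_release>());
}

On Python 3.11 and earlier this pattern usually did not crash because obmalloc kept process-global state; in 3.12 the allocator state is per interpreter, so _PyObject_Malloc reaches for the current thread state (the get_state()/_PyInterpreterState_GET() frames above), which is NULL while the GIL is released, hence the segfault.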

Versions

Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.35

Python version: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.114.2.el7.x86_64-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB

Nvidia driver version: 550.54.14
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Address sizes:       48 bits physical, 48 bits virtual
Byte Order:          Little Endian
CPU(s):              128
On-line CPU(s) list: 0-127
Vendor ID:           AuthenticAMD
Model name:          AMD EPYC 7763 64-Core Processor
CPU family:          25
Model:               1
Thread(s) per core:  1
Core(s) per socket:  64
Socket(s):           2
Stepping:            1
BogoMIPS:            4890.76
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc art rep_good nopl nonstop_tsc extd_apicid aperfmperf eagerfpu pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_l2 cpb cat_l3 cdp_l3 invpcid_single hw_pstate sme ssbd rsb_ctxsw ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq overflow_recov succor smca
Virtualization:      AMD-V
L1d cache:           4 MiB (128 instances)
L1i cache:           4 MiB (128 instances)
L2 cache:            64 MiB (128 instances)
L3 cache:            512 MiB (16 instances)
NUMA node(s):        8
NUMA node0 CPU(s):   0-15
NUMA node1 CPU(s):   16-31
NUMA node2 CPU(s):   32-47
NUMA node3 CPU(s):   48-63
NUMA node4 CPU(s):   64-79
NUMA node5 CPU(s):   80-95
NUMA node6 CPU(s):   96-111
NUMA node7 CPU(s):   112-127

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.1
[pip3] torchaudio==2.2.1
[pip3] torchmetrics==1.3.1
[pip3] torchnet==0.0.4
[pip3] torchvision==0.17.1
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0           py312h5eee18b_1  
[conda] mkl_fft                   1.3.8           py312h5eee18b_0  
[conda] mkl_random                1.2.4           py312hdb19cb5_0  
[conda] numpy                     1.26.4          py312hc5e2394_0  
[conda] numpy-base                1.26.4          py312h0da6c21_0  
[conda] pytorch                   2.2.1           py3.12_cuda11.8_cudnn8.7.0_0    pytorch
[conda] pytorch-cuda              11.8                 h7e8668a_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.2.1               py312_cu118    pytorch
[conda] torchmetrics              1.3.1                    pypi_0    pypi
[conda] torchnet                  0.0.4                    pypi_0    pypi
[conda] torchvision               0.17.1              py312_cu118    pytorch

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@tringwald added the oncall: distributed label on May 12, 2024
@tringwald (Collaborator) commented:

Thank you for your bug report. I can reproduce the crash in a clean Python 3.12 environment.

@wconstab (Contributor) commented:

@kurman is this bug specific to one rendezvous method? I'm not sure why that would be the case, but if so, I wonder whether we plan to keep this rendezvous method after the cleanup/consolidation work.

@kurman (Contributor) commented May 13, 2024

> is this bug specific to one rendezvous method?

I believe @XilunWu was able to isolate it to a segfault in TCPStore: #116423. If so, this could be a larger issue.

@wconstab (Contributor) commented:

I wonder if this issue can be reproduced when specifying the USE_LIBUV=1 env var?

@c-p-i-o (Contributor) commented May 13, 2024

> I wonder if this issue can be reproduced when specifying the USE_LIBUV=1 env var?

The issue still reproduces with USE_LIBUV=1; same core dump.

USE_LIBUV=1 torchrun --standalone --nproc-per-node 4 random_name.py
OR
export USE_LIBUV=1 && torchrun --standalone --nproc-per-node 4 random_name.py
OR
(torch-3.12) [cpio@devvm17556.vll0 ~]$ env |grep LIBUV
USE_LIBUV=1
(torch-3.12) [cpio@devvm17556.vll0 ~]$ torchrun --standalone --nproc-per-node 4 random_name.py

W0513 14:57:46.757000 140207180518464 torch/distributed/run.py:757] *****************************************
Fatal Python error: Segmentation fault

Current thread 0x00007f8487309440 (most recent call first):
  File "/home/cpio/.conda/envs/torch-3.12/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py", line 113 in _call_store

@kurman (Contributor) commented May 13, 2024

I tried isolating the Store type by running a single test per store, and all of them segfault:

pytest test/distributed/test_store.py -k "FileStoreTest and test_compare_set"
pytest test/distributed/test_store.py -k "HashStoreTest and test_compare_set"
pytest test/distributed/test_store.py -k "PrefixFileStoreTest and test_compare_set"
pytest test/distributed/test_store.py -k "TCPStoreTest and test_compare_set"
pytest test/distributed/test_store.py -k "LibUvTCPStoreTest and test_compare_set"
pytest test/distributed/test_store.py -k "PrefixTCPStoreTest and test_compare_set"

@kurman (Contributor) commented May 13, 2024

Basic repro on TCPStore (both libuv and non-libuv):

import torch.distributed as dist
from datetime import timedelta

# Start a TCPStore as the server (world_size=1, is_master=True) with a 2-second timeout.
store = dist.TCPStore("localhost", 0, 1, True, timeout=timedelta(seconds=2))
# compare_set segfaults before returning.
store.compare_set('k', 'v1', 'v2')

Output:

Segmentation fault (core dumped)

GDB:

Thread 1 "pt_main_thread" received signal SIGSEGV, Segmentation fault.
0x00000000005042c9 in _PyInterpreterState_GET () at /usr/local/src/conda/python-3.12.0/Include/internal/pycore_pystate.h:118
118     /usr/local/src/conda/python-3.12.0/Include/internal/pycore_pystate.h: No such file or directory.
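
Given that the crash is in the same PyBytes_FromStringAndSize path as in the original report, one possible mitigation at the binding layer (a hypothetical sketch in the same spirit as the one earlier in this issue, not the actual PyTorch fix) is to re-acquire the GIL inside the gil_scoped_release-guarded function before constructing the py::bytes return value:

// Hypothetical sketch, not PyTorch source. The blocking store call still runs with
// the GIL released, but the GIL is re-acquired before any Python object is created.
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

static std::string fake_store_call(const std::string& key) {
    return "stored-value-for-" + key;  // stand-in for a blocking TCPStore round trip
}

PYBIND11_MODULE(gil_fix_demo, m) {
    m.def(
        "compare_set_like",
        [](const std::string& key) -> py::bytes {
            std::string result = fake_store_call(key);  // GIL released here (intended)
            py::gil_scoped_acquire acquire;             // take the GIL back before
            return py::bytes(result);                   // touching CPython's allocator
        },
        py::call_guard<py::gil_scoped_release>());
}

An alternative with the same effect would be to drop the call_guard and release the GIL only around the blocking store call, keeping it held wherever Python objects are created.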
