
[BUG] Using the MKL + GPU package, creating a new process from Python causes an error #14979

Closed
fierceX opened this issue May 17, 2019 · 51 comments · Fixed by #15762

fierceX (Contributor) commented May 17, 2019

Hardware and version information:

----------Python Info----------
Version : 3.6.8
Compiler : GCC 7.3.0
Build : ('default', 'Dec 30 2018 01:22:34')
Arch : ('64bit', '')
------------Pip Info-----------
Version : 19.1.1
Directory : /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/pip
----------MXNet Info-----------
Version : 1.4.1
Directory : /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet
Hashtag not found. Not installed from pre-built package.
----------System Info----------
Platform : Linux-4.15.0-50-generic-x86_64-with-debian-buster-sid
system : Linux
node : ctmp
release : 4.15.0-50-generic
version : #54-Ubuntu SMP Mon May 6 18:46:08 UTC 2019
----------Hardware Info----------
machine : x86_64
processor : x86_64
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 94
Model name: Intel(R) Core(TM) i7-6700 CPU @ 3.40GHz
Stepping: 3
CPU MHz: 800.218
CPU max MHz: 4000.0000
CPU min MHz: 800.0000
BogoMIPS: 6816.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 8192K
NUMA node0 CPU(s): 0-7
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp md_clear flush_l1d

Python package version

Package        Version 
-------------- --------
certifi        2019.3.9
chardet        3.0.4   
gluonnlp       0.6.0   
graphviz       0.8.4   
idna           2.8     
mxnet-cu100mkl 1.4.1   
numpy          1.14.6  
pip            19.1.1  
requests       2.22.0  
setuptools     41.0.1  
urllib3        1.25.2  
wheel          0.33.4

With the GPU + MKL package, if you create a new process in Python and then load data with multiple worker processes inside it, you get an error. Reproduction script:

from multiprocessing import Process

import gluonnlp as nlp
import numpy as np
from gluonnlp.data import SQuAD
from mxnet import nd, gluon
import mxnet as mx
from mxnet.gluon import nn


class Transform(object):
    def __init__(self):
        pass

    def __call__(self, record_index, question_id, question, context, answer_list,
                 answer_start_list):
        return np.ones((100, 1)), np.ones((100, 3))


def train():
    train_data = SQuAD('train')
    dataloader = gluon.data.DataLoader(train_data.transform(Transform()),
                                       batch_size=128, shuffle=True, num_workers=4)
    net = nn.HybridSequential()
    net.add(nn.Dense(10))
    net.initialize(mx.init.Xavier(), ctx=mx.gpu(0))
    print(net)


p = Process(target=train)
p.start()
p.join()
Segmentation fault: 11

Stack trace returned 10 entries:
[bt] (0) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x3f935a) [0x7ff39d25735a]
[bt] (1) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x3513b36) [0x7ff3a0371b36]
[bt] (2) /lib/x86_64-linux-gnu/libc.so.6(+0x3ef20) [0x7ff3e124ff20]
[bt] (3) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libiomp5.so(+0xa9ea5) [0x7ff3dce09ea5]
[bt] (4) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libiomp5.so(+0xa9ba4) [0x7ff3dce09ba4]
[bt] (5) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2da4d13) [0x7ff39fc02d13]
[bt] (6) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2db56c8) [0x7ff39fc136c8]
[bt] (7) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libmxnet.so(std::shared_ptr<mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)1> > mxnet::common::LazyAllocArray<mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)1> >::Get<mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2}>(int, mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2})+0x251) [0x7ff39fc18501]
[bt] (8) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2dbd359) [0x7ff39fc1b359]
[bt] (9) /home/bird/miniconda3/envs/test/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2da9428) [0x7ff39fc07428]

If you switch to mxnet-cu100 1.4.1 (without MKL), there is no error.
Similarly, mxnet-cu100mkl 1.5.0b20190516 fails while mxnet-cu100 1.5.0b20190516 does not.

In addition, none of the following three cases produces an error:
- using the CPU instead of the GPU
- removing the num_workers parameter
- not creating a new process

mxnet-label-bot (Contributor) commented:

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: Installation

TaoLv (Member) commented May 17, 2019

@fierceX I'm not sure why, but can you try the magic below? ;)

export KMP_INIT_AT_FORK=false
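
The same workaround can also be applied from inside the Python script rather than the shell. A minimal sketch, assuming libiomp5 reads the variable when the OpenMP runtime initializes (so it must be set before mxnet is imported):

import os

# Must be set before mxnet (and therefore libiomp5.so) is loaded.
os.environ['KMP_INIT_AT_FORK'] = 'false'

import mxnet as mx  # noqa: E402 -- import deliberately placed after setting the variable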

fierceX (Contributor, Author) commented May 17, 2019

@TaoLv Yes, with that set there is no error, but this still looks like a bug. What is the actual cause?

vdantu (Contributor) commented May 19, 2019

@mxnet-label-bot add [question, MKL]

larroy (Contributor) commented May 24, 2019

Looks like an OpenMP-related problem. Since the stack trace includes libc, I suspect we are re-entering MXNet in pthread_atfork handlers due to the interaction with Python multiprocessing. Since you are using multiprocessing, this could also be worked around at the Python level; see the sketch below.
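
A minimal sketch of such a Python-level workaround (not part of the original report; it assumes all MXNet work can happen inside the child): use the 'spawn' start method so the child starts a fresh interpreter instead of inheriting the forked state of an already-initialized libmxnet.

import multiprocessing as mp

def train():
    # Import mxnet only in the child so the parent never initializes the engine.
    import mxnet as mx
    from mxnet.gluon import nn
    net = nn.HybridSequential()
    net.add(nn.Dense(10))
    net.initialize(mx.init.Xavier(), ctx=mx.gpu(0))
    print(net)

if __name__ == '__main__':
    ctx = mp.get_context('spawn')  # fresh interpreter per child, no forked library state
    p = ctx.Process(target=train)
    p.start()
    p.join()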

larroy (Contributor) commented May 24, 2019

I would suggest reproducing with debug symbols, since the stack trace does not include function names.

larroy (Contributor) commented Jul 16, 2019

ping

larroy (Contributor) commented Aug 1, 2019

I tried the GPU version; there is also no crash in debug mode.

In [2]: mx.runtime.Features()                                                      
Out[2]: [✔ CUDA, ✔ CUDNN, ✖ NCCL, ✔ CUDA_RTC, ✖ TENSORRT, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, ✔ CPU_SSE4_1, ✔ CPU_SSE4_2, ✖ CPU_SSE4A, ✔ CPU_AVX, ✖ CPU_AVX2, ✔ OPENMP, ✖ SSE, ✔ F16C, ✔ JEMALLOC, ✔ BLAS_OPEN, ✖ BLAS_ATLAS, ✖ BLAS_MKL, ✖ BLAS_APPLE, ✔ LAPACK, ✔ MKLDNN, ✔ OPENCV, ✖ CAFFE, ✖ PROFILER, ✖ DIST_KVSTORE, ✖ CXX14, ✖ INT64_TENSOR_SIZE, ✔ SIGNAL_HANDLER, ✔ DEBUG, ✖ TVM_OP]
              

Revision 9d7fc7c

larroy (Contributor) commented Aug 1, 2019

I could reproduce with the cu101mkl binary distribution:

(py3_env_pip) piotr@ip-172-31-21-159:0: ~> python test.py
Downloading /home/piotr/.mxnet/datasets/squad/train-v1.1.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/squad/train-v1.1.zip...

Segmentation fault: 11

Stack trace:
  [bt] (0) /home/piotr/py3_env_pip/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2f9cf20) [0x7fd2107a7f20]
  [bt] (1) /lib/x86_64-linux-gnu/libc.so.6(+0x3ef20) [0x7fd25813cf20]
  [bt] (2) /home/piotr/py3_env_pip/lib/python3.6/site-packages/mxnet/libiomp5.so(+0xac19c) [0x7fd253c7119c]
  [bt] (3) /home/piotr/py3_env_pip/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2648ea3) [0x7fd20fe53ea3]
  [bt] (4) /home/piotr/py3_env_pip/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x265c798) [0x7fd20fe67798]
  [bt] (5) /home/piotr/py3_env_pip/lib/python3.6/site-packages/mxnet/libmxnet.so(std::shared_ptr<mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)1> > mxnet::common::LazyAllocArray<mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)1> >::Get<mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2}>(int, mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2})+0x241) [0x7fd20fe6d741]
  [bt] (6) /home/piotr/py3_env_pip/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x26650c4) [0x7fd20fe700c4]
  [bt] (7) /home/piotr/py3_env_pip/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x264dc6e) [0x7fd20fe58c6e]
  [bt] (8) /home/piotr/py3_env_pip/lib/python3.6/site-packages/mxnet/libmxnet.so(mxnet::CopyFromTo(mxnet::NDArray const&, mxnet::NDArray const&, int, bool)+0xa39) [0x7fd2100732c9]

larroy (Contributor) commented Aug 1, 2019

Reproduced with a release CMake build
Also reproduced with a release Make build.

(py3_venv) piotr@ip-172-31-21-159:0: ~/mxnet [master]> python ~/test.py 

Segmentation fault: 11

Stack trace:
  [bt] (0) /home/piotr/mxnet/python/mxnet/../../build/libmxnet.so(+0x345d9d9) [0x7fc3e00189d9]
  [bt] (1) /lib/x86_64-linux-gnu/libc.so.6(+0x3ef20) [0x7fc4055e0f20]
  [bt] (2) /home/piotr/mxnet/build/3rdparty/openmp/runtime/src/libomp.so(+0x34250) [0x7fc3b863c250]
  [bt] (3) /home/piotr/mxnet/build/3rdparty/openmp/runtime/src/libomp.so(+0x34d3e) [0x7fc3b863cd3e]
  [bt] (4) /home/piotr/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::OpenMP::set_reserve_cores(int)+0x6d) [0x7fc3dff68d5d]
  [bt] (5) /home/piotr/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2}::operator()() const+0x4f) [0x7fc3dff79c0f]
  [bt] (6) /home/piotr/mxnet/python/mxnet/../../build/libmxnet.so(std::shared_ptr<mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)1> > mxnet::common::LazyAllocArray<mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)1> >::Get<mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2}>(int, mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2})+0x414) [0x7fc3dff7b0f4]
  [bt] (7) /home/piotr/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)+0x481) [0x7fc3dff7c871]
  [bt] (8) /home/piotr/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::ThreadedEngine::Push(mxnet::engine::Opr*, mxnet::Context, int, bool)+0x1a8) [0x7fc3dff6d358]

Flags:

USE_CUDA: "ON" # Build with CUDA support
USE_OLDCMAKECUDA: "OFF" # Build with old cmake cuda
USE_NCCL: "OFF" # Use NVidia NCCL with CUDA
USE_OPENCV: "ON" # Build with OpenCV support
USE_OPENMP: "ON" # Build with Openmp support
USE_CUDNN: "ON" # Build with cudnn support) # one could set CUDNN_ROOT for search path
USE_SSE: "ON" # Build with x86 SSE instruction support IF NOT ARM
USE_F16C: "ON" # Build with x86 F16C instruction support) # autodetects support if "ON"
USE_LAPACK: "ON" # Build with lapack support
USE_MKL_IF_AVAILABLE: "ON" # Use MKL if found
USE_MKLML_MKL: "ON" # Use MKLDNN variant of MKL (if MKL found) IF USE_MKL_IF_AVAILABLE AND (NOT APPLE)
USE_MKLDNN: "ON" # Use MKLDNN variant of MKL (if MKL found) IF USE_MKL_IF_AVAILABLE AND (NOT APPLE)
USE_OPERATOR_TUNING: "ON" # Enable auto-tuning of operators IF NOT MSVC
USE_GPERFTOOLS: "ON" # Build with GPerfTools support (if found)
USE_JEMALLOC: "ON" # Build with Jemalloc support
USE_PROFILER: "ON" # Build with Profiler support
USE_DIST_KVSTORE: "OFF" # Build with DIST_KVSTORE support
USE_PLUGINS_WARPCTC: "OFF" # Use WARPCTC Plugins
USE_PLUGIN_CAFFE: "OFF" # Use Caffe Plugin
USE_CPP_PACKAGE: "OFF" # Build C++ Package
USE_MXNET_LIB_NAMING: "ON" # Use MXNet library naming conventions.
USE_GPROF: "OFF" # Compile with gprof (profiling) flag
USE_CXX14_IF_AVAILABLE: "OFF" # Build with C++14 if the compiler supports it
USE_VTUNE: "OFF" # Enable use of Intel Amplifier XE (VTune)) # one could set VTUNE_ROOT for search path
ENABLE_CUDA_RTC: "ON" # Build with CUDA runtime compilation support
BUILD_CPP_EXAMPLES: "ON" # Build cpp examples
INSTALL_EXAMPLES: "OFF" # Install the example source files.
USE_SIGNAL_HANDLER: "ON" # Print stack traces on segfaults.
USE_TENSORRT: "OFF" # Enable inference optimization with TensorRT.
USE_ASAN: "OFF" # Enable Clang/GCC ASAN sanitizers.
ENABLE_TESTCOVERAGE: "OFF" # Enable compilation with test coverage metric output
CMAKE_BUILD_TYPE: "Release"
CMAKE_CUDA_COMPILER_LAUNCHER: "ccache"
CMAKE_C_COMPILER_LAUNCHER: "ccache"
CMAKE_CXX_COMPILER_LAUNCHER: "ccache"

larroy (Contributor) commented Aug 1, 2019

Can't reproduce with Debug builds.

larroy (Contributor) commented Aug 1, 2019

RelWithDebSymbols:

Assertion failure at kmp_runtime.cpp(6481): __kmp_team_pool == __null.
*** invalid %N$ use detected ***

No crash.

larroy (Contributor) commented Aug 2, 2019

out.txt

larroy (Contributor) commented Aug 2, 2019

@fierceX Forking a process after MXNet has been initialized is not supported. The first Process creation should not be done this way, as the state of the library after the fork is inconsistent. The code in the train function is never executed.

With respect to the crash, after investigating this I believe it is caused by calling setenv in a pthread_atfork handler. I will refactor this code so that unsafe calls to setenv are not made during forking.

Additionally, we can detect that we are in a forked state and emit additional errors in MXNet, for example when a DataLoader is used; a Python-level illustration of that idea is sketched below.
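
A hypothetical Python-level illustration of that detection idea (the actual fix in #15762 lives inside libmxnet; the helper names below are made up for the sketch):

import os
import warnings

_forked_after_init = False

def _mark_forked():
    global _forked_after_init
    _forked_after_init = True

# Available since Python 3.7; runs in the child right after a fork.
os.register_at_fork(after_in_child=_mark_forked)

def warn_if_forked(component):
    # Would be called at the entry of fork-sensitive code paths, e.g. a DataLoader.
    if _forked_after_init:
        warnings.warn('%s used in a process forked after library initialization; '
                      'this is not supported and may crash.' % component)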

larroy added several commits to larroy/mxnet that referenced this issue on Aug 6 and Aug 8, 2019
zachgk pushed a commit that referenced this issue Aug 9, 2019
…cal concurrency crashes. (#15762)

* Refactor LibraryInitializer so it's thread safe.
Fixes #13438
Fixes #14979

* Refactor around lib loading

* Fix lint

* CR

* Add option to choose between OMP implementations

* Fix bug

* Fix from CR
anirudhacharya pushed a commit with the same change (apache#15762) to anirudhacharya/mxnet that referenced this issue Aug 20, 2019
@leezu leezu reopened this Dec 8, 2019
leezu (Contributor) commented Dec 8, 2019

There are currently two hypotheses about the root cause of this error (#14979 (comment)): a) a bug in llvm / intel openmp, or b) an interaction between gomp and llvm / intel openmp.

I did some more investigation and conclude that we can rule out option b. In particular, I compile with CC=clang-8 CXX=clang++-8 cmake -DUSE_CUDA=1 -DUSE_MKLDNN=1 -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DBUILD_CYTHON_MODULES=1 -DUSE_OPENCV=0 ...

We can investigate the shared library dependencies of the resulting libmxnet.so:

% readelf -Wa libmxnet.so | grep NEEDED
 0x0000000000000001 (NEEDED)             Shared library: [libnvToolsExt.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [libopenblas.so.0]
 0x0000000000000001 (NEEDED)             Shared library: [librt.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [libjemalloc.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [liblapack.so.3]
 0x0000000000000001 (NEEDED)             Shared library: [libcublas.so.10.0]
 0x0000000000000001 (NEEDED)             Shared library: [libcufft.so.10.0]
 0x0000000000000001 (NEEDED)             Shared library: [libcusolver.so.10.0]
 0x0000000000000001 (NEEDED)             Shared library: [libcurand.so.10.0]
 0x0000000000000001 (NEEDED)             Shared library: [libnvrtc.so.10.0]
 0x0000000000000001 (NEEDED)             Shared library: [libcuda.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [libdl.so.2]
 0x0000000000000001 (NEEDED)             Shared library: [libpthread.so.0]
 0x0000000000000001 (NEEDED)             Shared library: [libomp.so.5]
 0x0000000000000001 (NEEDED)             Shared library: [libstdc++.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [libm.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [libgcc_s.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [libc.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [ld-linux-x86-64.so.2]

Among those, libopenblas.so.0 is provided by the system and depends on libgomp.so. (If we compiled with OpenCV, it would also transitively depend on libgomp.so, so I simply disable it for the purpose of this test.) We can see libgomp.so show up among the transitive shared library dependencies:

% ldd libmxnet.so
        linux-vdso.so.1 (0x00007ffd382ca000)
        libnvToolsExt.so.1 => /usr/local/cuda/lib64/libnvToolsExt.so.1 (0x00007efdc9594000)
        libopenblas.so.0 => /usr/local/lib/libopenblas.so.0 (0x00007efdc85fb000)
        librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007efdc83f3000)
        libjemalloc.so.1 => /usr/lib/x86_64-linux-gnu/libjemalloc.so.1 (0x00007efdc81bd000)
        liblapack.so.3 => /usr/lib/x86_64-linux-gnu/liblapack.so.3 (0x00007efdc78fe000)
        libcublas.so.10.0 => /usr/local/cuda/lib64/libcublas.so.10.0 (0x00007efdc3368000)
        libcufft.so.10.0 => /usr/local/cuda/lib64/libcufft.so.10.0 (0x00007efdbceb4000)
        libcusolver.so.10.0 => /usr/local/cuda/lib64/libcusolver.so.10.0 (0x00007efdb47cd000)
        libcurand.so.10.0 => /usr/local/cuda/lib64/libcurand.so.10.0 (0x00007efdb0666000)
        libnvrtc.so.10.0 => /usr/local/cuda/lib64/libnvrtc.so.10.0 (0x00007efdaf04a000)
        libcuda.so.1 => /usr/lib/x86_64-linux-gnu/libcuda.so.1 (0x00007efdaded3000)
        libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007efdadccf000)
        libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007efdadab0000)
        libomp.so.5 => /usr/lib/x86_64-linux-gnu/libomp.so.5 (0x00007efe411b4000)
        libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007efdad727000)
        libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007efdad389000)
        libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007efdad171000)
        libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007efdacd80000)
        /lib64/ld-linux-x86-64.so.2 (0x00007efe410a8000)
        libgfortran.so.4 => /usr/lib/x86_64-linux-gnu/libgfortran.so.4 (0x00007efdac9a1000)
        libgomp.so.1 => /usr/lib/x86_64-linux-gnu/libgomp.so.1 (0x00007efdac772000)
        libblas.so.3 => /usr/lib/x86_64-linux-gnu/libblas.so.3 (0x00007efdac1b0000)
        libnvidia-fatbinaryloader.so.418.87.01 => /usr/lib/x86_64-linux-gnu/libnvidia-fatbinaryloader.so.418.87.01 (0x00007efdabf62000)
        libquadmath.so.0 => /usr/lib/x86_64-linux-gnu/libquadmath.so.0 (0x00007efdabd22000)

Thus I recompile OpenBLAS with clang. Then we can investigate the transitive dependencies while replacing the system OpenBLAS with the llvm-openmp based OpenBLAS:

% LD_PRELOAD=/home/ubuntu/src/OpenBLAS/libopenblas.so ldd libmxnet.so
        linux-vdso.so.1 (0x00007ffd8eac5000)
        /home/ubuntu/src/OpenBLAS/libopenblas.so (0x00007f06ee33a000)
        libnvToolsExt.so.1 => /usr/local/cuda/lib64/libnvToolsExt.so.1 (0x00007f06ee131000)
        librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007f06edf29000)
        libjemalloc.so.1 => /usr/lib/x86_64-linux-gnu/libjemalloc.so.1 (0x00007f06edcf3000)
        liblapack.so.3 => /usr/lib/x86_64-linux-gnu/liblapack.so.3 (0x00007f06ed434000)
        libcublas.so.10.0 => /usr/local/cuda/lib64/libcublas.so.10.0 (0x00007f06e8e9e000)
        libcufft.so.10.0 => /usr/local/cuda/lib64/libcufft.so.10.0 (0x00007f06e29ea000)
        libcusolver.so.10.0 => /usr/local/cuda/lib64/libcusolver.so.10.0 (0x00007f06da303000)
        libcurand.so.10.0 => /usr/local/cuda/lib64/libcurand.so.10.0 (0x00007f06d619c000)
        libnvrtc.so.10.0 => /usr/local/cuda/lib64/libnvrtc.so.10.0 (0x00007f06d4b80000)
        libcuda.so.1 => /usr/lib/x86_64-linux-gnu/libcuda.so.1 (0x00007f06d3a09000)
        libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f06d3805000)
        libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f06d35e6000)
        libomp.so.5 => /usr/lib/x86_64-linux-gnu/libomp.so.5 (0x00007f0766c79000)
        libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007f06d325d000)
        libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f06d2ebf000)
        libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007f06d2ca7000)
        libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f06d28b6000)
        /lib64/ld-linux-x86-64.so.2 (0x00007f0766b6d000)
        libgfortran.so.4 => /usr/lib/x86_64-linux-gnu/libgfortran.so.4 (0x00007f06d24d7000)
        libblas.so.3 => /usr/lib/x86_64-linux-gnu/libblas.so.3 (0x00007f06d1f15000)
        libnvidia-fatbinaryloader.so.418.87.01 => /usr/lib/x86_64-linux-gnu/libnvidia-fatbinaryloader.so.418.87.01 (0x00007f06d1cc7000)
        libquadmath.so.0 => /usr/lib/x86_64-linux-gnu/libquadmath.so.0 (0x00007f06d1a87000)

and we find that libmxnet.so no longer depends on libgomp.so. (A small check of which OpenMP runtimes actually get mapped at runtime is sketched below.)
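
As an aside, which OpenMP runtimes actually end up mapped into the process can also be checked from inside Python. A small diagnostic sketch, assuming Linux and that the build above is importable:

import mxnet as mx  # triggers loading of libmxnet.so and its dependencies

omp_runtimes = set()
with open('/proc/self/maps') as maps:
    for line in maps:
        for name in ('libgomp', 'libomp', 'libiomp5'):
            if name in line:
                omp_runtimes.add(line.rsplit('/', 1)[-1].strip())

print(sorted(omp_runtimes))  # e.g. ['libomp.so.5'] in the setup above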

So let's see if the test case by @fierceX still crashes:

LD_PRELOAD=/home/ubuntu/src/OpenBLAS/libopenblas.so python3 ~/test.py

Stack trace:
  [bt] (0) /home/ubuntu/src/mxnet/python/mxnet/../../build/libmxnet.so(+0x186faeb) [0x7f653ffcfaeb]
  [bt] (1) /lib/x86_64-linux-gnu/libc.so.6(+0x3ef20) [0x7f65cf785f20]
  [bt] (2) /usr/lib/x86_64-linux-gnu/libomp.so.5(+0x3d594) [0x7f65cd145594]
  [bt] (3) /home/ubuntu/src/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::OpenMP::set_reserve_cores(int)+0xf5) [0x7f653fed5255]
  [bt] (4) /home/ubuntu/src/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2}::operator()() const+0x42) [0x7f653fee8752]
  [bt] (5) /home/ubuntu/src/mxnet/python/mxnet/../../build/libmxnet.so(std::shared_ptr<mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)1> > mxnet::common::LazyAllocArray<mxnet::engine::ThreadedEnginePerDevice::ThreadWorkerBlock<(dmlc::ConcurrentQueueType)1> >::Get<mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2}>(int, mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)::{lambda()#2})+0x487) [0x7f653fee5b87]
  [bt] (6) /home/ubuntu/src/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::ThreadedEnginePerDevice::PushToExecute(mxnet::engine::OprBlock*, bool)+0x223) [0x7f653fee12f3]
  [bt] (7) /home/ubuntu/src/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::ThreadedEngine::Push(mxnet::engine::Opr*, mxnet::Context, int, bool)+0x1dc) [0x7f653fed625c]
  [bt] (8) /home/ubuntu/src/mxnet/python/mxnet/../../build/libmxnet.so(mxnet::engine::ThreadedEngine::PushAsync(std::function<void (mxnet::RunContext, mxnet::engine::CallbackOnComplete)>, mxnet::Context, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, mxnet::FnProperty, int, char const*, bool)+0x212) [0x7f653fed64d2]

As the crash remains, we can conclude that it is due to a bug in libomp.so, i.e. the llvm openmp.

As @fierceX's use-case is common and important among MXNet users, we conclude that we must not default to llvm openmp until this issue is fixed.

On a side note, forking in a multithreaded environment is largely undefined behavior according to the POSIX standard (the child is only allowed to exec afterwards). So it's not really a bug in llvm-openmp, since its behavior here is undefined. However, as it is an important use-case, and as it works with gomp, I suggest we just use gomp. You can also take a look at https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60035 for more background.

@cjolivier01 please let me know if you see any issue with this investigation.

PS: To compile with clang, a small change to dmlc-core is required dmlc/dmlc-core@master...leezu:omp

cjolivier01 (Member) commented:

What is the source file and line number of that crash in libmxnet.so? What's the line of code crashing?

TaoLv (Member) commented Dec 9, 2019

@leezu, I'm not sure the LLVM OMP problem is the same one described at the beginning of this issue. I took the original issue to be a problem in iomp5, which has since been removed from all binary releases of MXNet; hence the issue was closed.
If you have further interest, you can reproduce the original issue on a GPU with the code snippet and package provided in the description. Then please try replacing the libiomp5.so under the installation folder with the one released along with Intel compiler 2019.0 update 5. I expect the problem is addressed by the new version of iomp5. (A small helper for locating the bundled libiomp5.so is sketched below.)
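
A small helper sketch for locating the bundled libiomp5.so inside the pip installation, so it can be backed up and swapped manually as suggested:

import glob
import os
import mxnet

pkg_dir = os.path.dirname(mxnet.__file__)
for path in glob.glob(os.path.join(pkg_dir, 'libiomp5*')):
    print(path)  # e.g. .../site-packages/mxnet/libiomp5.so in the MKL builds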

TaoLv (Member) commented Dec 9, 2019

The library can be found at: https://software.intel.com/en-us/parallel-studio-xe/choose-download

leezu (Contributor) commented Dec 9, 2019

Is the iomp source code based on llvm? What version of llvm omp would the 2019.0 update correspond to? Or is the source different? I'll try it.

cjolivier01 (Member) commented:

Which line? The stack trace you listed shows libmxnet.so at stack level 0 rather than libomp.so, so it wouldn't be in the omp calls here, correct? Is the CHECK_GE failing?

leezu (Contributor) commented Dec 9, 2019

The function referred to above is bt [3] in the backtrace. CHECK_GE is not failing, because bt [2] is part of libomp.so. But I'm not sure whether it is the omp_set_num_threads(1); or the omp_set_num_threads(omp_thread_max_ - reserve_cores_); call. addr2line just prints ?? when I try to look up the address of bt [3].

libmxnet.so is on bt [0] because it has a segfault handler:

https://github.com/apache/incubator-mxnet/blob/e48ff96ef4375f3d7c505e152b73b1f15a8b7afe/src/initialize.cc#L62-L68

Specifically Line 65.
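
For reference, the lookup can be scripted; a small sketch using the module offset from frame [bt] (0) of the trace above (addr2line only yields a useful file:line when the library carries debug info, which is presumably why it printed ?? here):

import subprocess

lib = '/home/ubuntu/src/mxnet/python/mxnet/../../build/libmxnet.so'
offset = '0x186faeb'  # the +offset printed for frame [bt] (0)
print(subprocess.check_output(['addr2line', '-f', '-C', '-e', lib, offset], text=True))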

cjolivier01 (Member) commented:

Thanks, now the stack trace makes sense. Maybe libomp isn't built with -O0 in debug mode (assuming this is a debug build).

cjolivier01 (Member) commented:

Is this problem gone with the upgrade?

@cjolivier01 cjolivier01 self-assigned this Dec 9, 2019
cjolivier01 (Member) commented Dec 9, 2019

If it's gone with the upgrade, then fine.

However, if it's not, consider that it also happened with the official distribution of libiomp5 (and, if still happening, also with the official llvm distribution), and that llvm omp is in HUGE distribution globally as part of clang. It therefore seems pretty unlikely to me that a bug in the openmp runtime is causing this. Especially since I wrote most of the omp-related stuff in mxnet that appears in that stack trace, and I definitely didn't test it specifically with forking; it wasn't a use case at the time. In fact, at the time I wrote it, it was known that trying to use omp at all (with libgomp specifically) would hang if attempted in a forked process (I seem to recall there's an issue+PR in there somewhere fixing that by avoiding a kernel that used omp; it was a long time ago and before llvm openmp was added, and it was noted that the hang didn't happen in the mkl build, which used libiomp5 instead). Generally this wasn't a problem because OMP_NUM_THREADS gets set to 1 in the atfork call by the engine code.

However, if mxnet is loaded after the fork, then that environment variable was never set, because the engine code never ran to hook the fork call before then. I think it's possible there's a bug in the (my) mxnet omp code, since this wasn't a use-case that was considered. This would mean it would likely still occur with clang builds (assuming it's not intermittent and hard to reproduce).

cjolivier01 (Member) commented:

Regarding the libgomp hang I noted above: apparently it's a known issue with libgomp and forking that I am surprised to see still occurring today. The reporter lists gcc 8.

pytorch/pytorch#17199

leezu (Contributor) commented Dec 9, 2019

Yes, this crash still happens after the upgrade of llvm openmp. It also happens both when compiling with gcc and when compiling with llvm.

The only case where the crash does not happen is when compiling with gcc and using libgomp instead of libomp.

The gcc hang you refer to above, is it https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60035 ? I'm aware of that bug report and am therefore quite surprised that the crash doesn't occur with gcc libgomp, but rather in all the other settings.

cjolivier01 (Member) commented Dec 9, 2019

I will see if I can reproduce this week.

They link to the gcc issue in that pytorch issue. It's not the one you linked, which is a pull request of sorts. (Correction: I thought it was the PR one while looking on my phone, but now on my PC I see it isn't.) Yes, it seems it's that issue as well.

It's weird, because I saw that behavior maybe three years ago using gcc 3.x-ish, I think, so I assumed that libgomp had since been corrected to handle fork properly. I am surprised that the same behavior is being reproduced in such a new gcc version. I want to try to reproduce that issue this week as well.

llvm openmp, as you can see from the comments in the pytorch issue, is known to handle forking correctly.

Another thing that pytorch issue mentions is cuda after a fork. While it's reasonable to assume it's illegal to use cuda, then fork, and also use cuda in the forked process, I wonder if it works OK if you fork before using cuda for the first time, as in this issue. A small experiment for that is sketched below.
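
A minimal sketch of such an experiment (an illustration only, it was not run here; it assumes a single visible GPU and that the parent never touches CUDA before the fork):

from multiprocessing import get_context

import mxnet as mx  # loaded in the parent, but no CUDA use before the fork

def child():
    # First CUDA use happens here, inside the forked process.
    x = mx.nd.ones((2, 2), ctx=mx.gpu(0))
    print(x.asnumpy())

if __name__ == '__main__':
    p = get_context('fork').Process(target=child)
    p.start()
    p.join()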

TaoLv (Member) commented Dec 10, 2019

Is the iomp source code based on llvm? What version of llvm omp would the 2019.0 update correspond to? Or is the source different? I'll try it.

It should be a different code base.

cjolivier01 (Member) commented Dec 10, 2019

  • I root-caused the crash mentioned directly above (not the assert). I did this in an actual debugger :), but the answer is actually in the stack trace:
[bt] (2) /usr/lib/x86_64-linux-gnu/libomp.so.5(+0x3d594) [0x7f65cd145594]

This is loading the wrong omp library, not the one it was just built against. That library comes with the libomp5.deb package (on Ubuntu). The proper one would be in cmake-build-debug/3rdparty/openmp at build time. Why it is loading that other one I did not track down, because the problem went away when I linked against the proper library in libmxnet.so's directory. I also uninstalled the libomp5 package on my machine in the course of testing. It might be getting pulled in because the cython compile uses a different "toolchain" (which may or may not map back to the same compiler; on my machine it just blindly runs whatever x86_64-linux-gnu-gcc is in the path). Even if this is not the cause, it should be looked at, because with more than one toolchain on a lot of dev boxes these days, this is a recipe for trouble. Since the cython library has libmxnet as a dependency, it is conceivable that in some use-cases it gets first stab at loading whatever shared object it wants, and if it is not using the same toolchain this could get pretty nasty (i.e. imagine libmxnet.so being forced, at load time, to link against the libstdc++ from gcc 3.6 when mxnet was compiled with gcc 8). I know they have version tags in the symbols, but you get the idea, right? This should be looked into, imho.

Btw, this is why the location for the omp stack trace frame was ??? -- no debug info for "/usr/lib/x86_64-linux-gnu/libomp.so.5".

At any rate, there are a number of ways to resolve this, just as one would resolve the wrong opencv library being loaded -- it's not rocket science :)

Summary: no evidence found suggesting that this is a libomp5/libomp bug (the "upgrade" wasn't actually necessary, but it doesn't hurt anything, so it's good to leave it in).

  • The reason libgomp doesn't hang (@leezu's query) is that the environment gets set to one OMP thread at fork time (atfork is hooked in libmxnet.so's static init), so the forked process never tries to use OMP. A quick way to check this from the forked child is sketched below.
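
A small check sketch, reading the live C environment in the forked child via ctypes (os.environ is snapshotted at interpreter startup and would not reflect a setenv() call made from inside libmxnet's atfork handler; that the hook uses setenv is an assumption based on the earlier discussion in this thread):

import ctypes
from multiprocessing import Process

import mxnet as mx

def show_omp_threads():
    libc = ctypes.CDLL(None)  # handle to the already-loaded C library
    libc.getenv.restype = ctypes.c_char_p
    print('OMP_NUM_THREADS in child:', libc.getenv(b'OMP_NUM_THREADS'))

if __name__ == '__main__':
    mx.nd.ones((2, 2)).wait_to_read()  # force engine initialization in the parent
    p = Process(target=show_omp_threads)
    p.start()
    p.join()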

cjolivier01 (Member) commented Dec 10, 2019

By the way, I don't care whether it's used or not, but on the MXNet cython branch I did some cython work that, in the cmake files, uses the mxnet toolchain to build the cython library. That's one approach; there are others, with pros and cons for each.

leezu (Contributor) commented Dec 10, 2019

@cjolivier01 thank you for looking into this. I notice that the crash also happens when using the system llvm openmp at compile time (i.e. deleting 3rdparty/openmp before the build). I describe that in #14979 (comment). Thus it seems the cause you mention isn't the root cause?

BTW, the update of 3rdparty/openmp is about fixing the debug assertion #10856. It doesn't claim anything about the current issue.

cjolivier01 (Member) commented:

(quoting @leezu's comment above)

The assert issue is fixed in the PR referenced above.

The cython setup script apparently uses forking, which is what causes the problem during the compile.

cjolivier01 (Member) commented:

(quoting @leezu's comment above)

You can see that in the newer version they reset the pool in the atfork handler:

__kmp_atfork_child()
{
  ...
  __kmp_team_pool = NULL;
  ...
}

This is why the assert goes away, but the assert remains harmless even in the old version.

leezu (Contributor) commented Dec 11, 2019

Thanks for looking into it. Even when harmless, it's annoying when using a debug build, so it's good to make it go away.

leezu (Contributor) commented Dec 11, 2019

The cython setup script apparently uses forking which is causing the problem during compile.

When using the system llvm openmp instead of 3rdparty/openmp, why does the crash reported in this issue still happen if the cause is a linking problem? I'm not clear on your reasoning here.

leezu (Contributor) commented Dec 11, 2019

Fixed by #17039. Thanks, Chris.

@leezu leezu closed this as completed Dec 11, 2019
larroy (Contributor) commented Feb 4, 2020

This seems fixed now.
