From 402075d438e765f7aeddd8696dc7304545e79bc9 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Wed, 3 Jul 2024 21:13:21 +0300 Subject: [PATCH] [FRONTEND] fix conflicting multithreading and fork management (#4169) This PR disables multithreading in MLIR context after compilation ends. This is done to safely finalize thread pool implemented in MLIRContext. Not properly finalized thread pool can hang or crash in child process after fork. --------- Co-authored-by: Lei Zhang --- python/src/ir.cc | 10 +++-- python/test/unit/runtime/test_subproc.py | 51 ++++++++++++++++++++++++ python/triton/compiler/compiler.py | 5 +++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index 906438f7c5bc..7cf3e33a3180 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -207,9 +207,13 @@ void init_triton_ir(py::module &&m) { .def(py::init<>()) .def("printOpOnDiagnostic", [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) - .def("printStackTraceOnDiagnostic", [](MLIRContext &self, bool v) { - self.printStackTraceOnDiagnostic(v); - }); + .def("printStackTraceOnDiagnostic", + [](MLIRContext &self, bool v) { + self.printStackTraceOnDiagnostic(v); + }) + .def("disable_multithreading", + [](MLIRContext &self) { self.disableMultithreading(); }); + py::class_(m, "source_mgr_diag", py::module_local()) .def(py::init()); diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 333d1f929126..7d0bb7498798 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -71,3 +71,54 @@ def test_compile_in_forked_subproc() -> None: proc.start() proc.join() assert proc.exitcode == 0 + + +def compile_empty_kernel_with_gc(attrs): + + @triton.jit + def empty_kernel(): + pass + + import gc + gc.collect() + src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants=dict()) + triton.compile(src=src, target=target) + + +def test_compile_in_forked_subproc_with_forced_gc() -> None: + ''' + Tests that compilation artifacts can safely live in forked process. + + Scenario being tested here ("p" stands for parent process, "c" is child process): + 1. p compiles a kernel 1, and produces compilation artifacts. + 2. p forks the process to create c. + 3. c deletes compilation artifacts inherited from p, compiles kernel 2, and terminates. + 3. p wait for c and join it. + + This is a regression test that ensures thread pool in MLIRContext is released + safely after compilation. + ''' + reset_tmp_dir() + import gc + old_gc_state = gc.isenabled() + # disable GC to manage resources manually in the manner described in comment above + gc.disable() + + # stage 1.p + config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) + compile_empty_kernel_with_gc(config) + + # stage 2.p + reset_tmp_dir() + assert multiprocessing.get_start_method() == 'fork' + proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, )) + + # stage 3.c + proc.start() + # stage 3.p + proc.join() + + # restore gc state + if old_gc_state: + gc.enable() + assert proc.exitcode == 0 diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 32bc955826c9..6d3d478aed96 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -303,6 +303,11 @@ def compile(src, target=None, options=None): metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) + # Compilation completed, disabling multithreading in context. + # This is needed to safely finalize threads pool inside context: if current process forks before + # python GC deletes context object, thread pool in child process will be invalid, which could + # lead to child crash or hang. + context.disable_multithreading() # return handle to compiled kernel return CompiledKernel(src, metadata_group, hash)