Skip to content

Commit

Permalink
[FRONTEND] fix conflicting multithreading and fork management (triton…
Browse files Browse the repository at this point in the history
…-lang#4169)

This PR disables multithreading in MLIR context after compilation ends.
This is done to safely finalize thread pool implemented in MLIRContext.
Not properly finalized thread pool can hang or crash in child process
after fork.

---------

Co-authored-by: Lei Zhang <[email protected]>
  • Loading branch information
2 people authored and bertmaher committed Dec 4, 2024
1 parent 01a897a commit 402075d
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,13 @@ void init_triton_ir(py::module &&m) {
.def(py::init<>())
.def("printOpOnDiagnostic",
[](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); })
.def("printStackTraceOnDiagnostic", [](MLIRContext &self, bool v) {
self.printStackTraceOnDiagnostic(v);
});
.def("printStackTraceOnDiagnostic",
[](MLIRContext &self, bool v) {
self.printStackTraceOnDiagnostic(v);
})
.def("disable_multithreading",
[](MLIRContext &self) { self.disableMultithreading(); });

py::class_<SourceMgrDiagnosticHandler>(m, "source_mgr_diag",
py::module_local())
.def(py::init<llvm::SourceMgr &, MLIRContext *>());
Expand Down
51 changes: 51 additions & 0 deletions python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,54 @@ def test_compile_in_forked_subproc() -> None:
proc.start()
proc.join()
assert proc.exitcode == 0


def compile_empty_kernel_with_gc(attrs):

@triton.jit
def empty_kernel():
pass

import gc
gc.collect()
src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants=dict())
triton.compile(src=src, target=target)


def test_compile_in_forked_subproc_with_forced_gc() -> None:
'''
Tests that compilation artifacts can safely live in forked process.
Scenario being tested here ("p" stands for parent process, "c" is child process):
1. p compiles a kernel 1, and produces compilation artifacts.
2. p forks the process to create c.
3. c deletes compilation artifacts inherited from p, compiles kernel 2, and terminates.
3. p wait for c and join it.
This is a regression test that ensures thread pool in MLIRContext is released
safely after compilation.
'''
reset_tmp_dir()
import gc
old_gc_state = gc.isenabled()
# disable GC to manage resources manually in the manner described in comment above
gc.disable()

# stage 1.p
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
compile_empty_kernel_with_gc(config)

# stage 2.p
reset_tmp_dir()
assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, ))

# stage 3.c
proc.start()
# stage 3.p
proc.join()

# restore gc state
if old_gc_state:
gc.enable()
assert proc.exitcode == 0
5 changes: 5 additions & 0 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ def compile(src, target=None, options=None):
metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
binary=False)
fn_cache_manager.put_group(metadata_filename, metadata_group)
# Compilation completed, disabling multithreading in context.
# This is needed to safely finalize threads pool inside context: if current process forks before
# python GC deletes context object, thread pool in child process will be invalid, which could
# lead to child crash or hang.
context.disable_multithreading()
# return handle to compiled kernel
return CompiledKernel(src, metadata_group, hash)

Expand Down

0 comments on commit 402075d

Please sign in to comment.