diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index 53b362f57983..3a404902a0f7 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -18,6 +18,7 @@ with the distributed runtime. """ +import logging import os import pickle @@ -397,7 +398,19 @@ def _configure_structlog(self) -> None: except ImportError: return - config = pickle.dumps(structlog.get_config()) + root_logger = logging.getLogger() + if len(root_logger.handlers) == 1 and isinstance( + root_logger.handlers[0].formatter, structlog.stdlib.ProcessorFormatter + ): + stdlib_formatter = root_logger.handlers[0].formatter + else: + stdlib_formatter = None + + stdlib_level = root_logger.level + + full_config = (structlog.get_config(), stdlib_formatter, stdlib_level) + + config = pickle.dumps(full_config) func = self.get_global_func("runtime.disco._configure_structlog") func(config, os.getpid()) @@ -423,8 +436,18 @@ def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None: import structlog # pylint: disable=import-outside-toplevel - config = pickle.loads(pickled_config) - structlog.configure(**config) + full_config = pickle.loads(pickled_config) + structlog_config, stdlib_formatter, stdlib_level = full_config + + root_logger = logging.getLogger() + + root_logger.setLevel(stdlib_level) + if stdlib_formatter is not None: + handler = logging.StreamHandler() + handler.setFormatter(stdlib_formatter) + root_logger.addHandler(handler) + + structlog.configure(**structlog_config) @register_func("runtime.disco._import_python_module")