diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index 7d69a3f1ba3c4..021e7b1879a84 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -1496,6 +1496,18 @@ def __init__( self.parent_run_manager = parent_run_manager self.ended = False + def copy(self) -> CallbackManagerForChainGroup: + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + parent_run_manager=self.parent_run_manager, + ) + def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: """Run when traced chain group ends. diff --git a/libs/langchain/langchain/callbacks/tracers/root_listeners.py b/libs/langchain/langchain/callbacks/tracers/root_listeners.py new file mode 100644 index 0000000000000..be837de1baad4 --- /dev/null +++ b/libs/langchain/langchain/callbacks/tracers/root_listeners.py @@ -0,0 +1,46 @@ +from typing import Callable, Optional +from uuid import UUID + +from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.schemas import Run + + +class RootListenersTracer(BaseTracer): + def __init__( + self, + *, + on_start: Optional[Callable[[Run], None]], + on_end: Optional[Callable[[Run], None]], + on_error: Optional[Callable[[Run], None]] + ) -> None: + super().__init__() + + self._arg_on_start = on_start + self._arg_on_end = on_end + self._arg_on_error = on_error + self.root_id: Optional[UUID] = None + + def _persist_run(self, run: Run) -> None: + # This is a legacy method only called once for an entire run tree + # therefore not useful here + pass + + def _on_run_create(self, run: Run) -> None: + if self.root_id is not None: + return + + self.root_id = run.id + + if self._arg_on_start is not None: + self._arg_on_start(run) + + def _on_run_update(self, run: Run) -> None: + if run.id != self.root_id: + return + + if run.error is None: + if self._arg_on_end is not None: + self._arg_on_end(run) + else: + if self._arg_on_error is not None: + self._arg_on_error(run) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 025e0030e9942..e1abeef5e9dab 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -37,6 +37,7 @@ CallbackManagerForChainRun, ) from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch + from langchain.callbacks.tracers.schemas import Run from langchain.schema.runnable.fallbacks import ( RunnableWithFallbacks as RunnableWithFallbacksT, ) @@ -585,6 +586,39 @@ def with_config( kwargs={}, ) + def with_listeners( + self, + *, + on_start: Optional[Callable[[Run], None]] = None, + on_end: Optional[Callable[[Run], None]] = None, + on_error: Optional[Callable[[Run], None]] = None, + ) -> Runnable[Input, Output]: + """ + Bind lifecycle listeners to a Runnable, returning a new Runnable. + + on_start: Called before the runnable starts running, with the Run object. + on_end: Called after the runnable finishes running, with the Run object. + on_error: Called if the runnable throws an error, with the Run object. + + The Run object contains information about the run, including its id, + type, input, output, error, start_time, end_time, and any tags or metadata + added to the run. + """ + from langchain.callbacks.tracers.root_listeners import RootListenersTracer + + return RunnableBinding( + bound=self, + config_factories=[ + lambda: { + "callbacks": [ + RootListenersTracer( + on_start=on_start, on_end=on_end, on_error=on_error + ) + ], + } + ], + ) + def with_types( self, *, @@ -2323,6 +2357,30 @@ def with_config( ) -> RunnableEach[Input, Output]: return RunnableEach(bound=self.bound.with_config(config, **kwargs)) + def with_listeners( + self, + *, + on_start: Optional[Callable[[Run], None]] = None, + on_end: Optional[Callable[[Run], None]] = None, + on_error: Optional[Callable[[Run], None]] = None, + ) -> RunnableEach[Input, Output]: + """ + Bind lifecycle listeners to a Runnable, returning a new Runnable. + + on_start: Called before the runnable starts running, with the Run object. + on_end: Called after the runnable finishes running, with the Run object. + on_error: Called if the runnable throws an error, with the Run object. + + The Run object contains information about the run, including its id, + type, input, output, error, start_time, end_time, and any tags or metadata + added to the run. + """ + return RunnableEach( + bound=self.bound.with_listeners( + on_start=on_start, on_end=on_end, on_error=on_error + ) + ) + def _invoke( self, inputs: List[Input], @@ -2363,10 +2421,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]): bound: Runnable[Input, Output] - kwargs: Mapping[str, Any] + kwargs: Mapping[str, Any] = Field(default_factory=dict) config: RunnableConfig = Field(default_factory=dict) + config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list) + # Union[Type[Input], BaseModel] + things like List[str] custom_input_type: Optional[Any] = None # Union[Type[Output], BaseModel] + things like List[str] @@ -2379,8 +2439,9 @@ def __init__( self, *, bound: Runnable[Input, Output], - kwargs: Mapping[str, Any], + kwargs: Optional[Mapping[str, Any]] = None, config: Optional[RunnableConfig] = None, + config_factories: Optional[List[Callable[[], RunnableConfig]]] = None, custom_input_type: Optional[Union[Type[Input], BaseModel]] = None, custom_output_type: Optional[Union[Type[Output], BaseModel]] = None, **other_kwargs: Any, @@ -2397,8 +2458,9 @@ def __init__( ) super().__init__( bound=bound, - kwargs=kwargs, - config=config, + kwargs=kwargs or {}, + config=config or {}, + config_factories=config_factories or [], custom_input_type=custom_input_type, custom_output_type=custom_output_type, **other_kwargs, @@ -2472,6 +2534,43 @@ def with_config( custom_output_type=self.custom_output_type, ) + def with_listeners( + self, + *, + on_start: Optional[Callable[[Run], None]] = None, + on_end: Optional[Callable[[Run], None]] = None, + on_error: Optional[Callable[[Run], None]] = None, + ) -> Runnable[Input, Output]: + """ + Bind lifecycle listeners to a Runnable, returning a new Runnable. + + on_start: Called before the runnable starts running, with the Run object. + on_end: Called after the runnable finishes running, with the Run object. + on_error: Called if the runnable throws an error, with the Run object. + + The Run object contains information about the run, including its id, + type, input, output, error, start_time, end_time, and any tags or metadata + added to the run. + """ + from langchain.callbacks.tracers.root_listeners import RootListenersTracer + + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config=self.config, + config_factories=[ + lambda: { + "callbacks": [ + RootListenersTracer( + on_start=on_start, on_end=on_end, on_error=on_error + ) + ], + } + ], + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, + ) + def with_types( self, input_type: Optional[Union[Type[Input], BaseModel]] = None, @@ -2496,6 +2595,11 @@ def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: config=self.config, ) + def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: + return merge_configs( + self.config, *(f() for f in self.config_factories), *configs + ) + def invoke( self, input: Input, @@ -2504,7 +2608,7 @@ def invoke( ) -> Output: return self.bound.invoke( input, - merge_configs(self.config, config), + self._merge_configs(config), **{**self.kwargs, **kwargs}, ) @@ -2516,7 +2620,7 @@ async def ainvoke( ) -> Output: return await self.bound.ainvoke( input, - merge_configs(self.config, config), + self._merge_configs(config), **{**self.kwargs, **kwargs}, ) @@ -2531,10 +2635,10 @@ def batch( if isinstance(config, list): configs = cast( List[RunnableConfig], - [merge_configs(self.config, conf) for conf in config], + [self._merge_configs(conf) for conf in config], ) else: - configs = [merge_configs(self.config, config) for _ in range(len(inputs))] + configs = [self._merge_configs(config) for _ in range(len(inputs))] return self.bound.batch( inputs, configs, @@ -2553,10 +2657,10 @@ async def abatch( if isinstance(config, list): configs = cast( List[RunnableConfig], - [merge_configs(self.config, conf) for conf in config], + [self._merge_configs(conf) for conf in config], ) else: - configs = [merge_configs(self.config, config) for _ in range(len(inputs))] + configs = [self._merge_configs(config) for _ in range(len(inputs))] return await self.bound.abatch( inputs, configs, @@ -2572,7 +2676,7 @@ def stream( ) -> Iterator[Output]: yield from self.bound.stream( input, - merge_configs(self.config, config), + self._merge_configs(config), **{**self.kwargs, **kwargs}, ) @@ -2584,7 +2688,7 @@ async def astream( ) -> AsyncIterator[Output]: async for item in self.bound.astream( input, - merge_configs(self.config, config), + self._merge_configs(config), **{**self.kwargs, **kwargs}, ): yield item @@ -2597,7 +2701,7 @@ def transform( ) -> Iterator[Output]: yield from self.bound.transform( input, - merge_configs(self.config, config), + self._merge_configs(config), **{**self.kwargs, **kwargs}, ) @@ -2609,7 +2713,7 @@ async def atransform( ) -> AsyncIterator[Output]: async for item in self.bound.atransform( input, - merge_configs(self.config, config), + self._merge_configs(config), **{**self.kwargs, **kwargs}, ): yield item diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index bded08f7248cc..869c413ebf5c9 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -220,6 +220,51 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: **base.get(key, {}), # type: ignore **(config.get(key) or {}), # type: ignore } + elif key == "callbacks": + base_callbacks = base.get("callbacks") + these_callbacks = config["callbacks"] + # callbacks can be either None, list[handler] or manager + # so merging two callbacks values has 6 cases + if isinstance(these_callbacks, list): + if base_callbacks is None: + base["callbacks"] = these_callbacks + elif isinstance(base_callbacks, list): + base["callbacks"] = base_callbacks + these_callbacks + else: + # base_callbacks is a manager + mngr = base_callbacks.copy() + for callback in these_callbacks: + mngr.add_handler(callback, inherit=True) + base["callbacks"] = mngr + elif these_callbacks is not None: + # these_callbacks is a manager + if base_callbacks is None: + base["callbacks"] = these_callbacks + elif isinstance(base_callbacks, list): + mngr = these_callbacks.copy() + for callback in base_callbacks: + mngr.add_handler(callback, inherit=True) + base["callbacks"] = mngr + else: + # base_callbacks is also a manager + base["callbacks"] = base_callbacks.__class__( + parent_run_id=base_callbacks.parent_run_id + or these_callbacks.parent_run_id, + handlers=base_callbacks.handlers + these_callbacks.handlers, + inheritable_handlers=base_callbacks.inheritable_handlers + + these_callbacks.inheritable_handlers, + tags=list(set(base_callbacks.tags + these_callbacks.tags)), + inheritable_tags=list( + set( + base_callbacks.inheritable_tags + + these_callbacks.inheritable_tags + ) + ), + metadata={ + **base_callbacks.metadata, + **these_callbacks.metadata, + }, + ) else: base[key] = config[key] or base.get(key) # type: ignore return base diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 6780814829cd5..2cfe9d5cfef46 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -3748,6 +3748,7 @@ ] }, "config": {}, + "config_factories": [], "custom_input_type": null, "custom_output_type": null } diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_config.py b/libs/langchain/tests/unit_tests/schema/runnable/test_config.py new file mode 100644 index 0000000000000..410fee0610856 --- /dev/null +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_config.py @@ -0,0 +1,34 @@ +from langchain.callbacks.manager import CallbackManager +from langchain.callbacks.stdout import StdOutCallbackHandler +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler +from langchain.schema.runnable.config import RunnableConfig, merge_configs + + +def test_merge_config_callbacks() -> None: + manager: RunnableConfig = { + "callbacks": CallbackManager(handlers=[StdOutCallbackHandler()]) + } + handlers: RunnableConfig = {"callbacks": [ConsoleCallbackHandler()]} + other_handlers: RunnableConfig = {"callbacks": [StreamingStdOutCallbackHandler()]} + + merged = merge_configs(manager, handlers)["callbacks"] + + assert isinstance(merged, CallbackManager) + assert len(merged.handlers) == 2 + assert isinstance(merged.handlers[0], StdOutCallbackHandler) + assert isinstance(merged.handlers[1], ConsoleCallbackHandler) + + merged = merge_configs(handlers, manager)["callbacks"] + + assert isinstance(merged, CallbackManager) + assert len(merged.handlers) == 2 + assert isinstance(merged.handlers[0], StdOutCallbackHandler) + assert isinstance(merged.handlers[1], ConsoleCallbackHandler) + + merged = merge_configs(handlers, other_handlers)["callbacks"] + + assert isinstance(merged, list) + assert len(merged) == 2 + assert isinstance(merged[0], ConsoleCallbackHandler) + assert isinstance(merged[1], StreamingStdOutCallbackHandler) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index a0fc7e8823cf7..065af5b6056a4 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -19,7 +19,12 @@ from pytest_mock import MockerFixture from syrupy import SnapshotAssertion -from langchain.callbacks.manager import Callbacks, atrace_as_chain_group, collect_runs +from langchain.callbacks.manager import ( + Callbacks, + atrace_as_chain_group, + collect_runs, + trace_as_chain_group, +) from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch from langchain.callbacks.tracers.schemas import Run @@ -1495,6 +1500,39 @@ def test_prompt_template_params() -> None: prompt.invoke({}) +def test_with_listeners(mocker: MockerFixture) -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + chat = FakeListChatModel(responses=["foo"]) + + chain = prompt | chat + + mock_start = mocker.Mock() + mock_end = mocker.Mock() + + chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke( + {"question": "Who are you?"} + ) + + assert mock_start.call_count == 1 + assert mock_start.call_args[0][0].name == "RunnableSequence" + assert mock_end.call_count == 1 + + mock_start.reset_mock() + mock_end.reset_mock() + + with trace_as_chain_group("hello") as manager: + chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke( + {"question": "Who are you?"}, {"callbacks": manager} + ) + + assert mock_start.call_count == 1 + assert mock_start.call_args[0][0].name == "RunnableSequence" + assert mock_end.call_count == 1 + + @pytest.mark.asyncio @freeze_time("2023-01-01") async def test_prompt_with_chat_model(