Skip to content

Commit

Permalink
Add Runnable.with_listeners() (#12549)
Browse files Browse the repository at this point in the history
- This binds start/end/error listeners to a runnable, which will be
called with the Run object
  • Loading branch information
nfcampos authored Oct 31, 2023
1 parent bcc62d6 commit 2f563ce
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 15 deletions.
12 changes: 12 additions & 0 deletions libs/langchain/langchain/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,18 @@ def __init__(
self.parent_run_manager = parent_run_manager
self.ended = False

def copy(self) -> CallbackManagerForChainGroup:
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
parent_run_manager=self.parent_run_manager,
)

def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when traced chain group ends.
Expand Down
46 changes: 46 additions & 0 deletions libs/langchain/langchain/callbacks/tracers/root_listeners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Callable, Optional
from uuid import UUID

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run


class RootListenersTracer(BaseTracer):
def __init__(
self,
*,
on_start: Optional[Callable[[Run], None]],
on_end: Optional[Callable[[Run], None]],
on_error: Optional[Callable[[Run], None]]
) -> None:
super().__init__()

self._arg_on_start = on_start
self._arg_on_end = on_end
self._arg_on_error = on_error
self.root_id: Optional[UUID] = None

def _persist_run(self, run: Run) -> None:
# This is a legacy method only called once for an entire run tree
# therefore not useful here
pass

def _on_run_create(self, run: Run) -> None:
if self.root_id is not None:
return

self.root_id = run.id

if self._arg_on_start is not None:
self._arg_on_start(run)

def _on_run_update(self, run: Run) -> None:
if run.id != self.root_id:
return

if run.error is None:
if self._arg_on_end is not None:
self._arg_on_end(run)
else:
if self._arg_on_error is not None:
self._arg_on_error(run)
132 changes: 118 additions & 14 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
CallbackManagerForChainRun,
)
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.callbacks.tracers.schemas import Run
from langchain.schema.runnable.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT,
)
Expand Down Expand Up @@ -585,6 +586,39 @@ def with_config(
kwargs={},
)

def with_listeners(
self,
*,
on_start: Optional[Callable[[Run], None]] = None,
on_end: Optional[Callable[[Run], None]] = None,
on_error: Optional[Callable[[Run], None]] = None,
) -> Runnable[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the runnable starts running, with the Run object.
on_end: Called after the runnable finishes running, with the Run object.
on_error: Called if the runnable throws an error, with the Run object.
The Run object contains information about the run, including its id,
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
from langchain.callbacks.tracers.root_listeners import RootListenersTracer

return RunnableBinding(
bound=self,
config_factories=[
lambda: {
"callbacks": [
RootListenersTracer(
on_start=on_start, on_end=on_end, on_error=on_error
)
],
}
],
)

def with_types(
self,
*,
Expand Down Expand Up @@ -2323,6 +2357,30 @@ def with_config(
) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.with_config(config, **kwargs))

def with_listeners(
self,
*,
on_start: Optional[Callable[[Run], None]] = None,
on_end: Optional[Callable[[Run], None]] = None,
on_error: Optional[Callable[[Run], None]] = None,
) -> RunnableEach[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the runnable starts running, with the Run object.
on_end: Called after the runnable finishes running, with the Run object.
on_error: Called if the runnable throws an error, with the Run object.
The Run object contains information about the run, including its id,
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
return RunnableEach(
bound=self.bound.with_listeners(
on_start=on_start, on_end=on_end, on_error=on_error
)
)

def _invoke(
self,
inputs: List[Input],
Expand Down Expand Up @@ -2363,10 +2421,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):

bound: Runnable[Input, Output]

kwargs: Mapping[str, Any]
kwargs: Mapping[str, Any] = Field(default_factory=dict)

config: RunnableConfig = Field(default_factory=dict)

config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list)

# Union[Type[Input], BaseModel] + things like List[str]
custom_input_type: Optional[Any] = None
# Union[Type[Output], BaseModel] + things like List[str]
Expand All @@ -2379,8 +2439,9 @@ def __init__(
self,
*,
bound: Runnable[Input, Output],
kwargs: Mapping[str, Any],
kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[RunnableConfig] = None,
config_factories: Optional[List[Callable[[], RunnableConfig]]] = None,
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
**other_kwargs: Any,
Expand All @@ -2397,8 +2458,9 @@ def __init__(
)
super().__init__(
bound=bound,
kwargs=kwargs,
config=config,
kwargs=kwargs or {},
config=config or {},
config_factories=config_factories or [],
custom_input_type=custom_input_type,
custom_output_type=custom_output_type,
**other_kwargs,
Expand Down Expand Up @@ -2472,6 +2534,43 @@ def with_config(
custom_output_type=self.custom_output_type,
)

def with_listeners(
self,
*,
on_start: Optional[Callable[[Run], None]] = None,
on_end: Optional[Callable[[Run], None]] = None,
on_error: Optional[Callable[[Run], None]] = None,
) -> Runnable[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the runnable starts running, with the Run object.
on_end: Called after the runnable finishes running, with the Run object.
on_error: Called if the runnable throws an error, with the Run object.
The Run object contains information about the run, including its id,
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
from langchain.callbacks.tracers.root_listeners import RootListenersTracer

return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
config_factories=[
lambda: {
"callbacks": [
RootListenersTracer(
on_start=on_start, on_end=on_end, on_error=on_error
)
],
}
],
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)

def with_types(
self,
input_type: Optional[Union[Type[Input], BaseModel]] = None,
Expand All @@ -2496,6 +2595,11 @@ def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
config=self.config,
)

def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
return merge_configs(
self.config, *(f() for f in self.config_factories), *configs
)

def invoke(
self,
input: Input,
Expand All @@ -2504,7 +2608,7 @@ def invoke(
) -> Output:
return self.bound.invoke(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
)

Expand All @@ -2516,7 +2620,7 @@ async def ainvoke(
) -> Output:
return await self.bound.ainvoke(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
)

Expand All @@ -2531,10 +2635,10 @@ def batch(
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
configs = [self._merge_configs(config) for _ in range(len(inputs))]
return self.bound.batch(
inputs,
configs,
Expand All @@ -2553,10 +2657,10 @@ async def abatch(
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
configs = [self._merge_configs(config) for _ in range(len(inputs))]
return await self.bound.abatch(
inputs,
configs,
Expand All @@ -2572,7 +2676,7 @@ def stream(
) -> Iterator[Output]:
yield from self.bound.stream(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
)

Expand All @@ -2584,7 +2688,7 @@ async def astream(
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
):
yield item
Expand All @@ -2597,7 +2701,7 @@ def transform(
) -> Iterator[Output]:
yield from self.bound.transform(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
)

Expand All @@ -2609,7 +2713,7 @@ async def atransform(
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
):
yield item
Expand Down
45 changes: 45 additions & 0 deletions libs/langchain/langchain/schema/runnable/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,51 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
**base.get(key, {}), # type: ignore
**(config.get(key) or {}), # type: ignore
}
elif key == "callbacks":
base_callbacks = base.get("callbacks")
these_callbacks = config["callbacks"]
# callbacks can be either None, list[handler] or manager
# so merging two callbacks values has 6 cases
if isinstance(these_callbacks, list):
if base_callbacks is None:
base["callbacks"] = these_callbacks
elif isinstance(base_callbacks, list):
base["callbacks"] = base_callbacks + these_callbacks
else:
# base_callbacks is a manager
mngr = base_callbacks.copy()
for callback in these_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
elif these_callbacks is not None:
# these_callbacks is a manager
if base_callbacks is None:
base["callbacks"] = these_callbacks
elif isinstance(base_callbacks, list):
mngr = these_callbacks.copy()
for callback in base_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
else:
# base_callbacks is also a manager
base["callbacks"] = base_callbacks.__class__(
parent_run_id=base_callbacks.parent_run_id
or these_callbacks.parent_run_id,
handlers=base_callbacks.handlers + these_callbacks.handlers,
inheritable_handlers=base_callbacks.inheritable_handlers
+ these_callbacks.inheritable_handlers,
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
inheritable_tags=list(
set(
base_callbacks.inheritable_tags
+ these_callbacks.inheritable_tags
)
),
metadata={
**base_callbacks.metadata,
**these_callbacks.metadata,
},
)
else:
base[key] = config[key] or base.get(key) # type: ignore
return base
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3748,6 +3748,7 @@
]
},
"config": {},
"config_factories": [],
"custom_input_type": null,
"custom_output_type": null
}
Expand Down
Loading

0 comments on commit 2f563ce

Please sign in to comment.