Skip to content

Commit

Permalink
Improve Runnable type inference for input_schemas (langchain-ai#12630)
Browse files Browse the repository at this point in the history
- Prefer lambda type annotations over inferred dict schema
- For sequences that start with RunnableAssign, infer the sequence input type as
  the input type of the 2nd item in the sequence minus the keys produced by the RunnableAssign
  • Loading branch information
nfcampos authored and xieqihui committed Nov 21, 2023
1 parent 9856eb9 commit d0aa980
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 0 deletions.
20 changes: 20 additions & 0 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,23 @@ def OutputType(self) -> Type[Output]:
def get_input_schema(
    self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
    """Infer the input schema of this sequence.

    When the sequence starts with a ``RunnableAssign``, the sequence's real
    input is whatever the *next* step expects, minus the keys that the assign
    step itself produces. If that next step exposes a plain dict schema
    (no custom root type), build a trimmed model from its fields; otherwise
    fall back to the first step's own input schema.

    Args:
        config: Optional runnable config forwarded to the inner
            ``get_input_schema`` calls.

    Returns:
        A pydantic model class describing the sequence's expected input.
    """
    # Local import to avoid a circular import with the passthrough module.
    from langchain.schema.runnable.passthrough import RunnableAssign

    if isinstance(self.first, RunnableAssign):
        # NOTE: isinstance already narrows the type here, so the previous
        # `cast(RunnableAssign, self.first)` was redundant and is removed.
        first = self.first
        next_ = self.middle[0] if self.middle else self.last
        next_input_schema = next_.get_input_schema(config)
        if not next_input_schema.__custom_root_type__:
            # The next step expects a dict: drop the keys the assign step
            # will add, since callers don't need to supply those.
            return create_model(  # type: ignore[call-overload]
                "RunnableSequenceInput",
                **{
                    k: (v.annotation, v.default)
                    for k, v in next_input_schema.__fields__.items()
                    if k not in first.mapper.steps
                },
            )

    return self.first.get_input_schema(config)

def get_output_schema(
Expand Down Expand Up @@ -2152,6 +2169,9 @@ def get_input_schema(
else:
return create_model("RunnableLambdaInput", __root__=(List[Any], None))

if self.InputType != Any:
return super().get_input_schema(config)

if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
"RunnableLambdaInput",
Expand Down
4 changes: 4 additions & 0 deletions libs/langchain/langchain/schema/runnable/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ def get_output_schema(
for k, v in s.__fields__.items()
},
)
elif not map_output_schema.__custom_root_type__:
# ie. only map output is a dict
# ie. input type is either unknown or inferred incorrectly
return map_output_schema

return super().get_output_schema(config)

Expand Down
85 changes: 85 additions & 0 deletions libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from typing_extensions import TypedDict

from langchain.callbacks.manager import (
Callbacks,
Expand Down Expand Up @@ -508,6 +509,41 @@ async def typed_async_lambda_impl(x: str) -> int:
}


def test_passthrough_assign_schema() -> None:
    """Input/output schema inference for sequences starting with assign."""
    # Fixtures: retriever maps a query string to documents; the fake LLM
    # returns a canned string completion.
    retriever = FakeRetriever()  # str -> List[Document]
    prompt = PromptTemplate.from_template("{context} {question}")
    fake_llm = FakeListLLM(responses=["a"])  # str -> List[List[str]]

    seq_w_assign: Runnable = (
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        | prompt
        | fake_llm
    )

    expected_input_schema = {
        "properties": {"question": {"title": "Question", "type": "string"}},
        "title": "RunnableSequenceInput",
        "type": "object",
    }
    expected_output_schema = {
        "title": "FakeListLLMOutput",
        "type": "string",
    }
    assert seq_w_assign.input_schema.schema() == expected_input_schema
    assert seq_w_assign.output_schema.schema() == expected_output_schema

    invalid_seq_w_assign: Runnable = (
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        | fake_llm
    )

    # fallback to RunnableAssign.input_schema if next runnable doesn't have
    # expected dict input_schema
    expected_fallback_schema = {
        "properties": {"question": {"title": "Question"}},
        "title": "RunnableParallelInput",
        "type": "object",
    }
    assert invalid_seq_w_assign.input_schema.schema() == expected_fallback_schema


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
Expand Down Expand Up @@ -565,6 +601,55 @@ async def aget_values(input): # type: ignore[no-untyped-def]
},
}

class InputType(TypedDict):
variable_name: str
yo: int

class OutputType(TypedDict):
hello: str
bye: str
byebye: int

async def aget_values_typed(input: InputType) -> OutputType:
return {
"hello": input["variable_name"],
"bye": input["variable_name"],
"byebye": input["yo"],
}

assert RunnableLambda(aget_values_typed).input_schema.schema() == { # type: ignore[arg-type]
"title": "RunnableLambdaInput",
"$ref": "#/definitions/InputType",
"definitions": {
"InputType": {
"properties": {
"variable_name": {"title": "Variable " "Name", "type": "string"},
"yo": {"title": "Yo", "type": "integer"},
},
"required": ["variable_name", "yo"],
"title": "InputType",
"type": "object",
}
},
}

assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type]
"title": "RunnableLambdaOutput",
"$ref": "#/definitions/OutputType",
"definitions": {
"OutputType": {
"properties": {
"bye": {"title": "Bye", "type": "string"},
"byebye": {"title": "Byebye", "type": "integer"},
"hello": {"title": "Hello", "type": "string"},
},
"required": ["hello", "bye", "byebye"],
"title": "OutputType",
"type": "object",
}
},
}


def test_with_types_with_type_generics() -> None:
"""Verify that with_types works if we use things like List[int]"""
Expand Down

0 comments on commit d0aa980

Please sign in to comment.