diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index 7bc64b444..70852ea17 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -17,11 +17,15 @@ Callable, Dict, Generator, + Generic, List, Mapping, Optional, + Protocol, TypedDict, + TypeVar, cast, + runtime_checkable, ) from langsmith import client, run_trees, utils @@ -223,6 +227,20 @@ def _setup_run( return response_container +R = TypeVar("R", covariant=True) + + +@runtime_checkable +class SupportsLangsmithExtra(Protocol, Generic[R]): + def __call__( + self, + *args: Any, + langsmith_extra: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> R: + ... + + def traceable( run_type: str = "chain", *, @@ -233,7 +251,7 @@ def traceable( client: Optional[client.Client] = None, extra: Optional[Dict] = None, reduce_fn: Optional[Callable] = None, -) -> Callable: +) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]: """Decorator for creating or adding a run to a run tree. Args: diff --git a/python/pyproject.toml b/python/pyproject.toml index 294a55db2..923b26902 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langsmith" -version = "0.0.77" +version = "0.0.78" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." authors = ["LangChain "] license = "MIT" diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index 39079c208..8e03b6ece 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -106,7 +106,7 @@ def test_upload_csv(mock_session_cls: mock.Mock) -> None: assert dataset.description == "Test dataset" -def test_async_methods(): +def test_async_methods() -> None: """For every method defined on the Client, if there is a corresponding async method, then the async method args should be a @@ -140,7 +140,7 @@ def test_async_methods(): ) -def test_get_api_key(): +def test_get_api_key() -> None: assert _get_api_key("provided_api_key") == "provided_api_key" assert _get_api_key("'provided_api_key'") == "provided_api_key" assert _get_api_key('"_provided_api_key"') == "_provided_api_key" @@ -155,7 +155,7 @@ def test_get_api_key(): assert _get_api_key(" ") is None -def test_get_api_url(): +def test_get_api_url() -> None: assert _get_api_url("http://provided.url", "api_key") == "http://provided.url" with patch.dict(os.environ, {"LANGCHAIN_ENDPOINT": "http://env.url"}): @@ -174,7 +174,7 @@ def test_get_api_url(): _get_api_url(" ", "api_key") -def test_create_run_unicode(): +def test_create_run_unicode() -> None: client = Client(api_url="http://localhost:1984", api_key="123") inputs = { "foo": "これは私の友達です", @@ -194,7 +194,7 @@ def test_create_run_unicode(): @pytest.mark.parametrize("source_type", ["api", "model"]) -def test_create_feedback_string_source_type(source_type: str): +def test_create_feedback_string_source_type(source_type: str) -> None: client = Client(api_url="http://localhost:1984", api_key="123") session = mock.Mock() request_object = mock.Mock() @@ -215,7 +215,7 @@ def test_create_feedback_string_source_type(source_type: str): ) -def test_pydantic_serialize(): +def test_pydantic_serialize() -> None: """Test that pydantic objects can be serialized.""" test_uuid = uuid.uuid4() test_time = datetime.now()