Skip to content

Commit

Permalink
Update @Traceable Generic Typing (#341)
Browse files Browse the repository at this point in the history
Previously, mypy would complain when using strict typing
  • Loading branch information
hinthornw authored Jan 8, 2024
1 parent 9c2b756 commit ed34564
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
20 changes: 19 additions & 1 deletion python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
Callable,
Dict,
Generator,
Generic,
List,
Mapping,
Optional,
Protocol,
TypedDict,
TypeVar,
cast,
runtime_checkable,
)

from langsmith import client, run_trees, utils
Expand Down Expand Up @@ -223,6 +227,20 @@ def _setup_run(
return response_container


R = TypeVar("R", covariant=True)


@runtime_checkable
class SupportsLangsmithExtra(Protocol, Generic[R]):
def __call__(
self,
*args: Any,
langsmith_extra: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> R:
...


def traceable(
run_type: str = "chain",
*,
Expand All @@ -233,7 +251,7 @@ def traceable(
client: Optional[client.Client] = None,
extra: Optional[Dict] = None,
reduce_fn: Optional[Callable] = None,
) -> Callable:
) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]:
"""Decorator for creating or adding a run to a run tree.
Args:
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.0.77"
version = "0.0.78"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <[email protected]>"]
license = "MIT"
Expand Down
12 changes: 6 additions & 6 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_upload_csv(mock_session_cls: mock.Mock) -> None:
assert dataset.description == "Test dataset"


def test_async_methods():
def test_async_methods() -> None:
"""For every method defined on the Client, if there is a
corresponding async method, then the async method args should be a
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_async_methods():
)


def test_get_api_key():
def test_get_api_key() -> None:
assert _get_api_key("provided_api_key") == "provided_api_key"
assert _get_api_key("'provided_api_key'") == "provided_api_key"
assert _get_api_key('"_provided_api_key"') == "_provided_api_key"
Expand All @@ -155,7 +155,7 @@ def test_get_api_key():
assert _get_api_key(" ") is None


def test_get_api_url():
def test_get_api_url() -> None:
assert _get_api_url("http://provided.url", "api_key") == "http://provided.url"

with patch.dict(os.environ, {"LANGCHAIN_ENDPOINT": "http://env.url"}):
Expand All @@ -174,7 +174,7 @@ def test_get_api_url():
_get_api_url(" ", "api_key")


def test_create_run_unicode():
def test_create_run_unicode() -> None:
client = Client(api_url="http://localhost:1984", api_key="123")
inputs = {
"foo": "これは私の友達です",
Expand All @@ -194,7 +194,7 @@ def test_create_run_unicode():


@pytest.mark.parametrize("source_type", ["api", "model"])
def test_create_feedback_string_source_type(source_type: str):
def test_create_feedback_string_source_type(source_type: str) -> None:
client = Client(api_url="http://localhost:1984", api_key="123")
session = mock.Mock()
request_object = mock.Mock()
Expand All @@ -215,7 +215,7 @@ def test_create_feedback_string_source_type(source_type: str):
)


def test_pydantic_serialize():
def test_pydantic_serialize() -> None:
"""Test that pydantic objects can be serialized."""
test_uuid = uuid.uuid4()
test_time = datetime.now()
Expand Down

0 comments on commit ed34564

Please sign in to comment.