From 202da59dae00efaed0db353a9f3d6017ad1d7a28 Mon Sep 17 00:00:00 2001 From: Frawa Vetterli Date: Wed, 24 Jan 2024 10:31:20 +0100 Subject: [PATCH] added Literal type to endpoint. Made typing of type_of_llm, type_of_chat_llm better so that mypy does not complain anymore. --- lakera_chainguard/lakera_chainguard.py | 57 +++++++++++++------------- poetry.lock | 16 +++++++- pyproject.toml | 1 + 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/lakera_chainguard/lakera_chainguard.py b/lakera_chainguard/lakera_chainguard.py index 065ca4d..446af05 100644 --- a/lakera_chainguard/lakera_chainguard.py +++ b/lakera_chainguard/lakera_chainguard.py @@ -2,21 +2,37 @@ import os import warnings -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, TypeVar import requests from langchain.agents import AgentExecutor from langchain.schema import BaseMessage, PromptValue from langchain.tools import BaseTool -from langchain_core.agents import AgentAction, AgentFinish, AgentStep -from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.agents import AgentStep +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, + CallbackManagerForChainRun, +) +from langchain.schema.agent import AgentFinish, AgentAction from langchain_core.language_models import BaseChatModel, BaseLLM from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.outputs import ChatResult, LLMResult +# from langchain.callbacks.manager import CallbackManagerForChainRun + GuardInput = Union[str, List[BaseMessage], PromptValue] NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]] GuardChatMessages = list[dict[str, str]] +Endpoints = Literal[ + "prompt_injection", + "moderation", + "pii", + "relevant_language", + "sentiment", + "unknown_links", +] +BaseLLMT = TypeVar("BaseLLMT", bound=BaseLLM) +BaseChatModelT = TypeVar("BaseChatModelT", bound=BaseChatModel) class LakeraGuardError(RuntimeError): @@ -56,7 +72,7 @@ class LakeraChainGuard: def __init__( self, api_key: str = "", - endpoint: str = "prompt_injection", + endpoint: Endpoints = "prompt_injection", additional_json_properties: dict = dict(), raise_error: bool = True, ) -> None: @@ -132,21 +148,6 @@ def _call_lakera_guard(self, query: Union[str, GuardChatMessages]) -> dict: ) else: raise ValueError(str(response_body)) - if "code" in response_body: - errormessage = str(response_body) - if self.endpoint not in { - "prompt_injection", - "moderation", - "pii", - "relevant_language", - "sentiment", - "unknown_links", - }: - errormessage += ( - f" Provided endpoint {self.endpoint} is not supported " - "by Lakera Guard." - ) - raise ValueError(errormessage) if "results" not in response_body: raise ValueError(str(response_body)) @@ -240,7 +241,7 @@ def detect_with_response(self, prompt: GuardInput) -> dict: return lakera_guard_response - def get_guarded_llm(self, type_of_llm: Type[BaseLLM]) -> Type[BaseLLM]: + def get_guarded_llm(self, type_of_llm: Type[BaseLLMT]) -> Type[BaseLLMT]: """ Creates a subclass of type_of_llm where the input to the LLM always gets checked w.r.t. AI security risk specified in self.endpoint. 
@@ -252,7 +253,7 @@ def get_guarded_llm(self, type_of_llm: Type[BaseLLM]) -> Type[BaseLLM]: """ lakera_guard_instance = self - class GuardedLLM(type_of_llm): + class GuardedLLM(type_of_llm): # type: ignore @property def _llm_type(self) -> str: return "guarded_" + super()._llm_type @@ -270,8 +271,8 @@ def _generate( return GuardedLLM def get_guarded_chat_llm( - self, type_of_chat_llm: Type[BaseChatModel] - ) -> Type[BaseChatModel]: + self, type_of_chat_llm: Type[BaseChatModelT] + ) -> Type[BaseChatModelT]: """ Creates a subclass of type_of_chat_llm in which the input to the ChatLLM always gets checked w.r.t. AI security risk specified in self.endpoint. @@ -283,7 +284,7 @@ def get_guarded_chat_llm( """ lakera_guard_instance = self - class GuardedChatLLM(type_of_chat_llm): + class GuardedChatLLM(type_of_chat_llm): # type: ignore @property def _llm_type(self) -> str: return "guarded_" + super()._llm_type @@ -318,9 +319,8 @@ def _take_next_step( color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]], - *args, - **kwargs, - ): + run_manager: CallbackManagerForChainRun | None = None, + ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: for val in inputs.values(): lakera_guard_instance.detect(val) @@ -329,8 +329,7 @@ def _take_next_step( color_mapping, inputs, intermediate_steps, - *args, - **kwargs, + run_manager, ) for act in intermediate_steps: diff --git a/poetry.lock b/poetry.lock index e1087db..f48e20e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2618,6 +2618,20 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "types-requests" +version = "2.31.0.20240106" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.31.0.20240106.tar.gz", hash = "sha256:0e1c731c17f33618ec58e022b614a1a2ecc25f7dc86800b36ef341380402c612"}, + {file = "types_requests-2.31.0.20240106-py3-none-any.whl", hash = "sha256:da997b3b6a72cc08d09f4dba9802fdbabc89104b35fe24ee588e674037689354"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = "typing-extensions" version = "4.9.0" @@ -2836,4 +2850,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "58bd50182189cfb83f5fec52d8fb2fd46e0c4670cd73c865647b3952c6e647b9" +content-hash = "ef1dc5b39966aa62a360014b10fcf45a5ae4ad22f3eb489e644162faf1cc9cd2" diff --git a/pyproject.toml b/pyproject.toml index fa0d9e4..9f77426 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ python = "^3.11" requests = "^2.31.0" langchain = "^0.0.354" langchain-core = "^0.1.5" +types-requests = "^2.31.0.20240106" [tool.poetry.group.dev.dependencies]
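Below the patch, a minimal usage sketch (not part of the diff) of what the typing changes buy callers: the Endpoints Literal makes an unsupported endpoint string a mypy error rather than a runtime check, and the BaseLLMT/BaseChatModelT TypeVars let mypy infer the concrete guarded subclass type from get_guarded_llm / get_guarded_chat_llm. The import path for LakeraChainGuard and the OpenAI example class are assumptions based on the module layout above and the pinned langchain version.

# Hypothetical usage sketch, assuming the package layout in this patch.
# OpenAI is used only as an example BaseLLM subclass; API keys for both
# services are expected to be read from the environment.
from langchain.llms import OpenAI

from lakera_chainguard.lakera_chainguard import LakeraChainGuard

# `endpoint` is now typed as the Endpoints Literal, so mypy rejects
# unsupported values at type-check time (previously this only surfaced
# as a runtime ValueError from the API response handling).
guard = LakeraChainGuard(endpoint="prompt_injection")   # accepted by mypy
# guard = LakeraChainGuard(endpoint="jailbreak")        # rejected by mypy

# get_guarded_llm is generic over BaseLLMT (bound to BaseLLM), so mypy
# infers Type[OpenAI] here instead of the old Type[BaseLLM], keeping
# OpenAI-specific attributes visible on the guarded subclass.
GuardedOpenAI = guard.get_guarded_llm(OpenAI)
llm = GuardedOpenAI(model_name="gpt-3.5-turbo-instruct")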