added Literal type to endpoint. Made typing of type_of_llm, type_of_chat_llm better so that mypy does not complain anymore.
Frawa Vetterli committed Jan 24, 2024
1 parent 37a3c80 commit 202da59
Showing 3 changed files with 44 additions and 30 deletions.
57 changes: 28 additions & 29 deletions lakera_chainguard/lakera_chainguard.py
@@ -2,21 +2,37 @@

import os
import warnings
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, TypeVar

import requests
from langchain.agents import AgentExecutor
from langchain.schema import BaseMessage, PromptValue
from langchain.tools import BaseTool
-from langchain_core.agents import AgentAction, AgentFinish, AgentStep
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.agents import AgentStep
+from langchain.callbacks.manager import (
+    CallbackManagerForLLMRun,
+    CallbackManagerForChainRun,
+)
+from langchain.schema.agent import AgentFinish, AgentAction
from langchain_core.language_models import BaseChatModel, BaseLLM
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatResult, LLMResult

-# from langchain.callbacks.manager import CallbackManagerForChainRun

GuardInput = Union[str, List[BaseMessage], PromptValue]
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
GuardChatMessages = list[dict[str, str]]
+Endpoints = Literal[
+    "prompt_injection",
+    "moderation",
+    "pii",
+    "relevant_language",
+    "sentiment",
+    "unknown_links",
+]
+BaseLLMT = TypeVar("BaseLLMT", bound=BaseLLM)
+BaseChatModelT = TypeVar("BaseChatModelT", bound=BaseChatModel)


class LakeraGuardError(RuntimeError):
@@ -56,7 +72,7 @@ class LakeraChainGuard:
    def __init__(
        self,
        api_key: str = "",
-        endpoint: str = "prompt_injection",
+        endpoint: Endpoints = "prompt_injection",
        additional_json_properties: dict = dict(),
        raise_error: bool = True,
    ) -> None:
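A hedged sketch (not part of the commit, and assuming the package exports `LakeraChainGuard` at the top level) of what the `Endpoints` Literal buys: a valid endpoint name still passes, while a typo that previously surfaced only at runtime is now rejected by mypy at type-check time.

```python
from lakera_chainguard import LakeraChainGuard

guard = LakeraChainGuard(endpoint="pii")  # fine: "pii" is one of the Literal members

# mypy error (roughly): Argument "endpoint" to "LakeraChainGuard" has
# incompatible type "Literal['prompt-injection']"; expected
# "Literal['prompt_injection', 'moderation', 'pii', ...]"
bad = LakeraChainGuard(endpoint="prompt-injection")  # hyphen typo, caught statically
```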
@@ -132,21 +148,6 @@ def _call_lakera_guard(self, query: Union[str, GuardChatMessages]) -> dict:
            )
        else:
            raise ValueError(str(response_body))
-        if "code" in response_body:
-            errormessage = str(response_body)
-            if self.endpoint not in {
-                "prompt_injection",
-                "moderation",
-                "pii",
-                "relevant_language",
-                "sentiment",
-                "unknown_links",
-            }:
-                errormessage += (
-                    f" Provided endpoint {self.endpoint} is not supported "
-                    "by Lakera Guard."
-                )
-            raise ValueError(errormessage)
        if "results" not in response_body:
            raise ValueError(str(response_body))
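Note: with `endpoint` constrained to `Endpoints`, the membership check deleted above could only fire for values that mypy now rejects outright, which is presumably why the commit drops the runtime guard rather than keeping it redundantly.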

@@ -240,7 +241,7 @@ def detect_with_response(self, prompt: GuardInput) -> dict:

        return lakera_guard_response

-    def get_guarded_llm(self, type_of_llm: Type[BaseLLM]) -> Type[BaseLLM]:
+    def get_guarded_llm(self, type_of_llm: Type[BaseLLMT]) -> Type[BaseLLMT]:
        """
        Creates a subclass of type_of_llm where the input to the LLM always gets
        checked w.r.t. AI security risk specified in self.endpoint.
@@ -252,7 +253,7 @@ def get_guarded_llm(self, type_of_llm: Type[BaseLLM]) -> Type[BaseLLM]:
        """
        lakera_guard_instance = self

-        class GuardedLLM(type_of_llm):
+        class GuardedLLM(type_of_llm):  # type: ignore
            @property
            def _llm_type(self) -> str:
                return "guarded_" + super()._llm_type
@@ -270,8 +271,8 @@ def _generate(
        return GuardedLLM

    def get_guarded_chat_llm(
-        self, type_of_chat_llm: Type[BaseChatModel]
-    ) -> Type[BaseChatModel]:
+        self, type_of_chat_llm: Type[BaseChatModelT]
+    ) -> Type[BaseChatModelT]:
        """
        Creates a subclass of type_of_chat_llm in which the input to the ChatLLM always
        gets checked w.r.t. AI security risk specified in self.endpoint.
@@ -283,7 +284,7 @@ def get_guarded_chat_llm(
        """
        lakera_guard_instance = self

-        class GuardedChatLLM(type_of_chat_llm):
+        class GuardedChatLLM(type_of_chat_llm):  # type: ignore
            @property
            def _llm_type(self) -> str:
                return "guarded_" + super()._llm_type
@@ -318,9 +319,8 @@ def _take_next_step(
                color_mapping: Dict[str, str],
                inputs: Dict[str, str],
                intermediate_steps: List[Tuple[AgentAction, str]],
-                *args,
-                **kwargs,
-            ):
+                run_manager: CallbackManagerForChainRun | None = None,
+            ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
                for val in inputs.values():
                    lakera_guard_instance.detect(val)

@@ -329,8 +329,7 @@ def _take_next_step(
                    color_mapping,
                    inputs,
                    intermediate_steps,
-                    *args,
-                    **kwargs,
+                    run_manager,
                )

                for act in intermediate_steps:
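Why replacing `*args, **kwargs` with the parent's explicit parameters quiets mypy — a minimal, self-contained sketch under assumed names (nothing below is from this repo): without a return annotation the override's result is `Any`, so callers get no checking and a drifting parent signature goes unnoticed, whereas mirroring the parent's signature lets mypy verify both the `super()` delegation and the declared return type.

```python
class Executor:
    def take_step(self, inputs: dict[str, str], run_manager: object | None = None) -> list[str]:
        return list(inputs)

class LooselyTyped(Executor):
    # No annotations: the result is Any, so callers lose all checking, and
    # mypy accepts this override against any parent signature whatsoever.
    def take_step(self, *args, **kwargs):
        return super().take_step(*args, **kwargs)

class ExplicitlyTyped(Executor):
    # Mirrors Executor.take_step exactly, so mypy validates the override,
    # the delegation, and the declared list[str] result.
    def take_step(self, inputs: dict[str, str], run_manager: object | None = None) -> list[str]:
        return super().take_step(inputs, run_manager)
```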
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -11,6 +11,7 @@ python = "^3.11"
requests = "^2.31.0"
langchain = "^0.0.354"
langchain-core = "^0.1.5"
+types-requests = "^2.31.0.20240106"


[tool.poetry.group.dev.dependencies]
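The new `types-requests` pin supplies the typeshed stubs for `requests`; without them, mypy treats the HTTP client used in `_call_lakera_guard` as an untyped import and complains under stricter settings, so the dependency is presumably here to let the package type-check cleanly.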
