diff --git a/lakera_chainguard/lakera_chainguard.py b/lakera_chainguard/lakera_chainguard.py
index 446af05..c9acbbc 100644
--- a/lakera_chainguard/lakera_chainguard.py
+++ b/lakera_chainguard/lakera_chainguard.py
@@ -171,26 +171,25 @@ def _convert_to_lakera_guard_input(
         if isinstance(prompt, PromptValue):
             prompt = prompt.to_messages()
         if isinstance(prompt, List):
-            formatted_input = [
-                {"role": "system", "content": ""},
-                {"role": "user", "content": ""},
-                {"role": "assistant", "content": ""},
-            ]
-            # For system, human, assistant, we put the last message of each
-            # type in the guard input
+            user_message = ""
+            formatted_input = []
             for message in prompt:
                 if not isinstance(
                     message, (HumanMessage, SystemMessage, AIMessage)
                 ) or not isinstance(message.content, str):
                     raise TypeError("Input type not supported by Lakera Guard.")
+
+                role = "assistant"
                 if isinstance(message, SystemMessage):
-                    formatted_input[0]["content"] = message.content
+                    role = "system"
                 elif isinstance(message, HumanMessage):
-                    formatted_input[1]["content"] = message.content
-                else:  # must be AIMessage
-                    formatted_input[2]["content"] = message.content
+                    user_message = message.content
+                    role = "user"
+
+                formatted_input.append({"role": role, "content": message.content})
+
             if self.endpoint != "prompt_injection":
-                return formatted_input[1]["content"]
+                return user_message
             return formatted_input
         else:
             return str(prompt)
diff --git a/pyproject.toml b/pyproject.toml
index 15c8ce1..00944f8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "lakera-chainguard"
-version = "0.1.4"
+version = "0.1.5"
 description = "Guard your LangChain applications against prompt injection with Lakera Guard"
 homepage = "https://lakeraai.github.io/chainguard/"
 repository = "https://github.com/lakeraai/chainguard"
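
For context, below is a standalone sketch (not code from the repository) approximating what the refactored conversion now does: every system, human, and assistant message is forwarded to Lakera Guard as its own role/content entry instead of keeping only the last message of each type, and endpoints other than prompt_injection receive only the content of the last human message. The helper name messages_to_guard_input and the langchain_core import path are assumptions made for illustration.

# Illustrative sketch only; approximates the conversion logic in the patch.
# `messages_to_guard_input` is a hypothetical name, not part of the package.
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage


def messages_to_guard_input(messages, endpoint="prompt_injection"):
    user_message = ""
    formatted_input = []
    for message in messages:
        if not isinstance(
            message, (HumanMessage, SystemMessage, AIMessage)
        ) or not isinstance(message.content, str):
            raise TypeError("Input type not supported by Lakera Guard.")

        # Default to "assistant"; override for system and human messages.
        role = "assistant"
        if isinstance(message, SystemMessage):
            role = "system"
        elif isinstance(message, HumanMessage):
            user_message = message.content
            role = "user"

        # Keep every message, in order, rather than only the last of each type.
        formatted_input.append({"role": role, "content": message.content})

    # Endpoints other than prompt_injection get only the last human message.
    if endpoint != "prompt_injection":
        return user_message
    return formatted_input


messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hello!"),
    AIMessage(content="Hi, how can I help?"),
]
print(messages_to_guard_input(messages))
# -> [{'role': 'system', ...}, {'role': 'user', ...}, {'role': 'assistant', ...}]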