From 0fc51f46c156793b23def2beeb446c1402f72483 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 17 Jun 2023 08:27:22 +0100 Subject: [PATCH] Implemented template extra params using -p name value --- docs/templates.md | 26 ++++++++++++++ llm/__init__.py | 36 +++++++++++++++++++ llm/cli.py | 28 +++++++++------ tests/test_templates.py | 80 ++++++++++++++++++++++++++++++++++++----- 4 files changed, 151 insertions(+), 19 deletions(-) diff --git a/docs/templates.md b/docs/templates.md index 2805826b..76ccdea7 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -90,6 +90,32 @@ system: You speak like an excitable Victorian adventurer prompt: 'Summarize this: $input' ``` +## Additional template variables + +Templates that work against the user's normal input (content that is either piped to the tool via standard input or passed as a command-line argument) use just the `$input` variable. + +You can use additional named variables. These will then need to be provided using the `-p/--param` option when executing the template. + +Here's an example template called `recipe`, created using `llm templates edit recipe`: + +```yaml +prompt: | + Suggest a recipe using ingredients: $ingredients + + It should be based on cuisine from this country: $country +``` +This can be executed like so: + +```bash +llm -t recipe -p ingredients 'sausages, milk' -p country Germany +``` +My output started like this: +> Recipe: German Sausage and Potato Soup +> +> Ingredients: +> - 4 German sausages +> - 2 cups whole milk + ## Setting a default model for a template Templates executed using `llm -t template-name` will execute using the default model that the user has configured for the tool - or `gpt-3.5-turbo` if they have not configured their own default. diff --git a/llm/__init__.py b/llm/__init__.py index bf0ebe9b..194d9cf4 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -1,4 +1,5 @@ from pydantic import BaseModel +import string from typing import Optional @@ -10,3 +11,38 @@ class Template(BaseModel): class Config: extra = "forbid" + + class MissingVariables(Exception): + pass + + def execute(self, input, params=None): + params = params or {} + params["input"] = input + if not self.prompt: + system = self.interpolate(self.system, params) + prompt = input + else: + prompt = self.interpolate(self.prompt, params) + system = self.interpolate(self.system, params) + return prompt, system + + @classmethod + def interpolate(cls, text, params): + if not text: + return text + # Confirm all variables in text are provided + string_template = string.Template(text) + vars = cls.extract_vars(string_template) + missing = [p for p in vars if p not in params] + if missing: + raise cls.MissingVariables( + "Missing variables: {}".format(", ".join(missing)) + ) + return string_template.substitute(**params) + + @staticmethod + def extract_vars(string_template): + return [ + match.group("named") + for match in string_template.pattern.finditer(string_template.template) + ] diff --git a/llm/cli.py b/llm/cli.py index 5c7df203..140bd6ff 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -40,6 +40,13 @@ def cli(): @click.option("--system", help="System prompt to use") @click.option("-m", "--model", help="Model to use") @click.option("-t", "--template", help="Template to use") +@click.option( + "-p", + "--param", + multiple=True, + type=(str, str), + help="Parameters for template", +) @click.option("--no-stream", is_flag=True, help="Do not stream output") @click.option("-n", "--no-log", is_flag=True, help="Don't log to database") @click.option( @@ -57,25 +64,24 @@ def cli(): type=int, ) @click.option("--key", help="API key to use") -def prompt(prompt, system, model, template, no_stream, no_log, _continue, chat_id, key): +def prompt( + prompt, system, model, template, param, no_stream, no_log, _continue, chat_id, key +): "Execute a prompt against on OpenAI model" - if prompt is None: + if prompt is None and not param: # Read from stdin instead prompt = sys.stdin.read() openai.api_key = get_key(key, "openai", "OPENAI_API_KEY") if template: + params = dict(param) # Cannot be used with system if system: - raise click.ClickException("Cannot use --template and --system together") + raise click.ClickException("Cannot use -t/--template and --system together") template_obj = load_template(template) - if not template_obj.prompt: - # It's a system prompt template - system = template_obj.system - else: - # Interpolate our existing prompt - input = prompt - prompt = StringTemplate(template_obj.prompt).substitute(input=input) - system = template_obj.system + try: + prompt, system = template_obj.execute(prompt, params) + except Template.MissingVariables as ex: + raise click.ClickException(str(ex)) if model is None and template_obj.model: model = template_obj.model messages = [] diff --git a/tests/test_templates.py b/tests/test_templates.py index 7bc344c3..e071e3c7 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,10 +1,35 @@ from click.testing import CliRunner +from llm import Template from llm.cli import cli import os from unittest import mock import pytest +@pytest.mark.parametrize( + "prompt,system,params,expected_prompt,expected_system,expected_error", + ( + ("S: $input", None, {}, "S: input", None, None), + ("S: $input", "system", {}, "S: input", "system", None), + ("No vars", None, {}, "No vars", None, None), + ("$one and $two", None, {}, None, None, "Missing variables: one, two"), + ("$one and $two", None, {"one": 1, "two": 2}, "1 and 2", None, None), + ), +) +def test_template_execute( + prompt, system, params, expected_prompt, expected_system, expected_error +): + t = Template(name="t", prompt=prompt, system=system) + if expected_error: + with pytest.raises(Template.MissingVariables) as ex: + prompt, system = t.execute("input", params) + assert ex.value.args[0] == expected_error + else: + prompt, system = t.execute("input", params) + assert prompt == expected_prompt + assert system == expected_system + + def test_templates_list(templates_path): (templates_path / "one.yaml").write_text("template one", "utf-8") (templates_path / "two.yaml").write_text("template two", "utf-8") @@ -23,25 +48,60 @@ def test_templates_list(templates_path): @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"}) @pytest.mark.parametrize( - "template,extra_args,expected_model,expected_input", + "template,extra_args,expected_model,expected_input,expected_error", ( - ("'Summarize this: $input'", [], "gpt-3.5-turbo", "Summarize this: Input text"), + ( + "'Summarize this: $input'", + [], + "gpt-3.5-turbo", + "Summarize this: Input text", + None, + ), ( "prompt: 'Summarize this: $input'\nmodel: gpt-4", [], "gpt-4", "Summarize this: Input text", + None, ), ( "prompt: 'Summarize this: $input'", ["-m", "4"], "gpt-4", "Summarize this: Input text", + None, + ), + ( + "boo", + ["--system", "s"], + None, + None, + "Error: Cannot use -t/--template and --system together", + ), + ( + "prompt: 'Say $hello'", + [], + None, + None, + "Error: Missing variables: hello", + ), + ( + "prompt: 'Say $hello'", + ["-p", "hello", "Blah"], + "gpt-3.5-turbo", + "Say Blah", + None, ), ), ) def test_template_basic( - templates_path, mocked_openai, template, extra_args, expected_model, expected_input + templates_path, + mocked_openai, + template, + extra_args, + expected_model, + expected_input, + expected_error, ): (templates_path / "template.yaml").write_text(template, "utf-8") runner = CliRunner() @@ -50,8 +110,12 @@ def test_template_basic( ["--no-stream", "-t", "template", "Input text"] + extra_args, catch_exceptions=False, ) - assert result.exit_code == 0 - assert mocked_openai.last_request.json() == { - "model": expected_model, - "messages": [{"role": "user", "content": expected_input}], - } + if expected_error is None: + assert result.exit_code == 0 + assert mocked_openai.last_request.json() == { + "model": expected_model, + "messages": [{"role": "user", "content": expected_input}], + } + else: + assert result.exit_code == 1 + assert result.output.strip() == expected_error