Implemented template extra params using -p name value
simonw committed Jun 17, 2023
1 parent 8cb02fe commit 0fc51f4
Showing 4 changed files with 151 additions and 19 deletions.
26 changes: 26 additions & 0 deletions docs/templates.md
@@ -90,6 +90,32 @@ system: You speak like an excitable Victorian adventurer
prompt: 'Summarize this: $input'
```

## Additional template variables

Templates that operate on the user's normal input (content that is either piped to the tool via standard input or passed as a command-line argument) use just the `$input` variable.

Templates can also use additional named variables. These must be provided using the `-p/--param` option when the template is executed.

Here's an example template called `recipe`, created using `llm templates edit recipe`:

```yaml
prompt: |
Suggest a recipe using ingredients: $ingredients
It should be based on cuisine from this country: $country
```
This can be executed like so:

```bash
llm -t recipe -p ingredients 'sausages, milk' -p country Germany
```
My output started like this:
> Recipe: German Sausage and Potato Soup
>
> Ingredients:
> - 4 German sausages
> - 2 cups whole milk
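
Behind the scenes, those `-p` values are interpolated by the new `Template.execute()` method added to `llm/__init__.py` (shown further down in this commit). A minimal sketch of that behavior, assuming the `recipe` template above and the `Template` class from this change:

```python
from llm import Template

recipe = Template(
    name="recipe",
    prompt=(
        "Suggest a recipe using ingredients: $ingredients\n"
        "It should be based on cuisine from this country: $country"
    ),
    system=None,
)

# With every variable supplied, both placeholders are filled in
prompt, system = recipe.execute(
    "", {"ingredients": "sausages, milk", "country": "Germany"}
)

# A missing variable raises Template.MissingVariables, which the CLI
# reports as "Error: Missing variables: country"
try:
    recipe.execute("", {"ingredients": "sausages, milk"})
except Template.MissingVariables as ex:
    print(ex)  # Missing variables: country
```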

## Setting a default model for a template

Templates executed using `llm -t template-name` will use the default model that the user has configured for the tool, or `gpt-3.5-turbo` if they have not configured their own default.
36 changes: 36 additions & 0 deletions llm/__init__.py
@@ -1,4 +1,5 @@
from pydantic import BaseModel
import string
from typing import Optional


@@ -10,3 +11,38 @@ class Template(BaseModel):

    class Config:
        extra = "forbid"

    class MissingVariables(Exception):
        pass

    def execute(self, input, params=None):
        params = params or {}
        params["input"] = input
        if not self.prompt:
            system = self.interpolate(self.system, params)
            prompt = input
        else:
            prompt = self.interpolate(self.prompt, params)
            system = self.interpolate(self.system, params)
        return prompt, system

    @classmethod
    def interpolate(cls, text, params):
        if not text:
            return text
        # Confirm all variables in text are provided
        string_template = string.Template(text)
        vars = cls.extract_vars(string_template)
        missing = [p for p in vars if p not in params]
        if missing:
            raise cls.MissingVariables(
                "Missing variables: {}".format(", ".join(missing))
            )
        return string_template.substitute(**params)

    @staticmethod
    def extract_vars(string_template):
        return [
            match.group("named")
            for match in string_template.pattern.finditer(string_template.template)
        ]
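
The `extract_vars()` helper leans on `string.Template` internals: the compiled placeholder regex is exposed as `.pattern`, and `$name`-style placeholders land in its `named` group. A standalone illustration using only the standard library (the recipe text here is just an example):

```python
import string

template_text = (
    "Suggest a recipe using ingredients: $ingredients\n"
    "It should be based on cuisine from this country: $country"
)
t = string.Template(template_text)

# The "named" group captures $name-style placeholders
# (${braced} placeholders would match the "braced" group instead)
print([m.group("named") for m in t.pattern.finditer(t.template)])
# ['ingredients', 'country']

# substitute() ignores extra keys, which is why execute() can always
# inject "input" even when a prompt never references $input
print(t.substitute(ingredients="sausages, milk", country="Germany", input="unused"))
```
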
28 changes: 17 additions & 11 deletions llm/cli.py
@@ -40,6 +40,13 @@ def cli():
@click.option("--system", help="System prompt to use")
@click.option("-m", "--model", help="Model to use")
@click.option("-t", "--template", help="Template to use")
@click.option(
    "-p",
    "--param",
    multiple=True,
    type=(str, str),
    help="Parameters for template",
)
@click.option("--no-stream", is_flag=True, help="Do not stream output")
@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
@click.option(
@@ -57,25 +64,24 @@ def cli():
    type=int,
)
@click.option("--key", help="API key to use")
-def prompt(prompt, system, model, template, no_stream, no_log, _continue, chat_id, key):
+def prompt(
+    prompt, system, model, template, param, no_stream, no_log, _continue, chat_id, key
+):
    "Execute a prompt against an OpenAI model"
-    if prompt is None:
+    if prompt is None and not param:
        # Read from stdin instead
        prompt = sys.stdin.read()
    openai.api_key = get_key(key, "openai", "OPENAI_API_KEY")
    if template:
        params = dict(param)
        # Cannot be used with system
        if system:
-            raise click.ClickException("Cannot use --template and --system together")
+            raise click.ClickException("Cannot use -t/--template and --system together")
        template_obj = load_template(template)
-        if not template_obj.prompt:
-            # It's a system prompt template
-            system = template_obj.system
-        else:
-            # Interpolate our existing prompt
-            input = prompt
-            prompt = StringTemplate(template_obj.prompt).substitute(input=input)
-            system = template_obj.system
+        try:
+            prompt, system = template_obj.execute(prompt, params)
+        except Template.MissingVariables as ex:
+            raise click.ClickException(str(ex))
        if model is None and template_obj.model:
            model = template_obj.model
    messages = []
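
For reference, this is how the new `-p/--param` option collects its values: with `multiple=True` and `type=(str, str)`, Click yields a tuple of `(name, value)` pairs, and `dict(param)` turns that into the `params` mapping handed to `Template.execute()`. A small standalone sketch (the `demo` command is illustrative, not part of this commit):

```python
import click
from click.testing import CliRunner


@click.command()
@click.option("-p", "--param", multiple=True, type=(str, str))
def demo(param):
    # param arrives as (("ingredients", "sausages, milk"), ("country", "Germany"))
    click.echo(dict(param))


result = CliRunner().invoke(
    demo, ["-p", "ingredients", "sausages, milk", "-p", "country", "Germany"]
)
print(result.output)
# {'ingredients': 'sausages, milk', 'country': 'Germany'}
```
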
80 changes: 72 additions & 8 deletions tests/test_templates.py
@@ -1,10 +1,35 @@
from click.testing import CliRunner
from llm import Template
from llm.cli import cli
import os
from unittest import mock
import pytest


@pytest.mark.parametrize(
    "prompt,system,params,expected_prompt,expected_system,expected_error",
    (
        ("S: $input", None, {}, "S: input", None, None),
        ("S: $input", "system", {}, "S: input", "system", None),
        ("No vars", None, {}, "No vars", None, None),
        ("$one and $two", None, {}, None, None, "Missing variables: one, two"),
        ("$one and $two", None, {"one": 1, "two": 2}, "1 and 2", None, None),
    ),
)
def test_template_execute(
    prompt, system, params, expected_prompt, expected_system, expected_error
):
    t = Template(name="t", prompt=prompt, system=system)
    if expected_error:
        with pytest.raises(Template.MissingVariables) as ex:
            prompt, system = t.execute("input", params)
        assert ex.value.args[0] == expected_error
    else:
        prompt, system = t.execute("input", params)
        assert prompt == expected_prompt
        assert system == expected_system


def test_templates_list(templates_path):
(templates_path / "one.yaml").write_text("template one", "utf-8")
(templates_path / "two.yaml").write_text("template two", "utf-8")
@@ -23,25 +48,60 @@ def test_templates_list(templates_path):

@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize(
"template,extra_args,expected_model,expected_input",
"template,extra_args,expected_model,expected_input,expected_error",
(
("'Summarize this: $input'", [], "gpt-3.5-turbo", "Summarize this: Input text"),
(
"'Summarize this: $input'",
[],
"gpt-3.5-turbo",
"Summarize this: Input text",
None,
),
(
"prompt: 'Summarize this: $input'\nmodel: gpt-4",
[],
"gpt-4",
"Summarize this: Input text",
None,
),
(
"prompt: 'Summarize this: $input'",
["-m", "4"],
"gpt-4",
"Summarize this: Input text",
None,
),
(
"boo",
["--system", "s"],
None,
None,
"Error: Cannot use -t/--template and --system together",
),
(
"prompt: 'Say $hello'",
[],
None,
None,
"Error: Missing variables: hello",
),
(
"prompt: 'Say $hello'",
["-p", "hello", "Blah"],
"gpt-3.5-turbo",
"Say Blah",
None,
),
),
)
def test_template_basic(
-    templates_path, mocked_openai, template, extra_args, expected_model, expected_input
+    templates_path,
+    mocked_openai,
+    template,
+    extra_args,
+    expected_model,
+    expected_input,
+    expected_error,
):
(templates_path / "template.yaml").write_text(template, "utf-8")
runner = CliRunner()
@@ -50,8 +110,12 @@ def test_template_basic(
["--no-stream", "-t", "template", "Input text"] + extra_args,
catch_exceptions=False,
)
-    assert result.exit_code == 0
-    assert mocked_openai.last_request.json() == {
-        "model": expected_model,
-        "messages": [{"role": "user", "content": expected_input}],
-    }
+    if expected_error is None:
+        assert result.exit_code == 0
+        assert mocked_openai.last_request.json() == {
+            "model": expected_model,
+            "messages": [{"role": "user", "content": expected_input}],
+        }
+    else:
+        assert result.exit_code == 1
+        assert result.output.strip() == expected_error
