-o/--option, implemented for OpenAI models - closes #63
simonw committed Jul 3, 2023
1 parent 499a43b commit d649230
Showing 4 changed files with 85 additions and 20 deletions.
23 changes: 12 additions & 11 deletions docs/help.md
@@ -75,17 +75,18 @@ Usage: llm prompt [OPTIONS] [PROMPT]
   Documentation: https://llm.datasette.io/en/stable/usage.html
 
 Options:
-  -s, --system TEXT            System prompt to use
-  -m, --model TEXT             Model to use
-  -t, --template TEXT          Template to use
-  -p, --param <TEXT TEXT>...   Parameters for template
-  --no-stream                  Do not stream output
-  -n, --no-log                 Don't log to database
-  -c, --continue               Continue the most recent conversation.
-  --chat INTEGER               Continue the conversation with the given chat ID.
-  --key TEXT                   API key to use
-  --save TEXT                  Save prompt with this template name
-  --help                       Show this message and exit.
+  -s, --system TEXT            System prompt to use
+  -m, --model TEXT             Model to use
+  -o, --option <TEXT TEXT>...  key/value options for the model
+  -t, --template TEXT          Template to use
+  -p, --param <TEXT TEXT>...   Parameters for template
+  --no-stream                  Do not stream output
+  -n, --no-log                 Don't log to database
+  -c, --continue               Continue the most recent conversation.
+  --chat INTEGER               Continue the conversation with the given chat ID.
+  --key TEXT                   API key to use
+  --save TEXT                  Save prompt with this template name
+  --help                       Show this message and exit.
 ```
 ### llm init-db --help
 ```
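The new `-o/--option` flag accepts repeated key/value pairs, e.g. `llm 'Say hi' -o temperature 0.5 -o max_tokens 20` (an illustrative invocation, not taken from this commit). Below is a minimal standalone sketch of the click pattern the diff uses, `type=(str, str)` with `multiple=True`; the `demo` command is hypothetical and not part of llm:

```python
import click


# Hypothetical command demonstrating how click collects repeated
# -o/--option pairs as a tuple of (key, value) string tuples.
@click.command()
@click.option(
    "options",
    "-o",
    "--option",
    type=(str, str),
    multiple=True,
    help="key/value options for the model",
)
def demo(options):
    # demo -o temperature 0.5 -o max_tokens 20
    # -> options == (("temperature", "0.5"), ("max_tokens", "20"))
    click.echo(dict(options))


if __name__ == "__main__":
    demo()
```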
32 changes: 27 additions & 5 deletions llm/cli.py
@@ -58,6 +58,14 @@ def cli():
 @click.argument("prompt", required=False)
 @click.option("-s", "--system", help="System prompt to use")
 @click.option("model_id", "-m", "--model", help="Model to use")
+@click.option(
+    "options",
+    "-o",
+    "--option",
+    type=(str, str),
+    multiple=True,
+    help="key/value options for the model",
+)
 @click.option("-t", "--template", help="Template to use")
 @click.option(
     "-p",
@@ -88,6 +96,7 @@ def prompt(
     prompt,
     system,
     model_id,
+    options,
     template,
     param,
     no_stream,
@@ -116,17 +125,17 @@
     if save:
         # We are saving their prompt/system/etc to a new template
        # Fields to save: prompt, system, model - and more in the future
-        bad_options = []
+        disallowed_options = []
         for option, var in (
             ("--template", template),
             ("--continue", _continue),
             ("--chat", chat_id),
         ):
             if var:
-                bad_options.append(option)
-        if bad_options:
+                disallowed_options.append(option)
+        if disallowed_options:
             raise click.ClickException(
-                "--save cannot be used with {}".format(", ".join(bad_options))
+                "--save cannot be used with {}".format(", ".join(disallowed_options))
             )
         path = template_dir() / f"{save}.yaml"
         to_save = {}
@@ -197,13 +206,26 @@ def prompt(
     if model.needs_key and not model.key:
         model.key = get_key(key, model.needs_key, model.key_env_var)
 
+    # Validate options
+    validated_options = {}
+    if options:
+        # Validate with pydantic
+        try:
+            validated_options = dict(
+                (key, value)
+                for key, value in model.Options(**dict(options))
+                if value is not None
+            )
+        except pydantic.ValidationError as ex:
+            raise click.ClickException(str(ex))
+
     should_stream = model.can_stream and not no_stream
     if should_stream:
         method = model.stream
     else:
         method = model.prompt
 
-    response = method(prompt, system)
+    response = method(prompt, system, **validated_options)
 
     if should_stream:
         for chunk in response:
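To make the validation step above concrete, here is a hedged standalone sketch: the `Options` class below is a stand-in for the per-model `model.Options` (such as `Chat.Options` defined in the next file), and the tuples mimic what `-o` collects. Iterating a pydantic model yields `(field_name, value)` pairs, which is what the `dict(...)` generator expression relies on:

```python
from typing import Optional

import pydantic


# Stand-in for model.Options; real models define their own subclasses.
class Options(pydantic.BaseModel):
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None


# What click hands over for: -o temperature 0.5 -o max_tokens 10
options = (("temperature", "0.5"), ("max_tokens", "10"))

try:
    # Iterating the model yields (field_name, value) pairs; pydantic has
    # already coerced the string inputs, so drop the unset (None) fields.
    validated_options = dict(
        (key, value) for key, value in Options(**dict(options)) if value is not None
    )
except pydantic.ValidationError as ex:
    raise SystemExit(str(ex))

print(validated_options)  # {'temperature': 0.5, 'max_tokens': 10}
```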
48 changes: 45 additions & 3 deletions llm/default_plugins/openai_models.py
@@ -2,10 +2,11 @@
 from llm.errors import NeedsKeyException
 from llm.utils import dicts_to_table_string
 import click
-from dataclasses import asdict
 import datetime
 import openai
+from pydantic import field_validator
 import requests
+from typing import Optional, Union
 import json


@@ -76,6 +77,7 @@ def iter_prompt(self):
             model=self.prompt.model.model_id,
             messages=messages,
             stream=True,
+            **not_nulls(self.prompt.options),
         ):
             self._debug["model"] = chunk.model
             content = chunk["choices"][0].get("delta", {}).get("content")
@@ -96,8 +98,10 @@ def to_log(self) -> LogMessage:
             model=self.prompt.model.model_id,
             prompt=self.prompt.prompt,
             system=self.prompt.system,
-            options=dict(self.prompt.options),
-            prompt_json=json.dumps(asdict(self.prompt), default=repr),
+            options=not_nulls(self.prompt.options),
+            prompt_json=json.dumps(self.prompt.prompt_json)
+            if self.prompt.prompt_json
+            else None,
             response=self.text(),
             response_json={},
             chat_id=None,  # TODO
@@ -110,6 +114,40 @@ class Chat(Model):
     key_env_var = "OPENAI_API_KEY"
     can_stream: bool = True
 
+    class Options(Model.Options):
+        temperature: Optional[float] = None
+        max_tokens: Optional[int] = None
+        top_p: Optional[float] = None
+        frequency_penalty: Optional[float] = None
+        presence_penalty: Optional[float] = None
+        stop: Optional[str] = None
+        logit_bias: Optional[Union[dict, str]] = None
+
+        @field_validator("logit_bias")
+        def validate_logit_bias(cls, logit_bias):
+            if logit_bias is None:
+                return None
+
+            if isinstance(logit_bias, str):
+                try:
+                    logit_bias = json.loads(logit_bias)
+                except json.JSONDecodeError:
+                    raise ValueError("Invalid JSON in logit_bias string")
+
+            validated_logit_bias = {}
+            for key, value in logit_bias.items():
+                try:
+                    int_key = int(key)
+                    int_value = int(value)
+                    if -100 <= int_value <= 100:
+                        validated_logit_bias[int_key] = int_value
+                    else:
+                        raise ValueError("Value must be between -100 and 100")
+                except ValueError:
+                    raise ValueError("Invalid key-value pair in logit_bias dictionary")
+
+            return validated_logit_bias
+
     def __init__(self, model_id, key=None):
         self.model_id = model_id
         self.key = key
@@ -124,3 +162,7 @@ def execute(self, prompt: Prompt, stream: bool = True) -> ChatResponse:
 
     def __str__(self):
         return "OpenAI Chat: {}".format(self.model_id)
+
+
+def not_nulls(data) -> dict:
+    return {key: value for key, value in data if value is not None}
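As a usage sketch of the `logit_bias` validator above (an assumption: this commit's checkout of llm is importable, with the import path mirroring the file shown):

```python
from llm.default_plugins.openai_models import Chat

# A JSON string, as it would arrive via -o logit_bias '...', is parsed
# and coerced to an {int: int} mapping.
opts = Chat.Options(logit_bias='{"1234": -100, "5678": 50}')
print(opts.logit_bias)  # {1234: -100, 5678: 50}

# Values outside [-100, 100] raise a pydantic ValidationError.
try:
    Chat.Options(logit_bias={"1234": 200})
except Exception as ex:
    print("rejected:", ex)
```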
2 changes: 1 addition & 1 deletion llm/models.py
@@ -33,7 +33,7 @@ class LogMessage:
     prompt: str  # Simplified string version of prompt
     system: Optional[str]  # Simplified string of system prompt
     options: Dict[str, Any]  # Any options e.g. temperature
-    prompt_json: str  # Detailed JSON of prompt
+    prompt_json: Optional[str]  # Detailed JSON of prompt
     response: str  # Simplified string version of response
     response_json: Dict[str, Any]  # Detailed JSON of response
     chat_id: Optional[int]  # ID of chat, if this is part of one
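For reference, a partial sketch of `LogMessage` after this change, limited to the fields visible in the hunk (the real dataclass has additional fields), showing that `prompt_json=None` is now legal, matching the `to_log()` change above:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class LogMessage:
    prompt: str  # Simplified string version of prompt
    system: Optional[str]  # Simplified string of system prompt
    options: Dict[str, Any]  # Any options e.g. temperature
    prompt_json: Optional[str]  # Detailed JSON of prompt, now optional
    response: str  # Simplified string version of response
    response_json: Dict[str, Any]  # Detailed JSON of response
    chat_id: Optional[int]  # ID of chat, if this is part of one


msg = LogMessage("hi", None, {"temperature": 0.5}, None, "hello!", {}, None)
print(msg.prompt_json)  # None
```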
