diff --git a/docs/help.md b/docs/help.md
index 1af54d3f..99e17e23 100644
--- a/docs/help.md
+++ b/docs/help.md
@@ -75,17 +75,18 @@ Usage: llm prompt [OPTIONS] [PROMPT]
   Documentation: https://llm.datasette.io/en/stable/usage.html
 
 Options:
-  -s, --system TEXT    System prompt to use
-  -m, --model TEXT     Model to use
-  -t, --template TEXT  Template to use
-  -p, --param ...      Parameters for template
-  --no-stream          Do not stream output
-  -n, --no-log         Don't log to database
-  -c, --continue       Continue the most recent conversation.
-  --chat INTEGER       Continue the conversation with the given chat ID.
-  --key TEXT           API key to use
-  --save TEXT          Save prompt with this template name
-  --help               Show this message and exit.
+  -s, --system TEXT    System prompt to use
+  -m, --model TEXT     Model to use
+  -o, --option ...     key/value options for the model
+  -t, --template TEXT  Template to use
+  -p, --param ...      Parameters for template
+  --no-stream          Do not stream output
+  -n, --no-log         Don't log to database
+  -c, --continue       Continue the most recent conversation.
+  --chat INTEGER       Continue the conversation with the given chat ID.
+  --key TEXT           API key to use
+  --save TEXT          Save prompt with this template name
+  --help               Show this message and exit.
 ```
 ### llm init-db --help
 ```
diff --git a/llm/cli.py b/llm/cli.py
index 90e96fdd..4d0feac3 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -58,6 +58,14 @@ def cli():
 @click.argument("prompt", required=False)
 @click.option("-s", "--system", help="System prompt to use")
 @click.option("model_id", "-m", "--model", help="Model to use")
+@click.option(
+    "options",
+    "-o",
+    "--option",
+    type=(str, str),
+    multiple=True,
+    help="key/value options for the model",
+)
 @click.option("-t", "--template", help="Template to use")
 @click.option(
     "-p",
@@ -88,6 +96,7 @@ def prompt(
     prompt,
     system,
     model_id,
+    options,
     template,
     param,
     no_stream,
@@ -116,17 +125,17 @@ def prompt(
     if save:
         # We are saving their prompt/system/etc to a new template
         # Fields to save: prompt, system, model - and more in the future
-        bad_options = []
+        disallowed_options = []
         for option, var in (
             ("--template", template),
             ("--continue", _continue),
             ("--chat", chat_id),
         ):
             if var:
-                bad_options.append(option)
-        if bad_options:
+                disallowed_options.append(option)
+        if disallowed_options:
             raise click.ClickException(
-                "--save cannot be used with {}".format(", ".join(bad_options))
+                "--save cannot be used with {}".format(", ".join(disallowed_options))
             )
         path = template_dir() / f"{save}.yaml"
         to_save = {}
@@ -197,13 +206,26 @@ def prompt(
     if model.needs_key and not model.key:
         model.key = get_key(key, model.needs_key, model.key_env_var)
 
+    # Validate options
+    validated_options = {}
+    if options:
+        # Validate with pydantic
+        try:
+            validated_options = dict(
+                (key, value)
+                for key, value in model.Options(**dict(options))
+                if value is not None
+            )
+        except pydantic.ValidationError as ex:
+            raise click.ClickException(str(ex))
+
     should_stream = model.can_stream and not no_stream
     if should_stream:
         method = model.stream
     else:
         method = model.prompt
 
-    response = method(prompt, system)
+    response = method(prompt, system, **validated_options)
 
     if should_stream:
         for chunk in response:
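The cli.py changes above collect each `-o/--option` flag as a `(str, str)` tuple, then round-trip the pairs through the model's pydantic `Options` class so string values are coerced to their declared types and bad values fail fast. A minimal standalone sketch of that flow follows; the `Options` fields here are hypothetical stand-ins for a real model's `Options`, assuming only that it subclasses `pydantic.BaseModel`:

```python
# Sketch of the -o validation flow in prompt() above, independent of
# llm internals. Fields here are hypothetical stand-ins.
from typing import Optional

from pydantic import BaseModel, ValidationError


class Options(BaseModel):
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None


# Each "-o key value" arrives from click as a (str, str) tuple.
pairs = (("temperature", "0.5"), ("max_tokens", "10"))
try:
    # Iterating a BaseModel yields (field, value) pairs; pydantic has
    # already coerced "0.5" -> 0.5 and "10" -> 10 at this point.
    validated = {
        key: value for key, value in Options(**dict(pairs)) if value is not None
    }
except ValidationError as ex:
    raise SystemExit(str(ex))  # cli.py raises click.ClickException instead
print(validated)  # {'temperature': 0.5, 'max_tokens': 10}
```

Filtering out `None` values means options the user never set are neither sent to the model nor recorded in the logs, which is also why the `not_nulls()` helper appears in the OpenAI plugin below.
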
diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index fcfa981b..629251b1 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -2,10 +2,11 @@ from llm.errors import NeedsKeyException
 from llm.utils import dicts_to_table_string
 import click
-from dataclasses import asdict
 import datetime
 import openai
+from pydantic import field_validator
 import requests
+from typing import Optional, Union
 import json
@@ -76,6 +77,7 @@ def iter_prompt(self):
             model=self.prompt.model.model_id,
             messages=messages,
             stream=True,
+            **not_nulls(self.prompt.options),
         ):
             self._debug["model"] = chunk.model
             content = chunk["choices"][0].get("delta", {}).get("content")
@@ -96,8 +98,10 @@ def to_log(self) -> LogMessage:
             model=self.prompt.model.model_id,
             prompt=self.prompt.prompt,
             system=self.prompt.system,
-            options=dict(self.prompt.options),
-            prompt_json=json.dumps(asdict(self.prompt), default=repr),
+            options=not_nulls(self.prompt.options),
+            prompt_json=json.dumps(self.prompt.prompt_json)
+            if self.prompt.prompt_json
+            else None,
             response=self.text(),
             response_json={},
             chat_id=None,  # TODO
@@ -110,6 +114,40 @@ class Chat(Model):
     key_env_var = "OPENAI_API_KEY"
     can_stream: bool = True
 
+    class Options(Model.Options):
+        temperature: Optional[float] = None
+        max_tokens: Optional[int] = None
+        top_p: Optional[float] = None
+        frequency_penalty: Optional[float] = None
+        presence_penalty: Optional[float] = None
+        stop: Optional[str] = None
+        logit_bias: Optional[Union[dict, str]] = None
+
+        @field_validator("logit_bias")
+        def validate_logit_bias(cls, logit_bias):
+            if logit_bias is None:
+                return None
+
+            if isinstance(logit_bias, str):
+                try:
+                    logit_bias = json.loads(logit_bias)
+                except json.JSONDecodeError:
+                    raise ValueError("Invalid JSON in logit_bias string")
+
+            validated_logit_bias = {}
+            for key, value in logit_bias.items():
+                try:
+                    int_key = int(key)
+                    int_value = int(value)
+                    if -100 <= int_value <= 100:
+                        validated_logit_bias[int_key] = int_value
+                    else:
+                        raise ValueError("Value must be between -100 and 100")
+                except ValueError:
+                    raise ValueError("Invalid key-value pair in logit_bias dictionary")
+
+            return validated_logit_bias
+
     def __init__(self, model_id, key=None):
         self.model_id = model_id
         self.key = key
@@ -124,3 +162,7 @@ def execute(self, prompt: Prompt, stream: bool = True) -> ChatResponse:
 
     def __str__(self):
         return "OpenAI Chat: {}".format(self.model_id)
+
+
+def not_nulls(data) -> dict:
+    return {key: value for key, value in data if value is not None}
diff --git a/llm/models.py b/llm/models.py
index b250469d..0b99f0ab 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -33,7 +33,7 @@ class LogMessage:
     prompt: str  # Simplified string version of prompt
     system: Optional[str]  # Simplified string of system prompt
     options: Dict[str, Any]  # Any options e.g. temperature
-    prompt_json: str  # Detailed JSON of prompt
+    prompt_json: Optional[str]  # Detailed JSON of prompt
     response: str  # Simplified string version of response
     response_json: Dict[str, Any]  # Detailed JSON of response
     chat_id: Optional[int]  # ID of chat, if this is part of one
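Taken together, these changes let OpenAI model options be set per prompt, e.g. `llm 'Say hello' -o temperature 0.5 -o logit_bias '{"1712": -100}'`. The `logit_bias` validator accepts either a dict or a JSON string and normalizes it to an `{int: int}` mapping, rejecting values outside the API's -100 to 100 range. A quick check of that behaviour, assuming the patched llm package is importable and pydantic 2 is installed:

```python
# Exercises the Chat.Options validator from the diff above; assumes
# the patched llm package is on the import path.
from pydantic import ValidationError

from llm.default_plugins.openai_models import Chat

# A JSON string, as it would arrive from the command line, is parsed
# and normalized to an {int: int} mapping.
opts = Chat.Options(logit_bias='{"1712": -100, "892": 100}')
print(opts.logit_bias)  # {1712: -100, 892: 100}

# Values outside -100..100 are rejected during validation.
try:
    Chat.Options(logit_bias='{"1712": -200}')
except ValidationError as ex:
    print(ex)  # "Invalid key-value pair in logit_bias dictionary"
```

Note that the out-of-range `ValueError` is caught by the validator's own `except ValueError` clause, so the more specific "Value must be between -100 and 100" message is replaced by the generic one before it reaches the user.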