From 566573c1e93b5211850287f00aa5e4f86620aba8 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 1 Jul 2023 14:01:29 -0700 Subject: [PATCH] llm models default command, plus refactored env variables Closes #76 Closes #31 --- docs/help.md | 12 +++++++- docs/logging.md | 2 -- docs/setup.md | 14 ++++++++++ docs/usage.md | 37 +++++++++++++++++++++++++ llm/cli.py | 69 +++++++++++++++++++++++++++++++--------------- llm/plugins.py | 12 ++++++++ tests/conftest.py | 25 +++++++---------- tests/test_keys.py | 21 ++++++++------ tests/test_llm.py | 16 ++++++----- 9 files changed, 152 insertions(+), 56 deletions(-) diff --git a/docs/help.md b/docs/help.md index 9ae8d571..1af54d3f 100644 --- a/docs/help.md +++ b/docs/help.md @@ -179,7 +179,8 @@ Options: --help Show this message and exit. Commands: - list List available models + default Show or set the default model + list List available models ``` #### llm models list --help ``` @@ -187,6 +188,15 @@ Usage: llm models list [OPTIONS] List available models +Options: + --help Show this message and exit. +``` +#### llm models default --help +``` +Usage: llm models default [OPTIONS] [MODEL] + + Show or set the default model + Options: --help Show this message and exit. ``` diff --git a/docs/logging.md b/docs/logging.md index 45bddf31..6c4e35c7 100644 --- a/docs/logging.md +++ b/docs/logging.md @@ -18,8 +18,6 @@ On my Mac that outputs: ``` This will differ for other operating systems. -(You can customize the location of this file by setting a path in the `LLM_LOG_PATH` environment variable.) - Once that SQLite database has been created any prompts you run will be logged to that database. To avoid logging a prompt, pass `--no-log` or `-n` to the command: diff --git a/docs/setup.md b/docs/setup.md index 62dd1c2c..f5812f1d 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -89,3 +89,17 @@ The environment variable will be used only if no `--key` option is passed to the If no environment variable is found, the tool will fall back to checking `keys.json`. You can force the tool to use the key from `keys.json` even if an environment variable has also been set using `llm "prompt" --key openai`. + +## Custom directory location + +This tool stores various files - prompt templates, stored keys, preferences, a database of logs - in a directory on your computer. + +On macOS this is `~/Library/Application Support/io.datasette.llm/`. + +On Linux it may be something like `~/.config/io.datasette.llm/`. 
+
+You can set a custom location for this directory by setting the `LLM_USER_PATH` environment variable:
+
+```bash
+export LLM_USER_PATH=/path/to/my/custom/directory
+```
diff --git a/docs/usage.md b/docs/usage.md
index 8b418f8a..54a4d7a5 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -61,3 +61,40 @@ This is useful for piping content to standard input, for example:
 
     curl -s 'https://simonwillison.net/2023/May/15/per-interpreter-gils/' | \
       llm -s 'Suggest topics for this post as a JSON array'
+
+## Listing available models
+
+The `llm models list` command lists every model that can be used with LLM, along with any aliases:
+
+```
+llm models list
+```
+Example output:
+```
+OpenAI Chat: gpt-3.5-turbo (aliases: 3.5, chatgpt)
+OpenAI Chat: gpt-3.5-turbo-16k (aliases: chatgpt-16k, 3.5-16k)
+OpenAI Chat: gpt-4 (aliases: 4, gpt4)
+OpenAI Chat: gpt-4-32k (aliases: 4-32k)
+PaLM 2: chat-bison-001 (aliases: palm, palm2)
+```
+You can pass the full model name or any of the aliases to the `-m/--model` option:
+
+```
+llm -m chatgpt-16k 'As many names for cheesecakes as you can think of, with detailed descriptions'
+```
+Models that have been installed using plugins will be shown here as well.
+
+## Setting a custom default model
+
+The model used when calling `llm` without the `-m/--model` option defaults to `gpt-3.5-turbo` - the fastest and least expensive OpenAI model, and the same model family that powers ChatGPT.
+
+You can use the `llm models default` command to set a different default model. For GPT-4 (slower and more expensive, but more capable), run this:
+
+```bash
+llm models default gpt-4
+```
+You can view the current default model by running this:
+```
+llm models default
+```
+Any of the supported aliases for a model can be passed to this command.
\ No newline at end of file
diff --git a/llm/cli.py b/llm/cli.py
index 4481da32..d6ee580c 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -4,7 +4,13 @@ import json
 from llm import Template
 from .migrations import migrate
-from .plugins import pm, get_plugins, get_model_aliases, get_models_with_aliases
+from .plugins import (
+    pm,
+    get_plugins,
+    get_model,
+    get_model_aliases,
+    get_models_with_aliases,
+)
 import openai
 import os
 import pathlib
@@ -180,7 +186,7 @@ def prompt(
     # Figure out which model we are using
     if model_id is None:
-        model_id = history_model or DEFAULT_MODEL
+        model_id = history_model or get_default_model()
 
     # Now resolve the model
     try:
@@ -255,7 +261,7 @@ def init_db():
 
     All subsequent prompts will be logged to this database.
""" - path = log_db_path() + path = logs_db_path() if path.exists(): return # Ensure directory exists @@ -276,11 +282,7 @@ def keys_path_command(): def keys_path(): - llm_keys_path = os.environ.get("LLM_KEYS_PATH") - if llm_keys_path: - return pathlib.Path(llm_keys_path) - else: - return user_dir() / "keys.json" + return user_dir() / "keys.json" @keys.command(name="set") @@ -321,7 +323,7 @@ def logs(): @logs.command(name="path") def logs_path(): "Output the path to the logs.db file" - click.echo(log_db_path()) + click.echo(logs_db_path()) @logs.command(name="list") @@ -340,7 +342,7 @@ def logs_path(): @click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output") def logs_list(count, path, truncate): "Show recent logged prompts and their responses" - path = pathlib.Path(path or log_db_path()) + path = pathlib.Path(path or logs_db_path()) if not path.exists(): raise click.ClickException("No log database found at {}".format(path)) db = sqlite_utils.Database(path) @@ -369,6 +371,21 @@ def models_list(): click.echo(output) +@models.command(name="default") +@click.argument("model", required=False) +def models_default(model): + "Show or set the default model" + if not model: + click.echo(get_default_model()) + return + # Validate it is a known model + try: + model = get_model(model) + set_default_model(model.model_id) + except KeyError: + raise click.ClickException("Unknown model: {}".format(model)) + + @cli.group() def templates(): "Manage stored prompt templates" @@ -473,11 +490,7 @@ def uninstall(packages, yes): def template_dir(): - llm_templates_path = os.environ.get("LLM_TEMPLATES_PATH") - if llm_templates_path: - path = pathlib.Path(llm_templates_path) - else: - path = user_dir() / "templates" + path = user_dir() / "templates" path.mkdir(parents=True, exist_ok=True) return path @@ -514,15 +527,27 @@ def load_keys(): def user_dir(): + llm_user_path = os.environ.get("LLM_USER_PATH") + if llm_user_path: + return pathlib.Path(llm_user_path) return pathlib.Path(click.get_app_dir("io.datasette.llm")) -def log_db_path(): - llm_log_path = os.environ.get("LLM_LOG_PATH") - if llm_log_path: - return pathlib.Path(llm_log_path) +def get_default_model(): + path = user_dir() / "default_model.txt" + if path.exists(): + return path.read_text().strip() else: - return user_dir() / "logs.db" + return DEFAULT_MODEL + + +def set_default_model(model): + path = user_dir() / "default_model.txt" + path.write_text(model) + + +def logs_db_path(): + return user_dir() / "logs.db" def log(no_log, system, prompt, response, model, chat_id=None, debug=None, start=None): @@ -532,7 +557,7 @@ def log(no_log, system, prompt, response, model, chat_id=None, debug=None, start duration_ms = int((end - start) * 1000) if no_log: return - log_path = log_db_path() + log_path = logs_db_path() if not log_path.exists(): return db = sqlite_utils.Database(log_path) @@ -574,7 +599,7 @@ def load_template(name): def get_history(chat_id): if chat_id is None: return None, [] - log_path = log_db_path() + log_path = logs_db_path() if not log_path.exists(): raise click.ClickException( "This feature requires logging. 
Run `llm init-db` to create logs.db" diff --git a/llm/plugins.py b/llm/plugins.py index 309f6932..45ad634a 100644 --- a/llm/plugins.py +++ b/llm/plugins.py @@ -52,3 +52,15 @@ def get_model_aliases() -> Dict[str, Model]: model_aliases[alias] = model_with_aliases.model model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model return model_aliases + + +class UnknownModelError(KeyError): + pass + + +def get_model(name): + aliases = get_model_aliases() + try: + return aliases[name] + except KeyError: + raise UnknownModelError(name) diff --git a/tests/conftest.py b/tests/conftest.py index 78d3afa5..06f0f6e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,27 +8,22 @@ def pytest_configure(config): @pytest.fixture -def log_path(tmpdir): - return tmpdir / "logs.db" +def user_path(tmpdir): + dir = tmpdir / "llm.datasette.io" + dir.mkdir() + return dir @pytest.fixture -def keys_path(tmpdir): - return tmpdir / "keys.json" - - -@pytest.fixture -def templates_path(tmpdir): - path = tmpdir / "templates" - path.mkdir() - return path +def templates_path(user_path): + dir = user_path / "templates" + dir.mkdir() + return dir @pytest.fixture(autouse=True) -def env_setup(monkeypatch, log_path, keys_path, templates_path): - monkeypatch.setenv("LLM_KEYS_PATH", str(keys_path)) - monkeypatch.setenv("LLM_LOG_PATH", str(log_path)) - monkeypatch.setenv("LLM_TEMPLATES_PATH", str(templates_path)) +def env_setup(monkeypatch, user_path): + monkeypatch.setenv("LLM_USER_PATH", str(user_path)) @pytest.fixture diff --git a/tests/test_keys.py b/tests/test_keys.py index 0d84de9c..a36fac70 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -1,30 +1,31 @@ from click.testing import CliRunner import json from llm.cli import cli +import pathlib import pytest -@pytest.mark.parametrize("env", ({}, {"LLM_KEYS_PATH": "/tmp/foo.json"})) -def test_keys_path(monkeypatch, env, keys_path): +@pytest.mark.parametrize("env", ({}, {"LLM_USER_PATH": "/tmp/llm-keys-test"})) +def test_keys_in_user_path(monkeypatch, env, user_path): for key, value in env.items(): monkeypatch.setenv(key, value) runner = CliRunner() result = runner.invoke(cli, ["keys", "path"]) assert result.exit_code == 0 if env: - expected = env["LLM_KEYS_PATH"] + expected = env["LLM_USER_PATH"] + "/keys.json" else: - expected = keys_path + expected = user_path + "/keys.json" assert result.output.strip() == expected def test_keys_set(monkeypatch, tmpdir): - keys_path = str(tmpdir / "keys.json") - monkeypatch.setenv("LLM_KEYS_PATH", keys_path) + user_path = str(tmpdir / "user/keys") + monkeypatch.setenv("LLM_USER_PATH", user_path) runner = CliRunner() result = runner.invoke(cli, ["keys", "set", "openai"], input="foo") assert result.exit_code == 0 - content = open(keys_path).read() + content = open(user_path + "/keys.json").read() assert json.loads(content) == { "// Note": "This file stores secret API credentials. 
Do not share!", "openai": "foo", @@ -32,7 +33,9 @@ def test_keys_set(monkeypatch, tmpdir): def test_uses_correct_key(mocked_openai, monkeypatch, tmpdir): - keys_path = tmpdir / "keys.json" + user_dir = tmpdir / "user-dir" + pathlib.Path(user_dir).mkdir() + keys_path = user_dir / "keys.json" keys_path.write_text( json.dumps( { @@ -42,7 +45,7 @@ def test_uses_correct_key(mocked_openai, monkeypatch, tmpdir): ), "utf-8", ) - monkeypatch.setenv("LLM_KEYS_PATH", str(keys_path)) + monkeypatch.setenv("LLM_USER_PATH", str(user_dir)) monkeypatch.setenv("OPENAI_API_KEY", "from-env") def assert_key(key): diff --git a/tests/test_llm.py b/tests/test_llm.py index 73bd3c3f..c0de9215 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -17,8 +17,9 @@ def test_version(): @pytest.mark.parametrize("n", (None, 0, 2)) -def test_logs(n, log_path): - db = sqlite_utils.Database(str(log_path)) +def test_logs(n, user_path): + log_path = str(user_path / "logs.db") + db = sqlite_utils.Database(log_path) migrate(db) db["log"].insert_all( { @@ -45,24 +46,25 @@ def test_logs(n, log_path): assert len(logs) == expected_length -@pytest.mark.parametrize("env", ({}, {"LLM_LOG_PATH": "/tmp/logs.db"})) -def test_logs_path(monkeypatch, env, log_path): +@pytest.mark.parametrize("env", ({}, {"LLM_USER_PATH": "/tmp/llm-user-path"})) +def test_logs_path(monkeypatch, env, user_path): for key, value in env.items(): monkeypatch.setenv(key, value) runner = CliRunner() result = runner.invoke(cli, ["logs", "path"]) assert result.exit_code == 0 if env: - expected = env["LLM_LOG_PATH"] + expected = env["LLM_USER_PATH"] + "/logs.db" else: - expected = str(log_path) + expected = str(user_path) + "/logs.db" assert result.output.strip() == expected @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"}) @pytest.mark.parametrize("use_stdin", (True, False)) -def test_llm_default_prompt(mocked_openai, use_stdin, log_path): +def test_llm_default_prompt(mocked_openai, use_stdin, user_path): # Reset the log_path database + log_path = user_path / "logs.db" log_db = sqlite_utils.Database(str(log_path)) log_db["log"].delete_where() runner = CliRunner()