Skip to content

Commit

Permalink
Improve onboarding by prompting users through initial configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley committed Jan 28, 2024
1 parent ef5bdaf commit 3b68ead
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 40 deletions.
12 changes: 8 additions & 4 deletions src/shelloracle/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,24 @@
from importlib.metadata import version

from . import shelloracle
from .bootstrap import bootstrap
from .bootstrap import bootstrap_shelloracle


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--init", help="install %(prog)s keybindings", action="store_true")
parser.add_argument('--version', action='version', version=f'%(prog)s {version(__package__)}')

subparsers = parser.add_subparsers()
init_subparser = subparsers.add_parser("init", help="install %(prog)s keybindings")
init_subparser.set_defaults(subparser=init_subparser, func=bootstrap_shelloracle)

return parser.parse_args()


def main() -> None:
args = parse_args()
if args.init:
bootstrap()
if func := args.func:
func()
exit(0)

shelloracle.cli()
Expand Down
79 changes: 63 additions & 16 deletions src/shelloracle/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import inspect
import shutil
from pathlib import Path

from prompt_toolkit import print_formatted_text
from prompt_toolkit.application import create_app_session_from_tty
from prompt_toolkit import print_formatted_text, prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.shortcuts import confirm

from .config import Setting
from .providers import list_providers, get_provider, Provider


def print_info(info: str) -> None:
print_formatted_text(FormattedText([("ansiblue", info)]))


def print_error(error: str) -> None:
print_formatted_text(FormattedText([("ansired", error)]))
def print_warning(warning: str) -> None:
print_formatted_text(FormattedText([("ansiyellow", warning)]))


def replace_home_with_tilde(path: Path) -> Path:
Expand Down Expand Up @@ -74,17 +78,60 @@ def update_rc(shell: str) -> None:
print_info(f"Successfully updated {replace_home_with_tilde(rc_path)}")


def bootstrap() -> None:
with create_app_session_from_tty():
if not (shells := get_installed_shells()):
print_error(f"No compatible shells found. Supported shells: {', '.join(supported_shells)}")
return
if confirm("Enable terminal keybindings and update rc?", suffix=" ([y]/n) ") is False:
return
for shell in shells:
write_script_home(shell)
update_rc(shell)
def get_settings(provider: Provider) -> list[tuple[str, Setting]]:
settings = inspect.getmembers(provider, predicate=lambda p: isinstance(p, Setting))
return settings


def write_shelloracle_config(provider, settings):
configuration = (
"[shelloracle]\n"
"provider = %(provider)s\n"
"\n"
"[provider.%(provider)s]\n"
) % {"provider": provider.name}

for setting in settings:
s = "%(name)s = %(value)s\n" % {"name": setting[0], "value": setting[1]}
configuration += s

path = Path.home() / ".shelloracle" / "config.toml"
path.write_text(configuration)


def bootstrap_shelloracle() -> None:
"""Bootstrap shelloracle"""
provider = user_select_provider()
settings = user_configure_settings(provider)
write_shelloracle_config(provider, settings)
install_keybindings()


def install_keybindings() -> None:
if not (shells := get_installed_shells()):
print_warning("Cannot install keybindings: no compatible shells found. "
f"Supported shells: {', '.join(supported_shells)}")
return
if confirm("Enable terminal keybindings and update rc?", suffix=" ([y]/n) ") is False:
return
for shell in shells:
write_script_home(shell)
update_rc(shell)


def user_configure_settings(provider: Provider) -> list[tuple[str, str]]:
settings = []
for name, setting in get_settings(provider):
value = prompt(f"{name}: ", default=str(setting.default))
settings.append((name, value))
return settings


if __name__ == '__main__':
bootstrap()
def user_select_provider() -> Provider:
providers = list_providers()
completer = WordCompleter(providers, ignore_case=True)
selected_provider = prompt(f"Choose your LLM provider ({', '.join(providers)}): ", completer=completer)
case_insensitive_map = {p.lower(): p for p in providers}
selected_provider = case_insensitive_map[selected_provider.lower()]
provider = get_provider(selected_provider)
return provider
19 changes: 2 additions & 17 deletions src/shelloracle/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,12 @@
data_home = Path.home() / ".shelloracle"


def _default_config() -> tomlkit.TOMLDocument:
config = tomlkit.document()
shor_table = tomlkit.table()
shor_table.add("provider", "Ollama")
config.add("shelloracle", shor_table)
return config


class Configuration(MutableMapping):
filepath = data_home / "config.toml"

def __init__(self) -> None:
self._ensure_config_exists()
if not self.filepath.exists():
raise FileNotFoundError("Configuration file does not exist: Run `python -m shelloracle init` to initialize")

def __getitem__(self, item: str) -> dict:
with self.filepath.open("r") as file:
Expand Down Expand Up @@ -53,14 +46,6 @@ def __len__(self) -> int:
config = tomlkit.load(file)
return len(config)

def _ensure_config_exists(self) -> None:
if self.filepath.exists():
return
data_home.mkdir(exist_ok=True)
config = _default_config()
with self.filepath.open("w") as file:
tomlkit.dump(config, file)

@property
def provider(self) -> str:
return self["shelloracle"]["provider"]
Expand Down
4 changes: 1 addition & 3 deletions src/shelloracle/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
from collections.abc import AsyncIterator
from typing import Protocol, runtime_checkable
from typing import Protocol

system_prompt = (
"Based on the following user description, generate a corresponding Bash command. Focus solely "
Expand All @@ -16,7 +16,6 @@ class ProviderError(Exception):
"""LLM providers raise this error to gracefully indicate something has gone wrong."""


@runtime_checkable
class Provider(Protocol):
"""
LLM Provider Protocol
Expand Down Expand Up @@ -57,7 +56,6 @@ def get_provider(name: str) -> type[Provider]:
:param name: the provider name
:return: the requested provider
"""

return _providers()[name]


Expand Down

0 comments on commit 3b68ead

Please sign in to comment.