diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e86dcb7d..dac9127d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,7 +20,7 @@ python -m venv env && source ./env/bin/activate Install the necessary dependencies, including development and test dependencies: ```shell -pip install -e ."[dev,test]" +pip install -e ."[dev,test,litellm]" ``` ### Start Coding @@ -35,4 +35,4 @@ Before creating a pull request, run `scripts/lint.sh` and `scripts/tests.sh` to ### Code Review After submitting your pull request, be patient and receptive to feedback from reviewers. Address any concerns they raise and collaborate to refine the code. Together, we can enhance the ShellGPT project. -Thank you once again for your contribution! We're excited to have you join us. \ No newline at end of file +Thank you once again for your contribution! We're excited to have you join us. diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index a17d8024..3c7d53ec 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -1,4 +1,5 @@ import json +import re from pathlib import Path from typing import Any, Callable, Dict, Generator, List, Optional @@ -37,6 +38,7 @@ class Handler: def __init__(self, role: SystemRole, markdown: bool) -> None: self.role = role + self.is_shell = role.name == DefaultRoles.SHELL.value api_base_url = cfg.get("API_BASE_URL") self.base_url = None if api_base_url == "default" else api_base_url @@ -45,6 +47,13 @@ def __init__(self, role: SystemRole, markdown: bool) -> None: self.markdown = "APPLY MARKDOWN" in self.role.role and markdown self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR") + self.backticks_start = re.compile(r"(^|[\r\n]+)```\w*[\r\n]+") + end_regex_parts = [r"[\r\n]+", "`", "`", "`", r"([\r\n]+|$)"] + self.backticks_end_prefixes = [ + re.compile("".join(end_regex_parts[: i + 1])) + for i in range(len(end_regex_parts)) + ] + @property def printer(self) -> Printer: return ( @@ -82,6 +91,48 @@ def handle_function_call( yield f"```text\n{result}\n```\n" messages.append({"role": "function", "content": result, "name": name}) + def _matches_end_at(self, text: str) -> tuple[bool, int]: + end_of_match = 0 + for _i, regex in enumerate(self.backticks_end_prefixes): + m = regex.search(text) + if m: + end_of_match = m.end() + else: + return False, end_of_match + return True, m.start() + + def _filter_chunks( + self, chunks: Generator[str, None, None] + ) -> Generator[str, None, None]: + buffer = "" + inside_backticks = False + end_of_beginning = 0 + + for chunk in chunks: + buffer += chunk + if not inside_backticks: + m = self.backticks_start.search(buffer) + if not m: + continue + new_end_of_beginning = m.end() + if new_end_of_beginning > end_of_beginning: + end_of_beginning = new_end_of_beginning + continue + inside_backticks = True + buffer = buffer[end_of_beginning:] + if inside_backticks: + matches_end, index = self._matches_end_at(buffer) + if matches_end: + yield buffer[:index] + return + if index == len(buffer): + continue + else: + yield buffer + buffer = "" + if buffer: + yield buffer + @cache def get_completion( self, @@ -163,4 +214,6 @@ def handle( caching=caching, **kwargs, ) + if self.role.name == DefaultRoles.SHELL.value: + generator = self._filter_chunks(generator) return self.printer(generator, not disable_stream) diff --git a/tests/test_shell.py b/tests/test_shell.py index b78e2c96..d08dd984 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -2,6 +2,8 @@ from pathlib import Path from unittest.mock import patch +import pytest + from sgpt.config import cfg from sgpt.role import DefaultRoles, SystemRole @@ -22,6 +24,54 @@ def test_shell(completion): assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout +@patch("sgpt.handlers.handler.completion") +@pytest.mark.parametrize( + "prefix,suffix", + [ + ("", ""), + ("some text before\n```powershell\n", "\n```" ""), + ("```powershell\n", "\n```\nsome text after" ""), + ("some text before\n```powershell\n", "\n```\nsome text after" ""), + ( + "some text with ``` before\n```powershell\n", + "\n```\nsome text with ``` after" "", + ), + ("```powershell\n", "\n```" ""), + ("```\n", "\n```" ""), + ("```powershell\r\n", "\r\n```" ""), + ("```\r\n", "\r\n```" ""), + ("```powershell\r", "\r```" ""), + ("```\r", "\r```" ""), + ], +) +@pytest.mark.parametrize("group_by_size", range(10)) +def test_shell_no_backticks(completion, prefix: str, suffix: str, group_by_size: int): + expected_output = "Get-Process | \nWhere-Object { $_.Port -eq 9000 }\r\n | Select-Object Id | Text \r\nwith '```' inside" + produced_output = prefix + expected_output + suffix + if group_by_size == 0: + produced_tokens = list(produced_output) + else: + produced_tokens = [ + produced_output[i : i + group_by_size] + for i in range(0, len(produced_output), group_by_size) + ] + assert produced_output == "".join(produced_tokens) + + role = SystemRole.get(DefaultRoles.SHELL.value) + completion.return_value = mock_comp(produced_tokens) + + args = {"prompt": "find pid by port 9000", "--shell": True} + result = runner.invoke(app, cmd_args(**args)) + + completion.assert_called_once_with(**comp_args(role, args["prompt"])) + index = result.stdout.find(expected_output) + assert index >= 0 + rest = result.stdout[index + len(expected_output) :].strip() + assert "`" not in rest + assert result.exit_code == 0 + assert "[E]xecute, [D]escribe, [A]bort:" == rest + + @patch("sgpt.printer.TextPrinter.live_print") @patch("sgpt.printer.MarkdownPrinter.live_print") @patch("sgpt.handlers.handler.completion")