Skip to content

Commit

Permalink
Merge branch 'improvement-prompt' of https://github.com/UmerHA/gpt-en…
Browse files Browse the repository at this point in the history
…gineer into improvement-prompt
  • Loading branch information
UmerHA committed Sep 20, 2023
2 parents a50c772 + 89f19fa commit b73f793
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 31 deletions.
21 changes: 10 additions & 11 deletions gpt_engineer/chat_to_files.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from pathlib import Path
import re

from dataclasses import dataclass
from typing import List, Tuple

from gpt_engineer.db import DBs, DB
from gpt_engineer.db import DB, DBs
from gpt_engineer.file_selector import FILE_LIST_NAME


Expand Down Expand Up @@ -73,9 +74,9 @@ def to_files(chat: str, workspace: DB):
workspace[file_name] = file_content


def overwrite_files(chat, dbs: DBs):
def overwrite_files(chat: str, dbs: DBs) -> None:
"""
Replace the AI files with the older local files.
Parse the chat and overwrite all files in the workspace.
Parameters
----------
Expand All @@ -84,19 +85,17 @@ def overwrite_files(chat, dbs: DBs):
dbs : DBs
The database containing the workspace.
"""
dbs.workspace["all_output.txt"] = chat # TODO store this in memory db instead
dbs.memory["all_output_overwrite.txt"] = chat

files = parse_chat(chat)
for file_name, file_content in files:
if file_name == "README.md":
dbs.workspace[
"LAST_MODIFICATION_README.md"
] = file_content # TODO store this in memory db instead
dbs.memory["LAST_MODIFICATION_README.md"] = file_content
else:
dbs.workspace[file_name] = file_content


def get_code_strings(input: DB) -> dict[str, str]:
def get_code_strings(workspace_path: Path, metadata_db: DB) -> dict[str, str]:
"""
Read file_list.txt and return file names and their content.
Expand All @@ -110,13 +109,13 @@ def get_code_strings(input: DB) -> dict[str, str]:
dict[str, str]
A dictionary mapping file names to their content.
"""
files_paths = input[FILE_LIST_NAME].strip().split("\n")
files_paths = metadata_db[FILE_LIST_NAME].strip().split("\n")
files_dict = {}
for full_file_path in files_paths:
with open(full_file_path, "r") as file:
file_data = file.read()
if file_data:
file_name = os.path.relpath(full_file_path, input.path)
file_name = os.path.relpath(full_file_path, workspace_path)
files_dict[file_name] = file_data
return files_dict

Expand Down Expand Up @@ -146,7 +145,7 @@ def format_file_to_input(file_name: str, file_content: str) -> str:
return file_str


def overwrite_files_with_edits(chat, dbs: DBs):
def overwrite_files_with_edits(chat: str, dbs: DBs):
edits = parse_edits(chat)
apply_edits(edits, dbs.workspace)

Expand Down
1 change: 1 addition & 0 deletions gpt_engineer/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class DBs:
input: DB
workspace: DB
archive: DB
project_metadata: DB


def archive(dbs: DBs) -> None:
Expand Down
16 changes: 9 additions & 7 deletions gpt_engineer/file_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import tkinter.filedialog as fd

from pathlib import Path
from typing import List, Union
from typing import List, Mapping, Union

from gpt_engineer.db import DB, DBs

IGNORE_FOLDERS = {"site-packages", "node_modules", "venv"}
FILE_LIST_NAME = "file_list.txt"
Expand Down Expand Up @@ -231,7 +233,7 @@ def is_in_ignoring_extensions(path: Path) -> bool:
return is_hidden and is_pycache


def ask_for_files(db_input) -> None:
def ask_for_files(metadata_db: DB, workspace_db: DB) -> None:
"""
Ask user to select files to improve.
It can be done by terminal, gui, or using the old selection.
Expand All @@ -240,10 +242,10 @@ def ask_for_files(db_input) -> None:
dict[str, str]: Dictionary where key = file name and value = file path
"""
use_last_string = ""
if FILE_LIST_NAME in db_input:
if FILE_LIST_NAME in metadata_db:
use_last_string = (
"3. Use previous file list (available at "
+ f"{os.path.join(db_input.path, FILE_LIST_NAME)})\n"
+ f"{os.path.join(metadata_db.path, FILE_LIST_NAME)})\n"
)
selection_number = 3
else:
Expand All @@ -270,10 +272,10 @@ def ask_for_files(db_input) -> None:

if selection_number == 1:
# Open GUI selection
file_path_list = gui_file_selector(db_input.path)
file_path_list = gui_file_selector(workspace_db.path)
elif selection_number == 2:
# Open terminal selection
file_path_list = terminal_file_selector(db_input.path)
file_path_list = terminal_file_selector(workspace_db.path)
if (
selection_number <= 0
or selection_number > 3
Expand All @@ -283,7 +285,7 @@ def ask_for_files(db_input) -> None:
sys.exit(1)

if not selection_number == 3:
db_input[FILE_LIST_NAME] = "\n".join(file_path_list)
metadata_db[FILE_LIST_NAME] = "\n".join(file_path_list)


def gui_file_selector(input_path: str) -> List[str]:
Expand Down
6 changes: 4 additions & 2 deletions gpt_engineer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ def main(
)

input_path = Path(project_path).absolute()
memory_path = input_path / "memory"
workspace_path = input_path / "workspace"
archive_path = input_path / "archive"
project_metadata_path = input_path / ".gpteng"
memory_path = project_metadata_path / "memory"
archive_path = project_metadata_path / "archive"

dbs = DBs(
memory=DB(memory_path),
Expand All @@ -76,6 +77,7 @@ def main(
Path(__file__).parent / "preprompts"
), # Loads preprompts from the preprompts directory
archive=DB(archive_path),
project_metadata=DB(project_metadata_path),
)

if steps_config not in [
Expand Down
15 changes: 9 additions & 6 deletions gpt_engineer/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def setup_sys_prompt_existing_code(dbs: DBs) -> str:
def curr_fn() -> str:
"""
Get the name of the current function
NOTE: This will be the name of the function that called this function,
This will be the name of the function that called this function,
so it serves to ensure we don't hardcode the function name in the step,
but allow the step names to be refactored
"""
Expand Down Expand Up @@ -295,16 +296,16 @@ def use_feedback(ai: AI, dbs: DBs):

def set_improve_filelist(ai: AI, dbs: DBs):
"""Sets the file list for files to work with in existing code mode."""
ask_for_files(dbs.input) # stores files as full paths.
ask_for_files(dbs.project_metadata, dbs.input) # stores files as full paths.
return []


def assert_files_ready(ai: AI, dbs: DBs):
"""Checks that the required files are present for headless
improve code execution."""
assert (
"file_list.txt" in dbs.input
), "For auto_mode file_list.txt need to be in your project folder."
"file_list.txt" in dbs.project_metadata
), "For auto_mode file_list.txt need to be in your .gpteng folder."
assert "prompt" in dbs.input, "For auto_mode a prompt file must exist."
return []

Expand All @@ -324,7 +325,7 @@ def get_improve_prompt(ai: AI, dbs: DBs):
"-----------------------------",
"The following files will be used in the improvement process:",
f"{FILE_LIST_NAME}:",
str(dbs.input["file_list.txt"]),
str(dbs.project_metadata["file_list.txt"]),
"",
"The inserted prompt is the following:",
f"'{dbs.input['prompt']}'",
Expand All @@ -346,7 +347,9 @@ def improve_existing_code(ai: AI, dbs: DBs):
to sent the formatted prompt to the LLM.
"""

files_info = get_code_strings(dbs.input) # this only has file names not paths
files_info = get_code_strings(
dbs.input.path, dbs.project_metadata
) # this has file names relative to the workspace path

messages = [
ai.fsystem(setup_sys_prompt_existing_code(dbs)),
Expand Down
22 changes: 20 additions & 2 deletions tests/steps/test_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,16 @@ def setup_dbs(tmp_path, dir_names):

def test_archive(tmp_path, monkeypatch):
dbs = setup_dbs(
tmp_path, ["memory", "logs", "preprompts", "input", "workspace", "archive"]
tmp_path,
[
"memory",
"logs",
"preprompts",
"input",
"workspace",
"archive",
"project_metadata",
],
)
freeze_at(monkeypatch, datetime.datetime(2020, 12, 25, 17, 5, 55))
archive(dbs)
Expand All @@ -33,7 +42,16 @@ def test_archive(tmp_path, monkeypatch):
assert os.path.isdir(tmp_path / "archive" / "20201225_170555")

dbs = setup_dbs(
tmp_path, ["memory", "logs", "preprompts", "input", "workspace", "archive"]
tmp_path,
[
"memory",
"logs",
"preprompts",
"input",
"workspace",
"archive",
"project_metadata",
],
)
freeze_at(monkeypatch, datetime.datetime(2022, 8, 14, 8, 5, 12))
archive(dbs)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def test_collect_learnings(monkeypatch):
model = "test_model"
temperature = 0.5
steps = [gen_code_after_unit_tests]
dbs = DBs(DB("/tmp"), DB("/tmp"), DB("/tmp"), DB("/tmp"), DB("/tmp"), DB("/tmp"))
dbs = DBs(
DB("/tmp"), DB("/tmp"), DB("/tmp"), DB("/tmp"), DB("/tmp"), DB("/tmp"), DB("/tmp")
)
dbs.input = {
"prompt": "test prompt\n with newlines",
"feedback": "test feedback",
Expand Down
23 changes: 21 additions & 2 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ def test_DB_operations(tmp_path):


def test_DBs_initialization(tmp_path):
dir_names = ["memory", "logs", "preprompts", "input", "workspace", "archive"]
dir_names = [
"memory",
"logs",
"preprompts",
"input",
"workspace",
"archive",
"project_metadata",
]
directories = [tmp_path / name for name in dir_names]

# Create DB objects
Expand All @@ -34,6 +42,7 @@ def test_DBs_initialization(tmp_path):
assert isinstance(dbs_instance.input, DB)
assert isinstance(dbs_instance.workspace, DB)
assert isinstance(dbs_instance.archive, DB)
assert isinstance(dbs_instance.project_metadata, DB)


def test_invalid_path():
Expand Down Expand Up @@ -97,7 +106,15 @@ def test_error_messages(tmp_path):


def test_DBs_dataclass_attributes(tmp_path):
dir_names = ["memory", "logs", "preprompts", "input", "workspace", "archive"]
dir_names = [
"memory",
"logs",
"preprompts",
"input",
"workspace",
"archive",
"project_metadata",
]
directories = [tmp_path / name for name in dir_names]

# Create DB objects
Expand All @@ -111,3 +128,5 @@ def test_DBs_dataclass_attributes(tmp_path):
assert dbs_instance.preprompts == dbs[2]
assert dbs_instance.input == dbs[3]
assert dbs_instance.workspace == dbs[4]
assert dbs_instance.archive == dbs[5]
assert dbs_instance.project_metadata == dbs[6]

0 comments on commit b73f793

Please sign in to comment.