Skip to content

Commit

Permalink
tested code
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeadie committed Jun 20, 2023
1 parent 93fa067 commit a2aa813
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 20 deletions.
6 changes: 3 additions & 3 deletions gpt_engineer/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def start(self, initial_conversation: Messages) -> Messages:

def next(self, messages: Messages, user_prompt: Optional[str] = None) -> Messages:
if user_prompt:
messages.append(Message(user_prompt, Role.USER))
messages.messages.append(Message(user_prompt, Role.USER))

response = openai.ChatCompletion.create(
messages=[self._format_message(m) for m in messages],
messages=[self._format_message(m) for m in messages.messages],
stream=True,
**self.kwargs,
)
Expand All @@ -84,7 +84,7 @@ def next(self, messages: Messages, user_prompt: Optional[str] = None) -> Message
print(msg, end="")
chat.append(msg)

messages.append(Message("".join(chat), Role.ASSISTANT))
messages.messages.append(Message("".join(chat), Role.ASSISTANT))
return messages

def _format_message(self, msg: Message) -> Dict[str, str]:
Expand Down
5 changes: 3 additions & 2 deletions gpt_engineer/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import List

from dataclasses_json import dataclass_json

Expand All @@ -21,7 +22,7 @@ class Message:
@dataclass_json
@dataclass
class Messages:
messages: list[Message]
messages: List[Message]

def last_message_content(self) -> str:
return self.messages[-1][1].content
return self.messages[-1].content
25 changes: 12 additions & 13 deletions gpt_engineer/steps.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import re
import subprocess

Expand Down Expand Up @@ -51,7 +50,7 @@ def clarify(ai: AI, dbs: DBs) -> Messages:
def simple_gen(ai: AI, dbs: DBs) -> Messages:
"""Run the AI on the main prompt and save the results"""
messages = ai.start(
Message(
Messages(
[
Message(role=Role.SYSTEM, content=setup_sys_prompt(dbs)),
Message(role=Role.USER, content=dbs.input["main_prompt"]),
Expand All @@ -68,9 +67,9 @@ def run_clarified(ai: AI, dbs: DBs) -> Messages:
"""
# get the messages from previous step
messages = ai.next(
Message(
Messages(
[Message(role=Role.SYSTEM, content=setup_sys_prompt(dbs))]
+ [Message(m) for m in json.loads(dbs.logs[clarify.__name__])[1:]]
+ Messages.from_json(dbs.logs[clarify.__name__]).messages[1:]
),
user_prompt=dbs.identity["use_qa"],
)
Expand All @@ -93,14 +92,14 @@ def gen_spec(ai: AI, dbs: DBs) -> Messages:
)
messages = ai.next(messages, user_prompt=dbs.identity["spec"])

dbs.memory["specification"] = messages[-1][1].content
dbs.memory["specification"] = messages.last_message_content()

return messages


def respec(ai: AI, dbs: DBs) -> Messages:
messages = Messages(
[Message(role=Role.SYSTEM, content=log) for log in dbs.logs[gen_spec.__name__]]
Messages.from_json(dbs.logs[gen_spec.__name__]).messages
+ [Message(role=Role.SYSTEM, content=dbs.identity["respec"])]
)

Expand All @@ -115,7 +114,7 @@ def respec(ai: AI, dbs: DBs) -> Messages:
),
)

dbs.memory["specification"] = messages[-1][1].content
dbs.memory["specification"] = messages.last_message_content()
return messages


Expand All @@ -135,7 +134,7 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> Messages:

messages = ai.next(messages, user_prompt=dbs.identity["unit_tests"])

dbs.memory["unit_tests"] = messages[-1][1].content
dbs.memory["unit_tests"] = messages.last_message_content()
to_files(dbs.memory["unit_tests"], dbs.workspace)

return messages
Expand All @@ -147,7 +146,7 @@ def gen_clarified_code(ai: AI, dbs: DBs) -> Messages:
messages = ai.next(
Messages(
[Message(role=Role.SYSTEM, content=setup_sys_prompt(dbs))]
+ [Message.from_dict(m) for m in json.loads(dbs.logs[clarify.__name__])[1:]]
+ Messages.from_json(dbs.logs[clarify.__name__]).messages[1:]
),
user_prompt=dbs.identity["use_qa"],
)
Expand Down Expand Up @@ -192,11 +191,11 @@ def execute_entrypoint(ai, dbs) -> Messages:
print()
if input() != "":
print("Ok, not executing the code.")
return []
return Messages(messages=[])
print("Executing the code...")
print()
subprocess.run("bash run.sh", shell=True, cwd=dbs.workspace.path)
return []
return Messages(messages=[])


def gen_entrypoint(ai, dbs) -> Messages:
Expand Down Expand Up @@ -228,7 +227,7 @@ def gen_entrypoint(ai, dbs) -> Messages:
)
print()
regex = r"```\S*\n(.+?)```"
matches = re.finditer(regex, messages[-1][1].content, re.DOTALL)
matches = re.finditer(regex, messages.last_message_content(), re.DOTALL)
dbs.workspace["run.sh"] = "\n".join(match.group(1) for match in matches)
return messages

Expand All @@ -248,7 +247,7 @@ def use_feedback(ai: AI, dbs: DBs) -> Messages:


def fix_code(ai: AI, dbs: DBs) -> Messages:
code_output = json.loads(dbs.logs[gen_code.__name__])[-1]["content"]
code_output = Messages.from_json(dbs.logs[gen_code.__name__]).last_message_content()
messages = Messages(
[
Message(role=Role.SYSTEM, content=setup_sys_prompt(dbs)),
Expand Down
2 changes: 1 addition & 1 deletion scripts/rerun_edited_message_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import typer

from gpt_engineer.ai.ai import AI
from gpt_engineer.ai import AI
from gpt_engineer.chat_to_files import to_files

app = typer.Typer()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from gpt_engineer.ai.ai import AI
from gpt_engineer.ai import AI


@pytest.mark.xfail(reason="Constructor assumes API access")
Expand Down

0 comments on commit a2aa813

Please sign in to comment.