-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add system prompt processor (#186)
- Loading branch information
1 parent
847ff84
commit bbabd9a
Showing
6 changed files
with
127 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
<?php | ||
|
||
use PhpLlm\LlmChain\Bridge\OpenAI\GPT; | ||
use PhpLlm\LlmChain\Bridge\OpenAI\PlatformFactory; | ||
use PhpLlm\LlmChain\Chain; | ||
use PhpLlm\LlmChain\Chain\InputProcessor\SystemPromptInputProcessor; | ||
use PhpLlm\LlmChain\Model\Message\Message; | ||
use PhpLlm\LlmChain\Model\Message\MessageBag; | ||
use Symfony\Component\Dotenv\Dotenv; | ||
|
||
require_once dirname(__DIR__).'/vendor/autoload.php'; | ||
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env'); | ||
|
||
if (empty($_ENV['OPENAI_API_KEY'])) { | ||
echo 'Please set the OPENAI_API_KEY environment variable.'.PHP_EOL; | ||
exit(1); | ||
} | ||
|
||
$platform = PlatformFactory::create($_ENV['OPENAI_API_KEY']); | ||
$llm = new GPT(GPT::GPT_4O_MINI); | ||
|
||
$processor = new SystemPromptInputProcessor('You are Yoda and write like he speaks. But short.'); | ||
|
||
$chain = new Chain($platform, $llm, [$processor]); | ||
$messages = new MessageBag(Message::ofUser('What is the meaning of life?')); | ||
$response = $chain->call($messages); | ||
|
||
echo $response->getContent().PHP_EOL; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace PhpLlm\LlmChain\Chain\InputProcessor; | ||
|
||
use PhpLlm\LlmChain\Chain\Input; | ||
use PhpLlm\LlmChain\Chain\InputProcessor; | ||
use PhpLlm\LlmChain\Model\Message\Message; | ||
use Psr\Log\LoggerInterface; | ||
use Psr\Log\NullLogger; | ||
|
||
final readonly class SystemPromptInputProcessor implements InputProcessor | ||
{ | ||
public function __construct( | ||
private string $systemPrompt, | ||
private LoggerInterface $logger = new NullLogger(), | ||
) { | ||
} | ||
|
||
public function processInput(Input $input): void | ||
{ | ||
$messages = $input->messages; | ||
|
||
if (null !== $messages->getSystemMessage()) { | ||
$this->logger->debug('Skipping system prompt injection since MessageBag already contains a system message.'); | ||
|
||
return; | ||
} | ||
|
||
$input->messages = $messages->prepend(Message::forSystem($this->systemPrompt)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
tests/Chain/InputProcessor/SystemPromptInputProcessorTest.php
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace PhpLlm\LlmChain\Tests\Chain\InputProcessor; | ||
|
||
use PhpLlm\LlmChain\Bridge\OpenAI\GPT; | ||
use PhpLlm\LlmChain\Chain\Input; | ||
use PhpLlm\LlmChain\Chain\InputProcessor\SystemPromptInputProcessor; | ||
use PhpLlm\LlmChain\Model\Message\Message; | ||
use PhpLlm\LlmChain\Model\Message\MessageBag; | ||
use PhpLlm\LlmChain\Model\Message\SystemMessage; | ||
use PhpLlm\LlmChain\Model\Message\UserMessage; | ||
use PHPUnit\Framework\Attributes\CoversClass; | ||
use PHPUnit\Framework\Attributes\Small; | ||
use PHPUnit\Framework\Attributes\Test; | ||
use PHPUnit\Framework\Attributes\UsesClass; | ||
use PHPUnit\Framework\TestCase; | ||
|
||
#[CoversClass(SystemPromptInputProcessor::class)] | ||
#[UsesClass(GPT::class)] | ||
#[UsesClass(Message::class)] | ||
#[UsesClass(MessageBag::class)] | ||
#[Small] | ||
final class SystemPromptInputProcessorTest extends TestCase | ||
{ | ||
#[Test] | ||
public function processInputAddsSystemMessageWhenNoneExists(): void | ||
{ | ||
$processor = new SystemPromptInputProcessor('This is a system prompt'); | ||
|
||
$input = new Input(new GPT(), new MessageBag(Message::ofUser('This is a user message')), []); | ||
$processor->processInput($input); | ||
|
||
$messages = $input->messages->getMessages(); | ||
self::assertCount(2, $messages); | ||
self::assertInstanceOf(SystemMessage::class, $messages[0]); | ||
self::assertInstanceOf(UserMessage::class, $messages[1]); | ||
self::assertSame('This is a system prompt', $messages[0]->content); | ||
} | ||
|
||
#[Test] | ||
public function processInputDoesNotAddSystemMessageWhenOneExists(): void | ||
{ | ||
$processor = new SystemPromptInputProcessor('This is a system prompt'); | ||
|
||
$messages = new MessageBag( | ||
Message::forSystem('This is already a system prompt'), | ||
Message::ofUser('This is a user message'), | ||
); | ||
$input = new Input(new GPT(), $messages, []); | ||
$processor->processInput($input); | ||
|
||
$messages = $input->messages->getMessages(); | ||
self::assertCount(2, $messages); | ||
self::assertInstanceOf(SystemMessage::class, $messages[0]); | ||
self::assertInstanceOf(UserMessage::class, $messages[1]); | ||
self::assertSame('This is already a system prompt', $messages[0]->content); | ||
} | ||
} |