Skip to content

Commit

Permalink
feat: support image uploads via local path or URL
Browse files Browse the repository at this point in the history
  • Loading branch information
kardolus committed Jan 1, 2025
1 parent fde40af commit 5c6bd63
Show file tree
Hide file tree
Showing 9 changed files with 424 additions and 44 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ Azure, featuring streaming capabilities and extensive configuration options.
- [Persistent Autocompletion](#persistent-autocompletion)
- [Markdown Rendering](#markdown-rendering)
- [Development](#development)
- [Using the Makefile](#using-the-makefile)
- [Testing the CLI](#testing-the-cli)
- [Using the Makefile](#using-the-makefile)
- [Testing the CLI](#testing-the-cli)
- [Reporting Issues and Contributing](#reporting-issues-and-contributing)
- [Uninstallation](#uninstallation)
- [Useful Links](#useful-links)
Expand All @@ -62,6 +62,8 @@ Azure, featuring streaming capabilities and extensive configuration options.
* **Custom context from any source**: You can provide the GPT model with a custom context during conversation. This
context can be piped in from any source, such as local files, standard input, or even another program. This
flexibility allows the model to adapt to a wide range of conversational scenarios.
* **Support for images**: Upload an image or provide an image URL using the `--image` flag. Note that image support may
not be available for all models.
* **Model listing**: Access a list of available models using the `-l` or `--list-models` flag.
* **Thread listing**: Display a list of active threads using the `--list-threads` flag.
* **Advanced configuration options**: The CLI supports a layered configuration system where settings can be specified
Expand Down Expand Up @@ -503,6 +505,7 @@ To start developing, set the `OPENAI_API_KEY` environment variable to
your [ChatGPT secret key](https://platform.openai.com/account/api-keys).

### Using the Makefile

The Makefile simplifies development tasks by providing several targets for testing, building, and deployment.

* **all-tests**: Run all tests, including linting, formatting, and go mod tidy.
Expand All @@ -523,6 +526,7 @@ The Makefile simplifies development tasks by providing several targets for testi
```

For more available commands, use:

```shell
make help
```
Expand All @@ -534,6 +538,7 @@ make help
```

### Testing the CLI

1. After a successful build, test the application with the following command:

```shell
Expand Down
172 changes: 144 additions & 28 deletions api/client/client.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
package client

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/kardolus/chatgpt-cli/api"
"github.com/kardolus/chatgpt-cli/api/http"
"github.com/kardolus/chatgpt-cli/config"
"net/url"
"os"
"strings"
"time"
"unicode/utf8"

"github.com/kardolus/chatgpt-cli/history"
stdhttp "net/http"
)

const (
Expand All @@ -24,6 +28,11 @@ const (
InteractiveThreadPrefix = "int_"
gptPrefix = "gpt"
o1Prefix = "o1"
imageURLType = "image_url"
imageContent = "data:%s;base64,%s"
httpScheme = "http"
httpsScheme = "https"
bufferSize = 512
)

type Timer interface {
Expand All @@ -37,15 +46,39 @@ func (r *RealTime) Now() time.Time {
return time.Now()
}

type ImageReader interface {
ReadFile(name string) ([]byte, error)
ReadBufferFromFile(file *os.File) ([]byte, error)
Open(name string) (*os.File, error)
}

type RealImageReader struct{}

func (r *RealImageReader) Open(name string) (*os.File, error) {
return os.Open(name)
}

func (r *RealImageReader) ReadFile(name string) ([]byte, error) {
return os.ReadFile(name)
}

func (r *RealImageReader) ReadBufferFromFile(file *os.File) ([]byte, error) {
buffer := make([]byte, bufferSize)
_, err := file.Read(buffer)

return buffer, err
}

type Client struct {
Config config.Config
History []history.History
caller http.Caller
historyStore history.Store
timer Timer
reader ImageReader
}

func New(callerFactory http.CallerFactory, hs history.Store, t Timer, cfg config.Config, interactiveMode bool) *Client {
func New(callerFactory http.CallerFactory, hs history.Store, t Timer, r ImageReader, cfg config.Config, interactiveMode bool) *Client {
caller := callerFactory(cfg)

if interactiveMode && cfg.AutoCreateNewThread {
Expand All @@ -59,6 +92,7 @@ func New(callerFactory http.CallerFactory, hs history.Store, t Timer, cfg config
caller: caller,
historyStore: hs,
timer: t,
reader: r,
}
}

Expand Down Expand Up @@ -163,9 +197,9 @@ func (c *Client) Query(input string) (string, int, error) {
return "", response.Usage.TotalTokens, errors.New("no responses returned")
}

c.updateHistory(response.Choices[0].Message.Content)
c.updateHistory(response.Choices[0].Message.Content.(string))

return response.Choices[0].Message.Content, response.Usage.TotalTokens, nil
return response.Choices[0].Message.Content.(string), response.Usage.TotalTokens, nil
}

// Stream sends a query to the API and processes the response as a stream.
Expand Down Expand Up @@ -219,6 +253,45 @@ func (c *Client) createBody(stream bool) ([]byte, error) {
Stream: stream,
}

if c.Config.Image != "" {
var content api.ImageContent

if isValidURL(c.Config.Image) {
content = api.ImageContent{
Type: imageURLType,
ImageURL: struct {
URL string `json:"url"`
}{
URL: c.Config.Image,
},
}
} else {
mime, err := c.getMimeTypeFromFileContent(c.Config.Image)
if err != nil {
return nil, err
}

image, err := c.base64EncodeImage(c.Config.Image)
if err != nil {
return nil, err
}

content = api.ImageContent{
Type: imageURLType,
ImageURL: struct {
URL string `json:"url"`
}{
URL: fmt.Sprintf(imageContent, mime, image),
},
}
}

body.Messages = append(body.Messages, api.Message{
Role: UserRole,
Content: []api.ImageContent{content},
})
}

return json.Marshal(body)
}

Expand Down Expand Up @@ -314,33 +387,13 @@ func (c *Client) updateHistory(response string) {
}
}

func calculateEffectiveContextWindow(window int, bufferPercentage int) int {
adjustedPercentage := 100 - bufferPercentage
effectiveContextWindow := (window * adjustedPercentage) / 100
return effectiveContextWindow
}

func countTokens(entries []history.History) (int, []int) {
var result int
var rolling []int

for _, entry := range entries {
charCount, wordCount := 0, 0
words := strings.Fields(entry.Content)
wordCount += len(words)

for _, word := range words {
charCount += utf8.RuneCountInString(word)
}

// This is a simple approximation; actual token count may differ.
// You can adjust this based on your language and the specific tokenizer used by the model.
tokenCountForMessage := (charCount + wordCount) / 2
result += tokenCountForMessage
rolling = append(rolling, tokenCountForMessage)
func (c *Client) base64EncodeImage(path string) (string, error) {
imageData, err := c.reader.ReadFile(path)
if err != nil {
return "", err
}

return result, rolling
return base64.StdEncoding.EncodeToString(imageData), nil
}

func (c *Client) createHistoryEntriesFromString(input string) []history.History {
Expand Down Expand Up @@ -369,6 +422,23 @@ func (c *Client) createHistoryEntriesFromString(input string) []history.History
return result
}

func (c *Client) getMimeTypeFromFileContent(path string) (string, error) {
file, err := c.reader.Open(path)
if err != nil {
return "", err
}
defer file.Close()

buffer, err := c.reader.ReadBufferFromFile(file)
if err != nil {
return "", err
}

mimeType := stdhttp.DetectContentType(buffer)

return mimeType, nil
}

func (c *Client) printRequestDebugInfo(endpoint string, body []byte) {
fmt.Printf("\nGenerated cURL command:\n\n")
method := "POST"
Expand All @@ -395,3 +465,49 @@ func GenerateUniqueSlug(prefix string) string {
guid := uuid.New()
return prefix + guid.String()[:4]
}

func calculateEffectiveContextWindow(window int, bufferPercentage int) int {
adjustedPercentage := 100 - bufferPercentage
effectiveContextWindow := (window * adjustedPercentage) / 100
return effectiveContextWindow
}

func countTokens(entries []history.History) (int, []int) {
var result int
var rolling []int

for _, entry := range entries {
charCount, wordCount := 0, 0
words := strings.Fields(entry.Content.(string))
wordCount += len(words)

for _, word := range words {
charCount += utf8.RuneCountInString(word)
}

// This is a simple approximation; actual token count may differ.
// You can adjust this based on your language and the specific tokenizer used by the model.
tokenCountForMessage := (charCount + wordCount) / 2
result += tokenCountForMessage
rolling = append(rolling, tokenCountForMessage)
}

return result, rolling
}

func isValidURL(input string) bool {
parsedURL, err := url.ParseRequestURI(input)
if err != nil {
return false
}

// Ensure that the URL has a valid scheme
schemes := []string{httpScheme, httpsScheme}
for _, scheme := range schemes {
if strings.HasPrefix(parsedURL.Scheme, scheme) {
return true
}
}

return false
}
Loading

0 comments on commit 5c6bd63

Please sign in to comment.