Skip to content

Commit

Permalink
feat(data): add support for default value in instill Go tag (#891)
Browse files Browse the repository at this point in the history
Because

- Users may not always provide required data, we can set default values
for certain fields when not provided.

This commit

- Adds support for default values in the instill Go tag.

Note

- The current tag requires manual maintenance. In the future, io.go will
be auto-generated by compogen based on the tasks.json file, and default
values will also be auto-generated.
  • Loading branch information
donch1989 authored Nov 30, 2024
1 parent af8f412 commit b9c2d05
Show file tree
Hide file tree
Showing 15 changed files with 232 additions and 59 deletions.
8 changes: 4 additions & 4 deletions pkg/component/ai/anthropic/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ package anthropic

type MessagesInput struct {
ChatHistory []ChatMessage `instill:"chat-history"`
MaxNewTokens int `instill:"max-new-tokens"`
MaxNewTokens int `instill:"max-new-tokens,default=50"`
ModelName string `instill:"model-name"`
Prompt string `instill:"prompt"`
PromptImages []string `instill:"prompt-images"`
Seed int `instill:"seed"`
SystemMsg string `instill:"system-message"`
Temperature float32 `instill:"temperature"`
TopK int `instill:"top-k"`
SystemMsg string `instill:"system-message,default=You are a helpful assistant."`
Temperature float32 `instill:"temperature,default=0.7"`
TopK int `instill:"top-k,default=10"`
}

type ChatMessage struct {
Expand Down
11 changes: 6 additions & 5 deletions pkg/component/ai/mistralai/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ type ChatMessage struct {
Role string `instill:"role"`
Content []MultiModalContent `instill:"content"`
}

type URL struct {
URL string `instill:"url"`
}
Expand All @@ -16,15 +17,15 @@ type MultiModalContent struct {

type TextGenerationInput struct {
ChatHistory []ChatMessage `instill:"chat-history"`
MaxNewTokens int `instill:"max-new-tokens"`
MaxNewTokens int `instill:"max-new-tokens,default=50"`
ModelName string `instill:"model-name"`
Prompt string `instill:"prompt"`
PromptImages []string `instill:"prompt-images"`
Seed int `instill:"seed"`
SystemMsg string `instill:"system-message"`
Temperature float64 `instill:"temperature"`
TopK int `instill:"top-k"`
TopP float64 `instill:"top-p"`
SystemMsg string `instill:"system-message,default=You are a helpful assistant."`
Temperature float64 `instill:"temperature,default=0.7"`
TopK int `instill:"top-k,default=10"`
TopP float64 `instill:"top-p,default=0.5"`
Safe bool `instill:"safe"`
}

Expand Down
28 changes: 14 additions & 14 deletions pkg/component/ai/openai/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ type taskTextGenerationInput struct {
ChatHistory []*textMessage `instill:"chat-history"`
Model string `instill:"model"`
SystemMessage *string `instill:"system-message"`
Temperature *float32 `instill:"temperature"`
TopP *float32 `instill:"top-p"`
N *int `instill:"n"`
Temperature *float32 `instill:"temperature,default=1"`
TopP *float32 `instill:"top-p,default=1"`
N *int `instill:"n,default=1"`
Stop *string `instill:"stop"`
MaxTokens *int `instill:"max-tokens"`
PresencePenalty *float32 `instill:"presence-penalty"`
FrequencyPenalty *float32 `instill:"frequency-penalty"`
PresencePenalty *float32 `instill:"presence-penalty,default=0"`
FrequencyPenalty *float32 `instill:"frequency-penalty,default=0"`
ResponseFormat *responseFormatInputStruct `instill:"response-format"`
}

Expand Down Expand Up @@ -55,7 +55,7 @@ type taskSpeechRecognitionInput struct {
Audio format.Audio `instill:"audio"`
Model string `instill:"model"`
Prompt *string `instill:"prompt"`
Temperature *float32 `instill:"temperature"`
Temperature *float32 `instill:"temperature,default=0"`
Language *string `instill:"language"`
}

Expand All @@ -66,10 +66,10 @@ type taskSpeechRecognitionOutput struct {

type taskTextToSpeechInput struct {
Text string `instill:"text"`
Model string `instill:"model"`
Voice string `instill:"voice"`
ResponseFormat *string `instill:"response-format"`
Speed *float32 `instill:"speed"`
Model string `instill:"model,default=tts-1"`
Voice string `instill:"voice,default=alloy"`
ResponseFormat *string `instill:"response-format,default=mp3"`
Speed *float32 `instill:"speed,default=1"`
}

type taskTextToSpeechOutput struct {
Expand All @@ -79,10 +79,10 @@ type taskTextToSpeechOutput struct {
type taskTextToImageInput struct {
Prompt string `instill:"prompt"`
Model string `instill:"model"`
N *int `instill:"n"`
Quality *string `instill:"quality"`
Size *string `instill:"size"`
Style *string `instill:"style"`
N *int `instill:"n,default=1"`
Quality *string `instill:"quality,default=standard"`
Size *string `instill:"size,default=1024x1024"`
Style *string `instill:"style,default=vivid"`
}

type taskTextToImageOutput struct {
Expand Down
14 changes: 7 additions & 7 deletions pkg/component/ai/perplexityai/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ type Content struct {
// Parameter contains the input parameter.
type Parameter struct {
// MaxTokens is the maximum number of tokens to generate.
MaxTokens int `instill:"max-tokens"`
MaxTokens int `instill:"max-tokens,default=50"`
// Temperature is the temperature of the model.
Temperature float64 `instill:"temperature"`
Temperature float64 `instill:"temperature,default=0.2"`
// TopP is the top-p value of the model.
TopP float64 `instill:"top-p"`
TopP float64 `instill:"top-p,default=0.9"`
// Stream is whether to stream the output.
Stream bool `instill:"stream"`
Stream bool `instill:"stream,default=false"`
// SearchDomainFilter gives the list of domains,
// limit the citations used by the online model to URLs from the specified
// domains. Currently limited to only 3 domains for whitelisting and
Expand All @@ -59,17 +59,17 @@ type Parameter struct {
// - does not apply to images. Values include `month`, `week`, `day`, `hour`."
SearchRecencyFilter string `instill:"search-recency-filter"`
// TopK is the top-k value of the model.
TopK int `instill:"top-k"`
TopK int `instill:"top-k,default=0"`
// PresencePenalty is a value between -2.0 and 2.0. Positive values penalize new
// tokens based on whether they appear in the text so far, increasing the
// model's likelihood to talk about new topics. Incompatible with
// `frequency_penalty`.
PresencePenalty float64 `instill:"presence-penalty"`
PresencePenalty float64 `instill:"presence-penalty,default=0"`
// FrequencyPenalty is a multiplicative penalty greater than 0. Values greater
// than 1.0 penalize new tokens based on their existing frequency in the text so
// far, decreasing the model's likelihood to repeat the same line verbatim. A
// value of 1.0 means no penalty. Incompatible with `presence_penalty`.
FrequencyPenalty float64 `instill:"frequency-penalty"`
FrequencyPenalty float64 `instill:"frequency-penalty,default=1"`
}

// TextChatOutput is the output for the TASK_CHAT task.
Expand Down
20 changes: 10 additions & 10 deletions pkg/component/application/github/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ type getPullRequestOutput struct {

type listIssuesInput struct {
RepoInfo
State string `instill:"state"`
Sort string `instill:"sort"`
Direction string `instill:"direction"`
State string `instill:"state,default=open"`
Sort string `instill:"sort,default=created"`
Direction string `instill:"direction,default=desc"`
Since string `instill:"since"`
NoPullRequest bool `instill:"no-pull-request"`
PageOptions
Expand All @@ -81,9 +81,9 @@ type listIssuesOutput struct {
type listPullRequestsInput struct {
RepoInfo
PageOptions
State string `instill:"state"`
Sort string `instill:"sort"`
Direction string `instill:"direction"`
State string `instill:"state,default=open"`
Sort string `instill:"sort,default=created"`
Direction string `instill:"direction,default=desc"`
}

type listPullRequestsOutput struct {
Expand All @@ -92,10 +92,10 @@ type listPullRequestsOutput struct {

type listReviewCommentsInput struct {
RepoInfo
PRNumber int `instill:"pr-number"`
Sort string `instill:"sort"`
Direction string `instill:"direction"`
Since string `instill:"since"`
PRNumber int `instill:"pr-number,default=0"`
Sort string `instill:"sort,default=created"`
Direction string `instill:"direction,default=desc"`
Since string `instill:"since,default=2021-01-01"`
PageOptions
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/component/application/leadiq/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type FindProspectsInput struct {
// Company is the company information to find prospects for.
Company Company `instill:"company"`
// Limit is the maximum number of prospects to return.
Limit int `instill:"limit"`
Limit int `instill:"limit,default=10"`
// FilterBy is the filter to apply to the prospects.
FilterBy FilterBy `instill:"filter-by"`
}
Expand Down
1 change: 1 addition & 0 deletions pkg/component/data/googlesheets/v0/config/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@
"description": "The starting row number to retrieve (1-based index).",
"instillFormat": "number",
"minimum": 1,
"default": 2,
"title": "Start Row",
"type": "integer",
"instillUIOrder": 0
Expand Down
2 changes: 1 addition & 1 deletion pkg/component/data/googlesheets/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ type taskDeleteSpreadsheetColumnOutput struct {
type taskListRowsInput struct {
SharedLink string `instill:"shared-link"`
SheetName string `instill:"sheet-name"`
StartRow *int `instill:"start-row"`
StartRow *int `instill:"start-row,default=2"`
EndRow *int `instill:"end-row"`
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/component/generic/collection/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type intersectionOutput struct {

type splitInput struct {
Data format.Value `instill:"data"`
Size int `instill:"size"`
Size int `instill:"size,default=1"`
}

type splitOutput struct {
Expand Down
4 changes: 2 additions & 2 deletions pkg/component/operator/audio/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ type segmentData struct {

type detectActivityInput struct {
Audio format.Audio `instill:"audio"`
MinSilenceDuration int `instill:"min-silence-duration"`
SpeechPad int `instill:"speech-pad"`
MinSilenceDuration int `instill:"min-silence-duration,default=100"`
SpeechPad int `instill:"speech-pad,default=30"`
}

type detectActivityOutput struct {
Expand Down
8 changes: 4 additions & 4 deletions pkg/component/operator/document/v0/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import "github.com/instill-ai/pipeline-backend/pkg/data/format"

type ConvertDocumentToMarkdownInput struct {
Document format.Document `instill:"document"`
DisplayImageTag bool `instill:"display-image-tag"`
DisplayImageTag bool `instill:"display-image-tag,default=false"`
Filename string `instill:"filename"`
DisplayAllPageImage bool `instill:"display-all-page-image"`
Resolution int `instill:"resolution"`
DisplayAllPageImage bool `instill:"display-all-page-image,default=false"`
Resolution int `instill:"resolution,default=300"`
}

type ConvertDocumentToMarkdownOutput struct {
Expand All @@ -22,7 +22,7 @@ type ConvertDocumentToMarkdownOutput struct {
type ConvertDocumentToImagesInput struct {
Document format.Document `instill:"document"`
Filename string `instill:"filename"`
Resolution int `instill:"resolution"`
Resolution int `instill:"resolution,default=300"`
}

type ConvertDocumentToImagesOutput struct {
Expand Down
85 changes: 85 additions & 0 deletions pkg/data/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"

"github.com/instill-ai/pipeline-backend/pkg/data/format"
Expand Down Expand Up @@ -32,13 +33,15 @@ import (
// FirstName string `instill:"first_name"` // Will use "first_name" as the key
// LastName string // Will use "LastName" as the key
// Avatar format.Image `instill:"photo,image/png"` // Will use "photo" as key and convert to PNG
// Age *int `instill:"age,default=18"` // Will use 18 as default if nil
// }
//
// The format portion of the tag supports:
// - For Image: "image/png", "image/jpeg", etc
// - For Video: "video/mp4", "video/webm", etc
// - For Audio: "audio/mpeg", "audio/wav", etc
// - For Document: "application/pdf", "text/plain", etc
// - For pointers: "default=value" to specify default value when nil

// Marshaler is used to marshal a struct into a Map.
type Marshaler struct {
Expand Down Expand Up @@ -126,6 +129,20 @@ func (u *Unmarshaler) unmarshalStruct(ctx context.Context, m Map, v reflect.Valu
fieldName := u.getFieldName(field)
val, ok := m[fieldName]
if !ok {
// Check for default value if field is nil pointer or zero value
if (fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil()) ||
fieldValue.IsZero() {
tag := field.Tag.Get("instill")
parts := strings.Split(tag, ",")
for _, part := range parts {
if strings.HasPrefix(part, "default=") {
defaultVal := strings.TrimPrefix(part, "default=")
if err := u.setDefaultValue(fieldValue, defaultVal); err != nil {
return fmt.Errorf("error setting default value for field %s: %w", fieldName, err)
}
}
}
}
continue
}

Expand All @@ -136,6 +153,74 @@ func (u *Unmarshaler) unmarshalStruct(ctx context.Context, m Map, v reflect.Valu
return nil
}

// setDefaultValue sets the default value for a nil pointer field
func (u *Unmarshaler) setDefaultValue(field reflect.Value, defaultVal string) error {
// Handle format.Value types first
if field.Type().Implements(reflect.TypeOf((*format.Value)(nil)).Elem()) {
elemType := field.Type()
if elemType == reflect.TypeOf((*format.String)(nil)).Elem() {
field.Set(reflect.ValueOf(NewString(defaultVal)))
return nil
} else if elemType == reflect.TypeOf((*format.Number)(nil)).Elem() {
f, err := strconv.ParseFloat(defaultVal, 64)
if err != nil {
return err
}
field.Set(reflect.ValueOf(NewNumberFromFloat(f)))
return nil
} else if elemType == reflect.TypeOf((*format.Boolean)(nil)).Elem() {
b, err := strconv.ParseBool(defaultVal)
if err != nil {
return err
}
field.Set(reflect.ValueOf(NewBoolean(b)))
return nil
}
return fmt.Errorf("unsupported format.Value type: %v", elemType)
}

// Handle pointer types
if field.Kind() == reflect.Ptr {
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
field = field.Elem()
}

switch field.Kind() {
case reflect.String:
field.SetString(defaultVal)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i, err := strconv.ParseInt(defaultVal, 10, 64)
if err != nil {
return err
}
field.SetInt(i)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
i, err := strconv.ParseUint(defaultVal, 10, 64)
if err != nil {
return err
}
field.SetUint(i)
case reflect.Float32, reflect.Float64:
f, err := strconv.ParseFloat(defaultVal, 64)
if err != nil {
return err
}
field.SetFloat(f)
case reflect.Bool:
b, err := strconv.ParseBool(defaultVal)
if err != nil {
return err
}
field.SetBool(b)
default:
return fmt.Errorf("unsupported default value type: %v", field.Kind())
}

return nil
}

// unmarshalValue dispatches to type-specific unmarshal functions based on the value type.
func (u *Unmarshaler) unmarshalValue(ctx context.Context, val format.Value, field reflect.Value, structField reflect.StructField) error {
switch v := val.(type) {
Expand Down
Loading

0 comments on commit b9c2d05

Please sign in to comment.