Skip to content

Commit

Permalink
Add Snowflake backend plugin (flyteorg#202)
Browse files Browse the repository at this point in the history
* Add snowflake plugin

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Fix test

Signed-off-by: Kevin Su <[email protected]>

* Remove duplicate code

Signed-off-by: Kevin Su <[email protected]>

* Improve test coverage

Signed-off-by: Kevin Su <[email protected]>

* Add integration tests

Signed-off-by: Kevin Su <[email protected]>

* Improve test coverage

Signed-off-by: Kevin Su <[email protected]>

* Improve test coverage

Signed-off-by: Kevin Su <[email protected]>

* Fix lint and tests

Signed-off-by: Kevin Su <[email protected]>

* update proto

Signed-off-by: Kevin Su <[email protected]>

* remove snowflake proto

Signed-off-by: Kevin Su <[email protected]>

* Update idl version

Signed-off-by: Kevin Su <[email protected]>

* fix test

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Sep 7, 2021
1 parent 5a8e120 commit 298278f
Show file tree
Hide file tree
Showing 8 changed files with 615 additions and 5 deletions.
2 changes: 1 addition & 1 deletion flyteplugins/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.0.0
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v0.20.0
github.com/flyteorg/flyteidl v0.20.1
github.com/flyteorg/flytestdlib v0.3.33
github.com/go-logr/zapr v0.4.0 // indirect
github.com/go-test/deep v1.0.7
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg=
github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM=
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/flyteorg/flyteidl v0.20.0 h1:g5xGayFfPSzFJxJedgL390WFSEbGYjFiPey+NXAB030=
github.com/flyteorg/flyteidl v0.20.0/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flyteidl v0.20.1 h1:S+jJBmtRtzUcLNAgXJNsVza/6SRS/cmtKu/zAUvC6+U=
github.com/flyteorg/flyteidl v0.20.1/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220=
github.com/flyteorg/flytestdlib v0.3.33 h1:+oCx3zXUIldL7CWmNMD7PMFPXvGqaPgYkSKn9wB6qvY=
github.com/flyteorg/flytestdlib v0.3.33/go.mod h1:7cDWkY3v7xsoesFcDdu6DSW5Q2U2W5KlHUbUHSwBG1Q=
Expand Down
71 changes: 71 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/snowflake/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package snowflake

import (
"time"

pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flytestdlib/config"
)

var (
defaultConfig = Config{
WebAPI: webapi.PluginConfig{
ResourceQuotas: map[core.ResourceNamespace]int{
"default": 1000,
},
ReadRateLimiter: webapi.RateLimiterConfig{
Burst: 100,
QPS: 10,
},
WriteRateLimiter: webapi.RateLimiterConfig{
Burst: 100,
QPS: 10,
},
Caching: webapi.CachingConfig{
Size: 500000,
ResyncInterval: config.Duration{Duration: 30 * time.Second},
Workers: 10,
MaxSystemFailures: 5,
},
ResourceMeta: nil,
},
ResourceConstraints: core.ResourceConstraintsSpec{
ProjectScopeResourceConstraint: &core.ResourceConstraint{
Value: 100,
},
NamespaceScopeResourceConstraint: &core.ResourceConstraint{
Value: 50,
},
},
DefaultWarehouse: "COMPUTE_WH",
TokenKey: "FLYTE_SNOWFLAKE_CLIENT_TOKEN",
}

configSection = pluginsConfig.MustRegisterSubSection("snowflake", &defaultConfig)
)

// Config is config for 'snowflake' plugin
type Config struct {
// WeCreateTaskInfobAPI defines config for the base WebAPI plugin
WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."`

// ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time
ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."`

DefaultWarehouse string `json:"defaultWarehouse" pflag:",Defines the default warehouse to use when running on Snowflake unless overwritten by the task."`

TokenKey string `json:"snowflakeTokenKey" pflag:",Name of the key where to find Snowflake token in the secret manager."`

// snowflakeEndpoint overrides Snowflake client endpoint, only for testing
snowflakeEndpoint string
}

func GetConfig() *Config {
return configSection.GetConfig().(*Config)
}

func SetConfig(cfg *Config) error {
return configSection.SetConfig(cfg)
}
18 changes: 18 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/snowflake/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package snowflake

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestGetAndSetConfig(t *testing.T) {
cfg := defaultConfig
cfg.DefaultWarehouse = "test-warehouse"
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, &cfg, GetConfig())
}
107 changes: 107 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/snowflake/integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package snowflake

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/flyteorg/flyteidl/clients/go/coreutils"
coreIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
pluginCoreMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/tests"
"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestEndToEnd(t *testing.T) {
server := newFakeSnowflakeServer()
defer server.Close()

iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
return nil
}

cfg := defaultConfig
cfg.snowflakeEndpoint = server.URL
cfg.DefaultWarehouse = "test-warehouse"
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
err := SetConfig(&cfg)
assert.NoError(t, err)

pluginEntry := pluginmachinery.CreateRemotePlugin(newSnowflakeJobTaskPlugin())
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext())
assert.NoError(t, err)

t.Run("SELECT 1", func(t *testing.T) {
config := make(map[string]string)
config["database"] = "my-database"
config["account"] = "snowflake"
config["schema"] = "my-schema"
config["warehouse"] = "my-warehouse"

inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
template := flyteIdlCore.TaskTemplate{
Type: "snowflake",
Config: config,
Target: &coreIdl.TaskTemplate_Sql{Sql: &coreIdl.Sql{Statement: "SELECT 1", Dialect: coreIdl.Sql_ANSI}},
}

phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter)

assert.Equal(t, true, phase.Phase().IsSuccess())
})
}

func newFakeSnowflakeServer() *httptest.Server {
statementHandle := "019e7546-0000-278c-0000-40f10001a082"
return httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if request.URL.Path == "/api/statements" && request.Method == "POST" {
writer.WriteHeader(202)
bytes := []byte(fmt.Sprintf(`{
"statementHandle": "%v",
"message": "Asynchronous execution in progress."
}`, statementHandle))
_, _ = writer.Write(bytes)
return
}

if request.URL.Path == "/api/statements/"+statementHandle && request.Method == "GET" {
writer.WriteHeader(200)
bytes := []byte(fmt.Sprintf(`{
"statementHandle": "%v",
"message": "Statement executed successfully."
}`, statementHandle))
_, _ = writer.Write(bytes)
return
}

if request.URL.Path == "/api/statements/"+statementHandle+"/cancel" && request.Method == "POST" {
writer.WriteHeader(200)
return
}

writer.WriteHeader(500)
}))
}

func newFakeSetupContext() *pluginCoreMocks.SetupContext {
fakeResourceRegistrar := pluginCoreMocks.ResourceRegistrar{}
fakeResourceRegistrar.On("RegisterResourceQuota", mock.Anything, mock.Anything, mock.Anything).Return(nil)
labeled.SetMetricKeys(contextutils.NamespaceKey)

fakeSetupContext := pluginCoreMocks.SetupContext{}
fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test"))
fakeSetupContext.OnResourceRegistrar().Return(&fakeResourceRegistrar)

return &fakeSetupContext
}
Loading

0 comments on commit 298278f

Please sign in to comment.