diff --git a/examples/delete_model/main.go b/examples/delete_models/main.go similarity index 94% rename from examples/delete_model/main.go rename to examples/delete_models/main.go index 0fb0566..7d9bcb9 100644 --- a/examples/delete_model/main.go +++ b/examples/delete_models/main.go @@ -34,7 +34,7 @@ func run(opts *Options) error { if err != nil { return err } - rsp, err := client.DeleteModel(opts.Database, opts.Engine, opts.Model) + rsp, err := client.DeleteModels(opts.Database, opts.Engine, []string{opts.Model}) if err != nil { return err } diff --git a/examples/get_model/main.go b/examples/get_model/main.go index 5f5fcc8..14cd342 100644 --- a/examples/get_model/main.go +++ b/examples/get_model/main.go @@ -34,7 +34,7 @@ func run(opts *Options) error { if err != nil { return err } - rsp, err := client.ListModels(opts.Database, opts.Engine) + rsp, err := client.GetModel(opts.Database, opts.Engine, opts.Model) if err != nil { return err } diff --git a/examples/list_model_names/main.go b/examples/list_model_names/main.go deleted file mode 100644 index 72127fb..0000000 --- a/examples/list_model_names/main.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2022 RelationalAI, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "log" - "os" - - "github.com/jessevdk/go-flags" - "github.com/relationalai/rai-sdk-go/rai" -) - -type Options struct { - Database string `short:"d" long:"database" required:"true" description:"database name"` - Engine string `short:"e" long:"engine" required:"true" description:"engine name"` - Profile string `long:"profile" default:"default" description:"config profile"` -} - -func run(opts *Options) error { - client, err := rai.NewClientFromConfig(opts.Profile) - if err != nil { - return err - } - rsp, err := client.ListModelNames(opts.Database, opts.Engine) - if err != nil { - return err - } - rai.ShowJSON(rsp, 4) - return nil -} - -func main() { - var opts Options - if _, err := flags.ParseArgs(&opts, os.Args); err != nil { - os.Exit(1) - } - if err := run(&opts); err != nil { - log.Fatal(err) - } -} diff --git a/examples/load_model/main.go b/examples/load_models/main.go similarity index 88% rename from examples/load_model/main.go rename to examples/load_models/main.go index 8451f75..3042e07 100644 --- a/examples/load_model/main.go +++ b/examples/load_models/main.go @@ -15,6 +15,7 @@ package main import ( + "fmt" "log" "os" "path/filepath" @@ -41,16 +42,22 @@ func run(opts *Options) error { if err != nil { return err } - r, err := os.Open(opts.File) + + value, err := os.ReadFile(opts.File) if err != nil { return err } - name := sansext(opts.File) - rsp, err := client.LoadModel(opts.Database, opts.Engine, name, r) + + model := map[string]string{ + sansext(opts.File): string(value), + } + + rsp, err := client.LoadModels("hnr-db", "hnr-engine", model) if err != nil { - return err + return nil } - rsp.Show() + + fmt.Println(rsp) return nil } diff --git a/rai/client.go b/rai/client.go index 6c6035f..002fb33 100644 --- a/rai/client.go +++ b/rai/client.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "io/ioutil" + "math/rand" "mime" "mime/multipart" "net/http" @@ -669,110 +670,115 @@ func (c *Client) ListOAuthClients() ([]OAuthClient, error) { // Models // -func (c *Client) DeleteModel( - database, engine, name string, -) (*TransactionResult, error) { - return c.DeleteModels(database, engine, []string{name}) +func (c *Client) LoadModels( + database, engine string, models map[string]string, +) (*TransactionResponse, error) { + randUint := rand.Uint32() + queries := make([]string, 0) + queryInputs := make(map[string]string) + + index := 0 + for name, value := range models { + index++ + inputName := fmt.Sprintf("input_%d_%d", randUint, index) + queries = append(queries, + fmt.Sprintf(` + def delete:rel:catalog:model["%s"] = rel:catalog:model["%s"] + def insert:rel:catalog:model["%s"] = %s + `, name, name, name, inputName, + ), + ) + queryInputs[inputName] = value + } + + return c.Execute(database, engine, strings.Join(queries, "\n"), queryInputs, false) } -func (c *Client) DeleteModels( - database, engine string, models []string, -) (*TransactionResult, error) { - var result TransactionResult - tx := TransactionV1{ - Region: c.Region, - Database: database, - Engine: engine, - Mode: "OPEN", - Readonly: false} - data := tx.Payload(makeDeleteModelsAction(models)) - err := c.Post(PathTransaction, tx.QueryArgs(), data, &result) - if err != nil { - return nil, err +func (c *Client) LoadModelsAsync( + database, engine string, models map[string]string, +) (*TransactionResponse, error) { + randUint := rand.Uint32() + queries := make([]string, 0) + queryInputs := make(map[string]string) + + index := 0 + for name, value := range models { + inputName := fmt.Sprintf("input_%d_%d", randUint, index) + queries = append(queries, + fmt.Sprintf(` + def delete:rel:catalog:model["%s"] = rel:catalog:model["%s"] + def insert:rel:catallog:model["%s"] = %s + `, name, name, name, inputName, + ), + ) + queryInputs[inputName] = value + index++ } - return &result, err + + return c.ExecuteAsync(database, engine, strings.Join(queries, "\n"), queryInputs, false) } -func (c *Client) GetModel(database, engine, model string) (*Model, error) { - var result listModelsResponse - tx := NewTransaction(c.Region, database, engine, "OPEN") - data := tx.Payload(makeListModelsAction()) - err := c.Post(PathTransaction, tx.QueryArgs(), data, &result) +// Returns a list of model names for the given database. +func (c *Client) ListModels(database, engine string) ([]string, error) { + outName := fmt.Sprintf("models_%d", rand.Uint32()) + query := fmt.Sprintf("def output:%s[name] = rel:catalog:model(name, _)", outName) + resp, err := c.Execute(database, engine, query, nil, true) if err != nil { return nil, err } - // assert len(result.Actions) == 1 - for _, item := range result.Actions[0].Result.Models { - if item.Name == model { - return &item, nil + + models := make([]string, 0) + + rc := resp.Relations("output", outName) + if len(rc) > 0 { + c := rc.Union().Column(2) + for i := 0; i < c.NumRows(); i++ { + models = append(models, c.String(i)) } + return models, nil } - return nil, ErrNotFound -} -func (c *Client) LoadModel( - database, engine, name string, r io.Reader, -) (*TransactionResult, error) { - return c.LoadModels(database, engine, map[string]io.Reader{name: r}) + return nil, errors.Errorf("output:%s relation is empty", outName) } -func (c *Client) LoadModels( - database, engine string, models map[string]io.Reader, -) (*TransactionResult, error) { - var result TransactionResult - tx := TransactionV1{ - Region: c.Region, - Database: database, - Engine: engine, - Mode: "OPEN", - Readonly: false} - actions := []DbAction{} - for name, r := range models { - model, err := ioutil.ReadAll(r) - if err != nil { - return nil, err - } - action := makeLoadModelAction(name, string(model)) - actions = append(actions, action) - } - data := tx.Payload(actions...) - err := c.Post(PathTransaction, tx.QueryArgs(), data, &result) +func (c *Client) GetModel(database, engine, model string) (*Model, error) { + outName := fmt.Sprintf("model_%d", rand.Uint32()) + query := fmt.Sprintf(`def output:%s = rel:catalog:model["%s"]`, outName, model) + resp, err := c.Execute(database, engine, query, nil, true) + if err != nil { return nil, err } - return &result, nil -} -// Returns a list of model names for the given database. -func (c *Client) ListModelNames(database, engine string) ([]string, error) { - var models listModelsResponse - tx := NewTransaction(c.Region, database, engine, "OPEN") - data := tx.Payload(makeListModelsAction()) - err := c.Post(PathTransaction, tx.QueryArgs(), data, &models) - if err != nil { - return nil, err + rc := resp.Relations("output", outName) + if len(rc) > 0 { + value := rc.Union().Column(2).String(0) + return &Model{model, value}, nil } - actions := models.Actions - // assert len(actions) == 1 - result := []string{} - for _, model := range actions[0].Result.Models { - result = append(result, model.Name) + + return nil, ErrNotFound +} + +func (c *Client) DeleteModels( + database, engine string, models []string, +) (*TransactionResponse, error) { + queries := make([]string, 0) + for _, model := range models { + queries = append(queries, fmt.Sprintf(`def delete:rel:catalog:model["%s"] = rel:catalog:model["%s"]`, model, model)) } - return result, nil + + return c.Execute(database, engine, strings.Join(queries, "\n"), nil, false) } -// Returns the names of models installed in the given database. -func (c *Client) ListModels(database, engine string) ([]Model, error) { - var models listModelsResponse - tx := NewTransaction(c.Region, database, engine, "OPEN") - data := tx.Payload(makeListModelsAction()) - err := c.Post(PathTransaction, tx.QueryArgs(), data, &models) - if err != nil { - return nil, err +func (c *Client) DeleteModelsAsync( + database, engine string, models []string, +) (*TransactionResponse, error) { + queries := make([]string, 0) + for _, model := range models { + queries = append(queries, fmt.Sprintf(`def delete:rel:catalog:model["%s"] = rel:catalog:model["%s"]`, model, model)) } - actions := models.Actions - // assert len(actions) == 1 - return actions[0].Result.Models, nil + + return c.ExecuteAsync(database, engine, strings.Join(queries, "\n"), nil, false) } // diff --git a/rai/client_test.go b/rai/client_test.go index 0d63d9c..51fd6ae 100644 --- a/rai/client_test.go +++ b/rai/client_test.go @@ -49,15 +49,6 @@ func findEDB(edbs []EDB, name string) *EDB { return nil } -func findModel(models []Model, name string) *Model { - for _, model := range models { - if model.Name == name { - return &model - } - } - return nil -} - // Test database management APIs. func TestDatabase(t *testing.T) { client := test.client @@ -104,19 +95,17 @@ func TestDatabase(t *testing.T) { edb := findEDB(edbs, "rel") assert.NotNil(t, edb) - modelNames, err := client.ListModelNames(test.databaseName, test.engineName) + modelNames, err := client.ListModels(test.databaseName, test.engineName) assert.Nil(t, err) assert.True(t, len(modelNames) > 0) assert.True(t, contains(modelNames, "rel/stdlib")) - models, err := client.ListModels(test.databaseName, test.engineName) + modelNames, err = client.ListModels(test.databaseName, test.engineName) assert.Nil(t, err) - assert.True(t, len(models) > 0) - model := findModel(models, "rel/stdlib") - assert.NotNil(t, model) - assert.True(t, len(model.Value) > 0) + assert.True(t, len(modelNames) > 0) + assert.True(t, contains(modelNames, "rel/stdlib")) - model, err = client.GetModel(test.databaseName, test.engineName, "rel/stdlib") + model, err := client.GetModel(test.databaseName, test.engineName, "rel/stdlib") assert.Nil(t, err) assert.NotNil(t, model) assert.True(t, len(model.Value) > 0) @@ -515,44 +504,33 @@ func TestLoadJSON(t *testing.T) { func TestModels(t *testing.T) { client := test.client - const testModel = "def R = \"hello\", \"world\"" + testModel := map[string]string{"test_model": "def R = \"hello\", \"world\""} - r := strings.NewReader(testModel) - rsp, err := client.LoadModel(test.databaseName, test.engineName, "test_model", r) + rsp, err := client.LoadModels(test.databaseName, test.engineName, testModel) assert.Nil(t, err) - assert.Equal(t, false, rsp.Aborted) - assert.Equal(t, 0, len(rsp.Output)) + assert.Equal(t, TransactionState("COMPLETED"), rsp.Transaction.State) assert.Equal(t, 0, len(rsp.Problems)) model, err := client.GetModel(test.databaseName, test.engineName, "test_model") assert.Nil(t, err) + assert.Equal(t, testModel["test_model"], model.Value) assert.Equal(t, "test_model", model.Name) - modelNames, err := client.ListModelNames(test.databaseName, test.engineName) + modelNames, err := client.ListModels(test.databaseName, test.engineName) assert.Nil(t, err) assert.True(t, contains(modelNames, "test_model")) - models, err := client.ListModels(test.databaseName, test.engineName) + deleteResp, err := client.DeleteModels(test.databaseName, test.engineName, []string{"test_model"}) assert.Nil(t, err) - model = findModel(models, "test_model") - assert.NotNil(t, model) - - rsp, err = client.DeleteModel(test.databaseName, test.engineName, "test_model") - assert.Equal(t, false, rsp.Aborted) - assert.Equal(t, 0, len(rsp.Output)) - assert.Equal(t, 0, len(rsp.Problems)) + assert.Equal(t, TransactionState("COMPLETED"), deleteResp.Transaction.State) + assert.Equal(t, 0, len(deleteResp.Problems)) _, err = client.GetModel(test.databaseName, test.engineName, "test_model") assert.True(t, isErrNotFound(err)) - modelNames, err = client.ListModelNames(test.databaseName, test.engineName) + modelNames, err = client.ListModels(test.databaseName, test.engineName) assert.Nil(t, err) assert.False(t, contains(modelNames, "test_model")) - - models, err = client.ListModels(test.databaseName, test.engineName) - assert.Nil(t, err) - model = findModel(models, "test_model") - assert.Nil(t, model) } func findOAuthClient(clients []OAuthClient, id string) *OAuthClient {