Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move models actions to v2 protocol #35

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion examples/get_model/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func run(opts *Options) error {
if err != nil {
return err
}
rsp, err := client.ListModels(opts.Database, opts.Engine)
rsp, err := client.GetModel(opts.Database, opts.Engine, opts.Model)
if err != nil {
return err
}
Expand Down
135 changes: 77 additions & 58 deletions rai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -885,108 +885,127 @@ func (c *Client) ListOAuthClients() ([]OAuthClient, error) {

func (c *Client) DeleteModel(
database, engine, name string,
) (*TransactionResult, error) {
) (*TransactionAsyncResult, error) {
return c.DeleteModels(database, engine, []string{name})
}

func (c *Client) DeleteModelAsync(
database, engine, name string,
) (*TransactionAsyncResult, error) {
return c.DeleteModelsAsync(database, engine, []string{name})
}

func (c *Client) DeleteModels(
database, engine string, models []string,
) (*TransactionResult, error) {
var result TransactionResult
tx := Transaction{
Region: c.Region,
Database: database,
Engine: engine,
Mode: "OPEN",
Readonly: false}
data := tx.Payload(makeDeleteModelsAction(models))
err := c.Post(PathTransaction, tx.QueryArgs(), data, &result)
if err != nil {
return nil, err
) (*TransactionAsyncResult, error) {
queries := make([]string, 0)
for _, model := range models {
queries = append(queries, fmt.Sprintf("def delete:rel:catalog:model[\"%s\"] = rel:catalog:model[\"%s\"]", model, model))
}
return &result, err

return c.Execute(database, engine, strings.Join(queries, "\n"), nil, false)
}

func (c *Client) DeleteModelsAsync(
database, engine string, models []string,
) (*TransactionAsyncResult, error) {
queries := make([]string, 0)
for _, model := range models {
queries = append(queries, fmt.Sprintf("def delete:rel:catalog:model[\"%s\"] = rel:catalog:model[\"%s\"]", model, model))
}

return c.ExecuteAsync(database, engine, strings.Join(queries, "\n"), nil, false)
}

func (c *Client) GetModel(database, engine, model string) (*Model, error) {
var result listModelsResponse
tx := NewTransaction(c.Region, database, engine, "OPEN")
data := tx.Payload(makeListModelsAction())
err := c.Post(PathTransaction, tx.QueryArgs(), data, &result)
resp, err := c.Execute(database, engine, fmt.Sprintf("def output:__model__ = rel:catalog:model[\"%s\"]", model), nil, true)
if err != nil {
return nil, err
}
// assert len(result.Actions) == 1
for _, item := range result.Actions[0].Result.Models {
if item.Name == model {
return &item, nil

for _, res := range resp.Results {
if strings.Contains(res.RelationID, "/:output/:__model__") {
return &Model{model, fmt.Sprintf("%v", res.Table[0][0])}, nil
}
}

return nil, ErrNotFound
}

func (c *Client) LoadModel(
database, engine, name string, r io.Reader,
) (*TransactionResult, error) {
) (*TransactionAsyncResult, error) {
return c.LoadModels(database, engine, map[string]io.Reader{name: r})
}

func (c *Client) LoadModels(
database, engine string, models map[string]io.Reader,
) (*TransactionResult, error) {
var result TransactionResult
tx := Transaction{
Region: c.Region,
Database: database,
Engine: engine,
Mode: "OPEN",
Readonly: false}
actions := []DbAction{}
) (*TransactionAsyncResult, error) {
queries := make([]string, 0)
for name, r := range models {
model, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
action := makeLoadModelAction(name, string(model))
actions = append(actions, action)

queries = append(queries, fmt.Sprintf("def insert:rel:catalog:model[\"%s\"] = \"\"\"%s\"\"\"", name, model))
}
data := tx.Payload(actions...)
err := c.Post(PathTransaction, tx.QueryArgs(), data, &result)
if err != nil {
return nil, err

return c.Execute(database, engine, strings.Join(queries, "\n"), nil, false)
}

func (c *Client) LoadModelsAsync(
database, engine string, models map[string]io.Reader,
) (*TransactionAsyncResult, error) {
queries := make([]string, 0)
for name, r := range models {
model, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}

queries = append(queries, fmt.Sprintf("def insert:rel:catalog:model[\"%s\"] = \"\"\"%s\"\"\"", name, model))
}
return &result, nil

return c.ExecuteAsync(database, engine, strings.Join(queries, "\n"), nil, false)
}

// Returns a list of model names for the given database.
func (c *Client) ListModelNames(database, engine string) ([]string, error) {
var models listModelsResponse
tx := NewTransaction(c.Region, database, engine, "OPEN")
data := tx.Payload(makeListModelsAction())
err := c.Post(PathTransaction, tx.QueryArgs(), data, &models)
modelNames := make([]string, 0)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar issue here. We shouldn't query out the models and all their source code just to return the names. We can write a query to only return the names.

resp, err := c.Execute(database, engine, "def output:__models__[name] = rel:catalog:model(name, _)", nil, true)
if err != nil {
return nil, err
return modelNames, err
}
actions := models.Actions
// assert len(actions) == 1
result := []string{}
for _, model := range actions[0].Result.Models {
result = append(result, model.Name)

for _, res := range resp.Results {
if strings.Contains(res.RelationID, "/:output/:__models__") {
Copy link

@larf311 larf311 Sep 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we using weird relation names like __models__?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will update this to /:output/:models

for i := 0; i < len(res.Table[0]); i++ {
modelNames = append(modelNames, fmt.Sprintf("%v", res.Table[0][i]))
}
}
}
return result, nil

return modelNames, nil
}

// Returns the names of models installed in the given database.
func (c *Client) ListModels(database, engine string) ([]Model, error) {
var models listModelsResponse
tx := NewTransaction(c.Region, database, engine, "OPEN")
data := tx.Payload(makeListModelsAction())
err := c.Post(PathTransaction, tx.QueryArgs(), data, &models)
models := make([]Model, 0)
resp, err := c.Execute(database, engine, "def output:__models__ = rel:catalog:model", nil, true)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, why are we qualifying the output relation here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh this is an old commit, we no longer query for all models (including name and source)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we qualify output relation to prevent accidental collisions between output results

if err != nil {
return nil, err
return models, err
}

for _, res := range resp.Results {
if strings.Contains(res.RelationID, "/:output/:__models__") {
for i := 0; i < len(res.Table[0]); i++ {
models = append(models, Model{fmt.Sprintf("%v", res.Table[0][i]), fmt.Sprintf("%v", res.Table[1][i])})
}
}
}
actions := models.Actions
// assert len(actions) == 1
return actions[0].Result.Models, nil

return models, nil
}

//
Expand Down
10 changes: 4 additions & 6 deletions rai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,7 @@ func TestModels(t *testing.T) {
r := strings.NewReader(testModel)
rsp, err := client.LoadModel(databaseName, engineName, "test_model", r)
assert.Nil(t, err)
assert.Equal(t, false, rsp.Aborted)
assert.Equal(t, 0, len(rsp.Output))
assert.Equal(t, "COMPLETED", rsp.Transaction.State)
assert.Equal(t, 0, len(rsp.Problems))

model, err := client.GetModel(databaseName, engineName, "test_model")
Expand All @@ -689,10 +688,9 @@ func TestModels(t *testing.T) {
model = findModel(models, "test_model")
assert.NotNil(t, model)

rsp, err = client.DeleteModel(databaseName, engineName, "test_model")
assert.Equal(t, false, rsp.Aborted)
assert.Equal(t, 0, len(rsp.Output))
assert.Equal(t, 0, len(rsp.Problems))
deleteResp, err := client.DeleteModel(databaseName, engineName, "test_model")
assert.Equal(t, "COMPLETED", deleteResp.Transaction.State)
assert.Equal(t, 0, len(deleteResp.Problems))

_, err = client.GetModel(databaseName, engineName, "test_model")
assert.True(t, isErrNotFound(err))
Expand Down