Skip to content

Commit

Permalink
feat(metastore): add transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
lianxmfor committed Nov 16, 2021
1 parent 47f1d6e commit ea88174
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 45 deletions.
35 changes: 35 additions & 0 deletions internal/database/metadata/postgres/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@ import (
)

var _ metadata.Store = &DB{}
var _ metadata.Store = &Tx{}

type DB struct {
*sqlx.DB
*informer.Informer
}

type Tx struct {
*sqlx.Tx
*informer.Informer
}

func Open(ctx context.Context, option *types.PostgresOpt) (*DB, error) {
db, err := OpenDB(ctx, option.Host, option.Port, option.User, option.Password, option.Database)
if err != nil {
Expand Down Expand Up @@ -87,3 +93,32 @@ func list(ctx context.Context, db *sqlx.DB) (*informer.Cache, error) {
})
return cache, err
}

func (db *DB) WithTransaction(ctx context.Context, fn func(context.Context, metadata.Store) error) (err error) {
tx, err := db.BeginTxx(ctx, nil)
if err != nil {
return
}

txStore := &Tx{Tx: tx}

defer func() {
if p := recover(); p != nil {
// a panic occurred, rollback and repanic
_ = tx.Rollback()
panic(p)
} else if err != nil {
// something went wrong, rollback
_ = tx.Rollback()
} else {
// all good, commit
err = tx.Commit()
}
}()

return fn(ctx, txStore)
}

func (tx *Tx) WithTransaction(ctx context.Context, fn func(context.Context, metadata.Store) error) (err error) {
return fn(ctx, tx)
}
48 changes: 48 additions & 0 deletions internal/database/metadata/postgres/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package postgres

import (
"context"

"github.com/oom-ai/oomstore/internal/database/metadata"
)

func (db *DB) CreateEntity(ctx context.Context, opt metadata.CreateEntityOpt) (int16, error) {
return createEntity(ctx, db, opt)
}

func (db *DB) UpdateEntity(ctx context.Context, opt metadata.UpdateEntityOpt) error {
return updateEntity(ctx, db, opt)
}

func (db *DB) CreateFeatureGroup(ctx context.Context, opt metadata.CreateFeatureGroupOpt) (int16, error) {
return createFeatureGroup(ctx, db, opt)
}

func (db *DB) UpdateFeatureGroup(ctx context.Context, opt metadata.UpdateFeatureGroupOpt) error {
return updateFeatureGroup(ctx, db, opt)
}

func (db *DB) CreateFeature(ctx context.Context, opt metadata.CreateFeatureOpt) (int16, error) {
return createFeature(ctx, db, opt)
}

func (db *DB) UpdateFeature(ctx context.Context, opt metadata.UpdateFeatureOpt) error {
return updateFeature(ctx, db, opt)
}

func (db *DB) CreateRevision(ctx context.Context, opt metadata.CreateRevisionOpt) (int32, string, error) {
var (
revisionId int32
dataTable string
err error
)
err = db.WithTransaction(ctx, func(c context.Context, s metadata.Store) error {
revisionId, dataTable, err = createRevision(ctx, db, opt)
return err
})
return revisionId, dataTable, err
}

func (db *DB) UpdateRevision(ctx context.Context, opt metadata.UpdateRevisionOpt) error {
return updateRevision(ctx, db, opt)
}
8 changes: 4 additions & 4 deletions internal/database/metadata/postgres/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
"github.com/oom-ai/oomstore/internal/database/metadata"
)

func (db *DB) CreateEntity(ctx context.Context, opt metadata.CreateEntityOpt) (int16, error) {
func createEntity(ctx context.Context, exec metadata.ExecContext, opt metadata.CreateEntityOpt) (int16, error) {
var entityId int16
query := "insert into feature_entity(name, length, description) values($1, $2, $3) returning id"
err := db.GetContext(ctx, &entityId, query, opt.Name, opt.Length, opt.Description)
err := exec.GetContext(ctx, &entityId, query, opt.Name, opt.Length, opt.Description)
if er, ok := err.(*pq.Error); ok {
if er.Code == pgerrcode.UniqueViolation {
return 0, fmt.Errorf("entity %s already exists", opt.Name)
Expand All @@ -21,9 +21,9 @@ func (db *DB) CreateEntity(ctx context.Context, opt metadata.CreateEntityOpt) (i
return entityId, err
}

func (db *DB) UpdateEntity(ctx context.Context, opt metadata.UpdateEntityOpt) error {
func updateEntity(ctx context.Context, exec metadata.ExecContext, opt metadata.UpdateEntityOpt) error {
query := "UPDATE feature_entity SET description = $1 WHERE id = $2"
result, err := db.ExecContext(ctx, query, opt.NewDescription, opt.EntityID)
result, err := exec.ExecContext(ctx, query, opt.NewDescription, opt.EntityID)
if err != nil {
return err
}
Expand Down
14 changes: 7 additions & 7 deletions internal/database/metadata/postgres/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ import (
"github.com/oom-ai/oomstore/internal/database/metadata"
)

func (db *DB) CreateFeature(ctx context.Context, opt metadata.CreateFeatureOpt) (int16, error) {
if err := db.validateDataType(ctx, opt.DBValueType); err != nil {
func createFeature(ctx context.Context, exec metadata.ExecContext, opt metadata.CreateFeatureOpt) (int16, error) {
if err := validateDataType(ctx, exec, opt.DBValueType); err != nil {
return 0, fmt.Errorf("err when validating value_type input, details: %s", err.Error())
}
var featureId int16
query := "INSERT INTO feature(name, group_id, db_value_type, value_type, description) VALUES ($1, $2, $3, $4, $5) RETURNING id"
err := db.GetContext(ctx, &featureId, query, opt.Name, opt.GroupID, opt.DBValueType, opt.ValueType, opt.Description)
err := exec.GetContext(ctx, &featureId, query, opt.Name, opt.GroupID, opt.DBValueType, opt.ValueType, opt.Description)
if err != nil {
if e2, ok := err.(*pq.Error); ok {
if e2.Code == pgerrcode.UniqueViolation {
Expand All @@ -27,9 +27,9 @@ func (db *DB) CreateFeature(ctx context.Context, opt metadata.CreateFeatureOpt)
return featureId, err
}

func (db *DB) UpdateFeature(ctx context.Context, opt metadata.UpdateFeatureOpt) error {
func updateFeature(ctx context.Context, exec metadata.ExecContext, opt metadata.UpdateFeatureOpt) error {
query := "UPDATE feature SET description = $1 WHERE id = $2"
result, err := db.ExecContext(ctx, query, opt.NewDescription, opt.FeatureID)
result, err := exec.ExecContext(ctx, query, opt.NewDescription, opt.FeatureID)
if err != nil {
return err
}
Expand All @@ -43,9 +43,9 @@ func (db *DB) UpdateFeature(ctx context.Context, opt metadata.UpdateFeatureOpt)
return nil
}

func (db *DB) validateDataType(ctx context.Context, dataType string) error {
func validateDataType(ctx context.Context, exec metadata.ExecContext, dataType string) error {
tmpTable := dbutil.TempTable("validate_data_type")
stmt := fmt.Sprintf("CREATE TEMPORARY TABLE %s (a %s) ON COMMIT DROP", tmpTable, dataType)
_, err := db.ExecContext(ctx, stmt)
_, err := exec.ExecContext(ctx, stmt)
return err
}
8 changes: 4 additions & 4 deletions internal/database/metadata/postgres/feature_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import (
"github.com/oom-ai/oomstore/pkg/oomstore/types"
)

func (db *DB) CreateFeatureGroup(ctx context.Context, opt metadata.CreateFeatureGroupOpt) (int16, error) {
func createFeatureGroup(ctx context.Context, exec metadata.ExecContext, opt metadata.CreateFeatureGroupOpt) (int16, error) {
if opt.Category != types.BatchFeatureCategory && opt.Category != types.StreamFeatureCategory {
return 0, fmt.Errorf("illegal category '%s', should be either 'stream' or 'batch'", opt.Category)
}
var featureGroupId int16
query := "insert into feature_group(name, entity_id, category, description) values($1, $2, $3, $4) returning id"
err := db.GetContext(ctx, &featureGroupId, query, opt.Name, opt.EntityID, opt.Category, opt.Description)
err := exec.GetContext(ctx, &featureGroupId, query, opt.Name, opt.EntityID, opt.Category, opt.Description)
if err != nil {
if e2, ok := err.(*pq.Error); ok {
if e2.Code == pgerrcode.UniqueViolation {
Expand All @@ -29,7 +29,7 @@ func (db *DB) CreateFeatureGroup(ctx context.Context, opt metadata.CreateFeature
return featureGroupId, err
}

func (db *DB) UpdateFeatureGroup(ctx context.Context, opt metadata.UpdateFeatureGroupOpt) error {
func updateFeatureGroup(ctx context.Context, exec metadata.ExecContext, opt metadata.UpdateFeatureGroupOpt) error {
and := make(map[string]interface{})
if opt.NewDescription != nil {
and["description"] = *opt.NewDescription
Expand All @@ -48,7 +48,7 @@ func (db *DB) UpdateFeatureGroup(ctx context.Context, opt metadata.UpdateFeature
}

query := fmt.Sprintf("UPDATE feature_group SET %s WHERE id = ?", strings.Join(cond, ","))
result, err := db.ExecContext(ctx, db.Rebind(query), args...)
result, err := exec.ExecContext(ctx, exec.Rebind(query), args...)
if err != nil {
return err
}
Expand Down
55 changes: 25 additions & 30 deletions internal/database/metadata/postgres/revision.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,49 @@ import (
"strings"

"github.com/jackc/pgerrcode"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/oom-ai/oomstore/internal/database/dbutil"
"github.com/oom-ai/oomstore/internal/database/metadata"
)

func (db *DB) CreateRevision(ctx context.Context, opt metadata.CreateRevisionOpt) (int32, string, error) {
func createRevision(ctx context.Context, exec metadata.ExecContext, opt metadata.CreateRevisionOpt) (int32, string, error) {
var dataTable string
if opt.DataTable != nil {
dataTable = *opt.DataTable
}

var revisionId int32
err := dbutil.WithTransaction(db.DB, ctx, func(ctx context.Context, tx *sqlx.Tx) error {
insertQuery := "INSERT INTO feature_group_revision(group_id, revision, data_table, anchored, description) VALUES ($1, $2, $3, $4, $5) RETURNING id"
if err := tx.GetContext(ctx, &revisionId, insertQuery, opt.GroupID, opt.Revision, dataTable, opt.Anchored, opt.Description); err != nil {
if e2, ok := err.(*pq.Error); ok {
if e2.Code == pgerrcode.UniqueViolation {
return fmt.Errorf("revision already exists: groupId=%d, revision=%d", opt.GroupID, opt.Revision)
}
insertQuery := "INSERT INTO feature_group_revision(group_id, revision, data_table, anchored, description) VALUES ($1, $2, $3, $4, $5) RETURNING id"
if err := exec.GetContext(ctx, &revisionId, insertQuery, opt.GroupID, opt.Revision, dataTable, opt.Anchored, opt.Description); err != nil {
if e2, ok := err.(*pq.Error); ok {
if e2.Code == pgerrcode.UniqueViolation {
return 0, "", fmt.Errorf("revision already exists: groupId=%d, revision=%d", opt.GroupID, opt.Revision)
}
return err
}
if opt.DataTable == nil {
updateQuery := "UPDATE feature_group_revision SET data_table = $1 WHERE id = $2"
dataTable = fmt.Sprintf("data_%d_%d", opt.GroupID, revisionId)
result, err := tx.ExecContext(ctx, updateQuery, dataTable, revisionId)
if err != nil {
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return fmt.Errorf("failed to update revision %d: revision not found", revisionId)
}
return 0, "", err
}
if opt.DataTable == nil {
updateQuery := "UPDATE feature_group_revision SET data_table = $1 WHERE id = $2"
dataTable = fmt.Sprintf("data_%d_%d", opt.GroupID, revisionId)
result, err := exec.ExecContext(ctx, updateQuery, dataTable, revisionId)
if err != nil {
return 0, "", err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, "", err
}
if rowsAffected != 1 {
return 0, "", fmt.Errorf("failed to update revision %d: revision not found", revisionId)
}
}

return nil
})

return revisionId, dataTable, err
return revisionId, dataTable, nil
}

// UpdateRevision = MustUpdateRevision
// If fail to update any row or update more than one row, return error
func (db *DB) UpdateRevision(ctx context.Context, opt metadata.UpdateRevisionOpt) error {
func updateRevision(ctx context.Context, exec metadata.ExecContext, opt metadata.UpdateRevisionOpt) error {
and := make(map[string]interface{})
if opt.NewRevision != nil {
and["revision"] = *opt.NewRevision
Expand All @@ -71,7 +66,7 @@ func (db *DB) UpdateRevision(ctx context.Context, opt metadata.UpdateRevisionOpt
args = append(args, opt.RevisionID)

query := fmt.Sprintf("UPDATE feature_group_revision SET %s WHERE id = ?", strings.Join(cond, ","))
result, err := db.ExecContext(ctx, db.Rebind(query), args...)
result, err := exec.ExecContext(ctx, exec.Rebind(query), args...)
if err != nil {
return err
}
Expand Down
39 changes: 39 additions & 0 deletions internal/database/metadata/postgres/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package postgres

import (
"context"

"github.com/oom-ai/oomstore/internal/database/metadata"
)

func (tx *Tx) CreateEntity(ctx context.Context, opt metadata.CreateEntityOpt) (int16, error) {
return createEntity(ctx, tx, opt)
}

func (tx *Tx) UpdateEntity(ctx context.Context, opt metadata.UpdateEntityOpt) error {
return updateEntity(ctx, tx, opt)
}

func (tx *Tx) CreateFeatureGroup(ctx context.Context, opt metadata.CreateFeatureGroupOpt) (int16, error) {
return createFeatureGroup(ctx, tx, opt)
}

func (tx *Tx) UpdateFeatureGroup(ctx context.Context, opt metadata.UpdateFeatureGroupOpt) error {
return updateFeatureGroup(ctx, tx, opt)
}

func (tx *Tx) CreateFeature(ctx context.Context, opt metadata.CreateFeatureOpt) (int16, error) {
return createFeature(ctx, tx, opt)
}

func (tx *Tx) UpdateFeature(ctx context.Context, opt metadata.UpdateFeatureOpt) error {
return updateFeature(ctx, tx, opt)
}

func (tx *Tx) CreateRevision(ctx context.Context, opt metadata.CreateRevisionOpt) (int32, string, error) {
return createRevision(ctx, tx, opt)
}

func (tx *Tx) UpdateRevision(ctx context.Context, opt metadata.UpdateRevisionOpt) error {
return updateRevision(ctx, tx, opt)
}
13 changes: 13 additions & 0 deletions internal/database/metadata/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package metadata

import (
"context"
"database/sql"
"io"

"github.com/oom-ai/oomstore/pkg/oomstore/types"
Expand Down Expand Up @@ -36,7 +37,19 @@ type Store interface {
GetRevisionBy(ctx context.Context, groupID int16, revision int64) (*types.Revision, error)
ListRevision(ctx context.Context, opt ListRevisionOpt) types.RevisionList

// transaction
WithTransaction(ctx context.Context, fn func(context.Context, Store) error) error

// refresh
Refresh() error
io.Closer
}

type ExecContext interface {
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

DriverName() string
Rebind(string) string
}

0 comments on commit ea88174

Please sign in to comment.