Skip to content

Commit

Permalink
Add first draft of Data Access Layer, with test (kubeflow#34)
Browse files Browse the repository at this point in the history
* Add first draft of Data Access Layer, with test

* implement code review feedback

* move code into file structure as requested

* Move DB fn into db_context.go

* Reuse enum from db/type.go

* Use :memory: sqlite as requested

---------

Co-authored-by: Andrea Lamparelli <[email protected]>
  • Loading branch information
tarilabs and lampajr authored Oct 4, 2023
1 parent d8277f5 commit 57395c7
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 14 deletions.
26 changes: 13 additions & 13 deletions internal/server/grpc/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
"github.com/opendatahub-io/model-registry/internal/model/db"
"github.com/opendatahub-io/model-registry/internal/server"
"github.com/opendatahub-io/model-registry/internal/service"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
Expand All @@ -27,7 +27,7 @@ func NewGrpcServer(dbConnection *gorm.DB) proto.MetadataStoreServiceServer {
var REQUIRED_TYPE_FIELDS = []string{"name"}

func (g grpcServer) PutArtifactType(ctx context.Context, request *proto.PutArtifactTypeRequest) (resp *proto.PutArtifactTypeResponse, err error) {
ctx, _ = server.Begin(ctx, g.dbConnection)
ctx, _ = service.Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

artifactType := request.GetArtifactType()
Expand Down Expand Up @@ -56,7 +56,7 @@ func (g grpcServer) PutArtifactType(ctx context.Context, request *proto.PutArtif
func (g grpcServer) createOrUpdateType(ctx context.Context, value *db.Type,
properties map[string]proto.PropertyType) error {
// TODO handle CanAdd, CanOmit properties from type request
dbConn, _ := server.FromContext(ctx)
dbConn, _ := service.FromContext(ctx)

if err := dbConn.Where("name = ?", value.Name).Assign(value).FirstOrCreate(value).Error; err != nil {
err = fmt.Errorf("error creating type %s: %v", value.Name, err)
Expand All @@ -70,7 +70,7 @@ func (g grpcServer) createOrUpdateType(ctx context.Context, value *db.Type,
}

func (g grpcServer) PutExecutionType(ctx context.Context, request *proto.PutExecutionTypeRequest) (resp *proto.PutExecutionTypeResponse, err error) {
ctx, _ = server.Begin(ctx, g.dbConnection)
ctx, _ = service.Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

executionType := request.GetExecutionType()
Expand All @@ -96,7 +96,7 @@ func (g grpcServer) PutExecutionType(ctx context.Context, request *proto.PutExec
}

func (g grpcServer) PutContextType(ctx context.Context, request *proto.PutContextTypeRequest) (resp *proto.PutContextTypeResponse, err error) {
ctx, _ = server.Begin(ctx, g.dbConnection)
ctx, _ = service.Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

contextType := request.GetContextType()
Expand All @@ -122,7 +122,7 @@ func (g grpcServer) PutContextType(ctx context.Context, request *proto.PutContex
}

func (g grpcServer) PutTypes(ctx context.Context, request *proto.PutTypesRequest) (resp *proto.PutTypesResponse, err error) {
ctx, _ = server.Begin(ctx, g.dbConnection)
ctx, _ = service.Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

response := &proto.PutTypesResponse{}
Expand Down Expand Up @@ -172,7 +172,7 @@ func (g grpcServer) PutTypes(ctx context.Context, request *proto.PutTypesRequest
var REQUIRED_ARTIFACT_FIELDS = []string{"type_id", "uri"}

func (g grpcServer) PutArtifacts(ctx context.Context, request *proto.PutArtifactsRequest) (resp *proto.PutArtifactsResponse, err error) {
ctx, dbConn := server.Begin(ctx, g.dbConnection)
ctx, dbConn := service.Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

var artifactIds []int64
Expand Down Expand Up @@ -250,7 +250,7 @@ func (g grpcServer) PutParentContexts(ctx context.Context, request *proto.PutPar
}

func (g grpcServer) GetArtifactType(ctx context.Context, request *proto.GetArtifactTypeRequest) (resp *proto.GetArtifactTypeResponse, err error) {
ctx, dbConn := server.Begin(ctx, g.dbConnection)
ctx, dbConn := service.Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

err = requiredFields(REQUIRED_TYPE_FIELDS, request.TypeName)
Expand Down Expand Up @@ -484,7 +484,7 @@ func (g grpcServer) mustEmbedUnimplementedMetadataStoreServiceServer() {
}

func (g grpcServer) createTypeProperties(ctx context.Context, properties map[string]proto.PropertyType, typeId int64) (err error) {
ctx, dbConn := server.Begin(ctx, g.dbConnection)
ctx, dbConn := service.Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

for propName, prop := range properties {
Expand All @@ -504,7 +504,7 @@ func (g grpcServer) createTypeProperties(ctx context.Context, properties map[str
}

func (g grpcServer) createArtifactProperties(ctx context.Context, artifactId int64, properties map[string]*proto.Value, isCustomProperty bool) (err error) {
ctx, dbConn := server.Begin(ctx, g.dbConnection)
ctx, dbConn := service.Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

for propName, prop := range properties {
Expand Down Expand Up @@ -576,14 +576,14 @@ func nilSafeCopy[D int32 | int64 | *int64 | string, S int64 | proto.Artifact_Sta
func handleTransaction(ctx context.Context, err *error) {
// handle panic
if perr := recover(); perr != nil {
_ = server.Rollback(ctx)
_ = service.Rollback(ctx)
*err = status.Errorf(codes.Internal, "server panic: %v", perr)
return
}
if err == nil || *err == nil {
*err = server.Commit(ctx)
*err = service.Commit(ctx)
} else {
_ = server.Rollback(ctx)
_ = service.Rollback(ctx)
if _, ok := status.FromError(*err); !ok {
*err = status.Errorf(codes.Internal, "internal error: %v", *err)
}
Expand Down
9 changes: 9 additions & 0 deletions internal/service/artifact_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package service

import (
"gorm.io/gorm"
)

type artifactHandler struct {
db *gorm.DB
}
22 changes: 21 additions & 1 deletion internal/server/db_context.go → internal/service/db_context.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package server
package service

import (
"context"
"fmt"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
)

Expand Down Expand Up @@ -63,3 +66,20 @@ func Rollback(ctx context.Context) error {
// rollback in unwrapped parent context
return nil
}

func handleTransaction(ctx context.Context, err *error) {
// handle panic
if perr := recover(); perr != nil {
_ = Rollback(ctx)
*err = status.Errorf(codes.Internal, "server panic: %v", perr)
return
}
if err == nil || *err == nil {
*err = Commit(ctx)
} else {
_ = Rollback(ctx)
if _, ok := status.FromError(*err); !ok {
*err = status.Errorf(codes.Internal, "internal error: %v", *err)
}
}
}
39 changes: 39 additions & 0 deletions internal/service/db_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package service

import (
"gorm.io/gorm"

"github.com/opendatahub-io/model-registry/internal/model/db"
)

var _ DBService = dbServiceHandler{}
var _ DBService = (*dbServiceHandler)(nil)

func NewDBService(db *gorm.DB) DBService {
return &dbServiceHandler{
typeHandler: &typeHandler{db: db},
artifactHandler: &artifactHandler{db: db},
}
}

type DBService interface {
InsertType(db.Type) (*db.Type, error)
UpsertType(db.Type) (*db.Type, error)
ReadType(db.Type) (*db.Type, error)
// Get-like function to use a signature similar to the gorm `Where` func
ReadAllType(query interface{}, args ...interface{}) ([]*db.Type, error)
UpdateType(db.Type) (*db.Type, error)
DeleteType(db.Type) (*db.Type, error)

// InsertEEE(db.EEE) (*db.EEE, error)
// UpsertEEE(db.EEE) (*db.EEE, error)
// ReadEEE(db.EEE) (*db.EEE, error)
// ReadAllEEE(query interface{}, args ...interface{}) ([]*db.EEE, error)
// UpdateEEE(db.EEE) (*db.EEE, error)
// DeleteEEE(db.EEE) (*db.EEE, error)
}

type dbServiceHandler struct {
*typeHandler
*artifactHandler
}
146 changes: 146 additions & 0 deletions internal/service/db_service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package service

import (
"fmt"
"testing"

"github.com/opendatahub-io/model-registry/internal/model/db"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

func migrateDatabase(dbConn *gorm.DB) error {
// using only needed RDBMS type for the scope under test
err := dbConn.AutoMigrate(
db.Type{},
db.TypeProperty{},
// TODO: add as needed.
)
if err != nil {
return fmt.Errorf("db migration failed: %w", err)
}
return nil
}

func setup() (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
if err != nil {
return nil, err
}
err = migrateDatabase(db)
if err != nil {
return nil, err
}
return db, nil
}

// Bare minimal test of PutArtifactType with a given Name, and Get.
func TestInsertTypeThenReadAllType(t *testing.T) {
dbc, err := setup()
if err != nil {
t.Errorf("Should expect DB connection: %v", err)
}
defer func() {
dbi, err := dbc.DB()
if err != nil {
t.Errorf("Test need to clear sqlite DB for the next one, but errored: %v", err)
}
dbi.Close()
}()
dal := NewDBService(dbc)

artifactName := "John Doe"
newType := db.Type{
Name: artifactName,
TypeKind: int8(db.ARTIFACT_TYPE),
}

at, err := dal.InsertType(newType)
if err != nil {
t.Errorf("Should create ArtifactType: %v", err)
}
if at.ID < 0 {
t.Errorf("Should have ID for ArtifactType: %v", at.ID)
}
if at.Name != artifactName {
t.Errorf("Should have Name for ArtifactType per constant: %v", at.Name)
}

ats, err2 := dal.ReadAllType(newType)
if err2 != nil {
t.Errorf("Should get ArtifactType: %v", err2)
}
if len(ats) != 1 { // TODO if temp file is okay, this is superfluos
t.Errorf("The test is running under different assumption")
}
at0 := ats[0]
t.Logf("at0: %v", at0)
if at0.ID != at.ID {
t.Errorf("Should have same ID")
}
if at0.Name != at.Name {
t.Errorf("Should have same Name")
}

}

func TestReadAllType(t *testing.T) {
dbc, err := setup()
if err != nil {
t.Errorf("Should expect DB connection: %v", err)
}
defer func() {
dbi, err := dbc.DB()
if err != nil {
t.Errorf("Test need to clear sqlite DB for the next one, but errored: %v", err)
}
dbi.Close()
}()
dal := NewDBService(dbc)

fixVersion := "version"

if _, err := dal.InsertType(db.Type{Name: "at0", Version: &fixVersion, TypeKind: int8(db.ARTIFACT_TYPE)}); err != nil {
t.Errorf("Should create ArtifactType: %v", err)
}
if _, err := dal.InsertType(db.Type{Name: "at1", Version: &fixVersion, TypeKind: int8(db.ARTIFACT_TYPE)}); err != nil {
t.Errorf("Should create ArtifactType: %v", err)
}

results, err := dal.ReadAllType(db.Type{Version: &fixVersion})
t.Logf("results: %v", results)
if err != nil {
t.Errorf("Should get ArtifactTypes: %v", err)
}
if len(results) != 2 {
t.Errorf("Should have retrieved 2 artifactTypes")
}
}

func TestUpsertType(t *testing.T) {
dbc, err := setup()
if err != nil {
t.Errorf("Should expect DB connection: %v", err)
}
defer func() {
dbi, err := dbc.DB()
if err != nil {
t.Errorf("Test need to clear sqlite DB for the next one, but errored: %v", err)
}
dbi.Close()
}()
dal := NewDBService(dbc)

artifactName := "John Doe"
v0 := "v0"
v1 := "v1"
if _, err := dal.InsertType(db.Type{Name: artifactName, Version: &v0, TypeKind: int8(db.ARTIFACT_TYPE)}); err != nil {
t.Errorf("Should Insert ArtifactType: %v", err)
}
if res, err := dal.InsertType(db.Type{Name: artifactName, Version: &v0, TypeKind: int8(db.ARTIFACT_TYPE)}); err == nil {
t.Errorf("Subsequent Insert must have failed: %v", res)
}
if _, err := dal.UpsertType(db.Type{Name: artifactName, Version: &v1, TypeKind: int8(db.ARTIFACT_TYPE)}); err != nil {
t.Errorf("Should Upsert ArtifactType: %v", err)
}
}
Loading

0 comments on commit 57395c7

Please sign in to comment.