Skip to content

Commit

Permalink
Use generated mocks in flyteadmin (#6197)
Browse files Browse the repository at this point in the history
* build mock files with mockery v2

Signed-off-by: Alex Wu <[email protected]>

* add more mocks

Signed-off-by: Alex Wu <[email protected]>

* fix go generate command in interface files

Signed-off-by: Alex Wu <[email protected]>

* minor fix

Signed-off-by: Alex Wu <[email protected]>

---------

Signed-off-by: Alex Wu <[email protected]>
  • Loading branch information
popojk authored Jan 31, 2025
1 parent 448aba9 commit 2d457df
Show file tree
Hide file tree
Showing 47 changed files with 3,549 additions and 1,144 deletions.
41 changes: 21 additions & 20 deletions flyteadmin/dataproxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
Expand All @@ -33,8 +34,8 @@ func TestNewService(t *testing.T) {
dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{
Upload: config.DataProxyUploadConfig{},
}, nodeExecutionManager, dataStore, taskExecutionManager)
Expand All @@ -59,8 +60,8 @@ func Test_createStorageLocation(t *testing.T) {
func TestCreateUploadLocation(t *testing.T) {
dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)
t.Run("No project/domain", func(t *testing.T) {
Expand Down Expand Up @@ -113,8 +114,8 @@ func TestCreateUploadLocationMore(t *testing.T) {
}

assert.NoError(t, err)
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, &ds, taskExecutionManager)
assert.NoError(t, err)

Expand Down Expand Up @@ -171,15 +172,15 @@ func (t testMetadata) Exists() bool {

func TestCreateDownloadLink(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
nodeExecutionManager.SetGetNodeExecutionFunc(func(ctx context.Context, request *admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
nodeExecutionManager := &mocks.NodeExecutionInterface{}
nodeExecutionManager.EXPECT().GetNodeExecution(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
return &admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
DeckUri: "s3://something/something",
},
}, nil
})
taskExecutionManager := &mocks.MockTaskExecutionManager{}
taskExecutionManager := &mocks.TaskExecutionInterface{}

s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)
Expand Down Expand Up @@ -262,8 +263,8 @@ func TestCreateDownloadLink(t *testing.T) {

func TestCreateDownloadLocation(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{Download: config.DataProxyDownloadConfig{MaxExpiresIn: stdlibConfig.Duration{Duration: time.Hour}}}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

Expand Down Expand Up @@ -300,8 +301,8 @@ func TestCreateDownloadLocation(t *testing.T) {

func TestService_GetData(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

Expand Down Expand Up @@ -340,15 +341,15 @@ func TestService_GetData(t *testing.T) {
},
}

nodeExecutionManager.SetGetNodeExecutionDataFunc(
nodeExecutionManager.EXPECT().GetNodeExecutionData(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) {
return &admin.NodeExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
}, nil
},
)
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.EXPECT().ListTaskExecutions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: []*admin.TaskExecution{
{
Expand All @@ -374,7 +375,7 @@ func TestService_GetData(t *testing.T) {
},
}, nil
})
taskExecutionManager.SetGetTaskExecutionDataCallback(func(ctx context.Context, request *admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
taskExecutionManager.EXPECT().GetTaskExecutionData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
return &admin.TaskExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
Expand Down Expand Up @@ -441,13 +442,13 @@ func TestService_GetData(t *testing.T) {

func TestService_Error(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
taskExecutionManager := &mocks.MockTaskExecutionManager{}
nodeExecutionManager := &mocks.NodeExecutionInterface{}
taskExecutionManager := &mocks.TaskExecutionInterface{}
s, err := NewService(config.DataProxyConfig{}, nodeExecutionManager, dataStore, taskExecutionManager)
assert.NoError(t, err)

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.EXPECT().ListTaskExecutions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return nil, errors.NewFlyteAdminErrorf(1, "not found")
})
nodeExecID := &core.NodeExecutionIdentifier{
Expand All @@ -463,7 +464,7 @@ func TestService_Error(t *testing.T) {
})

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.EXPECT().ListTaskExecutions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: nil,
Token: "",
Expand Down
25 changes: 13 additions & 12 deletions flyteadmin/pkg/async/schedule/aws/workflow_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"

flyteAdminErrors "github.com/flyteorg/flyte/flyteadmin/pkg/errors"
Expand Down Expand Up @@ -138,8 +139,8 @@ func TestGetActiveLaunchPlanVersion(t *testing.T) {
Version: "foo",
}

launchPlanManager := mocks.NewMockLaunchPlanManager()
launchPlanManager.(*mocks.MockLaunchPlanManager).SetListLaunchPlansCallback(
launchPlanManager := mocks.LaunchPlanInterface{}
launchPlanManager.EXPECT().ListLaunchPlans(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *admin.ResourceListRequest) (
*admin.LaunchPlanList, error) {
assert.True(t, proto.Equal(launchPlanNamedIdentifier, request.GetId()))
Expand All @@ -153,7 +154,7 @@ func TestGetActiveLaunchPlanVersion(t *testing.T) {
},
}, nil
})
testExecutor := newWorkflowExecutorForTest(nil, nil, launchPlanManager)
testExecutor := newWorkflowExecutorForTest(nil, nil, &launchPlanManager)
launchPlan, err := testExecutor.getActiveLaunchPlanVersion(launchPlanNamedIdentifier)
assert.Nil(t, err)
assert.True(t, proto.Equal(&launchPlanIdentifier, launchPlan.GetId()))
Expand All @@ -167,13 +168,13 @@ func TestGetActiveLaunchPlanVersion_ManagerError(t *testing.T) {
}

expectedErr := errors.New("expected error")
launchPlanManager := mocks.NewMockLaunchPlanManager()
launchPlanManager.(*mocks.MockLaunchPlanManager).SetListLaunchPlansCallback(
launchPlanManager := mocks.LaunchPlanInterface{}
launchPlanManager.EXPECT().ListLaunchPlans(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *admin.ResourceListRequest) (
*admin.LaunchPlanList, error) {
return nil, expectedErr
})
testExecutor := newWorkflowExecutorForTest(nil, nil, launchPlanManager)
testExecutor := newWorkflowExecutorForTest(nil, nil, &launchPlanManager)
_, err := testExecutor.getActiveLaunchPlanVersion(launchPlanIdentifier)
assert.EqualError(t, err, expectedErr.Error())
}
Expand Down Expand Up @@ -229,23 +230,23 @@ func TestRun(t *testing.T) {
testSubscriber := pubsubtest.TestSubscriber{
JSONMessages: messages,
}
testExecutionManager := mocks.MockExecutionManager{}
testExecutionManager := mocks.ExecutionInterface{}
var messagesSeen int
testExecutionManager.SetCreateCallback(func(
testExecutionManager.EXPECT().CreateExecution(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(
ctx context.Context, request *admin.ExecutionCreateRequest, requestedAt time.Time) (
*admin.ExecutionCreateResponse, error) {
assert.Equal(t, "project", request.GetProject())
assert.Equal(t, "domain", request.GetDomain())
assert.Equal(t, "ar8fphnlc5wh9dksjncj", request.GetName())
if messagesSeen == 0 {
assert.Contains(t, request.GetInputs().GetLiterals(), testKickoffTime)
assert.Equal(t, testKickoffTimeProtoLiteral, request.GetInputs().GetLiterals()[testKickoffTime])
assert.True(t, proto.Equal(testKickoffTimeProtoLiteral, request.GetInputs().GetLiterals()[testKickoffTime]))
}
messagesSeen++
return &admin.ExecutionCreateResponse{}, nil
})
launchPlanManager := mocks.NewMockLaunchPlanManager()
launchPlanManager.(*mocks.MockLaunchPlanManager).SetListLaunchPlansCallback(
launchPlanManager := mocks.LaunchPlanInterface{}
launchPlanManager.EXPECT().ListLaunchPlans(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *admin.ResourceListRequest) (
*admin.LaunchPlanList, error) {
assert.Equal(t, "project", request.GetId().GetProject())
Expand Down Expand Up @@ -280,7 +281,7 @@ func TestRun(t *testing.T) {
},
}, nil
})
testExecutor := newWorkflowExecutorForTest(&testSubscriber, &testExecutionManager, launchPlanManager)
testExecutor := newWorkflowExecutorForTest(&testSubscriber, &testExecutionManager, &launchPlanManager)
err := testExecutor.run()
assert.Len(t, messages, messagesSeen)
assert.Nil(t, err)
Expand Down
17 changes: 10 additions & 7 deletions flyteadmin/pkg/clusterresource/impl/db_admin_data_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

Expand All @@ -28,9 +29,9 @@ func TestGetClusterResourceAttributes(t *testing.T) {
"K1": "V1",
"K2": "V2",
}
resourceManager := mocks.MockResourceManager{}
t.Run("happy case", func(t *testing.T) {
resourceManager.GetResourceFunc = func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
resourceManager := mocks.ResourceInterface{}
resourceManager.EXPECT().GetResource(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
return &interfaces.ResourceResponse{
Project: request.Project,
Domain: request.Domain,
Expand All @@ -43,7 +44,7 @@ func TestGetClusterResourceAttributes(t *testing.T) {
},
},
}, nil
}
})
provider := dbAdminProvider{
resourceManager: &resourceManager,
}
Expand All @@ -52,17 +53,19 @@ func TestGetClusterResourceAttributes(t *testing.T) {
assert.EqualValues(t, attrs.GetAttributes(), attributes)
})
t.Run("error", func(t *testing.T) {
resourceManager.GetResourceFunc = func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
resourceManager := mocks.ResourceInterface{}
resourceManager.EXPECT().GetResource(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
return nil, errFoo
}
})
provider := dbAdminProvider{
resourceManager: &resourceManager,
}
_, err := provider.GetClusterResourceAttributes(context.TODO(), project, domain)
assert.EqualError(t, err, errFoo.Error())
})
t.Run("weird db response", func(t *testing.T) {
resourceManager.GetResourceFunc = func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
resourceManager := mocks.ResourceInterface{}
resourceManager.EXPECT().GetResource(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) {
return &interfaces.ResourceResponse{
Project: request.Project,
Domain: request.Domain,
Expand All @@ -75,7 +78,7 @@ func TestGetClusterResourceAttributes(t *testing.T) {
},
},
}, nil
}
})
provider := dbAdminProvider{
resourceManager: &resourceManager,
}
Expand Down
Loading

0 comments on commit 2d457df

Please sign in to comment.