Skip to content

Commit

Permalink
Fix map test with empty input list (flyteorg#246)
Browse files Browse the repository at this point in the history
* fix map task with no inputs

Signed-off-by: Daniel Rammer <[email protected]>

* fixed lint issues

Signed-off-by: Daniel Rammer <[email protected]>

* added unit test

Signed-off-by: Daniel Rammer <[email protected]>

* updated flytestdlib version to merged fix

Signed-off-by: Daniel Rammer <[email protected]>

* updated idl for tests

Signed-off-by: Daniel Rammer <[email protected]>

* flyteidl dependency fix

Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw authored Mar 10, 2022
1 parent 26d0bc7 commit cf25a2c
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 52 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v0.23.0
github.com/flyteorg/flytestdlib v0.4.7
github.com/flyteorg/flytestdlib v0.4.13
github.com/go-logr/zapr v0.4.0 // indirect
github.com/go-test/deep v1.0.7
github.com/golang/protobuf v1.4.3
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ github.com/flyteorg/flyteidl v0.21.23/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/
github.com/flyteorg/flyteidl v0.23.0 h1:Pjl9Tq1pJfIK0au5PiqPVpl25xTYosN6BruZl+PgWAk=
github.com/flyteorg/flyteidl v0.23.0/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220=
github.com/flyteorg/flytestdlib v0.4.7 h1:SMPPXI3j/MjP7D2fqaR+lPQkTrqYS7xZbwsgJI2F8SU=
github.com/flyteorg/flytestdlib v0.4.7/go.mod h1:fv1ar34LJLMTaf0tbfetisLykUlARi7rP+NQTUn6QQs=
github.com/flyteorg/flytestdlib v0.4.13 h1:TzgqhECRGfOHYH1A7rUwcKEEH2rTtPxGy+oYcif7iBw=
github.com/flyteorg/flytestdlib v0.4.13/go.mod h1:fv1ar34LJLMTaf0tbfetisLykUlARi7rP+NQTUn6QQs=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4=
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
return state, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size")
}

size := 0
size := -1
var literalCollection *idlCore.LiteralCollection
var discoveredInputName string
for inputName, literal := range inputs.Literals {
Expand All @@ -89,7 +89,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
}
}

if size == 0 {
if size < 0 {
// Something is wrong, we should have inferred the array size when it is not specified by the size of the
// input collection (for any input value). Non-collection type inputs are not currently supported for
// taskTypeVersion > 0.
Expand Down
9 changes: 8 additions & 1 deletion go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,14 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl
fallthrough

case PhaseWriteToDiscovery:
version := GetPhaseVersionOffset(p, state.GetOriginalArraySize()) + version
// If the array task has 0 inputs we need to ensure the phaseVersion changes so that the
// task can progess. Therefore we default to task length 1 to ensure phase updates.
length := int64(1)
if state.GetOriginalArraySize() != 0 {
length = state.GetOriginalArraySize()
}

version := GetPhaseVersionOffset(p, length) + version
phaseInfo = core.PhaseInfoRunning(version, nowTaskInfo)

case PhaseSuccess:
Expand Down
13 changes: 13 additions & 0 deletions go/tasks/plugins/array/outputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ func (w assembleOutputsWorker) Process(ctx context.Context, workItem workqueue.W
finalOutputs := &core.LiteralMap{
Literals: map[string]*core.Literal{},
}

// Initialize the final output literal with empty output variable collections. Otherwise, if a
// task has no input values they will never be written.
for _, varName := range i.varNames {
finalOutputs.Literals[varName] = &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: make([]*core.Literal, 0),
},
},
}
}

for idx, subTaskPhaseIdx := range i.finalPhases.GetItems() {
existingPhase := pluginCore.Phases[subTaskPhaseIdx]
if existingPhase.IsSuccess() {
Expand Down
140 changes: 94 additions & 46 deletions go/tasks/plugins/array/outputs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,62 +71,110 @@ func init() {
func Test_assembleOutputsWorker_Process(t *testing.T) {
ctx := context.Background()

memStore, err := storage.NewDataStore(&storage.Config{
Type: storage.TypeMemory,
}, promutils.NewTestScope())
assert.NoError(t, err)
t.Run("EmptyInputs", func(t *testing.T) {
memStore, err := storage.NewDataStore(&storage.Config{
Type: storage.TypeMemory,
}, promutils.NewTestScope())
assert.NoError(t, err)

// Write data to 1st and 3rd tasks only. Simulate a failed 2nd and 4th tasks.
l := coreutils.MustMakeLiteral(map[string]interface{}{
"var1": 5,
"var2": "hello world",
// Setup the expected data to be written to outputWriter.
ow := &mocks2.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("/bucket/prefix")
ow.OnGetOutputPath().Return("/bucket/prefix/outputs.pb")
ow.OnGetRawOutputPrefix().Return("/bucket/sandbox/")

// Setup the input phases that inform outputs worker about which tasks failed/succeeded.
phases := arrayCore.NewPhasesCompactArray(0)

item := &outputAssembleItem{
outputPaths: ow,
varNames: []string{"var1", "var2"},
finalPhases: phases,
dataStore: memStore,
isAwsSingleJob: false,
}

w := assembleOutputsWorker{}
actual, err := w.Process(ctx, item)
assert.NoError(t, err)
assert.Equal(t, workqueue.WorkStatusSucceeded, actual)

actualOutputs := &core.LiteralMap{}
assert.NoError(t, memStore.ReadProtobuf(ctx, "/bucket/prefix/outputs.pb", actualOutputs))
expected := coreutils.MustMakeLiteral(map[string]interface{}{
"var1": []interface{}{},
"var2": []interface{}{},
}).GetMap()

expectedBytes, err := json.Marshal(expected)
assert.NoError(t, err)

actualBytes, err := json.Marshal(actualOutputs)
assert.NoError(t, err)

if diff := deep.Equal(string(actualBytes), string(expectedBytes)); diff != nil {
assert.FailNow(t, "Should be equal.", "Diff: %v", diff)
}
})
assert.NoError(t, memStore.WriteProtobuf(ctx, "/bucket/prefix/0/outputs.pb", storage.Options{}, l.GetMap()))
assert.NoError(t, memStore.WriteProtobuf(ctx, "/bucket/prefix/2/outputs.pb", storage.Options{}, l.GetMap()))

// Setup the expected data to be written to outputWriter.
ow := &mocks2.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("/bucket/prefix")
ow.OnGetOutputPath().Return("/bucket/prefix/outputs.pb")
ow.OnGetRawOutputPrefix().Return("/bucket/sandbox/")
t.Run("MissingTasks", func(t *testing.T) {
memStore, err := storage.NewDataStore(&storage.Config{
Type: storage.TypeMemory,
}, promutils.NewTestScope())
assert.NoError(t, err)

// Setup the input phases that inform outputs worker about which tasks failed/succeeded.
phases := arrayCore.NewPhasesCompactArray(4)
phases.SetItem(0, bitarray.Item(pluginCore.PhaseSuccess))
phases.SetItem(1, bitarray.Item(pluginCore.PhasePermanentFailure))
phases.SetItem(2, bitarray.Item(pluginCore.PhaseSuccess))
phases.SetItem(3, bitarray.Item(pluginCore.PhasePermanentFailure))
// Write data to 1st and 3rd tasks only. Simulate a failed 2nd and 4th tasks.
l := coreutils.MustMakeLiteral(map[string]interface{}{
"var1": 5,
"var2": "hello world",
})
assert.NoError(t, memStore.WriteProtobuf(ctx, "/bucket/prefix/0/outputs.pb", storage.Options{}, l.GetMap()))
assert.NoError(t, memStore.WriteProtobuf(ctx, "/bucket/prefix/2/outputs.pb", storage.Options{}, l.GetMap()))

item := &outputAssembleItem{
outputPaths: ow,
varNames: []string{"var1", "var2"},
finalPhases: phases,
dataStore: memStore,
isAwsSingleJob: false,
}
// Setup the expected data to be written to outputWriter.
ow := &mocks2.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("/bucket/prefix")
ow.OnGetOutputPath().Return("/bucket/prefix/outputs.pb")
ow.OnGetRawOutputPrefix().Return("/bucket/sandbox/")

// Setup the input phases that inform outputs worker about which tasks failed/succeeded.
phases := arrayCore.NewPhasesCompactArray(4)
phases.SetItem(0, bitarray.Item(pluginCore.PhaseSuccess))
phases.SetItem(1, bitarray.Item(pluginCore.PhasePermanentFailure))
phases.SetItem(2, bitarray.Item(pluginCore.PhaseSuccess))
phases.SetItem(3, bitarray.Item(pluginCore.PhasePermanentFailure))

item := &outputAssembleItem{
outputPaths: ow,
varNames: []string{"var1", "var2"},
finalPhases: phases,
dataStore: memStore,
isAwsSingleJob: false,
}

w := assembleOutputsWorker{}
actual, err := w.Process(ctx, item)
assert.NoError(t, err)
assert.Equal(t, workqueue.WorkStatusSucceeded, actual)
w := assembleOutputsWorker{}
actual, err := w.Process(ctx, item)
assert.NoError(t, err)
assert.Equal(t, workqueue.WorkStatusSucceeded, actual)

actualOutputs := &core.LiteralMap{}
assert.NoError(t, memStore.ReadProtobuf(ctx, "/bucket/prefix/outputs.pb", actualOutputs))
// Since 2nd and 4th tasks failed, there should be nil literals in their expected places.
expected := coreutils.MustMakeLiteral(map[string]interface{}{
"var1": []interface{}{5, nil, 5, nil},
"var2": []interface{}{"hello world", nil, "hello world", nil},
}).GetMap()
actualOutputs := &core.LiteralMap{}
assert.NoError(t, memStore.ReadProtobuf(ctx, "/bucket/prefix/outputs.pb", actualOutputs))
// Since 2nd and 4th tasks failed, there should be nil literals in their expected places.
expected := coreutils.MustMakeLiteral(map[string]interface{}{
"var1": []interface{}{5, nil, 5, nil},
"var2": []interface{}{"hello world", nil, "hello world", nil},
}).GetMap()

expectedBytes, err := json.Marshal(expected)
assert.NoError(t, err)
expectedBytes, err := json.Marshal(expected)
assert.NoError(t, err)

actualBytes, err := json.Marshal(actualOutputs)
assert.NoError(t, err)
actualBytes, err := json.Marshal(actualOutputs)
assert.NoError(t, err)

if diff := deep.Equal(string(actualBytes), string(expectedBytes)); diff != nil {
assert.FailNow(t, "Should be equal.", "Diff: %v", diff)
}
if diff := deep.Equal(string(actualBytes), string(expectedBytes)); diff != nil {
assert.FailNow(t, "Should be equal.", "Diff: %v", diff)
}
})
}

func Test_appendSubTaskOutput(t *testing.T) {
Expand Down

0 comments on commit cf25a2c

Please sign in to comment.