Commit d5ed64e

Task parallelism updated. (flyteorg#348)

Signed-off-by: Ketan Umare <[email protected]>
kumare3 authored Oct 21, 2021
1 parent 37408ed commit d5ed64e
Showing 5 changed files with 132 additions and 24 deletions.
64 changes: 43 additions & 21 deletions flytepropeller/pkg/controller/nodes/executor.go
@@ -622,7 +622,14 @@ func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructur
	// Optimization!
	// If it is start node we directly move it to Queued without needing to run preExecute
	if currentPhase == v1alpha1.NodePhaseNotYetStarted && !nCtx.Node().IsStartNode() {
-		return c.handleNotYetStartedNode(ctx, dag, nCtx, h)
+		p, err := c.handleNotYetStartedNode(ctx, dag, nCtx, h)
+		if err != nil {
+			return p, err
+		}
+		if p.NodePhase == executors.NodePhaseQueued {
+			logger.Infof(ctx, "Node was queued, parallelism is now [%d]", nCtx.ExecutionContext().IncrementParallelism())
+		}
+		return p, err
	}

	if currentPhase == v1alpha1.NodePhaseFailing {
@@ -794,6 +801,39 @@ func canHandleNode(phase v1alpha1.NodePhase) bool {
		phase == v1alpha1.NodePhaseDynamicRunning
}

+// IsMaxParallelismAchieved checks whether the maximum parallelism has already been achieved. It returns true if the
+// desired max parallelism value has been reached, false otherwise.
+// MaxParallelism is defined as the maximum number of TaskNodes and LaunchPlans (together) that can be executed concurrently
+// by one workflow execution. A setting of `0` indicates that it is disabled.
+func IsMaxParallelismAchieved(ctx context.Context, currentNode v1alpha1.ExecutableNode, currentPhase v1alpha1.NodePhase,
+	execContext executors.ExecutionContext) bool {
+	maxParallelism := execContext.GetExecutionConfig().MaxParallelism
+	if maxParallelism == 0 {
+		logger.Debugf(ctx, "Parallelism control disabled")
+		return false
+	}
+
+	if currentNode.GetKind() == v1alpha1.NodeKindTask ||
+		(currentNode.GetKind() == v1alpha1.NodeKindWorkflow && currentNode.GetWorkflowNode() != nil && currentNode.GetWorkflowNode().GetLaunchPlanRefID() != nil) {
+		// If we are queued, let us see if we can proceed within the node parallelism bounds
+		if execContext.CurrentParallelism() >= maxParallelism {
+			logger.Infof(ctx, "Maximum Parallelism for task/launch-plan nodes achieved [%d] >= Max [%d], Round will be short-circuited.", execContext.CurrentParallelism(), maxParallelism)
+			return true
+		}
+		// We know that Propeller goes through each workflow in a single thread, so every node is processed
+		// sequentially. We can therefore continue, now that we know we are under the parallelism limit, and increment
+		// the parallelism counter if the node enters a running state.
+		logger.Debugf(ctx, "Parallelism criteria not met, Current [%d], Max [%d]", execContext.CurrentParallelism(), maxParallelism)
+	} else {
+		logger.Debugf(ctx, "NodeKind: %s in status [%s]. Parallelism control is not applicable. Current Parallelism [%d]",
+			currentNode.GetKind().String(), currentPhase.String(), execContext.CurrentParallelism())
+	}
+	return false
+}

+// RecursiveNodeHandler is the entrypoint for executing a node in a workflow. A workflow consists of nodes that can be
+// nested within other nodes. The system follows an actor model, where parent nodes control the execution of nested
+// nodes. The recursive node handler uses a modified depth-first algorithm to execute non-blocked nodes.
func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext,
	dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (
	executors.NodeStatus, error) {
@@ -821,26 +861,8 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext exe
		return executors.NodeStatusRunning, nil
	}

-	// Now if the node is of type task, then let us check if we are within the parallelism limit, only if the node
-	// has been queued already
-	if currentNode.GetKind() == v1alpha1.NodeKindTask && nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued {
-		maxParallelism := execContext.GetExecutionConfig().MaxParallelism
-		if maxParallelism > 0 {
-			// If we are queued, let us see if we can proceed within the node parallelism bounds
-			if execContext.CurrentParallelism() >= maxParallelism {
-				logger.Infof(ctx, "Maximum Parallelism for task nodes achieved [%d] >= Max [%d], Round will be short-circuited.", execContext.CurrentParallelism(), maxParallelism)
-				return executors.NodeStatusRunning, nil
-			}
-			// We know that Propeller goes through each workflow in a single thread, thus every node is really processed
-			// sequentially. So, we can continue - now that we know we are under the parallelism limits and increment the
-			// parallelism if the node, enters a running state
-			logger.Debugf(ctx, "Parallelism criteria not met, Current [%d], Max [%d]", execContext.CurrentParallelism(), maxParallelism)
-		} else {
-			logger.Debugf(ctx, "Parallelism control disabled")
-		}
-	} else {
-		logger.Debugf(ctx, "NodeKind: %s in status [%s]. Parallelism control is not applicable. Current Parallelism [%d]",
-			currentNode.GetKind().String(), nodeStatus.GetPhase().String(), execContext.CurrentParallelism())
-	}
+	if IsMaxParallelismAchieved(ctx, currentNode, nodePhase, execContext) {
+		return executors.NodeStatusRunning, nil
+	}

	nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl)
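To make the new control flow easier to follow, here is a minimal, self-contained sketch of the parallelism accounting that IsMaxParallelismAchieved and the queued-node increment rely on. The controlFlow type and admitNode helper below are illustrative stand-ins under stated assumptions, not flytepropeller's actual executors implementation:

```go
package main

import "fmt"

// controlFlow is a hypothetical stand-in for the per-round parallelism counter
// behind executors.ExecutionContext. Propeller evaluates a given workflow in a
// single thread, so no locking is shown for this counter.
type controlFlow struct {
	parallelism uint32
}

func (c *controlFlow) CurrentParallelism() uint32 { return c.parallelism }

// IncrementParallelism bumps the counter and returns the new value, mirroring
// the calls this commit adds when a node is queued or a launch plan runs.
func (c *controlFlow) IncrementParallelism() uint32 {
	c.parallelism++
	return c.parallelism
}

// admitNode (hypothetical) mimics the short-circuit in RecursiveNodeHandler:
// once the counter reaches maxParallelism, the node is deferred to a later
// evaluation round. A maxParallelism of 0 disables the check entirely,
// matching the early return in IsMaxParallelismAchieved.
func admitNode(cf *controlFlow, maxParallelism uint32, name string) bool {
	if maxParallelism != 0 && cf.CurrentParallelism() >= maxParallelism {
		fmt.Printf("node %s deferred: parallelism [%d] >= max [%d]\n", name, cf.CurrentParallelism(), maxParallelism)
		return false
	}
	fmt.Printf("node %s admitted, parallelism now [%d]\n", name, cf.IncrementParallelism())
	return true
}

func main() {
	cf := &controlFlow{}
	for _, n := range []string{"n0", "n1", "n2"} {
		admitNode(cf, 2, n) // with max parallelism 2, n2 is deferred
	}
}
```

Because only task nodes and launch-plan references are counted, start, end, branch, and plain subworkflow nodes never consume a slot — exactly what the table-driven test added below exercises.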
75 changes: 74 additions & 1 deletion flytepropeller/pkg/controller/nodes/executor_test.go
@@ -347,7 +347,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) {
		hf := &mocks2.HandlerFactory{}
		exec.nodeHandlerFactory = hf
		h := &nodeHandlerMocks.Node{}
-		hf.On("GetHandler", v1alpha1.NodeKindEnd).Return(h, nil)
+		hf.OnGetHandler(v1alpha1.NodeKindEnd).Return(h, nil)

		mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0)
		execContext := executors.NewExecutionContext(mockWf, nil, nil, nil, executors.InitializeControlFlow())
@@ -1286,6 +1286,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) {
		eCtx.OnGetRawOutputDataConfig().Return(v1alpha1.RawOutputDataConfig{
			RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: ""},
		})
+		eCtx.OnIncrementParallelism().Return(0)
		eCtx.OnCurrentParallelism().Return(0)
		eCtx.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
@@ -1879,6 +1880,17 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) {
		assert.Equal(t, s.NodePhase.String(), executors.NodePhaseRunning.String())
	})

+	t.Run("parallelism-met-not-yet-started", func(t *testing.T) {
+		mockWf, mockNode, _ := createSingleNodeWf(v1alpha1.NodePhaseNotYetStarted, 1)
+		cf := executors.InitializeControlFlow()
+		cf.IncrementParallelism()
+		eCtx := executors.NewExecutionContext(mockWf, mockWf, nil, nil, cf)
+
+		s, err := exec.RecursiveNodeHandler(ctx, eCtx, mockWf, mockWf, mockNode)
+		assert.NoError(t, err)
+		assert.Equal(t, s.NodePhase.String(), executors.NodePhaseRunning.String())
+	})
+
	t.Run("parallelism-disabled", func(t *testing.T) {
		mockWf, mockNode, _ := createSingleNodeWf(v1alpha1.NodePhaseQueued, 0)
		cf := executors.InitializeControlFlow()
@@ -2299,3 +2311,64 @@ func TestRecover(t *testing.T) {
		mockPBStore.AssertNumberOfCalls(t, "ReadProtobuf", 1)
	})
}

+func TestIsMaxParallelismAchieved(t *testing.T) {
+	// Creates an execution context for the test
+	createExecContext := func(maxParallelism, currentParallelism uint32) executors.ExecutionContext {
+		m := &mocks4.ExecutionContext{}
+		m.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{
+			MaxParallelism: maxParallelism,
+		})
+		m.OnCurrentParallelism().Return(currentParallelism)
+		return m
+	}
+
+	createNode := func(kind v1alpha1.NodeKind, lpRef bool) v1alpha1.ExecutableNode {
+		en := &mocks.ExecutableNode{}
+		en.OnGetKind().Return(kind)
+		if kind == v1alpha1.NodeKindWorkflow {
+			wn := &mocks.ExecutableWorkflowNode{}
+			var lp *v1alpha1.LaunchPlanRefID
+			if lpRef {
+				lp = &v1alpha1.LaunchPlanRefID{}
+			}
+			wn.OnGetLaunchPlanRefID().Return(lp)
+			en.OnGetWorkflowNode().Return(wn)
+		}
+		return en
+	}
+
+	type args struct {
+		currentNode  v1alpha1.ExecutableNode
+		currentPhase v1alpha1.NodePhase
+		execContext  executors.ExecutionContext
+	}
+	tests := []struct {
+		name string
+		args args
+		want bool
+	}{
+		{"start", args{createNode(v1alpha1.NodeKindStart, false), v1alpha1.NodePhaseQueued, createExecContext(1, 1)}, false},
+		{"end", args{createNode(v1alpha1.NodeKindEnd, false), v1alpha1.NodePhaseQueued, createExecContext(1, 1)}, false},
+		{"branch", args{createNode(v1alpha1.NodeKindBranch, false), v1alpha1.NodePhaseQueued, createExecContext(1, 1)}, false},
+		{"subworkflow", args{createNode(v1alpha1.NodeKindWorkflow, false), v1alpha1.NodePhaseQueued, createExecContext(1, 1)}, false},
+		{"lp-met", args{createNode(v1alpha1.NodeKindWorkflow, true), v1alpha1.NodePhaseQueued, createExecContext(1, 1)}, true},
+		{"lp-met-larger", args{createNode(v1alpha1.NodeKindWorkflow, true), v1alpha1.NodePhaseQueued, createExecContext(1, 2)}, true},
+		{"lp-disabled", args{createNode(v1alpha1.NodeKindWorkflow, true), v1alpha1.NodePhaseQueued, createExecContext(0, 1)}, false},
+		{"lp-not-met", args{createNode(v1alpha1.NodeKindWorkflow, true), v1alpha1.NodePhaseQueued, createExecContext(4, 1)}, false},
+		{"lp-not-met-1", args{createNode(v1alpha1.NodeKindWorkflow, true), v1alpha1.NodePhaseQueued, createExecContext(2, 1)}, false},
+		{"task-met", args{createNode(v1alpha1.NodeKindTask, false), v1alpha1.NodePhaseQueued, createExecContext(1, 1)}, true},
+		{"task-met-larger", args{createNode(v1alpha1.NodeKindTask, false), v1alpha1.NodePhaseQueued, createExecContext(1, 2)}, true},
+		{"task-disabled", args{createNode(v1alpha1.NodeKindTask, false), v1alpha1.NodePhaseQueued, createExecContext(0, 1)}, false},
+		{"task-not-met", args{createNode(v1alpha1.NodeKindTask, false), v1alpha1.NodePhaseQueued, createExecContext(4, 1)}, false},
+		{"task-not-met-1", args{createNode(v1alpha1.NodeKindTask, false), v1alpha1.NodePhaseQueued, createExecContext(2, 1)}, false},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := IsMaxParallelismAchieved(context.TODO(), tt.args.currentNode, tt.args.currentPhase, tt.args.execContext); got != tt.want {
+				t.Errorf("IsMaxParallelismAchieved() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
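Assuming a checkout of the flytepropeller repository with a standard Go toolchain, the new and updated parallelism tests can be run in isolation with something like the following (the package path is inferred from the file headers above and may differ):

```
go test ./pkg/controller/nodes/... -run 'TestIsMaxParallelismAchieved|TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit' -v
```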
@@ -108,6 +108,7 @@ func createNodeContextWithVersion(phase v1alpha1.WorkflowNodePhase, n v1alpha1.E
	ex.OnGetParentInfo().Return(nil)
	ex.OnGetName().Return("name")
	ex.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
+	ex.OnIncrementParallelism().Return(1)

	nCtx.OnExecutionContext().Return(ex)

@@ -171,6 +172,8 @@ func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) {
		s, err := h.Handle(ctx, nCtx)
		assert.NoError(t, err)
		assert.Equal(t, handler.EPhaseRunning, s.Info().GetPhase())
+		c := nCtx.ExecutionContext().(*execMocks.ExecutionContext)
+		c.AssertCalled(t, "IncrementParallelism")
	})

	t.Run("happy v1", func(t *testing.T) {
@@ -194,6 +197,8 @@ func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) {
		s, err := h.Handle(ctx, nCtx)
		assert.NoError(t, err)
		assert.Equal(t, handler.EPhaseRunning, s.Info().GetPhase())
+		c := nCtx.ExecutionContext().(*execMocks.ExecutionContext)
+		c.AssertCalled(t, "IncrementParallelism")
	})
}

Expand Down Expand Up @@ -243,6 +248,8 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) {
s, err := h.Handle(ctx, nCtx)
assert.NoError(t, err)
assert.Equal(t, handler.EPhaseRunning, s.Info().GetPhase())
c := nCtx.ExecutionContext().(*execMocks.ExecutionContext)
c.AssertCalled(t, "IncrementParallelism")
})
t.Run("stillRunning V1", func(t *testing.T) {

@@ -262,6 +269,8 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) {
		s, err := h.Handle(ctx, nCtx)
		assert.NoError(t, err)
		assert.Equal(t, handler.EPhaseRunning, s.Info().GetPhase())
+		c := nCtx.ExecutionContext().(*execMocks.ExecutionContext)
+		c.AssertCalled(t, "IncrementParallelism")
	})
}

7 changes: 5 additions & 2 deletions flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go
@@ -95,7 +95,8 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.No
			return handler.UnknownTransition, err
		}
	} else {
-		logger.Infof(ctx, "Launched launchplan with ID [%s]", childID.Name)
+		eCtx := nCtx.ExecutionContext()
+		logger.Infof(ctx, "Launched launchplan with ID [%s], Parallelism is now set to [%d]", childID.Name, eCtx.IncrementParallelism())
	}

	return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{
@@ -134,7 +135,8 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx hand

	if wfStatusClosure == nil {
-		logger.Info(ctx, "Retrieved Launch Plan status is nil. This might indicate pressure on the admin cache."+
-			" Consider tweaking its size to allow for more concurrent executions to be cached.")
+		logger.Infof(ctx, "Retrieved Launch Plan status is nil. This might indicate pressure on the admin cache."+
+			" Consider tweaking its size to allow for more concurrent executions to be cached."+
+			" Assuming LP is running, parallelism [%d].", nCtx.ExecutionContext().IncrementParallelism())
		return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{
			WorkflowNodeInfo: &handler.WorkflowNodeInfo{LaunchedWorkflowID: childID},
		})), nil
@@ -184,6 +186,7 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx hand
			OutputInfo: oInfo,
		})), nil
	}
+	logger.Infof(ctx, "LaunchPlan running, parallelism is now set to [%d]", nCtx.ExecutionContext().IncrementParallelism())
	return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil
}

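A small design detail in these launch-plan changes is worth calling out: the parallelism increment happens inside the argument list of the logging call. Go evaluates call arguments unconditionally, so the counter is incremented regardless of the configured log level; the log line merely reports the new value. Below is a minimal sketch of this pattern, with a hypothetical logf helper standing in for the logger used here:

```go
package main

import "fmt"

type counter struct{ n uint32 }

// Increment bumps the counter and returns the new value, like IncrementParallelism.
func (c *counter) Increment() uint32 { c.n++; return c.n }

// logf stands in for logger.Infof; even if it discarded its output, its
// arguments (including the Increment call) would still be evaluated.
func logf(format string, args ...interface{}) { fmt.Printf(format+"\n", args...) }

func main() {
	c := &counter{}
	// Mirrors: logger.Infof(ctx, "... parallelism is now set to [%d]", eCtx.IncrementParallelism())
	logf("Launched launchplan, parallelism is now set to [%d]", c.Increment())
	logf("LaunchPlan running, parallelism is now set to [%d]", c.Increment())
}
```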
@@ -255,6 +255,7 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) {
			WorkflowExecutionIdentifier: recoveredExecID,
		},
	})
+	ectx.OnIncrementParallelism().Return(1)
	nCtx.OnExecutionContext().Return(ectx)
	nCtx.OnCurrentAttempt().Return(uint32(1))
	nCtx.OnNode().Return(mockNode)
