Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Check the memoization cache after a node acquires its mutex synchronization lock. Fixes #11219 #11456

Merged
merged 1 commit into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 82 additions & 46 deletions workflow/controller/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1804,52 +1804,6 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string,
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

// If memoization is on, check if node output exists in cache
if node == nil && processedTmpl.Memoize != nil {
memoizationCache := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, processedTmpl.Memoize.Cache.ConfigMap.Name)
if memoizationCache == nil {
err := fmt.Errorf("cache could not be found or created")
woc.log.WithFields(log.Fields{"cacheName": processedTmpl.Memoize.Cache.ConfigMap.Name}).WithError(err)
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

entry, err := memoizationCache.Load(ctx, processedTmpl.Memoize.Key)
if err != nil {
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

hit := entry.Hit()
var outputs *wfv1.Outputs
if processedTmpl.Memoize.MaxAge != "" {
maxAge, err := time.ParseDuration(processedTmpl.Memoize.MaxAge)
if err != nil {
err := fmt.Errorf("invalid maxAge: %s", err)
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}
maxAgeOutputs, ok := entry.GetOutputsWithMaxAge(maxAge)
if !ok {
// The outputs are expired, so this cache entry is not hit
hit = false
}
outputs = maxAgeOutputs
} else {
outputs = entry.GetOutputs()
}

memoizationStatus := &wfv1.MemoizationStatus{
Hit: hit,
Key: processedTmpl.Memoize.Key,
CacheName: processedTmpl.Memoize.Cache.ConfigMap.Name,
}
if hit {
node = woc.initializeCacheHitNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, outputs, memoizationStatus)
} else {
node = woc.initializeCacheNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, memoizationStatus)
}
woc.wf.Status.Nodes.Set(node.ID, *node)
woc.updated = true
}

if node != nil {
if node.Fulfilled() {
woc.controller.syncManager.Release(woc.wf, node.ID, processedTmpl.Synchronization)
Expand Down Expand Up @@ -1897,6 +1851,8 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string,
return node, err
}

unlockedNode := false

if processedTmpl.Synchronization != nil {
lockAcquired, wfUpdated, msg, err := woc.controller.syncManager.TryAcquire(woc.wf, woc.wf.NodeID(nodeName), processedTmpl.Synchronization)
if err != nil {
Expand All @@ -1912,6 +1868,7 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string,
// unexpected behavior and is a bug.
panic("bug: GetLockName should not return an error after a call to TryAcquire")
}
woc.log.Infof("Could not acquire lock named: %s", lockName)
return woc.markNodeWaitingForLock(node.Name, lockName.EncodeName())
} else {
woc.log.Infof("Node %s acquired synchronization lock", nodeName)
Expand All @@ -1922,10 +1879,71 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string,
return nil, err
}
}
// Set this value to check that this node is using synchronization, and has acquired the lock
unlockedNode = true
}

woc.updated = woc.updated || wfUpdated
}

// Check memoization cache if the node is about to be created, or was created in the past but is only now allowed to run due to acquiring a lock
if processedTmpl.Memoize != nil {
if node == nil || unlockedNode {
memoizationCache := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, processedTmpl.Memoize.Cache.ConfigMap.Name)
if memoizationCache == nil {
err := fmt.Errorf("cache could not be found or created")
woc.log.WithFields(log.Fields{"cacheName": processedTmpl.Memoize.Cache.ConfigMap.Name}).WithError(err)
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

entry, err := memoizationCache.Load(ctx, processedTmpl.Memoize.Key)
if err != nil {
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

hit := entry.Hit()
var outputs *wfv1.Outputs
if processedTmpl.Memoize.MaxAge != "" {
maxAge, err := time.ParseDuration(processedTmpl.Memoize.MaxAge)
if err != nil {
err := fmt.Errorf("invalid maxAge: %s", err)
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}
maxAgeOutputs, ok := entry.GetOutputsWithMaxAge(maxAge)
if !ok {
// The outputs are expired, so this cache entry is not hit
hit = false
}
outputs = maxAgeOutputs
} else {
outputs = entry.GetOutputs()
}

memoizationStatus := &wfv1.MemoizationStatus{
Hit: hit,
Key: processedTmpl.Memoize.Key,
CacheName: processedTmpl.Memoize.Cache.ConfigMap.Name,
}
if hit {
if node == nil {
node = woc.initializeCacheHitNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, outputs, memoizationStatus)
} else {
woc.log.Infof("Node %s is using mutex with memoize. Cache is hit.", nodeName)
woc.updateAsCacheHitNode(node, outputs, memoizationStatus)
}
} else {
if node == nil {
node = woc.initializeCacheNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, memoizationStatus)
} else {
woc.log.Infof("Node %s is using mutex with memoize. Cache is NOT hit", nodeName)
woc.updateAsCacheNode(node, memoizationStatus)
}
}
woc.wf.Status.Nodes.Set(node.ID, *node)
woc.updated = true
}
}

// If the user has specified retries, node becomes a special retry node.
// This node acts as a parent of all retries that will be done for
// the container. The status of this node should be "Success" if any
Expand Down Expand Up @@ -2387,6 +2405,24 @@ func (woc *wfOperationCtx) initializeNode(nodeName string, nodeType wfv1.NodeTyp
return &node
}

// updateAsCacheNode attaches memStat to node as its memoization status,
// writes the resulting node back into the workflow's node map under node.ID,
// and flags the workflow as updated.
func (woc *wfOperationCtx) updateAsCacheNode(node *wfv1.NodeStatus, memStat *wfv1.MemoizationStatus) {
	node.MemoizationStatus = memStat

	// Nodes.Set takes the status by value, so store a snapshot taken after
	// the memoization status has been applied.
	snapshot := *node
	woc.wf.Status.Nodes.Set(snapshot.ID, snapshot)
	woc.updated = true
}

// updateAsCacheHitNode marks an already-existing node as fulfilled from the
// memoization cache: its phase becomes NodeSucceeded, its outputs are replaced
// with the supplied cached outputs, and FinishedAt is set to the current UTC
// time. The memoization status is then recorded and the node persisted via
// updateAsCacheNode.
//
// message is optional; when given, only the first element is appended to the
// log line. It is not stored on the node.
func (woc *wfOperationCtx) updateAsCacheHitNode(node *wfv1.NodeStatus, outputs *wfv1.Outputs, memStat *wfv1.MemoizationStatus, message ...string) {
	node.Phase = wfv1.NodeSucceeded
	node.Outputs = outputs
	node.FinishedAt = metav1.Time{Time: time.Now().UTC()}

	woc.updateAsCacheNode(node, memStat)

	// Formatting the variadic slice directly with %s renders "[]" when no
	// message is supplied, so the optional suffix is built explicitly.
	if len(message) == 0 {
		woc.log.Infof("%s node %v updated %s", node.Type, node.ID, node.Phase)
		return
	}
	woc.log.Infof("%s node %v updated %s (message: %s)", node.Type, node.ID, node.Phase, message[0])
}

// markNodePhase marks a node with the given phase, creating the node if necessary and handles timestamps
func (woc *wfOperationCtx) markNodePhase(nodeName string, phase wfv1.NodePhase, message ...string) *wfv1.NodeStatus {
node, err := woc.wf.GetNodeByName(nodeName)
Expand Down
91 changes: 91 additions & 0 deletions workflow/controller/operator_concurrency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package controller
import (
"context"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"testing"

"github.com/stretchr/testify/assert"
apiv1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
Expand Down Expand Up @@ -1046,3 +1048,92 @@ func TestSynchronizationForPendingShuttingdownWfs(t *testing.T) {
assert.Equal(t, wfv1.WorkflowRunning, wocTwo.execWf.Status.Phase)
})
}

// TestWorkflowMemoizationWithMutex runs a workflow whose two parallel steps
// (job-1 and job-2) use the same template, which declares both a mutex and a
// memoize key. It checks that job-1 holds the mutex after the first operate
// pass and that, after the pods are marked succeeded and the workflow is
// operated again, both nodes are Succeeded with job-1 recorded as a cache miss
// and job-2 as a cache hit.
//
// NOTE(review): the manifest literal below is reproduced byte-for-byte from
// this view; its YAML indentation looks flattened — confirm against the
// repository copy before relying on it.
func TestWorkflowMemoizationWithMutex(t *testing.T) {
	wf := wfv1.MustUnmarshalWorkflow(`apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: example-steps-simple
namespace: default
spec:
entrypoint: main
templates:
- name: main
steps:
- - name: job-1
template: sleep
arguments:
parameters:
- name: sleep_duration
value: 10
- name: job-2
template: sleep
arguments:
parameters:
- name: sleep_duration
value: 5

- name: sleep
synchronization:
mutex:
name: mutex-example-steps-simple
inputs:
parameters:
- name: sleep_duration
script:
image: alpine:latest
command: [/bin/sh]
source: |
echo "Sleeping for {{ inputs.parameters.sleep_duration }}"
sleep {{ inputs.parameters.sleep_duration }}
memoize:
key: "memo-key-1"
cache:
configMap:
name: cache-example-steps-simple
`)
	cancel, controller := newController(wf)
	defer cancel()

	ctx := context.Background()

	woc := newWorkflowOperationCtx(wf, controller)
	woc.operate(ctx)

	// Map node IDs to display names so lock holders can be matched to steps.
	holdingJobs := make(map[string]string)
	for _, node := range woc.wf.Status.Nodes {
		holdingJobs[node.ID] = node.DisplayName
	}

	// Check initial status: job-1 acquired the lock
	job1AcquiredLock := false
	if woc.wf.Status.Synchronization != nil && woc.wf.Status.Synchronization.Mutex != nil {
		for _, holding := range woc.wf.Status.Synchronization.Mutex.Holding {
			if holdingJobs[holding.Holder] == "job-1" {
				fmt.Println("acquired: ", holding.Holder)
				job1AcquiredLock = true
			}
		}
	}
	assert.True(t, job1AcquiredLock)

	// Make job-1's pod succeed
	makePodsPhase(ctx, woc, apiv1.PodSucceeded, func(pod *apiv1.Pod) {
		if pod.ObjectMeta.Name == "job-1" {
			pod.Status.Phase = apiv1.PodSucceeded
		}
	})
	woc.operate(ctx)

	// Check final status: both job-1 and job-2 succeeded, job-2 simply hit the cache.
	// Track which nodes were seen so the test cannot pass vacuously when a
	// node is missing, and guard MemoizationStatus so a missing status is
	// reported as a failure instead of a nil-pointer panic.
	sawJob1, sawJob2 := false, false
	for _, node := range woc.wf.Status.Nodes {
		switch node.DisplayName {
		case "job-1":
			sawJob1 = true
			assert.Equal(t, wfv1.NodeSucceeded, node.Phase)
			if assert.NotNil(t, node.MemoizationStatus) {
				assert.False(t, node.MemoizationStatus.Hit)
			}
		case "job-2":
			sawJob2 = true
			assert.Equal(t, wfv1.NodeSucceeded, node.Phase)
			if assert.NotNil(t, node.MemoizationStatus) {
				assert.True(t, node.MemoizationStatus.Hit)
			}
		}
	}
	assert.True(t, sawJob1, "job-1 node not found in workflow status")
	assert.True(t, sawJob2, "job-2 node not found in workflow status")
}