Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(): Workflow cancellation + gracefully handle non serializable state #10674

Merged
merged 6 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/beige-ligers-yawn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@medusajs/orchestration": patch
"@medusajs/workflows-sdk": patch
---

fix(): Workflow cancellation + gracefully handle non serializable state
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { TransactionStepState, TransactionStepStatus } from "@medusajs/utils"
import { setTimeout } from "timers/promises"
import {
DistributedTransaction,
DistributedTransactionType,
TransactionHandlerType,
TransactionOrchestrator,
Expand All @@ -10,6 +11,7 @@ import {
TransactionStepTimeoutError,
TransactionTimeoutError,
} from "../../transaction"
import { BaseInMemoryDistributedTransactionStorage } from "../../transaction/datastore/base-in-memory-storage"

describe("Transaction Orchestrator", () => {
afterEach(() => {
Expand Down Expand Up @@ -151,6 +153,104 @@ describe("Transaction Orchestrator", () => {
expect(actionOrder).toEqual(["one", "two", "three", "four", "five", "six"])
})

it("Should gracefully handle non serializable error when an async step fails", async () => {
class BaseInMemoryDistributedTransactionStorage_ extends BaseInMemoryDistributedTransactionStorage {
scheduleRetry() {
return Promise.resolve()
}
}
DistributedTransaction.setStorage(
new BaseInMemoryDistributedTransactionStorage_()
)

const actionOrder: string[] = []
async function handler(
actionId: string,
functionHandlerType: TransactionHandlerType,
payload: TransactionPayload
) {
if (functionHandlerType === TransactionHandlerType.INVOKE) {
actionOrder.push(actionId)
}

if (
functionHandlerType === TransactionHandlerType.INVOKE &&
actionId === "three"
) {
const error = new Error("Step 3 failed")

const obj: any = {}
obj.self = obj
;(error as any).metadata = obj
throw error
}
}

const flow: TransactionStepsDefinition = {
next: [
{
action: "one",
},
{
action: "two",
next: {
action: "four",
next: {
action: "six",
},
},
},
{
action: "three",
async: true,
maxRetries: 0,
next: {
action: "five",
},
},
],
}

const strategy = new TransactionOrchestrator({
id: "transaction-name",
definition: flow,
})

const transaction = await strategy.beginTransaction(
"transaction_id_123",
handler
)

await strategy.resume(transaction)

expect(transaction.getErrors()).toHaveLength(2)
expect(transaction.getErrors()).toEqual([
{
action: "three",
error: {
message: "Step 3 failed",
name: "Error",
stack: expect.any(String),
},
handlerType: "invoke",
},
{
action: "three",
error: expect.objectContaining({
message: expect.stringContaining(
"Converting circular structure to JSON"
),
stack: expect.any(String),
}),
handlerType: "invoke",
},
])

DistributedTransaction.setStorage(
new BaseInMemoryDistributedTransactionStorage()
)
})

it("Should not execute next steps when a step fails", async () => {
const actionOrder: string[] = []
async function handler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
TransactionHandlerType,
TransactionState,
} from "./types"
import { NonSerializableCheckPointError } from "./errors"

/**
* @typedef TransactionMetadata
Expand Down Expand Up @@ -204,19 +205,14 @@ class DistributedTransaction extends EventEmitter {
return
}

const data = new TransactionCheckpoint(
this.getFlow(),
this.getContext(),
this.getErrors()
)

const key = TransactionOrchestrator.getKeyName(
DistributedTransaction.keyPrefix,
this.modelId,
this.transactionId
)

const rawData = JSON.parse(JSON.stringify(data))
const rawData = this.#serializeCheckpointData()

await DistributedTransaction.keyValueStore.save(key, rawData, ttl, options)

return rawData
Expand Down Expand Up @@ -320,6 +316,76 @@ class DistributedTransaction extends EventEmitter {
public hasTemporaryData(key: string) {
return this.#temporaryStorage.has(key)
}

/**
* Try to serialize the checkpoint data
* If it fails, it means that the context or the errors are not serializable
* and we should handle it
*
* @internal
* @returns
*/
#serializeCheckpointData() {
const data = new TransactionCheckpoint(
this.getFlow(),
this.getContext(),
this.getErrors()
)

const isSerializable = (obj) => {
try {
JSON.parse(JSON.stringify(obj))
return true
} catch {
return false
}
}

let rawData
try {
rawData = JSON.parse(JSON.stringify(data))
} catch (e) {
if (!isSerializable(this.context)) {
// This is a safe guard, we should never reach this point
// If we do, it means that the context is not serializable
// and we should throw an error
throw new NonSerializableCheckPointError(
"Unable to serialize context object. Please make sure the workflow input and steps response are serializable."
)
}

if (!isSerializable(this.errors)) {
const nonSerializableErrors: TransactionStepError[] = []
for (const error of this.errors) {
if (!isSerializable(error.error)) {
error.error = {
name: error.error.name,
message: error.error.message,
stack: error.error.stack,
}
nonSerializableErrors.push({
...error,
error: e,
})
}
}

if (nonSerializableErrors.length) {
this.errors.push(...nonSerializableErrors)
}
}

const data = new TransactionCheckpoint(
this.getFlow(),
this.getContext(),
this.getErrors()
)

rawData = JSON.parse(JSON.stringify(data))
}

return rawData
}
}

DistributedTransaction.setStorage(
Expand Down
16 changes: 16 additions & 0 deletions packages/core/orchestration/src/transaction/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,19 @@ export class TransactionTimeoutError extends BaseStepErrror {
super("TransactionTimeoutError", message, stepResponse)
}
}

export class NonSerializableCheckPointError extends Error {
static isNonSerializableCheckPointError(
error: Error
): error is NonSerializableCheckPointError {
return (
error instanceof NonSerializableCheckPointError ||
error?.name === "NonSerializableCheckPointError"
)
}

constructor(message?: string) {
super(message)
this.name = "NonSerializableCheckPointError"
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { TransactionState } from "@medusajs/utils"
import { createStep } from "../create-step"
import { createWorkflow } from "../create-workflow"
import { StepResponse } from "../helpers"
Expand Down Expand Up @@ -42,6 +43,44 @@ describe("Workflow composer", () => {
expect(result).toEqual({ result: "hi from outside" })
})

it("should cancel transaction on failed sub workflow call", async function () {
const step1 = createStep("step1", async (_, context) => {
return new StepResponse("step1")
})

const step2 = createStep("step2", async (input: string, context) => {
return new StepResponse({ result: input })
})
const step3 = createStep("step3", async (input: string, context) => {
throw new Error("I have failed")
})

const subWorkflow = createWorkflow(
getNewWorkflowId(),
function (input: WorkflowData<string>) {
step1()
return new WorkflowResponse(step2(input))
}
)

const workflow = createWorkflow(getNewWorkflowId(), function () {
const subWorkflowRes = subWorkflow.runAsStep({
input: "hi from outside",
})
return new WorkflowResponse(step3(subWorkflowRes.result))
})

const { errors, transaction } = await workflow.run({
input: {},
throwOnError: false,
})

expect(errors).toHaveLength(1)
expect(errors[0].error.message).toEqual("I have failed")

expect(transaction.getState()).toEqual(TransactionState.REVERTED)
})

it("should skip step if condition is false", async function () {
const step1 = createStep("step1", async (_, context) => {
return new StepResponse({ result: "step1" })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
OrchestrationUtils,
} from "@medusajs/utils"
import { ulid } from "ulid"
import { exportWorkflow } from "../../helper"
import { exportWorkflow, WorkflowResult } from "../../helper"
import { createStep } from "./create-step"
import { proxify } from "./helpers/proxy"
import { StepResponse } from "./helpers/step-response"
Expand Down Expand Up @@ -201,20 +201,29 @@ export function createWorkflow<TData, TResult, THooks extends any[]>(
},
})

const { result, transaction: flowTransaction } = transaction
const { result } = transaction

if (!context.isAsync || flowTransaction.hasFinished()) {
return new StepResponse(result, transaction)
}

return
return new StepResponse(
result,
context.isAsync ? stepContext.transactionId : transaction
)
},
async (transaction, { container }) => {
async (transaction, stepContext) => {
if (!transaction) {
return
}

await workflow(container).cancel(transaction)
const { container, ...sharedContext } = stepContext

await workflow(container).cancel({
transaction: (transaction as WorkflowResult<any>).transaction,
transactionId: isString(transaction) ? transaction : undefined,
container,
context: {
...sharedContext,
parentStepIdempotencyKey: stepContext.idempotencyKey,
},
})
}
)(input) as ReturnType<StepFunction<TData, TResult>>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export * from "./workflow_async"
export * from "./workflow_step_timeout"
export * from "./workflow_transaction_timeout"
export * from "./workflow_when"
export * from "./workflow_async_compensate"
Loading
Loading