Skip to content

Commit

Permalink
start workflow job -> initialize workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Nov 1, 2024
1 parent 78ed565 commit 1331902
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 49 deletions.
10 changes: 5 additions & 5 deletions temporalio/bridge/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,11 @@ async def decode_activation(
await codec.decode_failure(job.resolve_signal_external_workflow.failure)
elif job.HasField("signal_workflow"):
await _decode_payloads(job.signal_workflow.input, codec)
elif job.HasField("start_workflow"):
await _decode_payloads(job.start_workflow.arguments, codec)
if job.start_workflow.HasField("continued_failure"):
await codec.decode_failure(job.start_workflow.continued_failure)
for val in job.start_workflow.memo.fields.values():
elif job.HasField("initialize_workflow"):
await _decode_payloads(job.initialize_workflow.arguments, codec)
if job.initialize_workflow.HasField("continued_failure"):
await codec.decode_failure(job.initialize_workflow.continued_failure)
for val in job.initialize_workflow.memo.fields.values():
# This uses API payload not bridge payload
new_payload = (await codec.decode([val]))[0]
val.metadata.clear()
Expand Down
66 changes: 33 additions & 33 deletions temporalio/worker/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ async def _handle_activation(

# Extract a couple of jobs from the activation
cache_remove_job = None
start_job = None
init_job = None
for job in act.jobs:
if job.HasField("remove_from_cache"):
cache_remove_job = job.remove_from_cache
elif job.HasField("start_workflow"):
start_job = job.start_workflow
elif job.HasField("initialize_workflow"):
init_job = job.initialize_workflow

# Build default success completion (e.g. remove-job-only activations)
completion = (
Expand All @@ -235,16 +235,16 @@ async def _handle_activation(
if not cache_remove_job or not self._disable_safe_eviction:
workflow = self._running_workflows.get(act.run_id)
if not workflow and not cache_remove_job:
# Must have a start job to create instance
if not start_job:
                        # Must have an initialize job to create instance
if not init_job:
raise RuntimeError(
"Missing start workflow, workflow could have unexpectedly been removed from cache"
"Missing initialize workflow, workflow could have unexpectedly been removed from cache"
)
workflow = self._create_workflow_instance(act, start_job)
workflow = self._create_workflow_instance(act, init_job)
self._running_workflows[act.run_id] = workflow
elif start_job:
elif init_job:
# This should never happen
logger.warn("Cache already exists for activation with start job")
logger.warn("Cache already exists for activation with initialize job")

# Run activation in separate thread so we can check if it's
# deadlocked
Expand Down Expand Up @@ -354,54 +354,54 @@ async def _handle_activation(
def _create_workflow_instance(
self,
act: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
start: temporalio.bridge.proto.workflow_activation.StartWorkflow,
init: temporalio.bridge.proto.workflow_activation.InitializeWorkflow,
) -> WorkflowInstance:
# Get the definition
defn = self._workflows.get(start.workflow_type, self._dynamic_workflow)
defn = self._workflows.get(init.workflow_type, self._dynamic_workflow)
if not defn:
workflow_names = ", ".join(sorted(self._workflows.keys()))
raise temporalio.exceptions.ApplicationError(
f"Workflow class {start.workflow_type} is not registered on this worker, available workflows: {workflow_names}",
f"Workflow class {init.workflow_type} is not registered on this worker, available workflows: {workflow_names}",
type="NotFoundError",
)

# Build info
parent: Optional[temporalio.workflow.ParentInfo] = None
if start.HasField("parent_workflow_info"):
if init.HasField("parent_workflow_info"):
parent = temporalio.workflow.ParentInfo(
namespace=start.parent_workflow_info.namespace,
run_id=start.parent_workflow_info.run_id,
workflow_id=start.parent_workflow_info.workflow_id,
namespace=init.parent_workflow_info.namespace,
run_id=init.parent_workflow_info.run_id,
workflow_id=init.parent_workflow_info.workflow_id,
)
info = temporalio.workflow.Info(
attempt=start.attempt,
continued_run_id=start.continued_from_execution_run_id or None,
cron_schedule=start.cron_schedule or None,
execution_timeout=start.workflow_execution_timeout.ToTimedelta()
if start.HasField("workflow_execution_timeout")
attempt=init.attempt,
continued_run_id=init.continued_from_execution_run_id or None,
cron_schedule=init.cron_schedule or None,
execution_timeout=init.workflow_execution_timeout.ToTimedelta()
if init.HasField("workflow_execution_timeout")
else None,
headers=dict(start.headers),
headers=dict(init.headers),
namespace=self._namespace,
parent=parent,
raw_memo=dict(start.memo.fields),
retry_policy=temporalio.common.RetryPolicy.from_proto(start.retry_policy)
if start.HasField("retry_policy")
raw_memo=dict(init.memo.fields),
retry_policy=temporalio.common.RetryPolicy.from_proto(init.retry_policy)
if init.HasField("retry_policy")
else None,
run_id=act.run_id,
run_timeout=start.workflow_run_timeout.ToTimedelta()
if start.HasField("workflow_run_timeout")
run_timeout=init.workflow_run_timeout.ToTimedelta()
if init.HasField("workflow_run_timeout")
else None,
search_attributes=temporalio.converter.decode_search_attributes(
start.search_attributes
init.search_attributes
),
start_time=act.timestamp.ToDatetime().replace(tzinfo=timezone.utc),
task_queue=self._task_queue,
task_timeout=start.workflow_task_timeout.ToTimedelta(),
task_timeout=init.workflow_task_timeout.ToTimedelta(),
typed_search_attributes=temporalio.converter.decode_typed_search_attributes(
start.search_attributes
init.search_attributes
),
workflow_id=start.workflow_id,
workflow_type=start.workflow_type,
workflow_id=init.workflow_id,
workflow_type=init.workflow_type,
)

# Create instance from details
Expand All @@ -411,7 +411,7 @@ def _create_workflow_instance(
interceptor_classes=self._interceptor_classes,
defn=defn,
info=info,
randomness_seed=start.randomness_seed,
randomness_seed=init.randomness_seed,
extern_functions=self._extern_functions,
disable_eager_activity_execution=self._disable_eager_activity_execution,
worker_level_failure_exception_types=self._workflow_failure_exception_types,
Expand Down
20 changes: 10 additions & 10 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ def activate(
elif job.HasField("signal_workflow") or job.HasField("do_update"):
job_sets[1].append(job)
elif not job.HasField("query_workflow"):
if job.HasField("start_workflow"):
start_job = job.start_workflow
if job.HasField("initialize_workflow"):
start_job = job.initialize_workflow
job_sets[2].append(job)
else:
job_sets[3].append(job)
Expand Down Expand Up @@ -477,8 +477,8 @@ def _apply(
)
elif job.HasField("signal_workflow"):
self._apply_signal_workflow(job.signal_workflow)
elif job.HasField("start_workflow"):
self._apply_start_workflow(job.start_workflow)
elif job.HasField("initialize_workflow"):
self._apply_initialize_workflow(job.initialize_workflow)
elif job.HasField("update_random_seed"):
self._apply_update_random_seed(job.update_random_seed)
else:
Expand Down Expand Up @@ -847,8 +847,8 @@ def _apply_signal_workflow(
return
self._process_signal_job(signal_defn, job)

def _apply_start_workflow(
self, job: temporalio.bridge.proto.workflow_activation.StartWorkflow
def _apply_initialize_workflow(
self, job: temporalio.bridge.proto.workflow_activation.InitializeWorkflow
) -> None:
# Async call to run on the scheduler thread. This will be wrapped in
# another function which applies exception handling.
Expand Down Expand Up @@ -886,14 +886,14 @@ def _apply_update_random_seed(
self._random.seed(job.randomness_seed)

def _make_workflow_input(
self, start_job: temporalio.bridge.proto.workflow_activation.StartWorkflow
self, init_job: temporalio.bridge.proto.workflow_activation.InitializeWorkflow
) -> ExecuteWorkflowInput:
# Set arg types, using raw values for dynamic
arg_types = self._defn.arg_types
if not self._defn.name:
# Dynamic is just the raw value for each input value
arg_types = [temporalio.common.RawValue] * len(start_job.arguments)
args = self._convert_payloads(start_job.arguments, arg_types)
arg_types = [temporalio.common.RawValue] * len(init_job.arguments)
args = self._convert_payloads(init_job.arguments, arg_types)
# Put args in a list if dynamic
if not self._defn.name:
args = [args]
Expand All @@ -903,7 +903,7 @@ def _make_workflow_input(
# TODO(cretz): Remove cast when https://github.com/python/mypy/issues/5485 fixed
run_fn=cast(Callable[..., Awaitable[Any]], self._defn.run_fn),
args=args,
headers=start_job.headers,
headers=init_job.headers,
)

#### _Runtime direct workflow call overrides ####
Expand Down
4 changes: 3 additions & 1 deletion tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,7 +1553,9 @@ async def test_workflow_with_custom_runner(client: Client):
)
assert result == "Hello, Temporal!"
# Confirm first activation and last completion
assert runner._pairs[0][0].jobs[0].start_workflow.workflow_type == "HelloWorkflow"
assert (
runner._pairs[0][0].jobs[0].initialize_workflow.workflow_type == "HelloWorkflow"
)
assert (
runner._pairs[-1][-1]
.successful.commands[0]
Expand Down

0 comments on commit 1331902

Please sign in to comment.