diff --git a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py index 93add1de1b..78c76187b2 100644 --- a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py +++ b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py @@ -71,6 +71,8 @@ def __init__(self, mlmd_handle: metadata.Metadata, self._pipeline = pipeline self._is_task_id_tracked_fn = is_task_id_tracked_fn self._service_job_manager = service_job_manager + # TODO(b/201294315): Remove once the underlying issue is fixed. + self._generate_invoked = False def generate(self) -> List[task_lib.Task]: """Generates tasks for all executable nodes in the async pipeline. @@ -80,7 +82,17 @@ def generate(self) -> List[task_lib.Task]: Returns: A `list` of tasks to execute. + + Raises: + RuntimeError: If `generate` invoked more than once on the same instance. """ + # TODO(b/201294315): Remove this artificial restriction once the underlying + # issue is fixed. + if self._generate_invoked: + raise RuntimeError( + 'Invoking `generate` more than once on the same instance of ' + 'AsyncPipelineTaskGenerator is restricted due to a bug.') + self._generate_invoked = True result = [] for node in [n.pipeline_node for n in self._pipeline.nodes]: node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node) diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py index 4a66f942be..017df1f93a 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops.py +++ b/tfx/orchestration/experimental/core/pipeline_ops.py @@ -576,9 +576,10 @@ def _filter_by_state(node_infos: List[_NodeInfo], # Initialize task generator for the pipeline. if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC: + # TODO(b/200618482): Remove fail_fast=True. 
generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator( mlmd_handle, pipeline_state, task_queue.contains_task_id, - service_job_manager) + service_job_manager, fail_fast=True) elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC: generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator( mlmd_handle, pipeline_state, task_queue.contains_task_id, diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py index 9ca3c3e700..91e63d17e7 100644 --- a/tfx/orchestration/experimental/core/pipeline_state.py +++ b/tfx/orchestration/experimental/core/pipeline_state.py @@ -77,7 +77,7 @@ class NodeState(json_utils.Jsonable): STARTING = 'starting' # Pending work before state can change to STARTED. STARTED = 'started' # Node is ready for execution. STOPPING = 'stopping' # Pending work before state can change to STOPPED. - STOPPED = 'stopped' # Node execution is stoped. + STOPPED = 'stopped' # Node execution is stopped. RUNNING = 'running' # Node is under active execution (i.e. triggered). COMPLETE = 'complete' # Node execution completed successfully. SKIPPED = 'skipped' # Node execution skipped due to conditional. @@ -121,6 +121,20 @@ def is_stoppable(self) -> bool: return self.state in set( [self.STARTING, self.STARTED, self.RUNNING, self.PAUSED]) + def is_success(self) -> bool: + return is_node_state_success(self.state) + + def is_failure(self) -> bool: + return is_node_state_failure(self.state) + + +def is_node_state_success(state: str) -> bool: + return state in (NodeState.COMPLETE, NodeState.SKIPPED) + + +def is_node_state_failure(state: str) -> bool: + return state == NodeState.FAILED + _NODE_STATE_TO_RUN_STATE_MAP = { NodeState.STARTING: run_state_pb2.RunState.UNKNOWN, @@ -174,7 +188,7 @@ class PipelineState: mlmd_handle: Handle to MLMD db. pipeline: The pipeline proto associated with this `PipelineState` object. 
TODO(b/201294315): Fix self.pipeline going out of sync with the actual - pipeline proto stored in the underlying MLMD execution in some cases. + pipeline proto stored in the underlying MLMD execution in some cases. execution_id: Id of the underlying execution in MLMD. pipeline_uid: Unique id of the pipeline. """ @@ -437,6 +451,14 @@ def get_node_state(self, node_uid: task_lib.NodeUid) -> NodeState: node_states_dict = _get_node_states_dict(self._execution) return node_states_dict.get(node_uid.node_id, NodeState()) + def get_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]: + self._check_context() + result = {} + for node in get_all_pipeline_nodes(self.pipeline): + node_uid = task_lib.NodeUid.from_pipeline_node(self.pipeline, node) + result[node_uid] = self.get_node_state(node_uid) + return result + def get_pipeline_execution_state(self) -> metadata_store_pb2.Execution.State: """Returns state of underlying pipeline execution.""" self._check_context() diff --git a/tfx/orchestration/experimental/core/pipeline_state_test.py b/tfx/orchestration/experimental/core/pipeline_state_test.py index 582c8b8df9..cf49c7780e 100644 --- a/tfx/orchestration/experimental/core/pipeline_state_test.py +++ b/tfx/orchestration/experimental/core/pipeline_state_test.py @@ -291,6 +291,43 @@ def test_initiate_node_start_stop(self): node_state = pipeline_state.get_node_state(node_uid) self.assertEqual(pstate.NodeState.STARTED, node_state.state) + def test_get_node_states_dict(self): + with self._mlmd_connection as m: + pipeline = pipeline_pb2.Pipeline() + pipeline.pipeline_info.id = 'pipeline1' + pipeline.execution_mode = pipeline_pb2.Pipeline.SYNC + pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline) + pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen' + pipeline.nodes.add().pipeline_node.node_info.id = 'Transform' + pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer' + pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator' + eg_node_uid = 
task_lib.NodeUid(pipeline_uid, 'ExampleGen') + transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform') + trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer') + evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator') + with pstate.PipelineState.new(m, pipeline) as pipeline_state: + with pipeline_state.node_state_update_context( + eg_node_uid) as node_state: + node_state.update(pstate.NodeState.COMPLETE) + with pipeline_state.node_state_update_context( + transform_node_uid) as node_state: + node_state.update(pstate.NodeState.RUNNING) + with pipeline_state.node_state_update_context( + trainer_node_uid) as node_state: + node_state.update(pstate.NodeState.STARTING) + with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state: + self.assertEqual( + { + eg_node_uid: + pstate.NodeState(state=pstate.NodeState.COMPLETE), + transform_node_uid: + pstate.NodeState(state=pstate.NodeState.RUNNING), + trainer_node_uid: + pstate.NodeState(state=pstate.NodeState.STARTING), + evaluator_node_uid: + pstate.NodeState(state=pstate.NodeState.STARTED), + }, pipeline_state.get_node_states_dict()) + def test_save_and_remove_property(self): property_key = 'key' property_value = 'value' diff --git a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py index 77e23041cf..5c710b93af 100644 --- a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py +++ b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py @@ -13,10 +13,11 @@ # limitations under the License. 
"""TaskGenerator implementation for sync pipelines.""" -from typing import Callable, Hashable, List, Optional, Sequence, Set +import collections +import typing +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Set from absl import logging -import cachetools from tfx.orchestration import data_types_utils from tfx.orchestration import metadata from tfx.orchestration.experimental.core import constants @@ -36,9 +37,6 @@ from google.protobuf import any_pb2 from ml_metadata.proto import metadata_store_pb2 -# Caches successful and skipped nodes so we don't have to query MLMD repeatedly. -_successful_nodes_cache = cachetools.LRUCache(maxsize=1024) - class SyncPipelineTaskGenerator(task_gen.TaskGenerator): """Task generator for executing a sync pipeline. @@ -49,10 +47,12 @@ class SyncPipelineTaskGenerator(task_gen.TaskGenerator): where the instances refer to the same MLMD db and the same pipeline IR. """ - def __init__(self, mlmd_handle: metadata.Metadata, + def __init__(self, + mlmd_handle: metadata.Metadata, pipeline_state: pstate.PipelineState, is_task_id_tracked_fn: Callable[[task_lib.TaskId], bool], - service_job_manager: service_jobs.ServiceJobManager): + service_job_manager: service_jobs.ServiceJobManager, + fail_fast: bool = False): """Constructs `SyncPipelineTaskGenerator`. Args: @@ -61,6 +61,9 @@ def __init__(self, mlmd_handle: metadata.Metadata, is_task_id_tracked_fn: A callable that returns `True` if a task_id is tracked by the task queue. service_job_manager: Used for handling service nodes in the pipeline. + fail_fast: If `True`, pipeline run is aborted immediately if any node + fails. If `False`, pipeline run is only aborted when no further + progress can be made due to node failures. 
""" self._mlmd_handle = mlmd_handle pipeline = pipeline_state.pipeline @@ -76,12 +79,17 @@ def __init__(self, mlmd_handle: metadata.Metadata, 'All sync pipeline nodes should be of type `PipelineNode`; found: ' '`{}`'.format(which_node)) self._pipeline_state = pipeline_state + with self._pipeline_state: + self._node_states_dict = self._pipeline_state.get_node_states_dict() self._pipeline_uid = self._pipeline_state.pipeline_uid self._pipeline = pipeline self._pipeline_run_id = ( pipeline.runtime_spec.pipeline_run_id.field_value.string_value) self._is_task_id_tracked_fn = is_task_id_tracked_fn self._service_job_manager = service_job_manager + self._fail_fast = fail_fast + # TODO(b/201294315): Remove once the underlying issue is fixed. + self._generate_invoked = False def generate(self) -> List[task_lib.Task]: """Generates tasks for executing the next executable nodes in the pipeline. @@ -91,24 +99,50 @@ def generate(self) -> List[task_lib.Task]: Returns: A `list` of tasks to execute. + + Raises: + RuntimeError: If `generate` invoked more than once on the same instance. """ + # TODO(b/201294315): Remove this artificial restriction once the underlying + # issue is fixed. 
+ if self._generate_invoked: + raise RuntimeError( + 'Invoking `generate` more than once on the same instance of ' + 'SyncPipelineTaskGenerator is restricted due to a bug.') + self._generate_invoked = True layers = _topsorted_layers(self._pipeline) terminal_node_ids = _terminal_node_ids(layers) exec_node_tasks = [] update_node_state_tasks = [] successful_node_ids = set() + failed_nodes_dict: Dict[str, status_lib.Status] = {} finalize_pipeline_task = None for layer_nodes in layers: for node in layer_nodes: - tasks = self._generate_tasks_for_node(node, successful_node_ids) + node_id = node.node_info.id + node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node) + node_state = self._node_states_dict[node_uid] + if node_state.is_success(): + successful_node_ids.add(node_id) + continue + if node_state.is_failure(): + failed_nodes_dict[node_id] = node_state.status + continue + if not self._upstream_nodes_successful(node, successful_node_ids): + continue + tasks = self._generate_tasks_for_node(node) for task in tasks: if task_lib.is_update_node_state_task(task): + task = typing.cast(task_lib.UpdateNodeStateTask, task) + if pstate.is_node_state_success(task.state): + successful_node_ids.add(node_id) + elif pstate.is_node_state_failure(task.state): + failed_nodes_dict[node_id] = task.status + if self._fail_fast: + finalize_pipeline_task = self._abort_task(task.status.message) update_node_state_tasks.append(task) elif task_lib.is_exec_node_task(task): exec_node_tasks.append(task) - else: - assert task_lib.is_finalize_pipeline_task(task) - finalize_pipeline_task = task if finalize_pipeline_task: break @@ -116,9 +150,21 @@ def generate(self) -> List[task_lib.Task]: if finalize_pipeline_task: break - layer_node_ids = set(node.node_info.id for node in layer_nodes) - successful_layer_node_ids = layer_node_ids & successful_node_ids - self._update_successful_nodes_cache(successful_layer_node_ids) + if not self._fail_fast and failed_nodes_dict: + assert not 
finalize_pipeline_task + node_by_id = _node_by_id(self._pipeline) + # Collect nodes that cannot be run because they have a failed ancestor. + unrunnable_node_ids = set() + for node_id in failed_nodes_dict: + unrunnable_node_ids |= _descendants(node_by_id, node_id) + # Nodes that are still runnable have neither succeeded nor failed, and + # don't have a failed ancestor. + runnable_node_ids = node_by_id.keys() - ( + unrunnable_node_ids | successful_node_ids | failed_nodes_dict.keys()) + # If there are no runnable nodes, we can abort the pipeline. + if not runnable_node_ids: + finalize_pipeline_task = self._abort_task( + f'Cannot make progress due to node failures: {failed_nodes_dict}') result = update_node_state_tasks if finalize_pipeline_task: @@ -134,28 +180,19 @@ def generate(self) -> List[task_lib.Task]: return result def _generate_tasks_for_node( - self, node: pipeline_pb2.PipelineNode, - successful_node_ids: Set[str]) -> List[task_lib.Task]: + self, node: pipeline_pb2.PipelineNode) -> List[task_lib.Task]: """Generates list of tasks for the given node.""" node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node) node_id = node.node_info.id result = [] - if self._in_successful_nodes_cache(node_uid): - successful_node_ids.add(node_id) - return result - - if not self._upstream_nodes_successful(node, successful_node_ids): + node_state = self._node_states_dict[node_uid] + if node_state.state in (pstate.NodeState.STOPPING, + pstate.NodeState.STOPPED): + logging.info('Ignoring node in state \'%s\' for task generation: %s', + node_state.state, node_uid) return result - with self._pipeline_state: - node_state = self._pipeline_state.get_node_state(node_uid) - if node_state.state in (pstate.NodeState.STOPPING, - pstate.NodeState.STOPPED): - logging.info('Ignoring node in state \'%s\' for task generation: %s', - node_state.state, node_uid) - return result - # If this is a pure service node, there is no ExecNodeTask to generate # but we ensure node services and 
check service status. service_status = self._ensure_node_services_if_pure(node_id) @@ -168,13 +205,11 @@ def _generate_tasks_for_node( state=pstate.NodeState.FAILED, status=status_lib.Status( code=status_lib.Code.ABORTED, message=error_msg))) - result.append(self._abort_task(error_msg)) elif service_status == service_jobs.ServiceStatus.SUCCESS: logging.info('Service node successful: %s', node_uid) result.append( task_lib.UpdateNodeStateTask( node_uid=node_uid, state=pstate.NodeState.COMPLETE)) - successful_node_ids.add(node_id) elif service_status == service_jobs.ServiceStatus.RUNNING: result.append( task_lib.UpdateNodeStateTask( @@ -195,7 +230,6 @@ def _generate_tasks_for_node( state=pstate.NodeState.FAILED, status=status_lib.Status( code=status_lib.Code.ABORTED, message=error_msg))) - result.append(self._abort_task(error_msg)) return result node_executions = task_gen_utils.get_executions(self._mlmd_handle, node) @@ -208,7 +242,6 @@ def _generate_tasks_for_node( result.append( task_lib.UpdateNodeStateTask( node_uid=node_uid, state=pstate.NodeState.COMPLETE)) - successful_node_ids.add(node_id) return result # If the latest execution failed or cancelled, the pipeline should be @@ -221,15 +254,13 @@ def _generate_tasks_for_node( constants.EXECUTION_ERROR_MSG_KEY) error_msg = data_types_utils.get_metadata_value( error_msg_value) if error_msg_value else '' + error_msg = f'node failed; node uid: {node_uid}; error: {error_msg}' result.append( task_lib.UpdateNodeStateTask( node_uid=node_uid, state=pstate.NodeState.FAILED, status=status_lib.Status( code=status_lib.Code.ABORTED, message=error_msg))) - result.append( - self._abort_task( - f'node failed; node uid: {node_uid}; error: {error_msg}')) return result exec_node_task = task_gen_utils.generate_task_from_active_execution( @@ -243,14 +274,13 @@ def _generate_tasks_for_node( # Finally, we are ready to generate tasks for the node by resolving inputs. 
result.extend( - self._resolve_inputs_and_generate_tasks_for_node( - node, node_executions, successful_node_ids)) + self._resolve_inputs_and_generate_tasks_for_node(node, node_executions)) return result def _resolve_inputs_and_generate_tasks_for_node( self, node: pipeline_pb2.PipelineNode, - node_executions: Sequence[metadata_store_pb2.Execution], - successful_node_ids: Set[str]) -> List[task_lib.Task]: + node_executions: Sequence[metadata_store_pb2.Execution] + ) -> List[task_lib.Task]: """Generates tasks for a node by freshly resolving inputs.""" result = [] node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node) @@ -259,9 +289,7 @@ def _resolve_inputs_and_generate_tasks_for_node( if resolved_info is None: result.append( task_lib.UpdateNodeStateTask( - node_uid=node_uid, - state=pstate.NodeState.SKIPPED)) - successful_node_ids.add(node.node_info.id) + node_uid=node_uid, state=pstate.NodeState.SKIPPED)) return result if resolved_info.input_artifacts is None: error_msg = f'failure to resolve inputs; node uid: {node_uid}' @@ -271,7 +299,6 @@ def _resolve_inputs_and_generate_tasks_for_node( state=pstate.NodeState.FAILED, status=status_lib.Status( code=status_lib.Code.ABORTED, message=error_msg))) - result.append(self._abort_task(error_msg)) return result execution = execution_publish_utils.register_execution( @@ -308,7 +335,6 @@ def _resolve_inputs_and_generate_tasks_for_node( contexts=contexts, execution_id=execution.id, output_artifacts=cached_outputs) - successful_node_ids.add(node.node_info.id) pstate.record_state_change_time() result.append( task_lib.UpdateNodeStateTask( @@ -326,7 +352,6 @@ def _resolve_inputs_and_generate_tasks_for_node( state=pstate.NodeState.FAILED, status=status_lib.Status( code=status_lib.Code.ABORTED, message=error_msg))) - result.append(self._abort_task(error_msg)) return result outputs_utils.make_output_dirs(output_artifacts) @@ -381,18 +406,6 @@ def _abort_task(self, error_msg: str) -> task_lib.FinalizePipelineTask: 
status=status_lib.Status( code=status_lib.Code.ABORTED, message=error_msg)) - def _update_successful_nodes_cache(self, node_ids: Set[str]) -> None: - for node_id in node_ids: - node_uid = task_lib.NodeUid( - pipeline_uid=self._pipeline_uid, node_id=node_id) - _successful_nodes_cache[self._node_cache_key(node_uid)] = True - - def _in_successful_nodes_cache(self, node_uid) -> bool: - return _successful_nodes_cache.get(self._node_cache_key(node_uid), False) - - def _node_cache_key(self, node_uid: task_lib.NodeUid) -> Hashable: - return (self._pipeline_run_id, node_uid) - # TODO(b/182944474): Raise error in _get_executor_spec if executor spec is # missing for a non-system node. @@ -410,10 +423,7 @@ def _get_executor_spec(pipeline: pipeline_pb2.Pipeline, def _topsorted_layers( pipeline: pipeline_pb2.Pipeline) -> List[List[pipeline_pb2.PipelineNode]]: """Returns pipeline nodes in topologically sorted layers.""" - node_by_id = { - node.pipeline_node.node_info.id: node.pipeline_node - for node in pipeline.nodes - } + node_by_id = _node_by_id(pipeline) return topsort.topsorted_layers( [node.pipeline_node for node in pipeline.nodes], get_node_id_fn=lambda node: node.node_info.id, @@ -432,3 +442,25 @@ def _terminal_node_ids( if not node.downstream_nodes: terminal_node_ids.add(node.node_info.id) return terminal_node_ids + + +def _node_by_id( + pipeline: pipeline_pb2.Pipeline) -> Dict[str, pipeline_pb2.PipelineNode]: + return { + node.pipeline_node.node_info.id: node.pipeline_node + for node in pipeline.nodes + } + + +def _descendants(node_by_id: Mapping[str, pipeline_pb2.PipelineNode], + node_id: str) -> Set[str]: + """Returns node_ids of all descendants of the given node_id.""" + queue = collections.deque() + queue.extend(node_by_id[node_id].downstream_nodes) + result = set() + while queue: + q_node_id = queue.popleft() + if q_node_id not in result: + queue.extend(node_by_id[q_node_id].downstream_nodes) + result.add(q_node_id) + return result diff --git 
a/tfx/orchestration/experimental/core/sync_pipeline_task_gen_test.py b/tfx/orchestration/experimental/core/sync_pipeline_task_gen_test.py index 4de918eadd..f0f74d43a8 100644 --- a/tfx/orchestration/experimental/core/sync_pipeline_task_gen_test.py +++ b/tfx/orchestration/experimental/core/sync_pipeline_task_gen_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for tfx.orchestration.experimental.core.sync_pipeline_task_gen.""" +import itertools import os import uuid @@ -123,7 +124,10 @@ def _finish_node_execution(self, artifact_custom_properties=artifact_custom_properties) self._finish_processing(use_task_queue, exec_node_task) - def _generate(self, use_task_queue, ignore_update_node_state_tasks=False): + def _generate(self, + use_task_queue, + ignore_update_node_state_tasks=False, + fail_fast=False): return test_utils.run_generator( self._mlmd_connection, sptg.SyncPipelineTaskGenerator, @@ -131,13 +135,15 @@ def _generate(self, use_task_queue, ignore_update_node_state_tasks=False): self._task_queue, use_task_queue, self._mock_service_job_manager, - ignore_update_node_state_tasks=ignore_update_node_state_tasks) + ignore_update_node_state_tasks=ignore_update_node_state_tasks, + fail_fast=fail_fast) def _run_next(self, use_task_queue, expect_nodes, finish_nodes=None, - artifact_custom_properties=None): + artifact_custom_properties=None, + fail_fast=False): """Runs a complete cycle of task generation and simulating their completion. Args: @@ -147,8 +153,9 @@ def _run_next(self, `None` (default), all of `expect_nodes` will be finished. artifact_custom_properties: A dict of custom properties to attach to the output artifacts. + fail_fast: If `True`, pipeline is aborted immediately if any node fails. 
""" - tasks = self._generate(use_task_queue, True) + tasks = self._generate(use_task_queue, True, fail_fast=fail_fast) for task in tasks: self.assertTrue(task_lib.is_exec_node_task(task)) expected_node_ids = [n.node_info.id for n in expect_nodes] @@ -171,7 +178,8 @@ def _generate_and_test(self, num_active_executions, pipeline=None, expected_exec_nodes=None, - ignore_update_node_state_tasks=False): + ignore_update_node_state_tasks=False, + fail_fast=False): """Generates tasks and tests the effects.""" return test_utils.run_generator_and_test( self, @@ -186,7 +194,8 @@ def _generate_and_test(self, num_new_executions=num_new_executions, num_active_executions=num_active_executions, expected_exec_nodes=expected_exec_nodes, - ignore_update_node_state_tasks=ignore_update_node_state_tasks) + ignore_update_node_state_tasks=ignore_update_node_state_tasks, + fail_fast=fail_fast) @parameterized.parameters(False, True) def test_tasks_generated_when_upstream_done(self, use_task_queue): @@ -312,8 +321,9 @@ def test_tasks_generated_when_upstream_done(self, use_task_queue): if use_task_queue: self.assertTrue(self._task_queue.is_empty()) - @parameterized.parameters(False, True) - def test_finalize_pipeline_after_terminal_nodes_success(self, use_task_queue): + @parameterized.parameters(itertools.product((False, True), repeat=2)) + def test_pipeline_succeeds_when_terminal_nodes_succeed( + self, use_task_queue, fail_fast): """Tests that pipeline is finalized only after terminal nodes are successful. Args: @@ -321,6 +331,7 @@ def test_finalize_pipeline_after_terminal_nodes_success(self, use_task_queue): a task with the same task_id does not already exist in the queue. `use_task_queue=False` is useful to test the case of task generation when task queue is empty (for eg: due to orchestrator restart). + fail_fast: If `True`, pipeline is aborted immediately if any node fails. """ # Check the expected terminal nodes. 
layers = sptg._topsorted_layers(self._pipeline) @@ -343,8 +354,8 @@ def test_finalize_pipeline_after_terminal_nodes_success(self, use_task_queue): self._run_next(use_task_queue, expect_nodes=[self._schema_gen]) # Both example-validator and transform are ready to execute. - [example_validator_task, - transform_task] = self._generate(use_task_queue, True) + [example_validator_task, transform_task] = self._generate( + use_task_queue, True, fail_fast=fail_fast) self.assertEqual(self._example_validator.node_info.id, example_validator_task.node_uid.node_id) self.assertEqual(self._transform.node_info.id, @@ -362,27 +373,32 @@ def test_finalize_pipeline_after_terminal_nodes_success(self, use_task_queue): use_task_queue, expect_nodes=[self._trainer] if use_task_queue else [self._example_validator, self._trainer], - finish_nodes=[self._trainer]) + finish_nodes=[self._trainer], + fail_fast=fail_fast) self._run_next( use_task_queue, expect_nodes=[self._chore_a] if use_task_queue else [self._example_validator, self._chore_a], - finish_nodes=[self._chore_a]) + finish_nodes=[self._chore_a], + fail_fast=fail_fast) self._run_next( use_task_queue, expect_nodes=[self._chore_b] if use_task_queue else [self._example_validator, self._chore_b], - finish_nodes=[self._chore_b]) + finish_nodes=[self._chore_b], + fail_fast=fail_fast) self._run_next( use_task_queue, expect_nodes=[] if use_task_queue else [self._example_validator], - finish_nodes=[]) + finish_nodes=[], + fail_fast=fail_fast) # FinalizePipelineTask is generated only after example-validator finishes. 
test_utils.fake_execute_node(self._mlmd_connection, example_validator_task) self._finish_processing(use_task_queue, example_validator_task) - [finalize_task] = self._generate(use_task_queue, True) + [finalize_task] = self._generate(use_task_queue, True, fail_fast=fail_fast) self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task)) + self.assertEqual(status_lib.Code.OK, finalize_task.status.code) def test_service_job_running(self): """Tests task generation when example-gen service job is still running.""" @@ -428,7 +444,8 @@ def test_service_job_success(self): self.assertEqual(pstate.NodeState.RUNNING, sg_update_node_state_task.state) self.assertTrue(task_lib.is_exec_node_task(sg_exec_node_task)) - def test_service_job_failed(self): + @parameterized.parameters(False, True) + def test_service_job_failed(self, fail_fast): """Tests task generation when example-gen service job fails.""" def _ensure_node_services(unused_pipeline_state, node_id): @@ -442,7 +459,8 @@ def _ensure_node_services(unused_pipeline_state, node_id): num_initial_executions=0, num_tasks_generated=2, num_new_executions=0, - num_active_executions=0) + num_active_executions=0, + fail_fast=fail_fast) self.assertTrue(task_lib.is_update_node_state_task(update_node_state_task)) self.assertEqual('my_example_gen', update_node_state_task.node_uid.node_id) self.assertEqual(pstate.NodeState.FAILED, update_node_state_task.state) @@ -490,7 +508,8 @@ def test_node_success(self): schema_gen_update_node_state_task.state) self.assertTrue(task_lib.is_exec_node_task(schema_gen_exec_node_task)) - def test_node_failed(self): + @parameterized.parameters(False, True) + def test_node_failed(self, fail_fast): """Tests task generation when a node registers a failed execution.""" test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1, 1) @@ -501,7 +520,8 @@ def test_node_failed(self): num_tasks_generated=1, num_new_executions=1, num_active_executions=1, - ignore_update_node_state_tasks=True) + 
ignore_update_node_state_tasks=True, + fail_fast=fail_fast) self.assertEqual( task_lib.NodeUid.from_pipeline_node(self._pipeline, self._stats_gen), stats_gen_task.node_uid) @@ -520,7 +540,8 @@ def test_node_failed(self): num_initial_executions=2, num_tasks_generated=2, num_new_executions=0, - num_active_executions=0) + num_active_executions=0, + fail_fast=fail_fast) self.assertTrue(task_lib.is_update_node_state_task(update_node_state_task)) self.assertEqual('my_statistics_gen', update_node_state_task.node_uid.node_id) @@ -762,6 +783,52 @@ def test_conditional_execution(self, evaluate): [finalize_task] = self._generate(False, True) self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task)) + @parameterized.parameters(False, True) + def test_pipeline_failure_strategies(self, fail_fast): + """Tests pipeline failure strategies.""" + test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1, + 1) + + self._run_next(False, expect_nodes=[self._stats_gen], fail_fast=fail_fast) + self._run_next(False, expect_nodes=[self._schema_gen], fail_fast=fail_fast) + + # Both example-validator and transform are ready to execute. + [example_validator_task, transform_task] = self._generate( + False, True, fail_fast=fail_fast) + self.assertEqual(self._example_validator.node_info.id, + example_validator_task.node_uid.node_id) + self.assertEqual(self._transform.node_info.id, + transform_task.node_uid.node_id) + + # Simulate Transform success. + self._finish_node_execution(False, transform_task) + + # But fail example-validator. + with self._mlmd_connection as m: + with mlmd_state.mlmd_execution_atomic_op( + m, example_validator_task.execution_id) as ev_exec: + # Fail example-validator execution.
+ ev_exec.last_known_state = metadata_store_pb2.Execution.FAILED + data_types_utils.set_metadata_value( + ev_exec.custom_properties[constants.EXECUTION_ERROR_MSG_KEY], + 'example-validator error') + + if fail_fast: + # Pipeline run should immediately fail because example-validator failed. + [finalize_task] = self._generate(False, True, fail_fast=fail_fast) + self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task)) + self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code) + else: + # Trainer and downstream nodes can execute as transform has finished. + # example-validator failure does not impact them as it is not upstream. + # Pipeline run will still fail but when no more progress can be made. + self._run_next(False, expect_nodes=[self._trainer], fail_fast=fail_fast) + self._run_next(False, expect_nodes=[self._chore_a], fail_fast=fail_fast) + self._run_next(False, expect_nodes=[self._chore_b], fail_fast=fail_fast) + [finalize_task] = self._generate(False, True, fail_fast=fail_fast) + self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task)) + self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code) + if __name__ == '__main__': tf.test.main() diff --git a/tfx/orchestration/experimental/core/test_utils.py b/tfx/orchestration/experimental/core/test_utils.py index 002983350c..e47c4672da 100644 --- a/tfx/orchestration/experimental/core/test_utils.py +++ b/tfx/orchestration/experimental/core/test_utils.py @@ -176,7 +176,8 @@ def run_generator(mlmd_connection, task_queue, use_task_queue, service_job_manager, - ignore_update_node_state_tasks=False): + ignore_update_node_state_tasks=False, + fail_fast=None): """Generates tasks for testing.""" with mlmd_connection as m: pipeline_state = get_or_create_pipeline_state(m, pipeline) @@ -185,12 +186,20 @@ def run_generator(mlmd_connection, pipeline_state=pipeline_state, is_task_id_tracked_fn=task_queue.contains_task_id, service_job_manager=service_job_manager) + if fail_fast is not None: 
+ generator_params['fail_fast'] = fail_fast task_gen = generator_class(**generator_params) tasks = task_gen.generate() if use_task_queue: for task in tasks: if task_lib.is_exec_node_task(task): task_queue.enqueue(task) + for task in tasks: + if task_lib.is_update_node_state_task(task): + with pipeline_state: + with pipeline_state.node_state_update_context( + task.node_uid) as node_state: + node_state.update(task.state, task.status) if ignore_update_node_state_tasks: tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)] return tasks @@ -231,7 +240,8 @@ def run_generator_and_test(test_case, num_new_executions, num_active_executions, expected_exec_nodes=None, - ignore_update_node_state_tasks=False): + ignore_update_node_state_tasks=False, + fail_fast=None): """Runs generator.generate() and tests the effects.""" if service_job_manager is None: service_job_manager = service_jobs.DummyServiceJobManager() @@ -247,7 +257,8 @@ def run_generator_and_test(test_case, task_queue, use_task_queue, service_job_manager, - ignore_update_node_state_tasks=ignore_update_node_state_tasks) + ignore_update_node_state_tasks=ignore_update_node_state_tasks, + fail_fast=fail_fast) with mlmd_connection as m: test_case.assertLen( tasks, num_tasks_generated,