Testing Trainer support for multiple artifact input with LatestArtifactResolver ResolverNode #2174

Closed · wants to merge 19 commits
24 changes: 19 additions & 5 deletions tfx/components/trainer/executor_test.py
@@ -23,6 +23,7 @@

# Standard Imports
import mock
+import copy
import tensorflow as tf

from google.protobuf import json_format
@@ -47,20 +48,27 @@ def setUp(self):
        self._testMethodName)

    # Create input dict.
-    examples = standard_artifacts.Examples()
-    examples.uri = os.path.join(self._source_data_dir,
-                                'transform/transformed_examples')
-    examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
+    e1 = standard_artifacts.Examples()
+    e1.uri = os.path.join(self._source_data_dir,
+                          'transform/transformed_examples')
+    e1.split_names = artifact_utils.encode_split_names(['train', 'eval'])
+
+    e2 = copy.deepcopy(e1)
+
+    self._single_artifact = [e1]
+    self._multiple_artifacts = [e1, e2]

    transform_output = standard_artifacts.TransformGraph()
    transform_output.uri = os.path.join(self._source_data_dir,
                                        'transform/transform_graph')

    schema = standard_artifacts.Schema()
    schema.uri = os.path.join(self._source_data_dir, 'schema_gen')
    previous_model = standard_artifacts.Model()
    previous_model.uri = os.path.join(self._source_data_dir, 'trainer/previous')

    self._input_dict = {
-        constants.EXAMPLES_KEY: [examples],
+        constants.EXAMPLES_KEY: self._single_artifact,
        constants.TRANSFORM_GRAPH_KEY: [transform_output],
        constants.SCHEMA_KEY: [schema],
        constants.BASE_MODEL_KEY: [previous_model]
@@ -186,6 +194,12 @@ def testDoWithHyperParameters(self):
    self._verify_model_exports()
    self._verify_model_run_exports()

+  def testMultipleArtifacts(self):
+    self._input_dict[constants.EXAMPLES_KEY] = self._multiple_artifacts
+    self._exec_properties['module_file'] = self._module_file
+    self._do(self._generic_trainer_executor)
+    self._verify_model_exports()
+    self._verify_model_run_exports()

if __name__ == '__main__':
  tf.test.main()
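A note on the fixture above: copy.deepcopy(e1) produces a second, distinct Examples artifact that happens to point at the same files on disk, which is all the multi-artifact code path needs. A minimal sketch of why that suffices, using a hypothetical stand-in class rather than the real TFX artifact types:

import copy
import os


class _FakeExamples(object):
  """Hypothetical stand-in for standard_artifacts.Examples."""

  def __init__(self, uri):
    self.uri = uri


e1 = _FakeExamples('/testdata/transform/transformed_examples')
e2 = copy.deepcopy(e1)  # A distinct object sharing the same uri.

# The executor only cares that the list holds N artifact objects; it derives
# one file pattern per artifact per split, so duplicated uris are harmless.
patterns = [os.path.join(a.uri, 'train', '*') for a in [e1, e2]]
print(e1 is e2)  # False: two separate artifacts.
print(patterns)  # Two patterns over the same train split.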
16 changes: 9 additions & 7 deletions tfx/components/trainer/fn_args_utils.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

+import absl
+
from typing import Any, Dict, List, Text, NamedTuple

from google.protobuf import json_format
@@ -57,16 +59,16 @@
def get_common_fn_args(input_dict: Dict[Text, List[types.Artifact]],
                       exec_properties: Dict[Text, Any],
                       working_dir: Text = None) -> FnArgs:
  """Get common args of training and tuning."""
  train_files = [
-      io_utils.all_files_pattern(
-          artifact_utils.get_split_uri(input_dict[constants.EXAMPLES_KEY],
-                                       'train'))
+      io_utils.all_files_pattern(uri) for uri in
+      artifact_utils.get_split_uris(input_dict[constants.EXAMPLES_KEY],
Collaborator:

Overall lgtm. Could you print input_dict[constants.EXAMPLES_KEY] here and paste the result? I want to check the uri and id to see if they match what ExampleGen generates.

Contributor Author:

Attached is that specific log:

I0717 20:04:34.873566 135202664122112 fn_args_utils.py:63] [Artifact(artifact: id: 1 type_id: 5 uri: "/tmp/iris_resolver_e2e_testfc2pebe5/tmpa0t72p3g/testIrisPipelineResolverWithDependency/tfx/pipelines/resolver_test/CsvExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:2758,xor_checksum:1595041459,sum_checksum:1595041459" } } custom_properties { key: "name" value { string_value: "examples" } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "pipeline_name" value { string_value: "resolver_test" } } custom_properties { key: "producer_component" value { string_value: "CsvExampleGen" } } custom_properties { key: "span" value { string_value: "0" } } custom_properties { key: "state" value { string_value: "published" } } create_time_since_epoch: 1595041459439 last_update_time_since_epoch: 1595041460679 , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } ), Artifact(artifact: id: 4 type_id: 5 uri: "/tmp/iris_resolver_e2e_testfc2pebe5/tmpa0t72p3g/testIrisPipelineResolverWithDependency/tfx/pipelines/resolver_test/CsvExampleGen/examples/4" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:2758,xor_checksum:1595041464,sum_checksum:1595041464" } } custom_properties { key: "name" value { string_value: "examples" } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "pipeline_name" value { string_value: "resolver_test" } } custom_properties { key: "producer_component" value { string_value: "CsvExampleGen" } } custom_properties { key: "span" value { string_value: "1" } } custom_properties { key: "state" value { string_value: "published" } } create_time_since_epoch: 1595041464576 last_update_time_since_epoch: 1595041465895 , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } ), Artifact(artifact: id: 7 type_id: 5 uri: "/tmp/iris_resolver_e2e_testfc2pebe5/tmpa0t72p3g/testIrisPipelineResolverWithDependency/tfx/pipelines/resolver_test/CsvExampleGen/examples/7" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:2758,xor_checksum:1595041469,sum_checksum:1595041469" } } custom_properties { key: "name" value { string_value: "examples" } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "pipeline_name" value { string_value: "resolver_test" } } custom_properties { key: "producer_component" value { string_value: "CsvExampleGen" } } custom_properties { key: "span" value { string_value: "2" } } custom_properties { key: "state" value { string_value: "published" } } create_time_since_epoch: 1595041469775 last_update_time_since_epoch: 1595041471049 , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } )]

+                                    'train')
  ]
  eval_files = [
-      io_utils.all_files_pattern(
-          artifact_utils.get_split_uri(input_dict[constants.EXAMPLES_KEY],
-                                       'eval'))
+      io_utils.all_files_pattern(uri) for uri in
+      artifact_utils.get_split_uris(input_dict[constants.EXAMPLES_KEY],
+                                    'eval')
  ]

if input_dict.get(constants.TRANSFORM_GRAPH_KEY):
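For readers skimming this diff: artifact_utils.get_split_uri (singular) expects exactly one artifact in the list, while get_split_uris returns one URI per artifact, which is what lets Trainer consume the resolver's multi-artifact channel. A rough sketch of the new comprehension's effect, with a local helper mimicking the assumed <artifact.uri>/<split> layout (the underscore-prefixed names are illustrative stand-ins, not TFX APIs):

import os


class _Examples(object):
  """Hypothetical stand-in for a TFX Examples artifact."""

  def __init__(self, uri):
    self.uri = uri


def _get_split_uris(artifact_list, split):
  # Mimics artifact_utils.get_split_uris under the assumption that each
  # split lives in a subdirectory named after it.
  return [os.path.join(a.uri, split) for a in artifact_list]


examples = [_Examples('/pipelines/CsvExampleGen/examples/1'),
            _Examples('/pipelines/CsvExampleGen/examples/4')]

train_files = [os.path.join(uri, '*') for uri in
               _get_split_uris(examples, 'train')]
print(train_files)
# ['/pipelines/CsvExampleGen/examples/1/train/*',
#  '/pipelines/CsvExampleGen/examples/4/train/*']

The real helpers live in tfx.utils.io_utils and tfx.types.artifact_utils; the sketch only illustrates the shape of the result.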
297 changes: 297 additions & 0 deletions tfx/examples/iris/iris_resolver_e2e_test.py
@@ -0,0 +1,297 @@
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the ResolverNode multiple artifact output with Trainer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from typing import List, Text

import tensorflow as tf

from ml_metadata.metadata_store import metadata_store
from ml_metadata.proto import metadata_store_pb2

from tfx.components import CsvExampleGen
from tfx.components import ResolverNode
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.dsl.experimental import latest_artifacts_resolver
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.types.standard_artifacts import Schema
from tfx.types.standard_artifacts import Examples
from tfx.types.channel import Channel
from tfx.proto import example_gen_pb2
from tfx.proto import trainer_pb2

from tfx.utils import path_utils
from tfx.utils import io_utils


def _create_example_pipeline(pipeline_name: Text, pipeline_root: Text,
data_root: Text, metadata_path: Text,
beam_pipeline_args: List[Text]
) -> pipeline.Pipeline:
"""Simple pipeline to ingest data into Examples artifacts."""
# Brings data into the pipeline or otherwise joins/converts training data.
input_config = example_gen_pb2.Input(splits=[
example_gen_pb2.Input.Split(name='single_split',
pattern='span{SPAN}/*')])
example_gen = CsvExampleGen(input_base=data_root, input_config=input_config)

# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

# Generates schema based on statistics files.
schema_gen = SchemaGen(
statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)

return pipeline.Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
components=[
example_gen,
statistics_gen,
schema_gen
],
enable_cache=False,
metadata_connection_config=metadata.sqlite_metadata_connection_config(
metadata_path),
beam_pipeline_args=beam_pipeline_args)

def _create_trainer_pipeline(pipeline_name: Text, pipeline_root: Text,
module_file: Text, metadata_path: Text,
window_size: int, beam_pipeline_args: List[Text],
) -> pipeline.Pipeline:
"""Trainer pipeline to train based on resolver outputs"""
# Get latest schema for training.
schema_resolver = ResolverNode(
instance_name='schema_resolver',
resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
schema=Channel(type=Schema))

# Resolve the latest window_size example artifacts into one channel for the trainer.
latest_examples_resolver = ResolverNode(
instance_name='latest_examples_resolver',
resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
resolver_configs={'desired_num_of_artifacts': window_size},
latest_n_examples=Channel(type=Examples))
Collaborator:

Could you verify that if we put ExampleGen and the resolver in the same pipeline, the resolver depends on ExampleGen?

With window size 3:
pipeline 1: ExampleGen -> TFDV
pipeline 2: ExampleGen -> TFDV
pipeline 3: ExampleGen -> TFDV -> resolver -> Trainer

Contributor Author:

Updated the test to also cover the dependency case.

Contributor Author (@18jeffreyma, Jul 20, 2020):

From running this "full pipeline" and printing out the dependency graph, we get the following:

I0720 13:11:47.808361 135087534106432 beam_dag_runner.py:151] Component CsvExampleGen depends on [].
I0720 13:11:47.829361 135087534106432 beam_dag_runner.py:165] Component CsvExampleGen is scheduled.
I0720 13:11:47.836360 135087534106432 beam_dag_runner.py:151] Component ResolverNode.latest_examples_resolver depends on ['Run[CsvExampleGen]'].
I0720 13:11:47.836360 135087534106432 beam_dag_runner.py:165] Component ResolverNode.latest_examples_resolver is scheduled.
I0720 13:11:47.843361 135087534106432 beam_dag_runner.py:151] Component StatisticsGen depends on ['Run[CsvExampleGen]'].
I0720 13:11:47.844361 135087534106432 beam_dag_runner.py:165] Component StatisticsGen is scheduled.
I0720 13:11:47.844361 135087534106432 beam_dag_runner.py:151] Component SchemaGen depends on ['Run[StatisticsGen]'].
I0720 13:11:47.851361 135087534106432 beam_dag_runner.py:165] Component SchemaGen is scheduled.
I0720 13:11:47.858361 135087534106432 beam_dag_runner.py:151] Component Trainer depends on ['Run[SchemaGen]', 'Run[ResolverNode.latest_examples_resolver]'].
I0720 13:11:47.858361 135087534106432 beam_dag_runner.py:165] Component Trainer is scheduled.

Seems like the dependency between CsvExampleGen and ResolverNode.latest_examples_resolver exists and is explicitly created.

Collaborator:

nice!
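The log above reflects how channel wiring implies the DAG: a runner can derive each node's upstream set by walking its input channels back to their producing components. A self-contained sketch of that idea, with hypothetical stand-in types rather than BeamDagRunner's actual internals:

class _Channel(object):
  """Hypothetical channel carrying only its producer's component id."""

  def __init__(self, producer_id=None):
    self.producer_id = producer_id


class _Node(object):
  """Hypothetical pipeline node with named input channels."""

  def __init__(self, node_id, inputs):
    self.id = node_id
    self.inputs = inputs  # Dict of input name -> _Channel.


def _upstreams(node):
  # A node depends on every component that produced one of its input channels.
  return sorted({c.producer_id for c in node.inputs.values() if c.producer_id})


resolver = _Node('ResolverNode.latest_examples_resolver',
                 {'latest_n_examples': _Channel('CsvExampleGen')})
print(_upstreams(resolver))  # ['CsvExampleGen']

The diff continues below with the Trainer wired to the resolver's output channel.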


# Uses user-provided Python function that implements a model using TF-Learn.
trainer = Trainer(
module_file=module_file,
custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
examples=latest_examples_resolver.outputs['latest_n_examples'],
schema=schema_resolver.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=2000),
eval_args=trainer_pb2.EvalArgs(num_steps=5))

return pipeline.Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
components=[
schema_resolver,
latest_examples_resolver,
trainer
],
enable_cache=True,
metadata_connection_config=metadata.sqlite_metadata_connection_config(
metadata_path),
beam_pipeline_args=beam_pipeline_args)

def _create_full_pipeline(pipeline_name: Text, pipeline_root: Text,
data_root: Text, module_file: Text,
metadata_path: Text, window_size: int,
beam_pipeline_args: List[Text],
) -> pipeline.Pipeline:
"""Full pipeline to train based on, testing Resolver/ExampleGen dependency"""
# Brings data into the pipeline or otherwise joins/converts training data.
input_config = example_gen_pb2.Input(splits=[
example_gen_pb2.Input.Split(name='single_split',
pattern='span{SPAN}/*')])
example_gen = CsvExampleGen(input_base=data_root, input_config=input_config)

# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

# Generates schema based on statistics files.
schema_gen = SchemaGen(
statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)

# Resolve the latest window_size example artifacts into one channel for the trainer.
latest_examples_resolver = ResolverNode(
instance_name='latest_examples_resolver',
resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
resolver_configs={'desired_num_of_artifacts': window_size},
latest_n_examples=example_gen.outputs['examples'])

# Uses user-provided Python function that implements a model using TF-Learn.
trainer = Trainer(
module_file=module_file,
custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
examples=latest_examples_resolver.outputs['latest_n_examples'],
schema=schema_gen.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=2000),
eval_args=trainer_pb2.EvalArgs(num_steps=5))

return pipeline.Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
components=[
example_gen,
statistics_gen,
schema_gen,
latest_examples_resolver,
trainer
],
enable_cache=True,
metadata_connection_config=metadata.sqlite_metadata_connection_config(
metadata_path),
beam_pipeline_args=beam_pipeline_args)

class IrisResolverEndToEndTest(tf.test.TestCase):

def setUp(self):
super(IrisResolverEndToEndTest, self).setUp()
self._test_dir = os.path.join(
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
self._testMethodName)

self._pipeline_name = 'resolver_test'
self._init_data_root = os.path.join(os.path.dirname(__file__), 'data')
self._data_root = os.path.join(self._test_dir, 'data')
self._module_file = os.path.join(os.path.dirname(__file__), 'iris_utils.py')
self._pipeline_root = os.path.join(self._test_dir, 'tfx', 'pipelines',
self._pipeline_name)
self._metadata_path = os.path.join(self._test_dir, 'tfx', 'metadata',
self._pipeline_name, 'metadata.db')
self._window_size = 3

def _testOutputs(self):
# Test Trainer output.
self.assertTrue(tf.io.gfile.exists(self._metadata_path))
trainer_dir = os.path.join(self._pipeline_root, 'Trainer', 'model')
working_dir = io_utils.get_only_uri_in_dir(trainer_dir)
self.assertTrue(
tf.io.gfile.exists(path_utils.serving_model_path(working_dir)))

# Query MLMD to see if trainer and resolver_node worked properly.
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = self._metadata_path
connection_config.sqlite.connection_mode = \
metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE
store = metadata_store.MetadataStore(connection_config)

# Get example artifact ids.
example_ids = [e.id for e in store.get_artifacts_by_type('Examples')]

# Get latest example resolver execution information.
all_resolvers = store.get_executions_by_type(
'tfx.components.common_nodes.resolver_node.ResolverNode')
resolve_exec = [e for e in all_resolvers
if e.properties['component_id'] == metadata_store_pb2.Value(
string_value='ResolverNode.latest_examples_resolver')][0]

# Check if window size is exactly equal to number of examples
# appearing in output events from example resolver.
resolve_events = store.get_events_by_execution_ids([resolve_exec.id])
self.assertEqual(self._window_size,
len([e for e in resolve_events if e.artifact_id in
example_ids and
e.type == metadata_store_pb2.Event.Type.OUTPUT]))

# Get trainer component execution information.
trainer_exec = store.get_executions_by_type(
'tfx.components.trainer.component.Trainer')[0]

# Check if window size is exactly equal to number of examples
# appearing in input events to Trainer.
train_events = store.get_events_by_execution_ids([trainer_exec.id])
self.assertEqual(self._window_size,
len([e for e in train_events if e.artifact_id in
example_ids and
e.type == metadata_store_pb2.Event.Type.INPUT]))

def testIrisPipelineResolver(self):
"""Test with ResolverNode having no ExampleGen dependency."""
example_gen_pipeline = _create_example_pipeline(
pipeline_name=self._pipeline_name,
pipeline_root=self._pipeline_root,
data_root=self._data_root,
metadata_path=self._metadata_path,
beam_pipeline_args=[])

trainer_pipeline = _create_trainer_pipeline(
pipeline_name=self._pipeline_name,
pipeline_root=self._pipeline_root,
module_file=self._module_file,
metadata_path=self._metadata_path,
window_size=self._window_size,
beam_pipeline_args=[])

# Generate window_size example artifacts, one per span.
for i in range(self._window_size):
io_utils.copy_file(os.path.join(self._init_data_root, 'iris.csv'),
os.path.join(self._data_root, 'span{}'.format(i),
'iris.csv'))
BeamDagRunner().run(example_gen_pipeline)

# Train on multiple example artifacts, which are pulled using ResolverNode.
BeamDagRunner().run(trainer_pipeline)
self._testOutputs()

def testIrisPipelineResolverWithDependency(self):
"""Test with ResolverNode having ExampleGen dependency."""
example_gen_pipeline = _create_example_pipeline(
pipeline_name=self._pipeline_name,
pipeline_root=self._pipeline_root,
data_root=self._data_root,
metadata_path=self._metadata_path,
beam_pipeline_args=[])

full_pipeline = _create_full_pipeline(
pipeline_name=self._pipeline_name,
pipeline_root=self._pipeline_root,
data_root=self._data_root,
module_file=self._module_file,
metadata_path=self._metadata_path,
window_size=self._window_size,
beam_pipeline_args=[])

# Generate the first window_size - 1 example artifacts, one per span.
for i in range(self._window_size-1):
io_utils.copy_file(os.path.join(self._init_data_root, 'iris.csv'),
os.path.join(self._data_root, 'span{}'.format(i),
'iris.csv'))
BeamDagRunner().run(example_gen_pipeline)

# Add the final span, then train on multiple example artifacts, which are
# pulled using the ResolverNode within the same pipeline.
io_utils.copy_file(os.path.join(self._init_data_root, 'iris.csv'),
os.path.join(self._data_root,
'span{}'.format(self._window_size-1),
'iris.csv'))
BeamDagRunner().run(full_pipeline)
self._testOutputs()


if __name__ == '__main__':
tf.test.main()
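For intuition about what this test exercises: a LatestArtifactsResolver configured with desired_num_of_artifacts=N keeps the N most recent artifacts in the channel. A rough sketch of that selection logic follows; it is not the actual TFX implementation, and ordering by MLMD-assigned id is an assumption (it does hold for the ids 1, 4, 7 in the log pasted earlier):

class _Artifact(object):
  """Hypothetical minimal artifact carrying an MLMD-assigned id."""

  def __init__(self, artifact_id, uri):
    self.id = artifact_id
    self.uri = uri


def _resolve_latest(artifacts, desired_num):
  # Newer artifacts get larger ids, so the latest N are the N largest.
  return sorted(artifacts, key=lambda a: a.id)[-desired_num:]


spans = [_Artifact(1, '/pipelines/CsvExampleGen/examples/1'),
         _Artifact(4, '/pipelines/CsvExampleGen/examples/4'),
         _Artifact(7, '/pipelines/CsvExampleGen/examples/7')]

latest = _resolve_latest(spans, 2)
print([a.uri for a in latest])
# ['/pipelines/CsvExampleGen/examples/4', '/pipelines/CsvExampleGen/examples/7']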