Skip to content

Commit

Permalink
unit test for excutor
Browse files Browse the repository at this point in the history
  • Loading branch information
FatimahAdwan committed Jul 1, 2021
1 parent 69c82db commit 934be2e
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 69 deletions.
194 changes: 125 additions & 69 deletions component/executor_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Lint as: python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,93 +11,150 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TFX Schema Curation Custom Executor."""
"""Tests for tfx.components.custom.executor."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile
import executor
import tensorflow as tf
from tfx import types
from tfx.utils import path_utils
from tfx.dsl.io import fileio
import executor
from tfx.types import standard_artifacts

from tfx.dsl.io import fileio
from tfx.types import artifact_utils
from tfx.utils import io_utils
from tfx.components.util import udf_utils
from tfx.types import standard_component_specs
from tfx.components.testdata.module_file import transform_module
import tensorflow_transform as tft



class ExecutorTest(tf.test.TestCase):
  """Tests `executor.Executor.Do` with a schema input and optional caches."""

  def _get_output_data_dir(self, sub_dir=None):
    """Returns the output directory path for the currently running test.

    The path is rooted at $TEST_UNDECLARED_OUTPUTS_DIR when that environment
    variable is set, otherwise at `self.get_temp_dir()`, followed by the name
    of the running test method.

    Args:
      sub_dir: Optional sub-directory appended below the test-method directory.

    Returns:
      The joined directory path.
    """
    test_dir = self._testMethodName
    if sub_dir is not None:
      test_dir = os.path.join(test_dir, sub_dir)
    return os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        test_dir)

  def _make_base_do_params(self, source_data_dir, output_data_dir):
    """Builds the input dict, output dict and exec properties for `Do`.

    Populates `self._input_dict`, `self._output`, `self._output_dict` and
    `self._exec_properties`; any previous values are replaced.

    Args:
      source_data_dir: Test data directory; the schema artifact points at its
        `schema_gen/` sub-directory.
      output_data_dir: Directory under which `custom_output` is placed.
    """
    # Create input dict.
    schema_artifact = standard_artifacts.Schema()
    schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen/')

    self._input_dict = {
        'schema': [schema_artifact],
    }

    # Create output dict.
    self._output = standard_artifacts.TransformGraph()
    self._output.uri = os.path.join(output_data_dir, 'custom_output')
    # Instantiate the artifact: assigning `uri` on the bare `types.Artifact`
    # class would modify the class itself instead of a scratch artifact.
    temp_path_output = types.Artifact('TempPath')
    temp_path_output.uri = tempfile.mkdtemp()

    self._output_dict = {
        'custom_output': [self._output],
        'temp_path': [temp_path_output],
    }

    # Create exec properties skeleton.
    self._exec_properties = {}

  def setUp(self):
    """Prepares directories, `Do` parameters and the executor under test."""
    super(ExecutorTest, self).setUp()

    self._source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    self._output_data_dir = self._get_output_data_dir()

    self._make_base_do_params(self._source_data_dir, self._output_data_dir)

    # Values the individual tests copy into the exec properties.
    self._module_file = os.path.join(self._source_data_dir,
                                     'module_file/custom_module.py')
    self._preprocessing_fn = '%s.%s' % (
        transform_module.preprocessing_fn.__module__,
        transform_module.preprocessing_fn.__name__)

    # Executor for test.
    self._custom_executor = executor.Executor()

  def _verify_outputs(self):
    """Asserts that the saved-model file exists under the output artifact."""
    path_to_saved_model = os.path.join(
        self._output.uri, tft.TFTransformOutput.TRANSFORM_FN_DIR,
        tf.saved_model.SAVED_MODEL_FILENAME_PB)
    self.assertTrue(tf.io.gfile.exists(path_to_saved_model))

  def testDoWithModuleFile(self):
    """Runs `Do` with only `module_file` set and checks the outputs."""
    self._exec_properties['module_file'] = self._module_file
    self._custom_executor.Do(self._input_dict, self._output_dict,
                             self._exec_properties)
    self._verify_outputs()

  def testDoWithPreprocessingFn(self):
    """Runs `Do` with only `preprocessing_fn` set and checks the outputs."""
    self._exec_properties['preprocessing_fn'] = self._preprocessing_fn
    self._custom_executor.Do(self._input_dict, self._output_dict,
                             self._exec_properties)
    self._verify_outputs()

  def testDoWithNoPreprocessingFn(self):
    """Expects ValueError when neither property is supplied."""
    with self.assertRaises(ValueError):
      self._custom_executor.Do(self._input_dict, self._output_dict,
                               self._exec_properties)

  def testDoWithDuplicatePreprocessingFn(self):
    """Expects ValueError when both properties are supplied together."""
    self._exec_properties['module_file'] = self._module_file
    self._exec_properties['preprocessing_fn'] = self._preprocessing_fn
    with self.assertRaises(ValueError):
      self._custom_executor.Do(self._input_dict, self._output_dict,
                               self._exec_properties)

  def testDoWithCache(self):
    """Runs `Do` twice: once writing a cache, once reading that cache back."""
    # First run that creates cache.
    output_cache_artifact = types.Artifact('OutputCache')
    output_cache_artifact.uri = os.path.join(self._output_data_dir, 'CACHE/')

    self._output_dict['cache_output_path'] = [output_cache_artifact]

    self._exec_properties['module_file'] = self._module_file
    self._custom_executor.Do(self._input_dict, self._output_dict,
                             self._exec_properties)
    self._verify_outputs()
    self.assertNotEqual(0,
                        len(tf.io.gfile.listdir(output_cache_artifact.uri)))

    # Second run from cache.
    self._output_data_dir = self._get_output_data_dir('2nd_run')
    input_cache_artifact = types.Artifact('InputCache')
    input_cache_artifact.uri = output_cache_artifact.uri

    output_cache_artifact = types.Artifact('OutputCache')
    output_cache_artifact.uri = os.path.join(self._output_data_dir, 'CACHE/')

    # Rebuild the dicts so the second run writes below the new output dir.
    self._make_base_do_params(self._source_data_dir, self._output_data_dir)

    self._input_dict['cache_input_path'] = [input_cache_artifact]
    self._output_dict['cache_output_path'] = [output_cache_artifact]

    self._exec_properties['module_file'] = self._module_file
    self._custom_executor.Do(self._input_dict, self._output_dict,
                             self._exec_properties)

    self._verify_outputs()
    self.assertNotEqual(0,
                        len(tf.io.gfile.listdir(output_cache_artifact.uri)))




# Run the TensorFlow test runner when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()


File renamed without changes.

0 comments on commit 934be2e

Please sign in to comment.