diff --git a/component/executor_test.py b/component/executor_test.py index dd56a30..176ecef 100644 --- a/component/executor_test.py +++ b/component/executor_test.py @@ -23,6 +23,7 @@ from six import string_types import tensorflow as tf +from tfx import types import executor from tfx.dsl.io import fileio from tfx.types import artifact_utils @@ -30,9 +31,10 @@ from tfx.types import standard_component_specs from tfx.utils import json_utils import sys + sys.path.append('.') -from schemacomponent.test_data.module_file import module_file +from test_data.module_file import module_file class ExecutorTest(tf.test.TestCase): @@ -43,39 +45,68 @@ def testDo(self): os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), self._testMethodName) - schema = standard_artifacts.Schema() - schema.uri = os.path.join(self._source_data_dir, 'schema_gen') + self.schema = standard_artifacts.Schema() + self.schema.uri = os.path.join(self._source_data_dir, 'schema_gen') - input_dict = { + self.input_dict = { - standard_component_specs.SCHEMA_KEY: [schema], + standard_component_specs.SCHEMA_KEY: [self.schema], } - schema_output = standard_artifacts.Schema() - schema_output.uri = os.path.join(self._output_data_dir, 'custom_schema') + self.schema_output = standard_artifacts.Schema() + self.schema_output.uri = os.path.join(self._output_data_dir, 'custom_schema') output_dict = { - 'custom_schema': [schema_output], + 'custom_schema': [self.schema_output], } - _module_file = os.path.join(self._source_data_dir, + self._module_file = os.path.join(self._source_data_dir, standard_component_specs.MODULE_FILE_KEY, 'module_file.py') - schema_fn = '%s.%s' % (module_file.schema_fn.__module__, + self.schema_fn = '%s.%s' % (module_file.schema_fn.__module__, module_file.schema_fn.__name__) - print(_module_file) - exec_properties = { - standard_component_specs.MODULE_FILE_KEY : _module_file + print(self._module_file ) + self.exec_properties = { + standard_component_specs.MODULE_FILE_KEY : self._module_file } - schemaCuration_executor = executor.Executor() + self.schemaCuration_executor = executor.Executor() + + self.schemaCuration_executor.Do(self.input_dict, output_dict, self.exec_properties) - schemaCuration_executor.Do(input_dict, output_dict, exec_properties) - self.assertNotEqual(0, len(fileio.listdir(schema_output.uri))) + def _verify_schemaCuration_outputs(self): + self.assertNotEqual(0, len(fileio.listdir(self.schema_output.uri))) + + + def testDoWithModuleFile(self): + self.exec_properties['module_file.py'] = self._module_file + self.schemaCuration_executor.Do(self.input_dict, self.output_dict, + self.exec_properties) + self._verify_schemaCuration_outputs() + + def schemaFn(self): + self._exec_properties['schema_fn'] = self.schema_fn + self.schemaCuration_executor.Do(self.input_dict, self.output_dict, + self._exec_properties) + self._verify_schemaCuration_outputs() + + def testDoWithCache(self): + # First run that creates cache. + output_cache_artifact = types.Artifact('OutputCache') + output_cache_artifact.uri = os.path.join(self._utput_data_dir, 'CACHE/') + + self.output_dict['cache_output_path'] = [output_cache_artifact] + + self.exec_properties['module_file'] = self._module_file + self.schemaCuration_executor.Do(self.input_dict, self.output_dict, + self.exec_properties) + self._verify_schemaCuration_outputs() + self.assertNotEqual(0, + len(tf.io.gfile.listdir(output_cache_artifact.uri))) if __name__ == '__main__':