Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
FatimahAdwan authored Jul 9, 2021
1 parent 77df9cd commit 749d4de
Showing 1 changed file with 47 additions and 16 deletions.
63 changes: 47 additions & 16 deletions component/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@
from six import string_types

import tensorflow as tf
from tfx import types
import executor
from tfx.dsl.io import fileio
from tfx.types import artifact_utils
from tfx.types import standard_artifacts
from tfx.types import standard_component_specs
from tfx.utils import json_utils
import sys

sys.path.append('.')

from schemacomponent.test_data.module_file import module_file
from test_data.module_file import module_file

class ExecutorTest(tf.test.TestCase):

Expand All @@ -43,39 +45,68 @@ def testDo(self):
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
self._testMethodName)

schema = standard_artifacts.Schema()
schema.uri = os.path.join(self._source_data_dir, 'schema_gen')
self.schema = standard_artifacts.Schema()
self.schema.uri = os.path.join(self._source_data_dir, 'schema_gen')

input_dict = {
self.input_dict = {

standard_component_specs.SCHEMA_KEY: [schema],
standard_component_specs.SCHEMA_KEY: [self.schema],
}

schema_output = standard_artifacts.Schema()
schema_output.uri = os.path.join(self._output_data_dir, 'custom_schema')
self.schema_output = standard_artifacts.Schema()
self.schema_output.uri = os.path.join(self._output_data_dir, 'custom_schema')

output_dict = {
'custom_schema': [schema_output],
'custom_schema': [self.schema_output],
}


_module_file = os.path.join(self._source_data_dir,
self._module_file = os.path.join(self._source_data_dir,
standard_component_specs.MODULE_FILE_KEY,
'module_file.py')
schema_fn = '%s.%s' % (module_file.schema_fn.__module__,
self.schema_fn = '%s.%s' % (module_file.schema_fn.__module__,
module_file.schema_fn.__name__)

print(_module_file)
exec_properties = {
standard_component_specs.MODULE_FILE_KEY : _module_file
print(self._module_file )
self.exec_properties = {
standard_component_specs.MODULE_FILE_KEY : self._module_file

}

schemaCuration_executor = executor.Executor()
self.schemaCuration_executor = executor.Executor()

self.schemaCuration_executor.Do(self.input_dict, output_dict, self.exec_properties)

schemaCuration_executor.Do(input_dict, output_dict, exec_properties)

self.assertNotEqual(0, len(fileio.listdir(schema_output.uri)))
def _verify_schemaCuration_outputs(self):
self.assertNotEqual(0, len(fileio.listdir(self.schema_output.uri)))


def testDoWithModuleFile(self):
self.exec_properties['module_file.py'] = self._module_file
self.schemaCuration_executor.Do(self.input_dict, self.output_dict,
self.exec_properties)
self._verify_schemaCuration_outputs()

def schemaFn(self):
self._exec_properties['schema_fn'] = self.schema_fn
self.schemaCuration_executor.Do(self.input_dict, self.output_dict,
self._exec_properties)
self._verify_schemaCuration_outputs()

def testDoWithCache(self):
# First run that creates cache.
output_cache_artifact = types.Artifact('OutputCache')
output_cache_artifact.uri = os.path.join(self._utput_data_dir, 'CACHE/')

self.output_dict['cache_output_path'] = [output_cache_artifact]

self.exec_properties['module_file'] = self._module_file
self.schemaCuration_executor.Do(self.input_dict, self.output_dict,
self.exec_properties)
self._verify_schemaCuration_outputs()
self.assertNotEqual(0,
len(tf.io.gfile.listdir(output_cache_artifact.uri)))


if __name__ == '__main__':
Expand Down

0 comments on commit 749d4de

Please sign in to comment.