diff --git a/CHANGELOG.md b/CHANGELOG.md index fc541fcea2c9..ba1990279785 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Throttling policy for unauthenticated users () - Added default label color table for mask export (https://github.com/opencv/cvat/pull/1549) +- Added visual identification for unavailable formats (https://github.com/opencv/cvat/pull/1567) ### Changed - Removed information about e-mail from the basic user information () @@ -19,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Fixed -- +- Fixed interpreter crash when trying to import `tensorflow` with no AVX instructions available (https://github.com/opencv/cvat/pull/1567) ### Security - diff --git a/cvat-core/package.json b/cvat-core/package.json index b2f95a0c1619..b347fcb15b19 100644 --- a/cvat-core/package.json +++ b/cvat-core/package.json @@ -1,6 +1,6 @@ { "name": "cvat-core", - "version": "2.0.1", + "version": "2.1.1", "description": "Part of Computer Vision Tool which presents an interface for client-side integration", "main": "babel.config.js", "scripts": { diff --git a/cvat-core/src/annotation-formats.js b/cvat-core/src/annotation-formats.js index 641f7bc06161..a248c9d72bd4 100644 --- a/cvat-core/src/annotation-formats.js +++ b/cvat-core/src/annotation-formats.js @@ -15,6 +15,7 @@ name: initialData.name, format: initialData.ext, version: initialData.version, + enabled: initialData.enabled, }; Object.defineProperties(this, { @@ -48,6 +49,16 @@ */ get: () => data.version, }, + enabled: { + /** + * @name enabled + * @type {string} + * @memberof module:API.cvat.classes.Loader + * @readonly + * @instance + */ + get: () => data.enabled, + }, }); } } @@ -63,6 +74,7 @@ name: initialData.name, format: initialData.ext, version: initialData.version, + enabled: initialData.enabled, }; Object.defineProperties(this, { @@ -96,6 +108,16 @@ */ get: () => data.version, }, + enabled: { + /** + * @name enabled + * @type {string} + * @memberof module:API.cvat.classes.Loader + * @readonly + * @instance + */ + get: () => data.enabled, + }, }); } } diff --git a/cvat-ui/package-lock.json b/cvat-ui/package-lock.json index 6960db8dcb68..a375e42da9d6 100644 --- a/cvat-ui/package-lock.json +++ b/cvat-ui/package-lock.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.2.0", + "version": "1.2.1", "lockfileVersion": 1, "requires": true, "dependencies": { diff --git a/cvat-ui/package.json b/cvat-ui/package.json index 00831d57a648..d9c4c6e6d538 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.2.0", + "version": "1.2.1", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/components/actions-menu/actions-menu.tsx b/cvat-ui/src/components/actions-menu/actions-menu.tsx index 7d4f305bd508..5a1070d05a01 100644 --- a/cvat-ui/src/components/actions-menu/actions-menu.tsx +++ b/cvat-ui/src/components/actions-menu/actions-menu.tsx @@ -16,8 +16,8 @@ interface Props { taskMode: string; bugTracker: string; - loaders: string[]; - dumpers: string[]; + loaders: any[]; + dumpers: any[]; loadActivity: string | null; dumpActivities: string[] | null; exportActivities: string[] | null; diff --git a/cvat-ui/src/components/actions-menu/dump-submenu.tsx b/cvat-ui/src/components/actions-menu/dump-submenu.tsx index 23a6ab6c1474..803a5db98927 100644 --- a/cvat-ui/src/components/actions-menu/dump-submenu.tsx +++ b/cvat-ui/src/components/actions-menu/dump-submenu.tsx @@ -15,7 +15,7 @@ function isDefaultFormat(dumperName: string, taskMode: string): boolean { interface Props { taskMode: string; menuKey: string; - dumpers: string[]; + dumpers: any[]; dumpActivities: string[] | null; } @@ -30,17 +30,21 @@ export default function DumpSubmenu(props: Props): JSX.Element { return ( { - dumpers.map((dumper: string): JSX.Element => { - const pending = (dumpActivities || []).includes(dumper); - const isDefault = isDefaultFormat(dumper, taskMode); + dumpers + .sort((a: any, b: any) => a.name.localeCompare(b.name)) + .map((dumper: any): JSX.Element => + { + const pending = (dumpActivities || []).includes(dumper.name); + const disabled = !dumper.enabled || pending; + const isDefault = isDefaultFormat(dumper.name, taskMode); return ( - {dumper} + {dumper.name} {pending && } ); diff --git a/cvat-ui/src/components/actions-menu/export-submenu.tsx b/cvat-ui/src/components/actions-menu/export-submenu.tsx index e5dd53e8388d..045d682f85e0 100644 --- a/cvat-ui/src/components/actions-menu/export-submenu.tsx +++ b/cvat-ui/src/components/actions-menu/export-submenu.tsx @@ -9,7 +9,7 @@ import Text from 'antd/lib/typography/Text'; interface Props { menuKey: string; - exporters: string[]; + exporters: any[]; exportActivities: string[] | null; } @@ -23,16 +23,20 @@ export default function ExportSubmenu(props: Props): JSX.Element { return ( { - exporters.map((exporter: string): JSX.Element => { - const pending = (exportActivities || []).includes(exporter); + exporters + .sort((a: any, b: any) => a.name.localeCompare(b.name)) + .map((exporter: any): JSX.Element => + { + const pending = (exportActivities || []).includes(exporter.name); + const disabled = !exporter.enabled || pending; return ( - {exporter} + {exporter.name} {pending && } ); diff --git a/cvat-ui/src/components/actions-menu/load-submenu.tsx b/cvat-ui/src/components/actions-menu/load-submenu.tsx index 110761a31ee9..2db41d176499 100644 --- a/cvat-ui/src/components/actions-menu/load-submenu.tsx +++ b/cvat-ui/src/components/actions-menu/load-submenu.tsx @@ -11,7 +11,7 @@ import Text from 'antd/lib/typography/Text'; interface Props { menuKey: string; - loaders: string[]; + loaders: any[]; loadActivity: string | null; onFileUpload(file: File): void; } @@ -27,13 +27,20 @@ export default function LoadSubmenu(props: Props): JSX.Element { return ( { - loaders.map((_loader: string): JSX.Element => { - const [loader, accept] = _loader.split('::'); - const pending = loadActivity === loader; + loaders + .sort((a: any, b: any) => a.name.localeCompare(b.name)) + .map((loader: any): JSX.Element => + { + const accept = loader.format + .split(',') + .map((x: string) => '.' + x.trimStart()) + .join(', '); // add '.' to each extension in a list + const pending = loadActivity === loader.name; + const disabled = !loader.enabled || !!loadActivity; return ( - diff --git a/cvat-ui/src/components/annotation-page/top-bar/annotation-menu.tsx b/cvat-ui/src/components/annotation-page/top-bar/annotation-menu.tsx index 51a460ea3781..92da2240d251 100644 --- a/cvat-ui/src/components/annotation-page/top-bar/annotation-menu.tsx +++ b/cvat-ui/src/components/annotation-page/top-bar/annotation-menu.tsx @@ -13,8 +13,8 @@ import ReIDPlugin from './reid-plugin'; interface Props { taskMode: string; - loaders: string[]; - dumpers: string[]; + loaders: any[]; + dumpers: any[]; loadActivity: string | null; dumpActivities: string[] | null; exportActivities: string[] | null; diff --git a/cvat-ui/src/containers/actions-menu/actions-menu.tsx b/cvat-ui/src/containers/actions-menu/actions-menu.tsx index c502eece1482..a92a42bb1fff 100644 --- a/cvat-ui/src/containers/actions-menu/actions-menu.tsx +++ b/cvat-ui/src/containers/actions-menu/actions-menu.tsx @@ -134,7 +134,7 @@ function ActionsMenuContainer(props: OwnProps & StateToProps & DispatchToProps): dumpAnnotations(taskInstance, dumper); } } else if (action === Actions.LOAD_TASK_ANNO) { - const [format] = additionalKey.split('::'); + const format = additionalKey; const [loader] = loaders .filter((_loader: any): boolean => _loader.name === format); if (loader && file) { @@ -166,8 +166,8 @@ function ActionsMenuContainer(props: OwnProps & StateToProps & DispatchToProps): taskID={taskInstance.id} taskMode={taskInstance.mode} bugTracker={taskInstance.bugTracker} - loaders={loaders.map((loader: any): string => `${loader.name}::${loader.format}`)} - dumpers={dumpers.map((dumper: any): string => dumper.name)} + loaders={loaders} + dumpers={dumpers} loadActivity={loadActivity} dumpActivities={dumpActivities} exportActivities={exportActivities} diff --git a/cvat-ui/src/containers/annotation-page/top-bar/annotation-menu.tsx b/cvat-ui/src/containers/annotation-page/top-bar/annotation-menu.tsx index b6633885c1be..139168a8f293 100644 --- a/cvat-ui/src/containers/annotation-page/top-bar/annotation-menu.tsx +++ b/cvat-ui/src/containers/annotation-page/top-bar/annotation-menu.tsx @@ -123,7 +123,7 @@ function AnnotationMenuContainer(props: Props): JSX.Element { dumpAnnotations(jobInstance.task, dumper); } } else if (action === Actions.LOAD_JOB_ANNO) { - const [format] = additionalKey.split('::'); + const format = additionalKey; const [loader] = loaders .filter((_loader: any): boolean => _loader.name === format); if (loader && file) { @@ -150,8 +150,8 @@ function AnnotationMenuContainer(props: Props): JSX.Element { return ( loader.name)} - dumpers={dumpers.map((dumper: any): string => dumper.name)} + loaders={loaders} + dumpers={dumpers} loadActivity={loadActivity} dumpActivities={dumpActivities} exportActivities={exportActivities} diff --git a/cvat/apps/dataset_manager/formats/registry.py b/cvat/apps/dataset_manager/formats/registry.py index 20377dd67030..ed4defc559df 100644 --- a/cvat/apps/dataset_manager/formats/registry.py +++ b/cvat/apps/dataset_manager/formats/registry.py @@ -13,6 +13,7 @@ class _Format: EXT = '' VERSION = '' DISPLAY_NAME = '{NAME} {VERSION}' + ENABLED = True class Exporter(_Format): def __call__(self, dst_file, task_data, **options): @@ -22,7 +23,7 @@ class Importer(_Format): def __call__(self, src_file, task_data, **options): raise NotImplementedError() -def _wrap_format(f_or_cls, klass, name, version, ext, display_name): +def _wrap_format(f_or_cls, klass, name, version, ext, display_name, enabled): import inspect assert inspect.isclass(f_or_cls) or inspect.isfunction(f_or_cls) if inspect.isclass(f_or_cls): @@ -44,14 +45,17 @@ def __call__(self, *args, **kwargs): target.DISPLAY_NAME = (display_name or klass.DISPLAY_NAME).format( NAME=name, VERSION=version, EXT=ext) assert all([target.NAME, target.VERSION, target.EXT, target.DISPLAY_NAME]) + target.ENABLED = enabled + return target EXPORT_FORMATS = {} -def exporter(name, version, ext, display_name=None): +def exporter(name, version, ext, display_name=None, enabled=True): assert name not in EXPORT_FORMATS, "Export format '%s' already registered" % name def wrap_with_params(f_or_cls): t = _wrap_format(f_or_cls, Exporter, - name=name, ext=ext, version=version, display_name=display_name) + name=name, ext=ext, version=version, display_name=display_name, + enabled=enabled) key = t.DISPLAY_NAME assert key not in EXPORT_FORMATS, "Export format '%s' already registered" % name EXPORT_FORMATS[key] = t @@ -59,10 +63,11 @@ def wrap_with_params(f_or_cls): return wrap_with_params IMPORT_FORMATS = {} -def importer(name, version, ext, display_name=None): +def importer(name, version, ext, display_name=None, enabled=True): def wrap_with_params(f_or_cls): t = _wrap_format(f_or_cls, Importer, - name=name, ext=ext, version=version, display_name=display_name) + name=name, ext=ext, version=version, display_name=display_name, + enabled=enabled) key = t.DISPLAY_NAME assert key not in IMPORT_FORMATS, "Import format '%s' already registered" % name IMPORT_FORMATS[key] = t diff --git a/cvat/apps/dataset_manager/formats/tfrecord.py b/cvat/apps/dataset_manager/formats/tfrecord.py index 0e4962fa6c4a..fef95aa710f6 100644 --- a/cvat/apps/dataset_manager/formats/tfrecord.py +++ b/cvat/apps/dataset_manager/formats/tfrecord.py @@ -14,7 +14,15 @@ from .registry import dm_env, exporter, importer -@exporter(name='TFRecord', ext='ZIP', version='1.0') +from datumaro.util.tf_util import import_tf +try: + import_tf() + tf_available = True +except ImportError: + tf_available = False + + +@exporter(name='TFRecord', ext='ZIP', version='1.0', enabled=tf_available) def _export(dst_file, task_data, save_images=False): extractor = CvatTaskDataExtractor(task_data, include_images=save_images) extractor = Dataset.from_extractors(extractor) # apply lazy transforms @@ -25,7 +33,7 @@ def _export(dst_file, task_data, save_images=False): make_zip_archive(temp_dir, dst_file) -@importer(name='TFRecord', ext='ZIP', version='1.0') +@importer(name='TFRecord', ext='ZIP', version='1.0', enabled=tf_available) def _import(src_file, task_data): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) diff --git a/cvat/apps/dataset_manager/serializers.py b/cvat/apps/dataset_manager/serializers.py index 51cf71ca8da3..e64c0cb93bcd 100644 --- a/cvat/apps/dataset_manager/serializers.py +++ b/cvat/apps/dataset_manager/serializers.py @@ -9,6 +9,7 @@ class DatasetFormatSerializer(serializers.Serializer): name = serializers.CharField(max_length=64, source='DISPLAY_NAME') ext = serializers.CharField(max_length=64, source='EXT') version = serializers.CharField(max_length=64, source='VERSION') + enabled = serializers.BooleanField(source='ENABLED') class DatasetFormatsSerializer(serializers.Serializer): importers = DatasetFormatSerializer(many=True) diff --git a/cvat/apps/dataset_manager/tests/_test_formats.py b/cvat/apps/dataset_manager/tests/_test_formats.py index bfea13af88e4..d5aa950d0f9f 100644 --- a/cvat/apps/dataset_manager/tests/_test_formats.py +++ b/cvat/apps/dataset_manager/tests/_test_formats.py @@ -335,6 +335,9 @@ def check(file_path): self.assertTrue(len(f.read()) != 0) for f in dm.views.get_export_formats(): + if not f.ENABLED: + self.skipTest("Format is disabled") + format_name = f.DISPLAY_NAME for save_images in { True, False }: with self.subTest(format=format_name, save_images=save_images): @@ -359,6 +362,9 @@ def test_empty_images_are_exported(self): ('YOLO 1.1', 'yolo'), ]: with self.subTest(format=format_name): + if not dm.formats.registry.EXPORT_FORMATS[format_name].ENABLED: + self.skipTest("Format is disabled") + task = self._generate_task() def check(file_path): diff --git a/cvat/apps/engine/tests/_test_rest_api.py b/cvat/apps/engine/tests/_test_rest_api.py index c8bee86c4609..fcd3d59a1d7f 100644 --- a/cvat/apps/engine/tests/_test_rest_api.py +++ b/cvat/apps/engine/tests/_test_rest_api.py @@ -3149,8 +3149,8 @@ def _get_initial_annotation(annotation_format): export_formats = data['exporters'] self.assertTrue(isinstance(import_formats, list) and import_formats) self.assertTrue(isinstance(export_formats, list) and export_formats) - import_formats = { v['name'] for v in import_formats } - export_formats = { v['name'] for v in export_formats } + import_formats = { v['name']: v for v in import_formats } + export_formats = { v['name']: v for v in export_formats } formats = { exp: exp if exp in import_formats else None for exp in export_formats } @@ -3159,12 +3159,12 @@ def _get_initial_annotation(annotation_format): formats['CVAT for video 1.1'] = 'CVAT 1.1' if 'CVAT for images 1.1' in export_formats: formats['CVAT for images 1.1'] = 'CVAT 1.1' - if import_formats ^ export_formats: + if set(import_formats) ^ set(export_formats): # NOTE: this may not be an error, so we should not fail print("The following import formats have no pair:", - import_formats - export_formats) + set(import_formats) - set(export_formats)) print("The following export formats have no pair:", - export_formats - import_formats) + set(export_formats) - set(import_formats)) for export_format, import_format in formats.items(): with self.subTest(export_format=export_format, @@ -3183,7 +3183,12 @@ def _get_initial_annotation(annotation_format): # 3. download annotation response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator, "?format={}".format(export_format)) - self.assertEqual(response.status_code, HTTP_202_ACCEPTED) + if annotator and not export_formats[export_format]['enabled']: + self.assertEqual(response.status_code, + status.HTTP_405_METHOD_NOT_ALLOWED) + continue + else: + self.assertEqual(response.status_code, HTTP_202_ACCEPTED) response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator, "?format={}".format(export_format)) diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 47dfd80036a8..735d77c43370 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -482,7 +482,8 @@ def data(self, request, pk): responses={ '202': openapi.Response(description='Dump of annotations has been started'), '201': openapi.Response(description='Annotations file is ready to download'), - '200': openapi.Response(description='Download of file started') + '200': openapi.Response(description='Download of file started'), + '405': openapi.Response(description='Format is not available'), } ) @swagger_auto_schema(method='put', operation_summary='Method allows to upload task annotations', @@ -494,6 +495,7 @@ def data(self, request, pk): responses={ '202': openapi.Response(description='Uploading has been started'), '201': openapi.Response(description='Uploading has finished'), + '405': openapi.Response(description='Format is not available'), } ) @swagger_auto_schema(method='patch', operation_summary='Method performs a partial update of annotations in a specific task', @@ -619,7 +621,8 @@ def data_info(request, pk): ], responses={'202': openapi.Response(description='Exporting has been started'), '201': openapi.Response(description='Output file is ready for downloading'), - '200': openapi.Response(description='Download of file started') + '200': openapi.Response(description='Download of file started'), + '405': openapi.Response(description='Format is not available'), } ) @action(detail=True, methods=['GET'], serializer_class=None, @@ -799,17 +802,20 @@ def rq_handler(job, exc_type, exc_value, tb): # tags=['tasks']) # @api_view(['PUT']) def _import_annotations(request, rq_id, rq_func, pk, format_name): + format_desc = {f.DISPLAY_NAME: f + for f in dm.views.get_import_formats()}.get(format_name) + if format_desc is None: + raise serializers.ValidationError( + "Unknown input format '{}'".format(format_name)) + elif not format_desc.ENABLED: + return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED) + queue = django_rq.get_queue("default") rq_job = queue.fetch_job(rq_id) if not rq_job: serializer = AnnotationFileSerializer(data=request.data) if serializer.is_valid(raise_exception=True): - if format_name not in \ - [f.DISPLAY_NAME for f in dm.views.get_import_formats()]: - raise serializers.ValidationError( - "Unknown input format '{}'".format(format_name)) - anno_file = serializer.validated_data['annotation_file'] fd, filename = mkstemp(prefix='cvat_{}'.format(pk)) with open(filename, 'wb+') as f: @@ -843,9 +849,13 @@ def _export_annotations(db_task, rq_id, request, format_name, action, callback, raise serializers.ValidationError( "Unexpected action specified for the request") - if format_name not in [f.DISPLAY_NAME for f in dm.views.get_export_formats()]: + format_desc = {f.DISPLAY_NAME: f + for f in dm.views.get_export_formats()}.get(format_name) + if format_desc is None: raise serializers.ValidationError( "Unknown format specified for the request") + elif not format_desc.ENABLED: + return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED) queue = django_rq.get_queue("default") diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py index d4468edd9b95..3c590508626f 100644 --- a/datumaro/datumaro/components/project.py +++ b/datumaro/datumaro/components/project.py @@ -235,7 +235,17 @@ def _load_plugins(cls, plugins_dir, types): exports = cls._import_module(module_dir, module_name, types, package) except Exception as e: - log.debug("Failed to import module '%s': %s" % (module_name, e)) + module_search_error = ImportError + try: + module_search_error = ModuleNotFoundError # python 3.6+ + except NameError: + pass + + message = ["Failed to import module '%s': %s", module_name, e] + if isinstance(e, module_search_error): + log.debug(*message) + else: + log.warning(*message) continue log.debug("Imported the following symbols from %s: %s" % \ diff --git a/datumaro/datumaro/plugins/openvino_launcher.py b/datumaro/datumaro/plugins/openvino_launcher.py index 10f12feab614..438a4b3da086 100644 --- a/datumaro/datumaro/plugins/openvino_launcher.py +++ b/datumaro/datumaro/plugins/openvino_launcher.py @@ -10,11 +10,11 @@ import os import os.path as osp import platform -import subprocess from openvino.inference_engine import IENetwork, IEPlugin from datumaro.components.launcher import Launcher +from datumaro.util.os_util import check_instruction_set class InterpreterScript: @@ -45,17 +45,6 @@ class OpenVinoLauncher(Launcher): _DEFAULT_IE_PLUGINS_PATH = "/opt/intel/openvino_2019.1.144/deployment_tools/inference_engine/lib/intel64" _IE_PLUGINS_PATH = os.getenv("IE_PLUGINS_PATH", _DEFAULT_IE_PLUGINS_PATH) - @staticmethod - def _check_instruction_set(instruction): - return instruction == str.strip( - # Let's ignore a warning from bandit about using shell=True. - # In this case it isn't a security issue and we use some - # shell features like pipes. - subprocess.check_output( - 'lscpu | grep -o "{}" | head -1'.format(instruction), - shell=True).decode('utf-8') # nosec - ) - @staticmethod def make_plugin(device='cpu', plugins_path=_IE_PLUGINS_PATH): if plugins_path is None or not osp.isdir(plugins_path): @@ -63,10 +52,10 @@ def make_plugin(device='cpu', plugins_path=_IE_PLUGINS_PATH): (plugins_path)) plugin = IEPlugin(device='CPU', plugin_dirs=[plugins_path]) - if (OpenVinoLauncher._check_instruction_set('avx2')): + if (check_instruction_set('avx2')): plugin.add_cpu_extension(os.path.join(plugins_path, 'libcpu_extension_avx2.so')) - elif (OpenVinoLauncher._check_instruction_set('sse4')): + elif (check_instruction_set('sse4')): plugin.add_cpu_extension(os.path.join(plugins_path, 'libcpu_extension_sse4.so')) elif platform.system() == 'Darwin': diff --git a/datumaro/datumaro/util/os_util.py b/datumaro/datumaro/util/os_util.py new file mode 100644 index 000000000000..b4d05e376db2 --- /dev/null +++ b/datumaro/datumaro/util/os_util.py @@ -0,0 +1,17 @@ + +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import subprocess + + +def check_instruction_set(instruction): + return instruction == str.strip( + # Let's ignore a warning from bandit about using shell=True. + # In this case it isn't a security issue and we use some + # shell features like pipes. + subprocess.check_output( + 'lscpu | grep -o "%s" | head -1' % instruction, + shell=True).decode('utf-8') # nosec + ) \ No newline at end of file diff --git a/datumaro/datumaro/util/tf_util.py b/datumaro/datumaro/util/tf_util.py index 00bf834a0f0d..841fc53faf1f 100644 --- a/datumaro/datumaro/util/tf_util.py +++ b/datumaro/datumaro/util/tf_util.py @@ -3,7 +3,36 @@ # # SPDX-License-Identifier: MIT -def import_tf(): + +def check_import(): + # Workaround for checking import availability: + # Official TF builds include AVX instructions. Once we try to import, + # the program crashes. We raise an exception instead. + + import subprocess + import sys + + from .os_util import check_instruction_set + + result = subprocess.run([sys.executable, '-c', 'import tensorflow'], + timeout=60, + universal_newlines=True, # use text mode for output stream + stdout=subprocess.PIPE, stderr=subprocess.PIPE) # capture output + + if result.returncode != 0: + message = result.stderr + if not message: + message = "Can't import tensorflow. " \ + "Test process exit code: %s." % result.returncode + if not check_instruction_set('avx'): + # The process has probably crashed for AVX unavalability + message += " This is likely because your CPU does not " \ + "support AVX instructions, " \ + "which are required for tensorflow." + + raise ImportError(message) + +def import_tf(check=True): import sys tf = sys.modules.get('tensorflow', None) @@ -14,6 +43,9 @@ def import_tf(): import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + if check: + check_import() + import tensorflow as tf try: diff --git a/datumaro/tests/test_tfrecord_format.py b/datumaro/tests/test_tfrecord_format.py index 737ea6cf5f77..cc55a9fc9a3f 100644 --- a/datumaro/tests/test_tfrecord_format.py +++ b/datumaro/tests/test_tfrecord_format.py @@ -1,17 +1,34 @@ import numpy as np -from unittest import TestCase +from unittest import TestCase, skipIf from datumaro.components.extractor import (Extractor, DatasetItem, AnnotationType, Bbox, Mask, LabelCategories ) -from datumaro.plugins.tf_detection_api_format.importer import TfDetectionApiImporter -from datumaro.plugins.tf_detection_api_format.extractor import TfDetectionApiExtractor -from datumaro.plugins.tf_detection_api_format.converter import TfDetectionApiConverter from datumaro.util.image import Image from datumaro.util.test_utils import TestDir, compare_datasets - - +from datumaro.util.tf_util import check_import + +try: + from datumaro.plugins.tf_detection_api_format.importer import TfDetectionApiImporter + from datumaro.plugins.tf_detection_api_format.extractor import TfDetectionApiExtractor + from datumaro.plugins.tf_detection_api_format.converter import TfDetectionApiConverter + import_failed = False +except ImportError: + import_failed = True + + import importlib + module_found = importlib.util.find_spec('tensorflow') is not None + + @skipIf(not module_found, "Tensorflow package is not found") + class TfImportTest(TestCase): + def test_raises_when_crashes_on_import(self): + # Should fire if import can't be done for any reason except + # module unavailability and import crash + with self.assertRaisesRegex(ImportError, 'Test process exit code'): + check_import() + +@skipIf(import_failed, "Failed to import tensorflow") class TfrecordConverterTest(TestCase): def _test_save_and_load(self, source_dataset, converter, test_dir, target_dataset=None, importer_args=None): @@ -171,6 +188,7 @@ def test_labelmap_parsing(self): self.assertEqual(expected, parsed) +@skipIf(import_failed, "Failed to import tensorflow") class TfrecordImporterTest(TestCase): def test_can_detect(self): class TestExtractor(Extractor):