diff --git a/azext_iot/common/utility.py b/azext_iot/common/utility.py
index b60d58725..dafd84391 100644
--- a/azext_iot/common/utility.py
+++ b/azext_iot/common/utility.py
@@ -11,6 +11,7 @@
 import ast
 import base64
+import isodate
 import json
 import os
 import sys
@@ -34,13 +35,13 @@ def parse_entity(entity, filter_none=False):
         result (dict): a dictionary of attributes from the function input.
     """
     result = {}
-    attributes = [attr for attr in dir(entity) if not attr.startswith('_')]
+    attributes = [attr for attr in dir(entity) if not attr.startswith("_")]
     for attribute in attributes:
         value = getattr(entity, attribute, None)
         if filter_none and not value:
             continue
         value_behavior = dir(value)
-        if '__call__' not in value_behavior:
+        if "__call__" not in value_behavior:
             result[attribute] = value
     return result
@@ -73,19 +74,25 @@ def verify_transform(subject, mapping):
     verifies that subject[k] is of type mapping[k]
     """
     import jmespath
+
     for k in mapping.keys():
         result = jmespath.search(k, subject)
         if result is None:
             raise AttributeError('The property "{}" is required'.format(k))
         if not isinstance(result, mapping[k]):
-            supplemental_info = ''
+            supplemental_info = ""
             if mapping[k] == dict:
-                wiki_link = 'https://github.com/Azure/azure-iot-cli-extension/wiki/Tips'
-                supplemental_info = 'Review inline JSON examples here --> {}'.format(wiki_link)
-
-            raise TypeError('The property "{}" must be of {} but is {}. Input: {}. {}'.format(
-                k, str(mapping[k]), str(type(result)), result, supplemental_info))
+                wiki_link = "https://github.com/Azure/azure-iot-cli-extension/wiki/Tips"
+                supplemental_info = "Review inline JSON examples here --> {}".format(
+                    wiki_link
+                )
+
+            raise TypeError(
+                'The property "{}" must be of {} but is {}. Input: {}. {}'.format(
+                    k, str(mapping[k]), str(type(result)), result, supplemental_info
+                )
+            )
@@ -99,8 +106,8 @@ def validate_key_value_pairs(string):
     """
     result = None
     if string:
-        kv_list = [x for x in string.split(';') if '=' in x]  # key-value pairs
-        result = dict(x.split('=', 1) for x in kv_list)
+        kv_list = [x for x in string.split(";") if "=" in x]  # key-value pairs
+        result = dict(x.split("=", 1) for x in kv_list)
     return result
@@ -139,6 +146,7 @@ def shell_safe_json_parse(json_or_dict_string, preserve_order=False):
         if not preserve_order:
             return json.loads(json_or_dict_string)
         from collections import OrderedDict
+
         return json.loads(json_or_dict_string, object_pairs_hook=OrderedDict)
     except ValueError as json_ex:
         try:
@@ -146,14 +154,19 @@ def shell_safe_json_parse(json_or_dict_string, preserve_order=False):
         except SyntaxError:
             raise CLIError(json_ex)
         except ValueError as ex:
-            logger.debug(ex)  # log the exception which could be a python dict parsing error.
-            raise CLIError(json_ex)  # raise json_ex error which is more readable and likely.
+            logger.debug(
+                ex
+            )  # log the exception which could be a python dict parsing error.
+            raise CLIError(
+                json_ex
+            )  # raise json_ex error which is more readable and likely.


 def read_file_content(file_path, allow_binary=False):
     from codecs import open as codecs_open
+
     # Note, always put 'utf-8-sig' first, so that BOM in WinOS won't cause trouble.
-    for encoding in ['utf-8-sig', 'utf-8', 'utf-16', 'utf-16le', 'utf-16be']:
+    for encoding in ["utf-8-sig", "utf-8", "utf-16", "utf-16le", "utf-16be"]:
         try:
             with codecs_open(file_path, encoding=encoding) as f:
                 logger.debug("attempting to read file %s as %s", file_path, encoding)
@@ -163,19 +176,19 @@ def read_file_content(file_path, allow_binary=False):

     if allow_binary:
         try:
-            with open(file_path, 'rb') as input_file:
+            with open(file_path, "rb") as input_file:
                 logger.debug("attempting to read file %s as binary", file_path)
                 return base64.b64encode(input_file.read()).decode("utf-8")
         except Exception:  # pylint: disable=broad-except
             pass
-    raise CLIError('Failed to decode file {} - unknown decoding'.format(file_path))
+    raise CLIError("Failed to decode file {} - unknown decoding".format(file_path))


 def trim_from_start(s, substring):
     """ Trims a substring from the target string (if it exists) returning the trimmed string.
        Otherwise returns original target string. """
     if s.startswith(substring):
-        s = s[len(substring):]
+        s = s[len(substring) :]
     return s
@@ -186,12 +199,17 @@ def validate_min_python_version(major, minor, error_msg=None, exit_on_fail=True)
     if version.major > major:
         return True
     if major == version.major:
-        result = (version.minor >= minor)
+        result = version.minor >= minor
     if not result:
         if exit_on_fail:
-            msg = error_msg if error_msg else 'Python version {}.{} or higher required for this functionality.'.format(
-                major, minor)
+            msg = (
+                error_msg
+                if error_msg
+                else "Python version {}.{} or higher required for this functionality.".format(
+                    major, minor
+                )
+            )
             sys.exit(msg)
     return result
@@ -205,10 +223,10 @@ def unicode_binary_map(target):
     for k in target:
         key = k
         if isinstance(k, bytes):
-            key = str(k, 'utf8')
+            key = str(k, "utf8")

         if isinstance(target[k], bytes):
-            result[key] = str(target[k], 'utf8')
+            result[key] = str(target[k], "utf8")
         else:
             result[key] = target[k]
@@ -232,11 +250,11 @@ def execute_onthread(**kwargs):
         Event(), Thread(): Event object to set the cancellation flag, Executing Thread object

     """
-    interval = kwargs.get('interval')
-    method = kwargs.get('method')
-    method_args = kwargs.get('args')
-    max_runs = kwargs.get('max_runs')
-    handle = kwargs.get('return_handle')
+    interval = kwargs.get("interval")
+    method = kwargs.get("method")
+    method_args = kwargs.get("args")
+    max_runs = kwargs.get("max_runs")
+    handle = kwargs.get("return_handle")

     if not interval:
         interval = 2
@@ -292,6 +310,7 @@ def url_encode_str(s, plus=False):
 def test_import(package):
     """ Used to determine if a dependency is loading correctly """
     import importlib
+
     try:
         importlib.import_module(package)
     except ImportError:
@@ -302,10 +321,10 @@ def test_import(package):
 def unpack_pnp_http_error(e):
     error = unpack_msrest_error(e)
     if isinstance(error, dict):
-        if error.get('error'):
-            error = error['error']
-        if error.get('stackTrace'):
-            error.pop('stackTrace')
+        if error.get("error"):
+            error = error["error"]
+        if error.get("stackTrace"):
+            error.pop("stackTrace")
     return error
@@ -340,13 +359,13 @@ def init_monitoring(cmd, timeout, properties, enqueued_time, repair, yes):
     validate_min_python_version(3, 5)

     if timeout < 0:
-        raise CLIError('Monitoring timeout must be 0 (inf) or greater.')
-    timeout = (timeout * 1000)
+        raise CLIError("Monitoring timeout must be 0 (inf) or greater.")
+    timeout = timeout * 1000

     config = cmd.cli_ctx.config
     output = cmd.cli_ctx.invocation.data.get("output", None)
     if not output:
-        output = 'json'
+        output = "json"
     ensure_uamqp(config, yes, repair)

     if not properties:
@@ -359,14 +378,19 @@ def init_monitoring(cmd, timeout, properties, enqueued_time, repair, yes):


 def get_sas_token(target):
-    from azext_iot.common.digitaltwin_sas_token_auth import DigitalTwinSasTokenAuthentication
-    token = ''
-    if target.get('repository_id'):
-        token = DigitalTwinSasTokenAuthentication(target["repository_id"],
-                                                  target["entity"],
-                                                  target["policy"],
-                                                  target["primarykey"]).generate_sas_token()
-    return {'Authorization': '{}'.format(token)}
+    from azext_iot.common.digitaltwin_sas_token_auth import (
+        DigitalTwinSasTokenAuthentication,
+    )
+
+    token = ""
+    if target.get("repository_id"):
+        token = DigitalTwinSasTokenAuthentication(
+            target["repository_id"],
+            target["entity"],
+            target["policy"],
+            target["primarykey"],
+        ).generate_sas_token()
+    return {"Authorization": "{}".format(token)}


 def dict_clean(d):
@@ -396,7 +420,7 @@ def looks_like_file(element):
             ".java",
             ".ts",
             ".js",
-            ".cs"
+            ".cs",
         )
     )
@@ -412,3 +436,29 @@ def ensure_pkg_resources_entries():
         pkg_resources.working_set.add_entry(extension_path)

     return
+
+
+class ISO8601Validator:
+    def is_iso8601_date(self, to_validate) -> bool:
+        try:
+            return bool(isodate.parse_date(to_validate))
+        except Exception:
+            return False
+
+    def is_iso8601_datetime(self, to_validate: str) -> bool:
+        try:
+            return bool(isodate.parse_datetime(to_validate))
+        except Exception:
+            return False
+
+    def is_iso8601_duration(self, to_validate: str) -> bool:
+        try:
+            return bool(isodate.parse_duration(to_validate))
+        except Exception:
+            return False
+
+    def is_iso8601_time(self, to_validate: str) -> bool:
+        try:
+            return bool(isodate.parse_time(to_validate))
+        except Exception:
+            return False
diff --git a/azext_iot/operations/events3/_parser.py b/azext_iot/operations/events3/_parser.py
index 22d5a12e5..c01d845b0 100644
--- a/azext_iot/operations/events3/_parser.py
+++ b/azext_iot/operations/events3/_parser.py
@@ -4,13 +4,14 @@
 # Licensed under the MIT License. See License.txt in the project root for license information.
 # --------------------------------------------------------------------------------------------

+import json
 import random
 import re
-import json

 from knack.log import get_logger
 from uamqp.message import Message
-from azext_iot.common.utility import parse_entity, unicode_binary_map
+
+from azext_iot.common.utility import parse_entity, unicode_binary_map, ISO8601Validator
 from azext_iot.central.providers import CentralDeviceProvider

 SUPPORTED_ENCODINGS = ["utf-8"]
@@ -18,6 +19,8 @@ INTERFACE_NAME_IDENTIFIER = b"iothub-interface-name"

 random.seed(0)

+iso_validator = ISO8601Validator()
+

 class Event3Parser(object):
     _logger = get_logger(__name__)
@@ -97,14 +100,16 @@ def parse_message(
             message, origin_device_id, content_type, create_payload_error
         )

-        self._validate_payload_against_dcm(
-            origin_device_id,
-            payload,
-            central_device_provider,
-            create_payload_name_error,
+        self._perform_static_validations(
+            origin_device_id=origin_device_id, payload=payload
         )

-        self._validate_field_names(origin_device_id, payload)
+        self._perform_dynamic_validations(
+            origin_device_id=origin_device_id,
+            payload=payload,
+            central_device_provider=central_device_provider,
+            create_payload_name_error=create_payload_name_error,
+        )

         event["payload"] = payload
@@ -163,9 +168,7 @@ def _parse_system_properties(self, message: Message):
             return unicode_binary_map(parse_entity(message.properties, True))
         except Exception:
             self._errors.append(
-                "Failed to parse system_properties for message {message}.".format(
-                    message
-                )
+                "Failed to parse system_properties for message {}.".format(message)
             )
             return {}
@@ -270,83 +273,165 @@ def _parse_payload(

         return payload

-    def _validate_payload_against_dcm(
+    # Static validations should only need information present in the payload
+    # i.e. there should be no need for network calls
+    def _perform_static_validations(self, origin_device_id: str, payload: dict):
+        # if it is not a dictionary, something else went wrong with parsing
+        if not isinstance(payload, dict):
+            return
+
+        self._validate_field_names(origin_device_id=origin_device_id, payload=payload)
+
+    def _validate_field_names(self, origin_device_id: str, payload: dict):
+        # source:
+        # https://github.com/Azure/IoTPlugandPlay/tree/master/DTDL
+        regex = "^[a-zA-Z_][a-zA-Z0-9_]*$"
+
+        # if a field name does not match the above regex, it is an invalid field name
+        invalid_field_names = [
+            field_name
+            for field_name in payload.keys()
+            if not re.search(regex, field_name)
+        ]
+        if invalid_field_names:
+            self._errors.append(
+                "The following field names are not allowed: '{}'. "
+                "Payload: '{}'. "
+                "Message origin: '{}'.".format(
+                    invalid_field_names, payload, origin_device_id
+                )
+            )
+
+    # Dynamic validations need data external to the payload
+    # e.g. the device template
+    def _perform_dynamic_validations(
         self,
         origin_device_id: str,
-        payload: str,
+        payload: dict,
         central_device_provider: CentralDeviceProvider,
        create_payload_name_error=False,
    ):
-        if not central_device_provider:
+        # if the payload is not a dictionary some other parsing error occurred
+        if not isinstance(payload, dict):
+            return
+
+        # device provider was not passed in, no way to get the device template
+        if not isinstance(central_device_provider, CentralDeviceProvider):
+            return
+
+        template = self._get_device_template(
+            origin_device_id=origin_device_id,
+            central_device_provider=central_device_provider,
+        )
+
+        # _get_device_template should log an error if there was an issue
+        if not template:
             return

-        if not hasattr(payload, "keys"):
-            # some error happend while parsing
-            # should be captured by _parse_payload method above
+        template_schemas = self._extract_template_schemas_from_template(
+            origin_device_id=origin_device_id, template=template
+        )
+
+        # _extract_template_schemas_from_template should log an error if there was an issue
+        if not isinstance(template_schemas, dict):
             return

+        self._validate_payload_against_schema(
+            origin_device_id=origin_device_id,
+            payload=payload,
+            template_schemas=template_schemas,
+        )
+
+    def _get_device_template(
+        self, origin_device_id: str, central_device_provider: CentralDeviceProvider,
+    ):
         try:
-            template = central_device_provider.get_device_template_by_device_id(
+            return central_device_provider.get_device_template_by_device_id(
                 origin_device_id
             )
         except Exception as e:
             self._errors.append(
-                "Unable to get DCM for device: {}."
+                "Unable to retrieve template for device: {}. "
                 "Inner exception: {}".format(origin_device_id, e)
             )
-            return

+    def _extract_template_schemas_from_template(
+        self, origin_device_id: str, template: dict
+    ):
         try:
-            all_schema = self._extract_schema_from_template(template)
-            all_names = [schema["name"] for schema in all_schema]
+            schemas = []
+            dcm = template["capabilityModel"]
+            implements = dcm["implements"]
+            for implementation in implements:
+                contents = implementation["schema"]["contents"]
+                schemas.extend(contents)
+            return {schema["name"]: schema for schema in schemas}
         except Exception:
             self._errors.append(
                 "Unable to extract device schema for device: {}."
                 "Template: {}".format(origin_device_id, template)
             )
-            return

-        for telemetry_name in payload.keys():
-            if create_payload_name_error or telemetry_name not in all_names:
+    # currently validates:
+    # 1) primitive types match (e.g. boolean is indeed bool etc)
+    # 2) names match (i.e. Humidity vs humidity etc)
+    def _validate_payload_against_schema(
        self, origin_device_id: str, payload: dict, template_schemas: dict,
+    ):
+        template_schema_names = template_schemas.keys()
+        for name, value in payload.items():
+            schema = template_schemas.get(name)
+            if not schema:
                 self._errors.append(
-                    "Telemetry item '{}' is not present in DCM. "
+                    "Telemetry item '{}' is not present in capability model. "
                     "Device ID: {}. "
                     "List of allowed telemetry values for this type of device: {}. "
                     "NOTE: telemetry names are CASE-SENSITIVE".format(
-                        telemetry_name, origin_device_id, all_names
+                        name, origin_device_id, template_schema_names
                     )
                 )

-    def _extract_schema_from_template(self, template):
-        all_schema = []
-        dcm = template["capabilityModel"]
-        implements = dcm["implements"]
-        for implementation in implements:
-            contents = implementation["schema"]["contents"]
-            all_schema.extend(contents)
-
-        return all_schema
-
-    def _validate_field_names(self, origin_device_id: str, payload: dict):
-        # if its not a dictionary, something else went wrong with parsing
-        if not isinstance(payload, dict):
-            return
-
-        # source:
-        # https://github.com/Azure/IoTPlugandPlay/tree/master/DTDL
-        regex = "^[a-zA-Z_][a-zA-Z0-9_]*$"
-
-        # if a field name does not match the above regex, it is an invalid field name
-        invalid_field_names = [
-            field_name
-            for field_name in payload.keys()
-            if not re.search(regex, field_name)
-        ]
-        if invalid_field_names:
-            self._errors.append(
-                "The following field names are not allowed: '{}'. "
-                "Payload: '{}'. "
-                "Message origin: '{}'.".format(
-                    invalid_field_names, payload, origin_device_id
+            is_dict = isinstance(schema, dict)
+            if is_dict and not self._validate_types_match(value, schema):
+                expected_type = str(schema.get("schema"))
+                self._errors.append(
+                    "Type mismatch. "
+                    "Expected type: '{}'. "
+                    "Value received: '{}'. "
+                    "Telemetry identifier: {}. "
+                    "Device ID: {}. "
+                    "All dates/times/datetimes/durations must be ISO 8601 compliant.".format(
+                        expected_type, value, name, origin_device_id
+                    )
                 )
-            )
+
+    def _validate_types_match(self, value, schema: dict) -> bool:
+        # suppress error if there is no "schema" in schema
+        # means something else went wrong
+        schema_type = schema.get("schema")
+        if not schema_type:
+            return True
+
+        if schema_type == "boolean":
+            return isinstance(value, bool)
+        elif schema_type == "double":
+            return isinstance(value, (float, int))
+        elif schema_type == "float":
+            return isinstance(value, (float, int))
+        elif schema_type == "integer":
+            return isinstance(value, int)
+        elif schema_type == "long":
+            return isinstance(value, (float, int))
+        elif schema_type == "string":
+            return isinstance(value, str)
+        elif schema_type == "time":
+            return iso_validator.is_iso8601_time(value)
+        elif schema_type == "date":
+            return iso_validator.is_iso8601_date(value)
+        elif schema_type == "dateTime":
+            return iso_validator.is_iso8601_datetime(value)
+        elif schema_type == "duration":
+            return iso_validator.is_iso8601_duration(value)
+
+        # if the schema_type is not found above, suppress error
+        return True
diff --git a/azext_iot/tests/test_iot_utility_unit.py b/azext_iot/tests/test_iot_utility_unit.py
index 5333b6447..46f2d4f4c 100644
--- a/azext_iot/tests/test_iot_utility_unit.py
+++ b/azext_iot/tests/test_iot_utility_unit.py
@@ -17,6 +17,7 @@
     process_json_arg,
     read_file_content,
     logger,
+    ISO8601Validator,
 )
 from azext_iot.common.deps import ensure_uamqp
 from azext_iot.constants import EVENT_LIB, EXTENSION_NAME
@@ -284,6 +285,73 @@ def test_file_json_fail_invalidcontent(self, content, argname, set_cwd, mocker):
         assert mocked_util_logger.call_count == 0


+# none of these are valid anything in ISO 8601
+BAD_ARRAY = ["asd", "", 123.4, 123, True, False]
+
+
+class TestISO8601Validator:
+    validator = ISO8601Validator()
+
+    # Success suite
+    @pytest.mark.parametrize(
+        "to_validate", ["20200101", "20200101Z", "2020-01-01", "2020-01-01Z"]
+    )
+    def test_is_iso8601_date_pass(self, to_validate):
+        result = self.validator.is_iso8601_date(to_validate)
+        assert result
+
+    @pytest.mark.parametrize(
+        "to_validate",
+        [
+            "20200101T00:00:00",
+            "20200101T000000",
+            "2020-01-01T00:00:00",
+            "2020-01-01T00:00:00Z",
+            "2020-01-01T00:00:00.00",
+            "2020-01-01T00:00:00.00Z",
+            "2020-01-01T00:00:00.00+08:30",
+        ],
+    )
+    def test_is_iso8601_datetime_pass(self, to_validate):
+        result = self.validator.is_iso8601_datetime(to_validate)
+        assert result
+
+    @pytest.mark.parametrize("to_validate", ["P32DT7.592380349524318S", "P32DT7S"])
+    def test_is_iso8601_duration_pass(self, to_validate):
+        result = self.validator.is_iso8601_duration(to_validate)
+        assert result
+
+    @pytest.mark.parametrize(
+        "to_validate", ["00:00:00+08:30", "00:00:00Z", "00:00:00.123Z"]
+    )
+    def test_is_iso8601_time_pass(self, to_validate):
+        result = self.validator.is_iso8601_time(to_validate)
+        assert result
+
+    # Failure suite
+    @pytest.mark.parametrize(
+        "to_validate", ["2020-13-35", *BAD_ARRAY],
+    )
+    def test_is_iso8601_date_fail(self, to_validate):
+        result = self.validator.is_iso8601_date(to_validate)
+        assert not result
+
+    @pytest.mark.parametrize("to_validate", ["2020-13-35", "2020-00-00T", *BAD_ARRAY])
+    def test_is_iso8601_datetime_fail(self, to_validate):
+        result = self.validator.is_iso8601_datetime(to_validate)
+        assert not result
+
+    @pytest.mark.parametrize("to_validate", ["2020-01", *BAD_ARRAY])
+    def test_is_iso8601_duration_fail(self, to_validate):
+        result = self.validator.is_iso8601_duration(to_validate)
+        assert not result
+
+    @pytest.mark.parametrize("to_validate", [*BAD_ARRAY])
+    def test_is_iso8601_time_fail(self, to_validate):
+        result = self.validator.is_iso8601_time(to_validate)
+        assert not result
+
+
 class TestEvents3Parser:
     device_id = "some-device-id"
     payload = {"String": "someValue"}
@@ -296,6 +364,7 @@ class TestEvents3Parser:
     bad_content_type = "bad-content-type"
     bad_dcm_payload = {"temperature": "someValue"}
+    type_mismatch_payload = {"Bool": "someValue"}

     def test_parse_message_should_succeed(self):
         # setup
@@ -614,7 +683,7 @@ def test_validate_against_template_should_fail(self):
         assert len(parser._info) == 0

         actual_error = parser._errors[0]
-        expected_error = "Telemetry item '{}' is not present in DCM.".format(
+        expected_error = "Telemetry item '{}' is not present in capability model.".format(
             list(self.bad_dcm_payload)[0]
         )
         assert expected_error in actual_error
@@ -660,3 +729,55 @@ def test_validate_against_bad_template_should_not_throw(self):

         actual_error = parser._errors[0]
         assert "Unable to extract device schema for device" in actual_error
+
+    def test_type_mismatch_should_error(self):
+        # setup
+        app_prop_type = "some_property"
+        app_prop_value = "some_value"
+        properties = MessageProperties(
+            content_encoding=self.encoding, content_type=self.content_type
+        )
+        message = Message(
+            body=json.dumps(self.type_mismatch_payload).encode(),
+            properties=properties,
+            annotations={_parser.DEVICE_ID_IDENTIFIER: self.device_id.encode()},
+            application_properties={app_prop_type.encode(): app_prop_value.encode()},
+        )
+        parser = _parser.Event3Parser()
+
+        provider = CentralDeviceProvider(cmd=None, app_id=None)
+        device_template = load_json(FileNames.central_device_template_file)
+        provider.get_device_template_by_device_id = mock.MagicMock(
+            return_value=device_template
+        )
+
+        # act
+        parsed_msg = parser.parse_message(
+            message=message,
+            pnp_context=False,
+            interface_name=None,
+            properties={"all"},
+            content_type_hint=None,
+            simulate_errors=False,
+            central_device_provider=provider,
+        )
+
+        # verify
+        assert parsed_msg["event"]["payload"] == self.type_mismatch_payload
+        assert parsed_msg["event"]["origin"] == self.device_id
+
+        assert len(parser._errors) == 1
+        assert len(parser._warnings) == 0
+        assert len(parser._info) == 0
+
+        actual_error = parser._errors[0]
+        assert "Type mismatch" in actual_error
+        assert "Expected type" in actual_error
+        assert "Value received" in actual_error
+        assert "Device ID" in actual_error
+        assert (
+            "All dates/times/datetimes/durations must be ISO 8601 compliant."
+            in actual_error
+        )
+        assert list(self.type_mismatch_payload.values())[0] in actual_error
+        assert list(self.type_mismatch_payload.keys())[0] in actual_error
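
For context, a minimal usage sketch of the new ISO8601Validator; the sample values are reused from the test suite in this patch and are illustrative only, not part of the change:

from azext_iot.common.utility import ISO8601Validator

validator = ISO8601Validator()

# isodate accepts both basic and extended ISO 8601 forms, so these all validate.
assert validator.is_iso8601_date("2020-01-01")
assert validator.is_iso8601_datetime("2020-01-01T00:00:00.00+08:30")
assert validator.is_iso8601_duration("P32DT7S")
assert validator.is_iso8601_time("00:00:00.123Z")

# Anything isodate cannot parse is reported as invalid instead of raising.
assert not validator.is_iso8601_date("2020-13-35")
assert not validator.is_iso8601_duration(123.4)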
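
The dynamic validation path walks a device template shaped like the hypothetical, trimmed-down example below (capabilityModel -> implements[] -> schema.contents[]); the telemetry names mirror the test payloads and do not come from a real application:

# Hypothetical template fragment; only the keys the parser actually walks are shown.
template = {
    "capabilityModel": {
        "implements": [
            {
                "schema": {
                    "contents": [
                        {"@type": "Telemetry", "name": "Bool", "schema": "boolean"},
                        {"@type": "Telemetry", "name": "String", "schema": "string"},
                    ]
                }
            }
        ]
    }
}

# Same flattening that _extract_template_schemas_from_template performs:
schemas = []
for implementation in template["capabilityModel"]["implements"]:
    schemas.extend(implementation["schema"]["contents"])
template_schemas = {schema["name"]: schema for schema in schemas}

# A payload of {"Bool": "someValue"} then fails _validate_types_match for "boolean"
# and produces the "Type mismatch" error exercised by test_type_mismatch_should_error.
assert template_schemas["Bool"]["schema"] == "boolean"
assert not isinstance("someValue", bool)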