diff --git a/changelogs/fragments/45-api-split.yml b/changelogs/fragments/45-api-split.yml new file mode 100644 index 00000000..2d525cb4 --- /dev/null +++ b/changelogs/fragments/45-api-split.yml @@ -0,0 +1,2 @@ +breaking_changes: + - "api - splitting commands no longer uses a naive split by whitespace, but a more RouterOS CLI compatible splitting algorithm (https://github.com/ansible-collections/community.routeros/pull/45)." diff --git a/plugins/modules/api.py b/plugins/modules/api.py index 6288834d..93ea2500 100644 --- a/plugins/modules/api.py +++ b/plugins/modules/api.py @@ -258,7 +258,7 @@ from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import missing_required_lib -from ansible.module_utils.common.text.converters import to_native +from ansible.module_utils.common.text.converters import to_native, to_bytes import ssl import traceback @@ -274,6 +274,95 @@ LIB_IMP_ERR = traceback.format_exc() +class ParseError(Exception): + pass + + +ESCAPE_SEQUENCES = { + b'"': b'"', + b'\\': b'\\', + b'?': b'?', + b'$': b'$', + b'_': b'_', + b'a': b'\a', + b'b': b'\b', + b'f': b'\xFF', + b'n': b'\n', + b'r': b'\r', + b't': b'\t', + b'v': b'\v', +} + +ESCAPE_DIGITS = b'0123456789ABCDEF' + + +def split_routeros(line): + line = to_bytes(line) + result = [] + current = [] + index = 0 + length = len(line) + # States: + # 0 = outside param + # 1 = param before '=' + # 2 = param after '=' without quote + # 3 = param after '=' with quote + state = 0 + while index < length: + ch = line[index:index + 1] + index += 1 + if state == 0 and ch == b' ': + pass + elif state in (1, 2) and ch == b' ': + state = 0 + result.append(b''.join(current)) + current = [] + elif ch == b'=' and state == 1: + state = 2 + current.append(ch) + if index + 1 < length and line[index:index + 1] == b'"': + state = 3 + index += 1 + elif ch == b'"': + if state == 3: + state = 0 + result.append(b''.join(current)) + current = [] + if index + 1 < length and line[index:index + 1] != b' ': + raise ParseError('Ending \'"\' must be followed by space or end of string') + else: + raise ParseError('\'"\' must follow \'=\'') + elif ch == b'\\': + if index + 1 == length: + raise ParseError('\'\\\' must not be at the end of the line') + ch = line[index:index + 1] + index += 1 + if ch in ESCAPE_SEQUENCES: + current.append(ch) + else: + d1 = ESCAPE_DIGITS.find(ch) + if d1 < 0: + raise ParseError('Invalid escape sequence \'\\{0}\''.format(ch)) + if index + 1 == length: + raise ParseError('Hex escape sequence cut off at end of line') + ch2 = line[index:index + 1] + d2 = ESCAPE_DIGITS.find(ch2) + index += 1 + if d2 < 0: + raise ParseError('Invalid hex escape sequence \'\\{0}{1}\''.format(ch, ch2)) + result.append(chr(d1 * 16 + d2)) + else: + current.append(ch) + if state == 0: + state = 1 + if state in (1, 2): + if current: + result.append(b''.join(current)) + elif state == 3: + raise ParseError('Unexpected end of string during escaped parameter') + return [to_native(part) for part in result] + + class ROS_api_module: def __init__(self): module_args = dict( @@ -312,7 +401,7 @@ def __init__(self): self.module.params['ca_path'], ) - self.path = self.list_remove_empty(self.module.params['path'].split(' ')) + self.path = self.module.params['path'].split() self.add = self.module.params['add'] self.remove = self.module.params['remove'] self.update = self.module.params['update'] @@ -321,7 +410,7 @@ def __init__(self): self.where = None self.query = self.module.params['query'] if self.query: - self.query = self.list_remove_empty(self.query.split(' ')) + self.query = self.list_remove_empty(self.split_params(self.query)) try: idx = self.query.index('WHERE') self.where = self.query[idx + 1:] @@ -365,6 +454,14 @@ def list_to_dic(self, ldict): dict[p[0]] = p[1] return dict + def split_params(self, params): + if not isinstance(params, str): + raise AssertionError('Parameters can only be a string, received %s' % type(params)) + try: + return split_routeros(params) + except ParseError as e: + self.module.fail_json(msg=to_native(e)) + def api_add_path(self, api, path): api_path = api.path() for p in path: @@ -380,7 +477,7 @@ def api_get_all(self): self.errors(e) def api_add(self): - param = self.list_to_dic(self.add.split(' ')) + param = self.list_to_dic(self.split_params(self.add)) try: self.result['message'].append("added: .id= %s" % self.api_path.add(**param)) @@ -397,7 +494,7 @@ def api_remove(self): self.errors(e) def api_update(self): - param = self.list_to_dic(self.update.split(' ')) + param = self.list_to_dic(self.split_params(self.update)) if '.id' not in param.keys(): self.errors("missing '.id' for %s" % param) try: @@ -448,7 +545,7 @@ def api_query(self): def api_arbitrary(self): param = {} - self.arbitrary = self.arbitrary.split(' ') + self.arbitrary = self.split_params(self.arbitrary) arb_cmd = self.arbitrary[0] if len(self.arbitrary) > 1: param = self.list_to_dic(self.arbitrary[1:]) diff --git a/tests/unit/plugins/modules/test_api.py b/tests/unit/plugins/modules/test_api.py index cabb9183..0eabc049 100644 --- a/tests/unit/plugins/modules/test_api.py +++ b/tests/unit/plugins/modules/test_api.py @@ -288,3 +288,33 @@ def test_api_query_and_WHERE_no_cond(self): result = exc.exception.args[0] self.assertEqual(result['changed'], False) + + +TEST_SPLIT_ROUTEROS = [ + ('', []), + (' ', []), + (r'a b c', ['a', 'b', 'c']), + (r'a=b c d=e', ['a=b', 'c', 'd=e']), + (r'a="b f" c d=e', ['a=b f', 'c', 'd=e']), + (r'a="b\"f" c\FF d=\"e', ['a=b"f', '\xff', 'c', 'd="e']), +] + + +@pytest.mark.parametrize("command, result", TEST_SPLIT_ROUTEROS) +def test_split_routeros(command, result): + result_ = api.split_routeros(command) + print(result_, result) + assert result_ == result + + +TEST_SPLIT_ROUTEROS_ERRORS = [ + (r'a="b\"f" c\FF d="e', 'Unexpected end of string during escaped parameter'), +] + + +@pytest.mark.parametrize("command, message", TEST_SPLIT_ROUTEROS_ERRORS) +def test_split_routeros_errors(command, message): + with pytest.raises(api.ParseError) as exc: + api.split_routeros(command) + print(exc.value.args[0], message) + assert exc.value.args[0] == message