From 4b9864fbadc61425f83d2be6520cb5b84f49428c Mon Sep 17 00:00:00 2001 From: Brendan Burns Date: Wed, 23 Nov 2016 10:50:00 -0800 Subject: [PATCH] interpret the '@' syntax if something higher hasn't already done that. (#1423) --- .../azure/cli/core/application.py | 17 +++++++++- .../azure/cli/core/tests/test_application.py | 31 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/azure-cli-core/azure/cli/core/application.py b/src/azure-cli-core/azure/cli/core/application.py index 0bc115f746c..caab463a1c2 100644 --- a/src/azure-cli-core/azure/cli/core/application.py +++ b/src/azure-cli-core/azure/cli/core/application.py @@ -6,6 +6,7 @@ from collections import defaultdict import sys import os +import re import uuid import argparse from azure.cli.core.parser import AzCliCommandParser, enable_autocomplete @@ -196,10 +197,24 @@ def _register_builtin_arguments(**kwargs): global_group.add_argument('--debug', dest='_log_verbosity_debug', action='store_true', help='Increase logging verbosity to show all debug logs.') + @staticmethod + def _maybe_load_file(arg): + ix = arg.find('@') + if ix == -1: + return arg + + if ix == 0: + return Application._load_file(arg[1:]) + + res = re.match('(\\-\\-?[a-zA-Z0-9]+[\\-a-zA-Z0-9]*\\=)\\"?@([^\\"]*)\\"?', arg) + if not res: + return arg + return res.group(1) + Application._load_file(res.group(2)) + @staticmethod def _expand_file_prefixed_files(argv): return list( - [Application._load_file(arg[1:]) if arg.startswith('@') else arg for arg in argv] + [Application._maybe_load_file(arg) for arg in argv] ) @staticmethod diff --git a/src/azure-cli-core/azure/cli/core/tests/test_application.py b/src/azure-cli-core/azure/cli/core/tests/test_application.py index aef4796681e..252fd558afa 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_application.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_application.py @@ -5,10 +5,14 @@ import unittest +import os +import tempfile + from six import StringIO from azure.cli.core.application import Application, Configuration, IterateAction from azure.cli.core.commands import CliCommand +from azure.cli.core._util import CLIError class TestApplication(unittest.TestCase): @@ -80,5 +84,32 @@ def handler(args): self.assertEqual(hellos[1]['hello'], 'sir') self.assertEqual(hellos[1]['something'], 'else') + def test_expand_file_prefixed_files(self): + f = tempfile.NamedTemporaryFile(delete=False) + f.close() + + with open(f.name, 'w+') as stream: + stream.write('foo') + + cases = [ + [['--bar=baz'], ['--bar=baz']], + [['--bar', 'baz'], ['--bar', 'baz']], + [['--bar=@{}'.format(f.name)], ['--bar=foo']], + [['--bar', '@{}'.format(f.name)], ['--bar', 'foo']], + [['--bar', f.name], ['--bar', f.name]], + [['--bar="@{}"'.format(f.name)], ['--bar=foo']], + [['--bar=name@company.com'], ['--bar=name@company.com']], + [['--bar', 'name@company.com'], ['--bar', 'name@company.com']], + ] + + for test_case in cases: + try: + args = Application._expand_file_prefixed_files(test_case[0]) #pylint: disable=protected-access + self.assertEqual(args, test_case[1], 'Failed for: {}'.format(test_case[0])) + except CLIError as ex: + self.fail('Unexpected error for {} ({}): {}'.format(test_case[0], args, ex)) + + os.remove(f.name) + if __name__ == '__main__': unittest.main()