Skip to content

Commit

Permalink
areusch review: As per the API, options shouldn't always be required
Browse files Browse the repository at this point in the history
As an API-level concept, it shouldn't always be required to specify an
option. This commit changes TVMC behavior to only make options mandatory
to be specified when one or more options returned by the API are
required.

This also fixes some nits in some docstrings.
  • Loading branch information
gromero committed Nov 10, 2021
1 parent 6e73746 commit 93bb565
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 24 deletions.
38 changes: 36 additions & 2 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def get_options(options):
Returns
-------
opts: dict
dict indexed by option name of option name and associated value.
dict indexed by option names and associated values.
"""

opts = {}
Expand Down Expand Up @@ -714,6 +714,40 @@ def check_options_choices(options, valid_options):
if option in valid_options_choices:
if options[option] not in valid_options_choices[option]:
raise TVMCException(
f"Choice '{options[option]} for option '{option}' is invalid. "
f"Choice '{options[option]}' for option '{option}' is invalid. "
"Use --list-options to see all available choices for that option."
)

def get_and_check_options(passed_options, valid_options):
"""Get options and check if they are valid. If choices exist for them, check values against it.
Parameters
----------
passed_options: list of str
list of strings in the "key=value" form as captured by argparse.
valid_option: list
list with all options available for a given API method / project as returned by
get_project_options().
Returns
-------
opts: dict
dict indexed by option names and associated values.
Or None if passed_options is None.
"""

if passed_options is None:
# No options to check
return None

# From a list of k=v strings, make a dict options[k]=v
opts = get_options(passed_options)
# Check if passed options are valid
check_options(opts, valid_options)
# Check (when a list of choices exists) if the passed values are valid
check_options_choices(opts, valid_options)

return opts
19 changes: 5 additions & 14 deletions python/tvm/driver/tvmc/micro.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
from .common import (
TVMCException,
get_project_options,
get_options,
check_options,
check_options_choices,
get_and_check_options,
)


Expand Down Expand Up @@ -204,9 +202,8 @@ def _add_parser(parser, platform):
help_text_by_option = [opt["help_text"] for opt in options_by_method[method]]
help_text = "\n\n".join(help_text_by_option) + "\n\n"

# TODO(gromero): Experiment with required=required below
parser_by_subcmd_n_platform.add_argument(
"--project-option", required=True, metavar="OPTION=VALUE", nargs=nargs, help=help_text
"--project-option", required=required, metavar="OPTION=VALUE", nargs=nargs, help=help_text
)

parser_by_subcmd_n_platform.add_argument(
Expand Down Expand Up @@ -239,9 +236,7 @@ def create_project_handler(args):
if not os.path.exists(mlf_path):
raise TVMCException(f"MLF file {mlf_path} does not exist!")

options = get_options(args.project_option)
check_options(options, args.valid_options)
check_options_choices(options, args.valid_options)
options = get_and_check_options(args.project_option, args.valid_options)

try:
project.generate_project_from_mlf(template_dir, project_dir, mlf_path, options)
Expand All @@ -267,9 +262,7 @@ def build_handler(args):

project_dir = args.project_dir

options = get_options(args.project_option)
check_options(options, args.valid_options)
check_options_choices(options, args.valid_options)
options = get_and_check_options(args.project_option, args.valid_options)

try:
prj = project.GeneratedProject.from_directory(project_dir, options=options)
Expand All @@ -286,9 +279,7 @@ def flash_handler(args):

project_dir = args.project_dir

options = get_options(args.project_option)
check_options(options, args.valid_options)
check_options_choices(options, args.valid_options)
options = get_and_check_options(args.project_option, args.valid_options)

try:
prj = project.GeneratedProject.from_directory(project_dir, options=options)
Expand Down
11 changes: 3 additions & 8 deletions python/tvm/driver/tvmc/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@
TVMCException,
TVMCSilentArgumentParser,
get_project_options,
get_options,
check_options,
check_options_choices,
get_and_check_options,
)
from .main import register_parser
from .model import TVMCPackage, TVMCResult
Expand Down Expand Up @@ -153,7 +151,7 @@ def add_run_parser(subparsers, main_parser):
help_text_by_option = [opt["help_text"] for opt in options_by_method["open_transport"]]
help_text = "\n\n".join(help_text_by_option) + "\n\n"
parser.add_argument(
"--project-option", required=True, metavar="OPTION=VALUE", nargs=nargs, help=help_text
"--project-option", required=required, metavar="OPTION=VALUE", nargs=nargs, help=help_text
)


Expand Down Expand Up @@ -192,10 +190,7 @@ def drive_run(args):
raise TVMCException("--profile is not supported for micro targets.")

# Get and check options for micro targets.

options = get_options(args.project_option)
check_options(options, args.valid_options)
check_options_choices(options, args.valid_options)
options = get_and_check_options(args.project_option, args.valid_options)

try:
tvmc_package = TVMCPackage(package_path=path)
Expand Down

0 comments on commit 93bb565

Please sign in to comment.