From 3590cdd3dfeeea0584e0db235f587a5ec6b8929e Mon Sep 17 00:00:00 2001 From: Michael Chisholm Date: Wed, 13 Mar 2024 13:01:12 -0400 Subject: [PATCH 1/5] feat: Add module for statically analyzing python code to find task plugin function signatures --- src/dioptra/restapi/v1/lib/__init__.py | 16 + .../restapi/v1/lib/signature_analysis.py | 736 ++++++++++++++++++ .../restapi/lib/test_signature_analysis.py | 361 +++++++++ 3 files changed, 1113 insertions(+) create mode 100644 src/dioptra/restapi/v1/lib/__init__.py create mode 100644 src/dioptra/restapi/v1/lib/signature_analysis.py create mode 100644 tests/unit/restapi/lib/test_signature_analysis.py diff --git a/src/dioptra/restapi/v1/lib/__init__.py b/src/dioptra/restapi/v1/lib/__init__.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/restapi/v1/lib/__init__.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/restapi/v1/lib/signature_analysis.py b/src/dioptra/restapi/v1/lib/signature_analysis.py new file mode 100644 index 000000000..63bb847d3 --- /dev/null +++ b/src/dioptra/restapi/v1/lib/signature_analysis.py @@ -0,0 +1,736 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +""" +Extract task plugin function signature information from Python source code. +""" +import ast as ast_module # how many variables named "ast" might we have... 
+import itertools +import re +from pathlib import Path +from typing import Any, Container, Iterator, Optional, Union + +from dioptra.task_engine import type_registry + +_PYTHON_TO_DIOPTRA_TYPE_NAME = { + "str": "string", + "int": "integer", + "float": "number", + "bool": "boolean", + "None": "null", +} + + +def _is_constant(ast: ast_module.AST, value: Any) -> bool: + """ + Determine whether the given AST node represents a constant (literal) of + the given value. + + Args: + ast: An AST node + value: A value to compare to + + Returns: + True if the AST node is a constant of the given value; False if not + """ + return isinstance(ast, ast_module.Constant) and ast.value == value + + +def _is_simple_dotted_name(node: ast_module.AST) -> bool: + """ + Determine whether the given AST node represents a simple name or dotted + name, like "foo", "foo.bar", "foo.bar.baz", etc. + + Args: + node: The AST node + + Returns: + True if the node represents a simple dotted name; False if not + """ + return isinstance(node, ast_module.Name) or ( + isinstance(node, ast_module.Attribute) and _is_simple_dotted_name(node.value) + ) + + +def _update_symbols(symbol_tree: dict[str, Any], name: str) -> dict[str, Any]: + """ + Update/modify the given symbol tree such that it includes the given + name. + + The symbol tree is conceptually roughly a symbol hierarchy. This is how + modules and other types of values are naturally arranged in Python. An + import statement (assuming it is correct, and in the absence of any way or + desire to check, we assume they are all correct) reflects this hierarchy, + and the hierarchy may be inferred from it. + + It is implemented as a nested dict of dicts. The dicts map a symbol name + to other dicts which may have other symbol names, which map to other dicts, + etc. One can look up a symbol to get a "value", but we don't actually have + access to any runtime values. A symbol's "value" in this tree will be + whatever dict it maps to (which may be empty). + + Importantly, aliasing present in import statements ("as" clauses) is + reflected in the symbol tree by referring to the same dict in multiple + places. This means the structure is not technically a tree, since nodes + can have in-degree greater than one. But it makes aliasing trivial to + deal with: you can use the "is" operator to check whether two symbols' + "values" are the same. + + Args: + symbol_tree: A symbol tree structure to update + name: The name to update the tree with + + Returns: + The resulting "value" of the symbol after the tree has been updated + """ + names = name.split(".") + + curr_mod = symbol_tree + for symbol_name in names: + curr_mod = curr_mod.setdefault(symbol_name, {}) + + return curr_mod + + +def _look_up_symbol( + symbol_tree: Optional[dict[str, Any]], name: str +) -> Optional[dict[str, Any]]: + """ + Look up a symbol in the given symbol tree and return its "value". The + symbol tree data structure is comprised of nested dicts, so the value + returned (if the symbol is found) is always a dict. + + Args: + symbol_tree: A symbol tree structure + name: The name to look up, as a string. E.g. "foo", "foo.bar", etc. + + Returns: + The value of the given symbol, or None if it was not found in the + symbol tree + """ + if not name: + # Just in case... 
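+        # Defensive: no caller in this module should ever pass an empty name.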
+        raise ValueError("Symbol name must not be null/empty")
+
+    if not symbol_tree:
+        result = None
+    else:
+        dot_idx = name.find(".")
+        if dot_idx == -1:
+            result = symbol_tree.get(name)
+        else:
+            result = _look_up_symbol(
+                symbol_tree.get(name[:dot_idx]), name[dot_idx + 1 :]
+            )
+
+    return result
+
+
+def _are_aliases(symbol_tree: dict[str, Any], name1: str, name2: str) -> bool:
+    """
+    Determine whether two symbol names refer to the same value.
+
+    Args:
+        symbol_tree: A symbol tree structure
+        name1: A symbol name
+        name2: A symbol name
+
+    Returns:
+        True if both symbol names are defined and resolve to the same value;
+        False if not
+    """
+    name1_value = _look_up_symbol(symbol_tree, name1)
+    name2_value = _look_up_symbol(symbol_tree, name2)
+
+    return (
+        name1_value is not None
+        and name2_value is not None
+        and name1_value is name2_value
+    )
+
+
+def _process_import(stmt: ast_module.AST, symbol_tree: dict[str, Any]) -> None:
+    """
+    Update the given symbol tree according to the given import statement. This
+    can add new symbols to the tree, or change what existing symbols refer to.
+
+    Args:
+        stmt: A stmt AST node. Node types other than Import and ImportFrom
+            are ignored.
+        symbol_tree: A symbol tree structure to update.
+    """
+    if isinstance(stmt, ast_module.Import):
+        # For a normal import, update the hierarchy according to the
+        # imported name. If aliased, also introduce an alias symbol at
+        # the top level.
+        for alias in stmt.names:
+            value = _update_symbols(symbol_tree, alias.name)
+
+            if alias.asname:
+                symbol_tree[alias.asname] = value
+
+    elif isinstance(stmt, ast_module.ImportFrom):
+        # Can't hope to interpret relative imports by themselves, because
+        # we don't know what they're relative to. So just ignore those.
+        # E.g. "from ...foo import bar".
+        if stmt.level == 0:
+            # For mypy: stmt.module is Optional because a relative import may
+            # omit it (e.g. "from . import foo"), but a non-relative import
+            # always names the module it imports from.
+            assert stmt.module
+
+            # Update the symbol hierarchy with the module name
+            # (from "..."). This identifies a module to import from.
+            mod_value = _update_symbols(symbol_tree, stmt.module)
+
+            # Each imported symbol is introduced at the sub-module level
+            # (from ... import "..."), since the statement implies that
+            # symbol exists there. If the symbol is not aliased, it is
+            # also introduced at the top level. If it is aliased, only the
+            # alias is introduced at the top level.
+            for alias in stmt.names:
+                value = mod_value.setdefault(alias.name, {})
+                if alias.asname:
+                    symbol_tree[alias.asname] = value
+                else:
+                    symbol_tree[alias.name] = value
+
+
+def _is_register_decorator(decorator_symbol: str, symbol_tree: dict[str, Any]) -> bool:
+    """
+    Try to detect a pyplugs registration decorator symbol. In Dioptra, the
+    "register" symbol is defined in the "dioptra.pyplugs" module. So one could
+    import the dioptra.pyplugs module and just access the "register" symbol
+    from there, or import the "register" symbol directly. E.g.
+
+        import dioptra.pyplugs
+
+        @dioptra.pyplugs.register
+        def foo():
+            pass
+
+    Or:
+
+        from dioptra import pyplugs
+
+        @pyplugs.register
+        def foo():
+            pass
+
+    Or:
+
+        from dioptra.pyplugs import register
+
+        @register
+        def foo():
+            pass
+
+    In the first two cases, our symbol tree would contain "dioptra.pyplugs"
+    but not "register", since the latter was never mentioned in an import
+    statement. In the last case, the whole "dioptra.pyplugs.register" symbol
+    would be present. We need to handle both cases. This should also be
+    transparent to aliasing, e.g.
+ + from dioptra import pyplugs as bar + + @bar.register + def foo(): + pass + + must also work. + + Args: + decorator_symbol: A decorator symbol used on a function, as a string, + e.g. "foo", "foo.bar", etc + symbol_tree: A data structure representing symbol hierarchy inferred + from import statements + + Returns: + True if the decorator symbol represents a task plugin registration + decorator; False if not + """ + + if _are_aliases(symbol_tree, "dioptra.pyplugs.register", decorator_symbol): + result = True + + elif decorator_symbol.endswith(".register"): + deco_prefix = decorator_symbol[:-9] + result = _are_aliases(symbol_tree, "dioptra.pyplugs", deco_prefix) + + else: + result = False + + return result + + +def _is_task_plugin( + func_def: ast_module.FunctionDef, symbol_tree: dict[str, Any] +) -> bool: + """ + Determine whether the given function definition is defining a task plugin. + + Args: + func_def: A function definition AST node + symbol_tree: A data structure representing symbol hierarchy inferred + from import statements + + Returns: + True if the function definition is for a task plugin; False if not + """ + for decorator_expr in func_def.decorator_list: + + # we will only handle simple decorator expressions: simple dotted + # names, optionally with a function call. + if _is_simple_dotted_name(decorator_expr): + decorator_symbol = ast_module.unparse(decorator_expr) + + elif isinstance(decorator_expr, ast_module.Call) and _is_simple_dotted_name( + decorator_expr.func + ): + decorator_symbol = ast_module.unparse(decorator_expr.func) + + else: + decorator_symbol = None + + if decorator_symbol and _is_register_decorator(decorator_symbol, symbol_tree): + result = True + break + else: + result = False + + return result + + +def _find_plugins(ast: ast_module.Module) -> Iterator[ast_module.FunctionDef]: + """ + Find AST nodes corresponding to task plugin functions. + + Args: + ast: An AST node. Plugin functions will only be found inside Module + nodes + + Yields: + AST nodes corresponding to task plugin function definitions + """ + if isinstance(ast, ast_module.Module): + symbol_tree: dict[str, Any] = {} + for stmt in ast.body: + + if isinstance(stmt, (ast_module.Import, ast_module.ImportFrom)): + _process_import(stmt, symbol_tree) + + elif isinstance(stmt, ast_module.FunctionDef) and _is_task_plugin( + stmt, symbol_tree + ): + yield stmt + + +def _derive_type_name_from_annotation(annotation_ast: ast_module.AST) -> Optional[str]: + """ + Try to derive a suitable Dioptra type name from a type annotation AST. + Annotations can be arbitrarily complex and even nonsensical (not all + kind of errors are caught at parse time), so derivation may fail depending + on the AST. + + Args: + annotation_ast: An AST used as an argument or return type annotation + + Returns: + A type name if one could be derived, or None if one could not be + derived from the given annotation + """ + + # "None" isn't a type, but is used to mean the type of None + if _is_constant(annotation_ast, None): + type_name_suggestion = "null" + + # A name, e.g. int + elif isinstance(annotation_ast, ast_module.Name): + type_name_suggestion = annotation_ast.id + + # A string literal, e.g. "foo". Can be used in Python code to defer + # evaluation of an annotation. + elif isinstance(annotation_ast, ast_module.Constant) and isinstance( + annotation_ast.value, str + ): + type_name_suggestion = annotation_ast.value + + # Frequently used annotation expressions, e.g. list[str] is a "Subscript", + # and str | int is a "BinOp". 
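+    # ast.unparse() (Python 3.9+) turns any of these back into source text,
+    # e.g. "list[str]" or "str | int".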
+    elif isinstance(
+        annotation_ast, (ast_module.Subscript, ast_module.BinOp)
+    ) or _is_simple_dotted_name(annotation_ast):
+        type_name_suggestion = ast_module.unparse(annotation_ast)
+
+    else:
+        type_name_suggestion = None
+
+    # Normalize the suggestion, if we were able to derive one
+    if type_name_suggestion:
+        type_name_suggestion = type_name_suggestion.strip()
+        type_name_suggestion = type_name_suggestion.lower()
+        type_name_suggestion = type_name_suggestion.replace(" ", "")
+        # Replace non-alphanumerics with underscores
+        type_name_suggestion = re.sub(r"\W+", "_", type_name_suggestion)
+        # Condense multiple underscores to one
+        type_name_suggestion = re.sub("_+", "_", type_name_suggestion)
+        type_name_suggestion = type_name_suggestion.strip("_")
+
+        # Try to map to a Dioptra builtin type name.
+        type_name_suggestion = _PYTHON_TO_DIOPTRA_TYPE_NAME.get(
+            type_name_suggestion, type_name_suggestion
+        )
+
+        # After all this, if we wound up with an empty string, we failed.
+        # If the name doesn't begin with a letter (like all good identifiers
+        # should), we also failed.
+        if not type_name_suggestion or not type_name_suggestion[0].isalpha():
+            type_name_suggestion = None
+
+    return type_name_suggestion
+
+
+def _make_unique_type_name(existing_types: Container[str]) -> str:
+    """
+    Make a unique type name, i.e. one which doesn't exist in existing_types.
+    One never knows whether a user's type annotation resulted in a derived
+    type name which happens to match our generated name syntax, so it is not
+    sufficient to maintain a counter elsewhere which is incremented every time
+    we need a new unique name; that might produce name collisions. Instead,
+    this is done conservatively (if inefficiently) by concatenating a base
+    name with an incrementing integer counter starting at 1, until we obtain
+    a name which has not previously been seen.
+
+    Args:
+        existing_types: A container of existing type names
+
+    Returns:
+        A new type name which is not in the container
+    """
+    counter = 1
+    type_name = f"type{counter}"
+    while type_name in existing_types:
+        counter += 1
+        type_name = f"type{counter}"
+
+    return type_name
+
+
+def _pos_args_defaults(
+    args: ast_module.arguments,
+) -> Iterator[tuple[ast_module.arg, Optional[ast_module.expr]]]:
+    """
+    Generate the positional argument AST nodes paired with their defined
+    default AST nodes (if any), contained within the given AST arguments value.
+    This takes a little care: ast stores defaults for only the trailing
+    len(args.defaults) positional arguments, so the two lists can't simply be
+    zipped together. This includes all positional-only and "regular"
+    (non-keyword-only) arguments, in the order they appear in the function
+    signature.
+
+    Args:
+        args: An AST arguments value
+
+    Yields:
+        positional arg, arg default pairs. If an arg does not have a default
+        defined in the signature, it is generated as None.
+    """
+    num_pos_args = len(args.posonlyargs) + len(args.args)
+    idx_first_defaulted_arg = num_pos_args - len(args.defaults)
+
+    for arg_idx, arg in enumerate(itertools.chain(args.posonlyargs, args.args)):
+        if arg_idx >= idx_first_defaulted_arg:
+            arg_default = args.defaults[arg_idx - idx_first_defaulted_arg]
+        else:
+            arg_default = None
+
+        yield arg, arg_default
+
+
+def _func_args_defaults(
+    func: ast_module.FunctionDef,
+) -> Iterator[tuple[ast_module.arg, Optional[ast_module.expr]]]:
+    """
+    Generate all argument AST nodes paired with their defined default AST nodes
+    (if any).
This includes positional-only and keyword-only arguments, in the + order they appear in the function signature. + + Args: + func: A FunctionDef AST node representing a function definition + + Yields: + arg, arg default pairs. If an arg does not have a default defined in + the signature, it is generated as None. + """ + yield from _pos_args_defaults(func.args) + yield from zip(func.args.kwonlyargs, func.args.kw_defaults) + + +def _func_args(func: ast_module.FunctionDef) -> Iterator[ast_module.arg]: + """ + Generate all argument AST nodes. This does not include any of their + defaults. They are generated in the order they appear in the function + signature. + + Args: + func: A FunctionDef AST node representing a function definition + + Returns: + An iterator which produces all function argument AST nodes + """ + # Must use same iteration order as _func_args_defaults()! + return itertools.chain(func.args.posonlyargs, func.args.args, func.args.kwonlyargs) + + +def _get_function_signature_via_derivation( + func: ast_module.FunctionDef, +) -> dict[str, Any]: + """ + Create a dict structure which reflects the signature of the given function, + including where possible, argument and return type names suitable for use + with the Dioptra type system. This function tries to derive type names + from argument/return type annotations. This derivation may or may not + produce a suitable type name. Where it is unable to derive a type name, + None is used in the data structure. The end result is a structure which + accounts for all arguments and the return type, although some type names + may be None. + + Args: + func: A FunctionDef AST node representing a function definition + + Returns: + A function signature data structure as a dict + """ + inputs = [] + outputs = [] + suggested_types = [] + used_type_names = set() + + for arg, arg_default in _func_args_defaults(func): + if arg.annotation: + type_name_suggestion = _derive_type_name_from_annotation(arg.annotation) + else: + type_name_suggestion = None + + inputs.append( + { + "name": arg.arg, + "type": type_name_suggestion, # might be None + "required": arg_default is None, + } + ) + + # Add suggestions for non-Dioptra-builtin types only, which we have not + # already created a suggestion for + if ( + type_name_suggestion + and type_name_suggestion not in type_registry.BUILTIN_TYPES + and type_name_suggestion not in used_type_names + ): + # For mypy: we would not have a type name suggestion here if we did + # not have an annotation. + assert arg.annotation + suggested_types.append( + { + "suggestion": type_name_suggestion, + "type_annotation": ast_module.unparse(arg.annotation), + } + ) + + used_type_names.add(type_name_suggestion) + + # Also address any return annotation other than None. If it is None, + # skip the output. None means the plugin produces no output. 
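+    # (A missing return annotation behaves the same way: func.returns is None.)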
+    if func.returns and not _is_constant(func.returns, None):
+        type_name_suggestion = _derive_type_name_from_annotation(func.returns)
+
+        outputs.append(
+            {"name": "output", "type": type_name_suggestion}  # might be None
+        )
+
+        if (
+            type_name_suggestion
+            and type_name_suggestion not in type_registry.BUILTIN_TYPES
+            and type_name_suggestion not in used_type_names
+        ):
+            suggested_types.append(
+                {
+                    "suggestion": type_name_suggestion,
+                    "type_annotation": ast_module.unparse(func.returns),
+                }
+            )
+
+            used_type_names.add(type_name_suggestion)
+
+    signature = {
+        "name": func.name,
+        "inputs": inputs,
+        "outputs": outputs,
+        "suggested_types": suggested_types,
+    }
+
+    return signature
+
+
+def _complete_function_signature_via_generation(
+    func: ast_module.FunctionDef, signature: dict[str, Any]
+) -> None:
+    """
+    Search through the given signature structure for missing (None) type names,
+    and use name generation to generate unique names. The signature structure
+    is updated such that all arguments and the return type should have a type
+    name.
+
+    Args:
+        func: A FunctionDef AST node representing a function definition
+        signature: A function signature structure to update
+    """
+
+    # Gather used types; use this to ensure uniqueness of generated types.
+    used_type_names = {
+        input_["type"] for input_ in signature["inputs"] if input_["type"]
+    }
+
+    used_type_names.update(
+        output["type"] for output in signature["outputs"] if output["type"]
+    )
+
+    # For annotations for which we could not derive a type name, we must
+    # nevertheless recognize annotation reuse, and reuse the same generated
+    # unique type name. AST nodes don't support structural equality or
+    # hashing, so the only practical way to compare one annotation to another
+    # is via its unparsed Python code (as a string). This mapping therefore
+    # maps unparsed Python to a generated unique name.
+    ann_to_unique: dict[str, str] = {}
+    unparsed_ann: Optional[str]
+
+    for input_, arg in zip(signature["inputs"], _func_args(func)):
+        if not input_["type"]:
+            if arg.annotation:
+                unparsed_ann = ast_module.unparse(arg.annotation)
+                type_name_suggestion = ann_to_unique.get(unparsed_ann)
+            else:
+                unparsed_ann = type_name_suggestion = None
+
+            if not type_name_suggestion:
+                type_name_suggestion = _make_unique_type_name(used_type_names)
+                if unparsed_ann:
+                    ann_to_unique[unparsed_ann] = type_name_suggestion
+
+            input_["type"] = type_name_suggestion
+
+            if unparsed_ann and type_name_suggestion not in used_type_names:
+                signature["suggested_types"].append(
+                    {
+                        "suggestion": type_name_suggestion,
+                        "type_annotation": unparsed_ann,
+                    }
+                )
+
+            used_type_names.add(type_name_suggestion)
+
+    # Generate a type name for the output, if necessary
+    if signature["outputs"]:
+        output = signature["outputs"][0]
+        if not output["type"]:
+            # For mypy: we would not have a defined output if the function did
+            # not have a return type annotation.
+ assert func.returns + unparsed_ann = ast_module.unparse(func.returns) + type_name_suggestion = ann_to_unique.get(unparsed_ann) + if not type_name_suggestion: + type_name_suggestion = _make_unique_type_name(used_type_names) + ann_to_unique[unparsed_ann] = type_name_suggestion + + output["type"] = type_name_suggestion + + if type_name_suggestion not in used_type_names: + signature["suggested_types"].append( + { + "suggestion": type_name_suggestion, + "type_annotation": unparsed_ann, + } + ) + + used_type_names.add(type_name_suggestion) + + +def get_plugin_signatures( + python_source: str, filepath: Optional[Union[str, Path]] = None +) -> Iterator[dict[str, Any]]: + """ + Extract plugin signatures and build signature information structures from + all task plugins defined in the given source code. + + Args: + python_source: Some Python source code; should be complete with + supporting import statements to assist in understanding what + symbols mean + filepath: A value representative of where the python source came from. + This is an optional arg passed on to the underlying compile() + function, which documents it as: "The filename argument should + give the file from which the code was read; pass some recognizable + value if it wasn't read from a file ('' is commonly used)." + + Yields: + Function signature information data structures, as dicts + """ + if filepath: + ast = ast_module.parse(python_source, filename=filepath, feature_version=(3, 9)) + else: + ast = ast_module.parse(python_source, feature_version=(3, 9)) + + for plugin_func in _find_plugins(ast): + + # We need to come up with a syntax for unique type names. But no + # matter what syntax we choose, a user's type annotations might collide + # with it. So we can't easily do this in one pass where we generate a + # name whenever we fail to derive one from a type annotation. If a + # subsequent type name derived from a user type annotation collides + # with a unique name we already generated, the user's name must take + # precedence. + # + # A better way is to make two passes: the first pass derives type names + # from type annotations where possible, and determines what the + # user-annotation-derived type names are. The second pass uses unique + # name generation to generate all type names we could not derive in the + # first pass, where the generation can use the names derived in the + # first pass to ensure there are no naming collisions. + + # Pass #1 + signature = _get_function_signature_via_derivation(plugin_func) + + # Pass #2 + _complete_function_signature_via_generation(plugin_func, signature) + + yield signature + + +def get_plugin_signatures_from_file( + filepath: Union[str, Path], encoding: str = "utf-8" +) -> Iterator[dict[str, Any]]: + """ + Extract plugin signatures and build signature information structures from + all task plugins defined in the given Python source file. 
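+
+    Example (a minimal sketch; "plugins.py" is a hypothetical path):
+
+        for signature in get_plugin_signatures_from_file("plugins.py"):
+            print(signature["name"])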
+ + Args: + filepath: A path to a file with Python source code; should be complete + with supporting import statements to assist in understanding what + symbols mean + encoding: A text encoding used to read the given source file + + Returns: + An iterator of function signature information data structures, as dicts + """ + filepath = Path(filepath) + python_source = filepath.read_text(encoding=encoding) + + return get_plugin_signatures(python_source, filepath) diff --git a/tests/unit/restapi/lib/test_signature_analysis.py b/tests/unit/restapi/lib/test_signature_analysis.py new file mode 100644 index 000000000..626a89312 --- /dev/null +++ b/tests/unit/restapi/lib/test_signature_analysis.py @@ -0,0 +1,361 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra.restapi.v1.lib.signature_analysis import get_plugin_signatures + + +def test_plugin_recognition_1(): + source = """\ +import dioptra.pyplugs + +@dioptra.pyplugs.register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_2(): + source = """\ +from dioptra import pyplugs + +@pyplugs.register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_3(): + source = """\ +from dioptra.pyplugs import register + +@register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_alias_1(): + source = """\ +import dioptra.pyplugs as foo + +@foo.register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_alias_2(): + source = """\ +from dioptra import pyplugs as foo + +@foo.register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_alias_3(): + source = """\ +from dioptra.pyplugs import register as foo + +@foo +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_call(): + source = """\ +from dioptra.pyplugs import register + +@register() +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_alias_call(): + source = 
"""\ +from dioptra.pyplugs import register as foo + +@foo() +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_none(): + source = """\ +import dioptra.pyplugs + +# missing the decorator +def not_a_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert not signatures + + +def test_plugin_recognition_complex(): + source = """\ +from dioptra.pyplugs import register +import aaa + +@register() +def test_plugin(): + pass + +@aaa.register +def not_a_plugin(): + pass + +class SomeClass: + pass + +def some_other_func(): + pass + +x = 1 + +@register +def test_plugin2(): + pass + +# re-definition of the "register" symbol +from bbb import ccc as register + +@register +def also_not_a_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 2 + + +def test_dioptra_builtin_types(): + source = """\ +from dioptra.pyplugs import register + +@register +def test_plugin( + arg1: str, + arg2: int, + arg3: float, + arg4: bool, + arg5: None +): + pass +""" + + signatures = list(get_plugin_signatures(source)) + + assert signatures == [ + { + "name": "test_plugin", + "inputs": [ + {"name": "arg1", "required": True, "type": "string"}, + {"name": "arg2", "required": True, "type": "integer"}, + {"name": "arg3", "required": True, "type": "number"}, + {"name": "arg4", "required": True, "type": "boolean"}, + {"name": "arg5", "required": True, "type": "null"}, + ], + "outputs": [], + "suggested_types": [] + } + ] + + +def test_return_none(): + source = """\ +from dioptra.pyplugs import register + +# None is same as not having a return type annotation +@register +def my_plugin() -> None: + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "my_plugin", + "inputs": [], + "outputs": [], + "suggested_types": [] + } + ] + + +def test_derive_type_simple(): + source = """\ +import dioptra.pyplugs + +@dioptra.pyplugs.register() +def the_plugin(arg1: SomeType) -> SomeType: + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "the_plugin", + "inputs": [ + {"name": "arg1", "required": True, "type": "sometype"} + ], + "outputs": [ + {"name": "output", "type": "sometype"} + ], + "suggested_types": [ + {"suggestion": "sometype", "type_annotation": "SomeType"} + ] + } + ] + + +def test_derive_type_complex(): + source = """\ +import dioptra.pyplugs + +@dioptra.pyplugs.register() +def the_plugin(arg1: Optional[str]) -> Union[int, bool]: + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "the_plugin", + "inputs": [ + {"name": "arg1", "required": True, "type": "optional_str"} + ], + "outputs": [ + {"name": "output", "type": "union_int_bool"} + ], + "suggested_types": [ + {"suggestion": "optional_str", "type_annotation": "Optional[str]"}, + {"suggestion": "union_int_bool", "type_annotation": "Union[int, bool]"} + ] + } + ] + + +def test_generate_type(): + source = """\ +import dioptra.pyplugs + +# annotation is a function call; we don't attempt a type derivation for +# that kind of annotation. 
+@dioptra.pyplugs.register +def plugin_func(arg1: foo(2)) -> foo(2): + pass +""" + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "plugin_func", + "inputs": [ + {"name": "arg1", "required": True, "type": "type1"} + ], + "outputs": [ + {"name": "output", "type": "type1"} + ], + "suggested_types": [ + {"suggestion": "type1", "type_annotation": "foo(2)"} + ] + } + ] + + +def test_generate_type_conflict(): + source = """\ +import dioptra.pyplugs + +# annotation is a function call; we don't attempt a type derivation for +# that kind of annotation. Our first generated type would normally be "type1", +# but we can't use that either because the code author already used that! So +# our generated type will have to be "type2". +@dioptra.pyplugs.register +def plugin_func(arg1: foo(2), arg2: Type1) -> foo(2): + pass +""" + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "plugin_func", + "inputs": [ + {"name": "arg1", "required": True, "type": "type2"}, + {"name": "arg2", "required": True, "type": "type1"} + ], + "outputs": [ + {"name": "output", "type": "type2"} + ], + "suggested_types": [ + {"suggestion": "type1", "type_annotation": "Type1"}, + {"suggestion": "type2", "type_annotation": "foo(2)"} + ] + } + ] + + +def test_optional_arg(): + source = """\ +from dioptra import pyplugs + +@pyplugs.register() +def do_things(arg1: Optional[str], arg2: int = 123): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "do_things", + "inputs": [ + {"name": "arg1", "required": True, "type": "optional_str"}, + {"name": "arg2", "required": False, "type": "integer"} + ], + "outputs": [], + "suggested_types": [ + {"suggestion": "optional_str", "type_annotation": "Optional[str]"} + ] + } + ] From ab4abd02c647ec00ca3dc6feac65c44dd4b99036 Mon Sep 17 00:00:00 2001 From: jtsextonMITRE <45762017+jtsextonMITRE@users.noreply.github.com> Date: Thu, 9 Jan 2025 11:10:55 -0500 Subject: [PATCH 2/5] feat(restapi): add schema.py --- src/dioptra/restapi/v1/workflows/schema.py | 59 ++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 src/dioptra/restapi/v1/workflows/schema.py diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py new file mode 100644 index 000000000..6652486d8 --- /dev/null +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -0,0 +1,59 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The schemas for serializing/deserializing Workflow resources.""" +from enum import Enum + +from marshmallow import Schema, fields + + +class FileTypes(Enum): + TAR_GZ = "tar_gz" + ZIP = "zip" + + +class JobFilesDownloadQueryParametersSchema(Schema): + """The query parameters for making a jobFilesDownload workflow request.""" + + jobId = fields.String( + attribute="job_id", + metadata=dict(description="A job's unique identifier."), + ) + fileType = fields.Enum( + FileTypes, + attribute="file_type", + metadata=dict( + description="The type of file to download: tar_gz or zip.", + ), + by_value=True, + default=FileTypes.TAR_GZ.value, + ) + +class SignatureAnalysisSchema(Schema): + + fileContents = fields.String( + attribute="file_contents", + metadata=dict( + description="The contents of the file" + ) + ) + + filename = fields.String( + attribute="filename", + metadata=dict( + description="The name of the file" + ) + ) \ No newline at end of file From 3c93e3764465f79919690668988a727f5780a593 Mon Sep 17 00:00:00 2001 From: jtsextonMITRE <45762017+jtsextonMITRE@users.noreply.github.com> Date: Tue, 21 Jan 2025 16:15:49 -0500 Subject: [PATCH 3/5] feat(restapi): added workflow & tests for plugin signature analysis --- src/dioptra/client/workflows.py | 38 ++ .../restapi/v1/workflows/controller.py | 39 +- src/dioptra/restapi/v1/workflows/schema.py | 78 ++++ src/dioptra/restapi/v1/workflows/service.py | 51 ++- .../v1/signature_analysis/test_alias.py | 4 + .../signature_analysis/test_complex_type.py | 4 + .../signature_analysis/test_function_type.py | 4 + .../v1/signature_analysis/test_none_return.py | 4 + .../v1/signature_analysis/test_optional.py | 4 + .../signature_analysis/test_pyplugs_alias.py | 4 + .../v1/signature_analysis/test_real_world.py | 381 ++++++++++++++++ .../signature_analysis/test_redefinition.py | 22 + .../signature_analysis/test_register_alias.py | 4 + .../signature_analysis/test_type_conflict.py | 4 + .../restapi/v1/test_signature_analysis.py | 411 ++++++++++++++++++ 15 files changed, 1048 insertions(+), 4 deletions(-) create mode 100644 tests/unit/restapi/v1/signature_analysis/test_alias.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_complex_type.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_function_type.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_none_return.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_optional.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_real_world.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_redefinition.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_register_alias.py create mode 100644 tests/unit/restapi/v1/signature_analysis/test_type_conflict.py create mode 100644 tests/unit/restapi/v1/test_signature_analysis.py diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py index 8dfa4f6c6..4fd1cb899 100644 --- a/src/dioptra/client/workflows.py +++ b/src/dioptra/client/workflows.py @@ -22,6 +22,7 @@ T = TypeVar("T") JOB_FILES_DOWNLOAD: Final[str] = "jobFilesDownload" +SIGNATURE_ANALYSIS: Final[str] = "signatureAnalysis" class WorkflowsCollectionClient(CollectionClient[T]): @@ -86,3 +87,40 @@ def download_job_files( return self._session.download( self.url, JOB_FILES_DOWNLOAD, output_path=job_files_path, 
params=params
        )
+
+    def signature_analysis_contents(
+        self,
+        file_contents: str,
+        filename: str = "something.py",
+    ) -> T:
+        """
+        Request signature analysis for the task plugin functions in a Python
+        file.
+
+        Args:
+            file_contents: The contents of the Python file.
+            filename: The name of the file; a placeholder is used if not
+                given.
+
+        Returns:
+            The response from the Dioptra API.
+        """
+        return self._session.post(
+            self.url,
+            SIGNATURE_ANALYSIS,
+            json_={"filename": filename, "fileContents": file_contents},
+        )
+
+    def signature_analysis_file(self, filename: str) -> T:
+        """
+        Read a Python file and request signature analysis for the task plugin
+        functions it contains.
+
+        Args:
+            filename: The name of the file.
+
+        Returns:
+            The response from the Dioptra API.
+        """
+        with open(filename, "r") as f:
+            contents = f.read()
+
+        return self.signature_analysis_contents(
+            file_contents=contents, filename=filename
+        )
diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py
index 428619cdc..6b587aa6c 100644
--- a/src/dioptra/restapi/v1/workflows/controller.py
+++ b/src/dioptra/restapi/v1/workflows/controller.py
@@ -19,14 +19,14 @@
 import structlog
 from flask import request, send_file
-from flask_accepts import accepts
+from flask_accepts import accepts, responds
 from flask_login import login_required
 from flask_restx import Namespace, Resource
 from injector import inject
 from structlog.stdlib import BoundLogger

-from .schema import FileTypes, JobFilesDownloadQueryParametersSchema
-from .service import JobFilesDownloadService
+from .schema import (
+    FileTypes,
+    JobFilesDownloadQueryParametersSchema,
+    SignatureAnalysisOutputSchema,
+    SignatureAnalysisSchema,
+)
+from .service import JobFilesDownloadService, SignatureAnalysisService

 LOGGER: BoundLogger = structlog.stdlib.get_logger()

@@ -78,3 +78,36 @@ def get(self):
             mimetype=mimetype[parsed_query_params["file_type"]],
             download_name=download_name[parsed_query_params["file_type"]],
         )
+
+
+@api.route("/signatureAnalysis")
+class SignatureAnalysisEndpoint(Resource):
+    @inject
+    def __init__(
+        self, signature_analysis_service: SignatureAnalysisService, *args, **kwargs
+    ) -> None:
+        """Initialize the workflow resource.
+
+        All arguments are provided via dependency injection.
+
+        Args:
+            signature_analysis_service: A SignatureAnalysisService object.
+        """
+        self._signature_analysis_service = signature_analysis_service
+        super().__init__(*args, **kwargs)
+
+    @login_required
+    @accepts(schema=SignatureAnalysisSchema, api=api)
+    @responds(schema=SignatureAnalysisOutputSchema, api=api)
+    def post(self):
+        """Analyze an uploaded Python file and return the signatures of the task plugins it defines."""  # noqa: B950
+        log = LOGGER.new(  # noqa: F841
+            request_id=str(uuid.uuid4()),
+            resource="SignatureAnalysis",
+            request_type="POST",
+        )
+        parsed_obj = request.parsed_obj
+        return self._signature_analysis_service.post(
+            filename=parsed_obj["filename"],
+            file_contents=parsed_obj["file_contents"],
+        )
diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py
index 6c006c287..5be2ff3c2 100644
--- a/src/dioptra/restapi/v1/workflows/schema.py
+++ b/src/dioptra/restapi/v1/workflows/schema.py
@@ -57,3 +57,81 @@ class SignatureAnalysisSchema(Schema):
             description="The name of the file"
         )
     )
+
+
+class SignatureAnalysisSignatureParamSchema(Schema):
+    name = fields.String(
+        attribute="name",
+        metadata=dict(
+            description="The name of the parameter"
+        )
+    )
+    type = fields.String(
+        attribute="type",
+        metadata=dict(
+            description="The type of the parameter"
+        )
+    )
+
+
+class SignatureAnalysisSignatureInputSchema(SignatureAnalysisSignatureParamSchema):
+    required = fields.Boolean(
+        attribute="required",
+        metadata=dict(
+            description="Whether this is a required parameter"
+        )
+    )
+
+
+class SignatureAnalysisSignatureOutputSchema(SignatureAnalysisSignatureParamSchema):
+    """No additional fields beyond the common parameter fields."""
+
+
+class SignatureAnalysisSuggestedTypesSchema(Schema):
+    # TODO: this should be an integer (or a list of integer) type resource id
+    # in the next iteration
+    proposed_type = fields.String(
+        attribute="proposed_type",
+        metadata=dict(
+            description="A suggestion for the name of the type"
+        )
+    )
+
+    missing_type = fields.String(
+        attribute="missing_type",
+        metadata=dict(
+            description="The annotation the suggestion is attempting to represent"
+        )
+    )
+
+
+class SignatureAnalysisSignatureSchema(Schema):
+    name = fields.String(
+        attribute="name",
+        metadata=dict(
+            description="The name of the function"
+        )
+    )
+    inputs = fields.Nested(
+        SignatureAnalysisSignatureInputSchema,
+        metadata=dict(
+            description="A list of objects describing the input parameters."
+        ),
+        many=True
+    )
+    outputs = fields.Nested(
+        SignatureAnalysisSignatureOutputSchema,
+        metadata=dict(
+            description="A list of objects describing the output parameters."
+        ),
+        many=True
+    )
+    missing_types = fields.Nested(
+        SignatureAnalysisSuggestedTypesSchema,
+        metadata=dict(
+            description="A list of suggested types for non-primitives defined by the file"
+        ),
+        many=True
+    )
+
+
+class SignatureAnalysisOutputSchema(Schema):
+    plugins = fields.Nested(
+        SignatureAnalysisSignatureSchema,
+        metadata=dict(
+            description="A list of signature analyses for the plugins in the input file"
+        ),
+        many=True
+    )
diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py
index d5769e274..2d178fdb0 100644
--- a/src/dioptra/restapi/v1/workflows/service.py
+++ b/src/dioptra/restapi/v1/workflows/service.py
@@ -15,11 +15,12 @@
 # ACCESS THE FULL CC BY 4.0 LICENSE HERE:
 # https://creativecommons.org/licenses/by/4.0/legalcode
 """The server-side functions that perform workflows endpoint operations."""
-from typing import IO, Final
+from typing import IO, Any, Final, List

 import structlog
 from structlog.stdlib import BoundLogger

+from dioptra.restapi.v1.lib.signature_analysis import get_plugin_signatures
 from .lib import views
 from .lib.package_job_files import package_job_files
 from .schema import FileTypes
@@ -65,3 +66,51 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]:
         file_type=file_type,
         logger=log,
     )
+
+
+class SignatureAnalysisService(object):
+    """The service methods for performing signature analysis on a file."""
+
+    def post(
+        self, filename: str, file_contents: str, **kwargs
+    ) -> dict[str, List[dict[str, Any]]]:
+        """Perform signature analysis on a file.
+
+        Args:
+            filename: The name of the file.
+            file_contents: The contents of the file.
+
+        Returns:
+            A dictionary containing the signature analysis.
+        """
+        log: BoundLogger = kwargs.get("log", LOGGER.new())
+        log.debug(
+            "Performing signature analysis",
+            filename=filename,
+            python_source=file_contents,
+        )

+        signatures = get_plugin_signatures(
+            python_source=file_contents,
+            filepath=filename,
+        )
+
+        endpoint_analyses = []
+        for signature in signatures:
+            endpoint_analysis = {
+                "name": signature["name"],
+                "inputs": signature["inputs"],
+                "outputs": signature["outputs"],
+            }
+
+            # Compute the suggestions for the unknown types. TODO: replace
+            # these with type resource ids looked up in the database.
+            missing_types = []
+            for inference in signature["suggested_types"]:
+                missing_types.append(
+                    {
+                        "missing_type": inference["type_annotation"],
+                        "proposed_type": inference["suggestion"],
+                    }
+                )
+
+            endpoint_analysis["missing_types"] = missing_types
+            endpoint_analyses.append(endpoint_analysis)
+
+        return {"plugins": endpoint_analyses}
diff --git a/tests/unit/restapi/v1/signature_analysis/test_alias.py b/tests/unit/restapi/v1/signature_analysis/test_alias.py
new file mode 100644
index 000000000..66ba8dc2e
--- /dev/null
+++ b/tests/unit/restapi/v1/signature_analysis/test_alias.py
@@ -0,0 +1,4 @@
+import dioptra.pyplugs as foo
+@foo.register
+def test_plugin():
+    pass
\ No newline at end of file
diff --git a/tests/unit/restapi/v1/signature_analysis/test_complex_type.py b/tests/unit/restapi/v1/signature_analysis/test_complex_type.py
new file mode 100644
index 000000000..dc03cb3dd
--- /dev/null
+++ 
b/tests/unit/restapi/v1/signature_analysis/test_complex_type.py @@ -0,0 +1,4 @@ +import dioptra.pyplugs +@dioptra.pyplugs.register() +def the_plugin(arg1: Optional[str]) -> Union[int, bool]: + pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_function_type.py b/tests/unit/restapi/v1/signature_analysis/test_function_type.py new file mode 100644 index 000000000..b9e1296dd --- /dev/null +++ b/tests/unit/restapi/v1/signature_analysis/test_function_type.py @@ -0,0 +1,4 @@ +import dioptra.pyplugs +@dioptra.pyplugs.register +def plugin_func(arg1: foo(2)) -> foo(2): + pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_none_return.py b/tests/unit/restapi/v1/signature_analysis/test_none_return.py new file mode 100644 index 000000000..b27c11417 --- /dev/null +++ b/tests/unit/restapi/v1/signature_analysis/test_none_return.py @@ -0,0 +1,4 @@ +from dioptra.pyplugs import register +@register +def my_plugin() -> None: + pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_optional.py b/tests/unit/restapi/v1/signature_analysis/test_optional.py new file mode 100644 index 000000000..59aeae3d8 --- /dev/null +++ b/tests/unit/restapi/v1/signature_analysis/test_optional.py @@ -0,0 +1,4 @@ +from dioptra import pyplugs +@pyplugs.register() +def do_things(arg1: Optional[str], arg2: int = 123): + pass \ No newline at end of file diff --git a/tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py b/tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py new file mode 100644 index 000000000..1eb2bb716 --- /dev/null +++ b/tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py @@ -0,0 +1,4 @@ +from dioptra import pyplugs as foo +@foo.register +def test_plugin(): + pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_real_world.py b/tests/unit/restapi/v1/signature_analysis/test_real_world.py new file mode 100644 index 000000000..5cd1eb6f7 --- /dev/null +++ b/tests/unit/restapi/v1/signature_analysis/test_real_world.py @@ -0,0 +1,381 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from __future__ import annotations + +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple, Union, Any + +import mlflow +import numpy as np +import pandas as pd +import scipy.stats +import structlog +from structlog.stdlib import BoundLogger +from tensorflow.keras.preprocessing.image import ( + DirectoryIterator +) + +from dioptra import pyplugs +from .tensorflow import get_optimizer, get_model_callbacks, get_performance_metrics, evaluate_metrics_tensorflow, predict_tensorflow +from .estimators_keras_classifiers import init_classifier +from .registry_art import load_wrapped_tensorflow_keras_classifier +from .registry_mlflow import load_tensorflow_keras_classifier +from .random_rng import init_rng +from .random_sample import draw_random_integer +from .backend_configs_tensorflow import init_tensorflow +from .tracking_mlflow import log_parameters, log_tensorflow_keras_estimator, log_metrics +from .data_tensorflow import get_n_classes_from_directory_iterator, create_image_dataset, predictions_to_df, df_to_predictions +from .estimators_methods import fit +from .mlflow import add_model_to_registry +from .artifacts_restapi import get_uri_for_model, get_uris_for_job, get_uris_for_artifacts +from .artifacts_utils import make_directories, extract_tarfile +from .metrics_distance import get_distance_metric_list +from .attacks_fgm import fgm +from .artifacts_mlflow import upload_directory_as_tarball_artifact, upload_data_frame_artifact, download_all_artifacts +from .defenses_image_preprocessing import create_defended_dataset +from .attacks_patch import create_adversarial_patches, create_adversarial_patch_dataset +from .metrics_performance import get_performance_metric_list, evaluate_metrics_generic + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +@pyplugs.register +def load_dataset( + ep_seed: int = 10145783023, + training_dir: str = "/dioptra/data/Mnist/training", + testing_dir: str = "/dioptra/data/Mnist/testing", + subsets: List[str] = ['testing'], + image_size: Tuple[int, int, int] = [28, 28, 1], + rescale: float = 1.0 / 255, + validation_split: Optional[float] = 0.2, + batch_size: int = 32, + label_mode: str = "categorical", + shuffle: bool = False +) -> DirectoryIterator: + seed, rng = init_rng(ep_seed) + global_seed = draw_random_integer(rng) + dataset_seed = draw_random_integer(rng) + init_tensorflow(global_seed) + log_parameters( + {'entry_point_seed': ep_seed, + 'tensorflow_global_seed':global_seed, + 'dataset_seed':dataset_seed}) + training_dataset = None if "training" not in subsets else create_image_dataset( + data_dir=training_dir, + subset="training", + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle + ) + + validation_dataset = None if "validation" not in subsets else create_image_dataset( + data_dir=training_dir, + subset="validation", + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle + ) + testing_dataset = None if "testing" not in subsets else create_image_dataset( + data_dir=testing_dir, + subset=None, + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle + ) + return training_dataset, 
validation_dataset, testing_dataset + +@pyplugs.register +def create_model( + dataset: DirectoryIterator = None, + model_architecture: str = "le_net", + input_shape: Tuple[int, int, int] = [28, 28, 1], + loss: str = "categorical_crossentropy", + learning_rate: float = 0.001, + optimizer: str = "Adam", + metrics_list: List[Dict[str, Any]] = None, +): + n_classes = get_n_classes_from_directory_iterator(dataset) + optim = get_optimizer(optimizer, learning_rate) + perf_metrics = get_performance_metrics(metrics_list) + classifier = init_classifier(model_architecture, optim, perf_metrics, input_shape, n_classes, loss) + return classifier + +@pyplugs.register +def load_model( + model_name: str | None = None, + model_version: int | None = None, + imagenet_preprocessing: bool = False, + art: bool = False, + image_size: Any = None, + classifier_kwargs: Optional[Dict[str, Any]] = None +): + uri = get_uri_for_model(model_name, model_version) + if (art): + classifier = load_wrapped_tensorflow_keras_classifier(uri, imagenet_preprocessing, image_size, classifier_kwargs) + else: + classifier = load_tensorflow_keras_classifier(uri) + return classifier + +@pyplugs.register +def train( + estimator: Any, + x: Any = None, + y: Any = None, + callbacks_list: List[Dict[str, Any]] = None, + fit_kwargs: Optional[Dict[str, Any]] = None +): + fit_kwargs = {} if fit_kwargs is None else fit_kwargs + callbacks = get_model_callbacks(callbacks_list) + fit_kwargs['callbacks'] = callbacks + fit(estimator=estimator, x=x, y=y, fit_kwargs=fit_kwargs) + return estimator + +@pyplugs.register +def save_artifacts_and_models( + artifacts: List[Dict[str, Any]] = None, + models: List[Dict[str, Any]] = None +): + artifacts = [] if artifacts is None else artifacts + models = [] if models is None else models + + for model in models: + log_tensorflow_keras_estimator(model['model'], "model") + add_model_to_registry(model['name'], "model") + for artifact in artifacts: + if (artifact['type'] == 'tarball'): + upload_directory_as_tarball_artifact( + source_dir=artifact['adv_data_dir'], + tarball_filename=artifact['adv_tar_name'] + ) + if (artifact['type'] == 'dataframe'): + upload_data_frame_artifact( + data_frame=artifact['data_frame'], + file_name=artifact['file_name'], + file_format=artifact['file_format'], + file_format_kwargs=artifact['file_format_kwargs'] + ) +@pyplugs.register +def load_artifacts_for_job( + job_id: str, + files: List[str|Path] = None, + extract_files: List[str|Path] = None +): + files = [] if files is None else files + extract_files = [] if extract_files is None else extract_files + files += extract_files # need to download them to be able to extract + + uris = get_uris_for_job(job_id) + paths = download_all_artifacts(uris, files) + for extract in paths: + for ef in extract_files: + if (ef.endswith(str(ef))): + extract_tarfile(extract) + return paths + +@pyplugs.register +def load_artifacts( + artifact_ids: List[int] = None, extract_files: List[str|Path] = None +): + extract_files = [] if extract_files is None else extract_files + artifact_ids = [] if artifact_ids is not None else artifact_ids + uris = get_uris_for_artifacts(artifact_ids) + paths = download_all_artifacts(uris, extract_files) + for extract in paths: + extract_tarfile(extract) + +@pyplugs.register +def attack_fgm( + dataset: Any, + adv_data_dir: Union[str, Path], + classifier: Any, + distance_metrics: List[Dict[str, str]], + batch_size: int = 32, + eps: float = 0.3, + eps_step: float = 0.1, + minimal: bool = False, + norm: Union[int, float, str] = 
np.inf, +): + '''generate fgm examples''' + make_directories([adv_data_dir]) + distance_metrics_list = get_distance_metric_list(distance_metrics) + fgm_dataset = fgm( + data_flow=dataset, + adv_data_dir=adv_data_dir, + keras_classifier=classifier, + distance_metrics_list=distance_metrics_list, + batch_size=batch_size, + eps=eps, + eps_step=eps_step, + minimal=minimal, + norm=norm + ) + return fgm_dataset + +@pyplugs.register() +def attack_patch( + data_flow: Any, + adv_data_dir: Union[str, Path], + model: Any, + patch_target: int, + num_patch: int, + num_patch_samples: int, + rotation_max: float, + scale_min: float, + scale_max: float, + learning_rate: float, + max_iter: int, + patch_shape: Tuple, +): + '''generate patches''' + make_directories([adv_data_dir]) + create_adversarial_patches( + data_flow=data_flow, + adv_data_dir=adv_data_dir, + keras_classifier=model, + patch_target=patch_target, + num_patch=num_patch, + num_patch_samples=num_patch_samples, + rotation_max=rotation_max, + scale_min=scale_min, + scale_max=scale_max, + learning_rate=learning_rate, + max_iter=max_iter, + patch_shape=patch_shape, + ) + +@pyplugs.register() +def augment_patch( + data_flow: Any, + adv_data_dir: Union[str, Path], + patch_dir: Union[str, Path], + model: Any, + patch_shape: Tuple, + distance_metrics: List[Dict[str, str]], + batch_size: int = 32, + patch_scale: float = 0.4, + rotation_max: float = 22.5, + scale_min: float = 0.1, + scale_max: float = 1.0, +): + '''add patches to a dataset''' + make_directories([adv_data_dir]) + distance_metrics_list = get_distance_metric_list(distance_metrics) + create_adversarial_patch_dataset( + data_flow=data_flow, + adv_data_dir=adv_data_dir, + patch_dir=patch_dir, + keras_classifier=model, + patch_shape=patch_shape, + distance_metrics_list=distance_metrics_list, + batch_size=batch_size, + patch_scale=patch_scale, + rotation_max=rotation_max, + scale_min=scale_min, + scale_max=scale_max + ) +@pyplugs.register +def model_metrics( + classifier: Any, + dataset: Any +): + metrics = evaluate_metrics_tensorflow(classifier, dataset) + log_metrics(metrics) + return metrics + +@pyplugs.register +def prediction_metrics( + y_true: np.ndarray, + y_pred: np.ndarray, + metrics_list: List[Dict[str, str]], + func_kwargs: Dict[str, Dict[str, Any]] = None +): + func_kwargs = {} if func_kwargs is None else func_kwargs + callable_list = get_performance_metric_list(metrics_list) + metrics = evaluate_metrics_generic(y_true, y_pred, callable_list, func_kwargs) + log_metrics(metrics) + return pd.DataFrame(metrics, index=[1]) + +@pyplugs.register +def augment_data( + dataset: Any, + def_data_dir: Union[str, Path], + image_size: Tuple[int, int, int], + distance_metrics: List[Dict[str, str]], + batch_size: int = 50, + def_type: str = "spatial_smoothing", + defense_kwargs: Optional[Dict[str, Any]] = None, +): + make_directories([def_data_dir]) + distance_metrics_list = get_distance_metric_list(distance_metrics) + defended_dataset = create_defended_dataset( + data_flow=dataset, + def_data_dir=def_data_dir, + image_size=image_size, + distance_metrics_list=distance_metrics_list, + batch_size=batch_size, + def_type=def_type, + defense_kwargs=defense_kwargs, + ) + return defended_dataset + +@pyplugs.register +def predict( + classifier: Any, + dataset: Any, + show_actual: bool = False, + show_target: bool = False, +): + predictions = predict_tensorflow(classifier, dataset) + df = predictions_to_df( + predictions, + dataset, + show_actual=show_actual, + show_target=show_target) + return df + 
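+# NOTE: illustrative sketch, not one of the plugins under test. For reference,
+# running the signature analyzer over this file, e.g.
+#
+#     from dioptra.restapi.v1.lib.signature_analysis import get_plugin_signatures
+#     signatures = list(
+#         get_plugin_signatures(python_source=source, filepath="test_real_world.py")
+#     )
+#
+# (where `source` is assumed to hold this file's text), should report `predict`
+# above with required "any" inputs (classifier, dataset), optional "boolean"
+# inputs (show_actual, show_target), no outputs (it has no return annotation),
+# and no missing types.
+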
+@pyplugs.register +def load_predictions( + paths: List[str], + filename: str, + format: str = 'csv', + dataset: DirectoryIterator = None, + n_classes: int = -1, +): + loc = None + for m in paths: + if m.endswith(filename): + loc = m + if (format == 'csv'): + df = pd.read_csv(loc) + elif (format == 'json'): + df = pd.read_json(loc) + y_true, y_pred = df_to_predictions(df, dataset, n_classes) + return y_true, y_pred + + + diff --git a/tests/unit/restapi/v1/signature_analysis/test_redefinition.py b/tests/unit/restapi/v1/signature_analysis/test_redefinition.py new file mode 100644 index 000000000..a33dc6fb3 --- /dev/null +++ b/tests/unit/restapi/v1/signature_analysis/test_redefinition.py @@ -0,0 +1,22 @@ +from dioptra.pyplugs import register +import aaa +@register() +def test_plugin(): + pass +@aaa.register +def not_a_plugin(): + pass +class SomeClass: + pass + +def some_other_func(): + pass +x = 1 +@register +def test_plugin2(): + pass +# re-definition of the "register" symbol +from bbb import ccc as register +@register +def also_not_a_plugin(): + pass \ No newline at end of file diff --git a/tests/unit/restapi/v1/signature_analysis/test_register_alias.py b/tests/unit/restapi/v1/signature_analysis/test_register_alias.py new file mode 100644 index 000000000..aa63572f8 --- /dev/null +++ b/tests/unit/restapi/v1/signature_analysis/test_register_alias.py @@ -0,0 +1,4 @@ +from dioptra.pyplugs import register as foo +@foo +def test_plugin(): + pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_type_conflict.py b/tests/unit/restapi/v1/signature_analysis/test_type_conflict.py new file mode 100644 index 000000000..b9d75333a --- /dev/null +++ b/tests/unit/restapi/v1/signature_analysis/test_type_conflict.py @@ -0,0 +1,4 @@ +import dioptra.pyplugs +@dioptra.pyplugs.register +def plugin_func(arg1: foo(2), arg2: Type1) -> foo(2): + pass \ No newline at end of file diff --git a/tests/unit/restapi/v1/test_signature_analysis.py b/tests/unit/restapi/v1/test_signature_analysis.py new file mode 100644 index 000000000..6b50e6296 --- /dev/null +++ b/tests/unit/restapi/v1/test_signature_analysis.py @@ -0,0 +1,411 @@ +expected_outputs = {} + +expected_outputs['test_real_world.py'] = [{'name': 'load_dataset', + 'inputs': [{'name': 'ep_seed', 'type': 'integer', 'required': False}, + {'name': 'training_dir', 'type': 'string', 'required': False}, + {'name': 'testing_dir', 'type': 'string', 'required': False}, + {'name': 'subsets', 'type': 'list_str', 'required': False}, + {'name': 'image_size', 'type': 'tuple_int_int_int', 'required': False}, + {'name': 'rescale', 'type': 'number', 'required': False}, + {'name': 'validation_split', 'type': 'optional_float', 'required': False}, + {'name': 'batch_size', 'type': 'integer', 'required': False}, + {'name': 'label_mode', 'type': 'string', 'required': False}, + {'name': 'shuffle', 'type': 'boolean', 'required': False}], + 'outputs': [{'name': 'output', 'type': 'directoryiterator'}], + 'missing_types': [{'proposed_type': 'list_str', 'missing_type': 'List[str]'}, + {'proposed_type': 'tuple_int_int_int', + 'missing_type': 'Tuple[int, int, int]'}, + {'proposed_type': 'optional_float', 'missing_type': 'Optional[float]'}, + {'proposed_type': 'directoryiterator', + 'missing_type': 'DirectoryIterator'}]}, + {'name': 'create_model', + 'inputs': [{'name': 'dataset', + 'type': 'directoryiterator', + 'required': False}, + {'name': 'model_architecture', 'type': 'string', 'required': False}, + {'name': 'input_shape', 'type': 'tuple_int_int_int', 'required': False}, + {'name': 
'loss', 'type': 'string', 'required': False}, + {'name': 'learning_rate', 'type': 'number', 'required': False}, + {'name': 'optimizer', 'type': 'string', 'required': False}, + {'name': 'metrics_list', 'type': 'list_dict_str_any', 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'directoryiterator', + 'missing_type': 'DirectoryIterator'}, + {'proposed_type': 'tuple_int_int_int', + 'missing_type': 'Tuple[int, int, int]'}, + {'proposed_type': 'list_dict_str_any', + 'missing_type': 'List[Dict[str, Any]]'}]}, + {'name': 'load_model', + 'inputs': [{'name': 'model_name', 'type': 'str_none', 'required': False}, + {'name': 'model_version', 'type': 'int_none', 'required': False}, + {'name': 'imagenet_preprocessing', 'type': 'boolean', 'required': False}, + {'name': 'art', 'type': 'boolean', 'required': False}, + {'name': 'image_size', 'type': 'any', 'required': False}, + {'name': 'classifier_kwargs', + 'type': 'optional_dict_str_any', + 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'str_none', + 'missing_type': 'str | None'}, + {'proposed_type': 'int_none', 'missing_type': 'int | None'}, + {'proposed_type': 'optional_dict_str_any', + 'missing_type': 'Optional[Dict[str, Any]]'}]}, + {'name': 'train', + 'inputs': [{'name': 'estimator', 'type': 'any', 'required': True}, + {'name': 'x', 'type': 'any', 'required': False}, + {'name': 'y', 'type': 'any', 'required': False}, + {'name': 'callbacks_list', 'type': 'list_dict_str_any', 'required': False}, + {'name': 'fit_kwargs', 'type': 'optional_dict_str_any', 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'list_dict_str_any', + 'missing_type': 'List[Dict[str, Any]]'}, + {'proposed_type': 'optional_dict_str_any', + 'missing_type': 'Optional[Dict[str, Any]]'}]}, + {'name': 'save_artifacts_and_models', + 'inputs': [{'name': 'artifacts', + 'type': 'list_dict_str_any', + 'required': False}, + {'name': 'models', 'type': 'list_dict_str_any', 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'list_dict_str_any', + 'missing_type': 'List[Dict[str, Any]]'}]}, + {'name': 'load_artifacts_for_job', + 'inputs': [{'name': 'job_id', 'type': 'string', 'required': True}, + {'name': 'files', 'type': 'list_str_path', 'required': False}, + {'name': 'extract_files', 'type': 'list_str_path', 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'list_str_path', + 'missing_type': 'List[str | Path]'}]}, + {'name': 'load_artifacts', + 'inputs': [{'name': 'artifact_ids', 'type': 'list_int', 'required': False}, + {'name': 'extract_files', 'type': 'list_str_path', 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'list_int', 'missing_type': 'List[int]'}, + {'proposed_type': 'list_str_path', 'missing_type': 'List[str | Path]'}]}, + {'name': 'attack_fgm', + 'inputs': [{'name': 'dataset', 'type': 'any', 'required': True}, + {'name': 'adv_data_dir', 'type': 'union_str_path', 'required': True}, + {'name': 'classifier', 'type': 'any', 'required': True}, + {'name': 'distance_metrics', 'type': 'list_dict_str_str', 'required': True}, + {'name': 'batch_size', 'type': 'integer', 'required': False}, + {'name': 'eps', 'type': 'number', 'required': False}, + {'name': 'eps_step', 'type': 'number', 'required': False}, + {'name': 'minimal', 'type': 'boolean', 'required': False}, + {'name': 'norm', 'type': 'union_int_float_str', 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'union_str_path', + 'missing_type': 
'Union[str, Path]'}, + {'proposed_type': 'list_dict_str_str', + 'missing_type': 'List[Dict[str, str]]'}, + {'proposed_type': 'union_int_float_str', + 'missing_type': 'Union[int, float, str]'}]}, + {'name': 'attack_patch', + 'inputs': [{'name': 'data_flow', 'type': 'any', 'required': True}, + {'name': 'adv_data_dir', 'type': 'union_str_path', 'required': True}, + {'name': 'model', 'type': 'any', 'required': True}, + {'name': 'patch_target', 'type': 'integer', 'required': True}, + {'name': 'num_patch', 'type': 'integer', 'required': True}, + {'name': 'num_patch_samples', 'type': 'integer', 'required': True}, + {'name': 'rotation_max', 'type': 'number', 'required': True}, + {'name': 'scale_min', 'type': 'number', 'required': True}, + {'name': 'scale_max', 'type': 'number', 'required': True}, + {'name': 'learning_rate', 'type': 'number', 'required': True}, + {'name': 'max_iter', 'type': 'integer', 'required': True}, + {'name': 'patch_shape', 'type': 'tuple', 'required': True}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'union_str_path', + 'missing_type': 'Union[str, Path]'}, + {'proposed_type': 'tuple', 'missing_type': 'Tuple'}]}, + {'name': 'augment_patch', + 'inputs': [{'name': 'data_flow', 'type': 'any', 'required': True}, + {'name': 'adv_data_dir', 'type': 'union_str_path', 'required': True}, + {'name': 'patch_dir', 'type': 'union_str_path', 'required': True}, + {'name': 'model', 'type': 'any', 'required': True}, + {'name': 'patch_shape', 'type': 'tuple', 'required': True}, + {'name': 'distance_metrics', 'type': 'list_dict_str_str', 'required': True}, + {'name': 'batch_size', 'type': 'integer', 'required': False}, + {'name': 'patch_scale', 'type': 'number', 'required': False}, + {'name': 'rotation_max', 'type': 'number', 'required': False}, + {'name': 'scale_min', 'type': 'number', 'required': False}, + {'name': 'scale_max', 'type': 'number', 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'union_str_path', + 'missing_type': 'Union[str, Path]'}, + {'proposed_type': 'tuple', 'missing_type': 'Tuple'}, + {'proposed_type': 'list_dict_str_str', + 'missing_type': 'List[Dict[str, str]]'}]}, + {'name': 'model_metrics', + 'inputs': [{'name': 'classifier', 'type': 'any', 'required': True}, + {'name': 'dataset', 'type': 'any', 'required': True}], + 'outputs': [], + 'missing_types': []}, + {'name': 'prediction_metrics', + 'inputs': [{'name': 'y_true', 'type': 'np_ndarray', 'required': True}, + {'name': 'y_pred', 'type': 'np_ndarray', 'required': True}, + {'name': 'metrics_list', 'type': 'list_dict_str_str', 'required': True}, + {'name': 'func_kwargs', + 'type': 'dict_str_dict_str_any', + 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'np_ndarray', + 'missing_type': 'np.ndarray'}, + {'proposed_type': 'list_dict_str_str', + 'missing_type': 'List[Dict[str, str]]'}, + {'proposed_type': 'dict_str_dict_str_any', + 'missing_type': 'Dict[str, Dict[str, Any]]'}]}, + {'name': 'augment_data', + 'inputs': [{'name': 'dataset', 'type': 'any', 'required': True}, + {'name': 'def_data_dir', 'type': 'union_str_path', 'required': True}, + {'name': 'image_size', 'type': 'tuple_int_int_int', 'required': True}, + {'name': 'distance_metrics', 'type': 'list_dict_str_str', 'required': True}, + {'name': 'batch_size', 'type': 'integer', 'required': False}, + {'name': 'def_type', 'type': 'string', 'required': False}, + {'name': 'defense_kwargs', + 'type': 'optional_dict_str_any', + 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 
'union_str_path', + 'missing_type': 'Union[str, Path]'}, + {'proposed_type': 'tuple_int_int_int', + 'missing_type': 'Tuple[int, int, int]'}, + {'proposed_type': 'list_dict_str_str', + 'missing_type': 'List[Dict[str, str]]'}, + {'proposed_type': 'optional_dict_str_any', + 'missing_type': 'Optional[Dict[str, Any]]'}]}, + {'name': 'predict', + 'inputs': [{'name': 'classifier', 'type': 'any', 'required': True}, + {'name': 'dataset', 'type': 'any', 'required': True}, + {'name': 'show_actual', 'type': 'boolean', 'required': False}, + {'name': 'show_target', 'type': 'boolean', 'required': False}], + 'outputs': [], + 'missing_types': []}, + {'name': 'load_predictions', + 'inputs': [{'name': 'paths', 'type': 'list_str', 'required': True}, + {'name': 'filename', 'type': 'string', 'required': True}, + {'name': 'format', 'type': 'string', 'required': False}, + {'name': 'dataset', 'type': 'directoryiterator', 'required': False}, + {'name': 'n_classes', 'type': 'integer', 'required': False}], + 'outputs': [], + 'missing_types': [{'proposed_type': 'list_str', 'missing_type': 'List[str]'}, + {'proposed_type': 'directoryiterator', + 'missing_type': 'DirectoryIterator'}]}] + +expected_outputs['test_alias.py'] = [{ + 'name':'test_plugin', + 'inputs': [], + 'outputs': [], + 'missing_types': [] +}] + +expected_outputs['test_complex_type.py'] = [{ + 'name':'the_plugin', + 'inputs': [ + { + 'name': 'arg1', + 'type': 'optional_str', + 'required': True, + } + ], + 'outputs': [ + { + 'name': 'output', + 'type': 'union_int_bool' + } + ], + 'missing_types': [ + {'proposed_type': 'optional_str', 'missing_type': 'Optional[str]'}, + {'proposed_type': 'union_int_bool', 'missing_type': 'Union[int, bool]'}, + ] +}] + +expected_outputs['test_function_type.py'] = [{ + 'name':'plugin_func', + 'inputs': [ + { + 'name': 'arg1', + 'type': 'type1', + 'required': True, + } + ], + 'outputs': [ + { + 'name': 'output', + 'type': 'type1' + } + ], + 'missing_types': [ + {'proposed_type': 'type1', 'missing_type': 'foo(2)'}, + ] +}] + +expected_outputs['test_none_return.py'] = [{ + 'name':'my_plugin', + 'inputs': [], + 'outputs': [], + 'missing_types': [] +}] + +expected_outputs['test_optional.py'] = [{ + 'name':'do_things', + 'inputs': [ + { + 'name': 'arg1', + 'type': 'optional_str', + 'required': True, + }, + { + 'name': 'arg2', + 'type': 'integer', + 'required': False, + }, + + ], + 'outputs': [], + 'missing_types': [ + {'proposed_type': 'optional_str', 'missing_type': 'Optional[str]'}, + ] +}] + +expected_outputs['test_pyplugs_alias.py'] = [{ + 'name':'test_plugin', + 'inputs': [], + 'outputs': [], + 'missing_types': [] +}] + +expected_outputs['test_redefinition.py'] = [{ + 'name':'test_plugin', + 'inputs': [], + 'outputs': [], + 'missing_types': [] +},{ + 'name':'test_plugin2', + 'inputs': [], + 'outputs': [], + 'missing_types': [] +}] + +expected_outputs['test_register_alias.py'] = [{ + 'name':'test_plugin', + 'inputs': [], + 'outputs': [], + 'missing_types': [] +}] + +expected_outputs['test_type_conflict.py'] = [{ + 'name':'plugin_func', + 'inputs': [ + { + 'name': 'arg1', + 'type': 'type2', + 'required': True, + }, + { + 'name': 'arg2', + 'type': 'type1', + 'required': True, + } + ], + 'outputs': [ + { + 'name': 'output', + 'type': 'type2' + } + ], + 'missing_types': [ + {'proposed_type': 'type2', 'missing_type': 'foo(2)'}, + {'proposed_type': 'type1', 'missing_type': 'Type1'}, + ] +}] + +from pathlib import Path +from typing import Any +from http import HTTPStatus +from flask_sqlalchemy import SQLAlchemy + +from 
dioptra.client.base import DioptraResponseProtocol
+from dioptra.client.client import DioptraClient
+
+# -- Assertions ------------------------------------------------------------------------
+
+
+def assert_signature_analysis_response_matches_expectations(
+    response: dict[str, Any], expected_contents: dict[str, Any]
+) -> None:
+    """Assert that the contents of a signature analysis response are valid.
+
+    Args:
+        response: The actual response from the API.
+        expected_contents: The expected response from the API.
+
+    Raises:
+        AssertionError: If the API response does not match the expected response or
+            if the response contents are not valid.
+    """
+    # Check expected keys
+    expected_keys = {
+        "name",
+        "missing_types",
+        "outputs",
+        "inputs",
+    }
+    assert set(response.keys()) == expected_keys
+
+    # Check basic response types
+    assert isinstance(response["name"], str)
+    assert isinstance(response["outputs"], list)
+    assert isinstance(response["missing_types"], list)
+    assert isinstance(response["inputs"], list)
+
+    def sort_by_name(lst, k='name'):
+        return sorted(lst, key=lambda x: x[k])
+
+    assert sort_by_name(response["outputs"]) == sort_by_name(expected_contents["outputs"])
+    assert sort_by_name(response["inputs"]) == sort_by_name(expected_contents["inputs"])
+    assert sort_by_name(response["missing_types"], k='proposed_type') == sort_by_name(expected_contents["missing_types"], k='proposed_type')
+
+
+def assert_signature_analysis_responses_matches_expectations(
+    responses: list[dict[str, Any]], expected_contents: list[dict[str, Any]]
+) -> None:
+    assert (len(responses) == len(expected_contents))
+    for response in responses:
+        assert_signature_analysis_response_matches_expectations(response, [a for a in expected_contents if a['name']==response['name']][0])
+
+
+def assert_signature_analysis_file_load_and_contents(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    filename: str,
+):
+    location = Path("tests/unit/restapi/v1/signature_analysis") / filename
+    file_analysis = dioptra_client.workflows.signature_analysis_file(str(location))
+    with location.open('r') as f:
+        contents = f.read()
+    contents_analysis = dioptra_client.workflows.signature_analysis_contents(contents, filename)
+
+    assert(file_analysis.status_code == HTTPStatus.OK)
+    assert(contents_analysis.status_code == HTTPStatus.OK)
+
+    assert_signature_analysis_responses_matches_expectations(file_analysis.json()["plugins"], expected_contents=expected_outputs[filename])
+    assert_signature_analysis_responses_matches_expectations(contents_analysis.json()["plugins"], expected_contents=expected_outputs[filename])
+
+
+# -- Tests -----------------------------------------------------------------------------
+
+def test_signature_analysis(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    db: SQLAlchemy,
+    auth_account: dict[str, Any],
+) -> None:
+    """
+    Test that signature analysis returns the expected signatures for each of
+    the example plugin files.
+
+    Args:
+        dioptra_client: The Dioptra client used to call the API.
+
+    Raises:
+        AssertionError: If a response status code is not 200 or if the API response
+            does not match the expected response.
+ """ + + for fn in expected_outputs: + assert_signature_analysis_file_load_and_contents(dioptra_client=dioptra_client, filename=fn) From ac92dbb18c2c14b45bc3b30d7be3aaef18588dcf Mon Sep 17 00:00:00 2001 From: jtsextonMITRE <45762017+jtsextonMITRE@users.noreply.github.com> Date: Fri, 24 Jan 2025 11:11:03 -0500 Subject: [PATCH 4/5] chore: fix linting errors --- .flake8 | 1 + .../restapi/v1/workflows/controller.py | 7 +- src/dioptra/restapi/v1/workflows/schema.py | 67 +- src/dioptra/restapi/v1/workflows/service.py | 54 +- tests/unit/restapi/lib/actions.py | 3 +- tests/unit/restapi/lib/asserts.py | 4 +- tests/unit/restapi/lib/mock_mlflow.py | 7 +- .../restapi/lib/test_signature_analysis.py | 61 +- tests/unit/restapi/test_depth_limited_repr.py | 19 +- tests/unit/restapi/v1/conftest.py | 2 +- .../v1/signature_analysis/test_alias.py | 20 +- .../signature_analysis/test_complex_type.py | 18 + .../signature_analysis/test_function_type.py | 18 + .../v1/signature_analysis/test_none_return.py | 18 + .../v1/signature_analysis/test_optional.py | 20 +- .../signature_analysis/test_pyplugs_alias.py | 18 + .../v1/signature_analysis/test_real_world.py | 253 +++--- .../signature_analysis/test_redefinition.py | 38 +- .../signature_analysis/test_register_alias.py | 18 + .../signature_analysis/test_type_conflict.py | 20 +- .../restapi/v1/test_signature_analysis.py | 784 ++++++++++-------- 21 files changed, 885 insertions(+), 565 deletions(-) diff --git a/.flake8 b/.flake8 index 80ddeb7e9..2ae752bdd 100644 --- a/.flake8 +++ b/.flake8 @@ -5,6 +5,7 @@ select = B,C,E,F,W,B9 ignore = E203,E302,E501,W503,B905,B907,B909 per-file-ignores = examples/*/src/*.py:E402 + tests/unit/restapi/v1/signature_analysis/test*.py:F821,F401,B006,B950,E402 extend-exclude = .ipynb_checkpoints alembic diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index 6b587aa6c..a485fb704 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -25,7 +25,12 @@ from injector import inject from structlog.stdlib import BoundLogger -from .schema import FileTypes, JobFilesDownloadQueryParametersSchema, SignatureAnalysisSchema, SignatureAnalysisOutputSchema +from .schema import ( + FileTypes, + JobFilesDownloadQueryParametersSchema, + SignatureAnalysisOutputSchema, + SignatureAnalysisSchema, +) from .service import JobFilesDownloadService, SignatureAnalysisService LOGGER: BoundLogger = structlog.stdlib.get_logger() diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index 5be2ff3c2..b70fad100 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -42,96 +42,83 @@ class JobFilesDownloadQueryParametersSchema(Schema): default=FileTypes.TAR_GZ.value, ) + class SignatureAnalysisSchema(Schema): fileContents = fields.String( - attribute="file_contents", - metadata=dict( - description="The contents of the file" - ) + attribute="file_contents", metadata=dict(description="The contents of the file") ) filename = fields.String( - attribute="filename", - metadata=dict( - description="The name of the file" - ) + attribute="filename", metadata=dict(description="The name of the file") ) + class SignatureAnalysisSignatureParamSchema(Schema): name = fields.String( - attribute="name", - metadata=dict( - description="The name of the parameter" - ) + attribute="name", metadata=dict(description="The name of the parameter") ) type = fields.String( - 
attribute="type", - metadata=dict( - description="The type of the parameter" - ) + attribute="type", metadata=dict(description="The type of the parameter") ) + class SignatureAnalysisSignatureInputSchema(SignatureAnalysisSignatureParamSchema): required = fields.Boolean( attribute="required", - metadata=dict( - description="Whether this is a required parameter" - ) + metadata=dict(description="Whether this is a required parameter"), ) + + class SignatureAnalysisSignatureOutputSchema(SignatureAnalysisSignatureParamSchema): - ''' No additional fields. ''' + """No additional fields.""" + class SignatureAnalysisSuggestedTypes(Schema): # this should be an integer or a list of integer resource ids on the next iteration - proposed_type = fields.String( + proposed_type = fields.String( attribute="proposed_type", - metadata=dict( - description="A suggestion for the name of the type" - ) + metadata=dict(description="A suggestion for the name of the type"), ) - + missing_type = fields.String( attribute="missing_type", metadata=dict( description="The annotation the suggestion is attempting to represent" - ) + ), ) + class SignatureAnalysisSignatureSchema(Schema): name = fields.String( - attribute="name", - metadata=dict( - description="The name of the function" - ) + attribute="name", metadata=dict(description="The name of the function") ) inputs = fields.Nested( SignatureAnalysisSignatureInputSchema, - metadata=dict( - description="A list of objects describing the input parameters." - ), - many=True + metadata=dict(description="A list of objects describing the input parameters."), + many=True, ) outputs = fields.Nested( SignatureAnalysisSignatureOutputSchema, metadata=dict( description="A list of objects describing the output parameters." ), - many=True + many=True, ) missing_types = fields.Nested( SignatureAnalysisSuggestedTypes, metadata=dict( - description="A list of suggested types for non-primitives defined by the file" + description="A list of missing types for non-primitives defined by the file" ), - many=True + many=True, ) + class SignatureAnalysisOutputSchema(Schema): plugins = fields.Nested( SignatureAnalysisSignatureSchema, metadata=dict( description="A list of signature analyses for the plugins in the input file" - ), - many=True - ) \ No newline at end of file + ), + many=True, + ) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 2d178fdb0..ac5615d09 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -15,12 +15,13 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """The server-side functions that perform workflows endpoint operations.""" -from typing import IO, Any, Final, Iterator, List +from typing import IO, Any, Final, List import structlog from structlog.stdlib import BoundLogger from dioptra.restapi.v1.lib.signature_analysis import get_plugin_signatures + from .lib import views from .lib.package_job_files import package_job_files from .schema import FileTypes @@ -67,10 +68,13 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]: logger=log, ) + class SignatureAnalysisService(object): """The service methods for performing signature analysis on a file.""" - def post(self, filename: str, fileContents: str, **kwargs) -> dict[str, List[dict[str, Any]]]: + def post( + self, filename: str, fileContents: str, **kwargs + ) -> dict[str, List[dict[str, Any]]]: """Perform signature analysis on a file. 
Args: @@ -81,14 +85,19 @@ def post(self, filename: str, fileContents: str, **kwargs) -> dict[str, List[dic A dictionary containing the signature analysis. """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.debug("Performing signature analysis", filename=filename, python_source=fileContents) + log.debug( + "Performing signature analysis", + filename=filename, + python_source=fileContents, + ) + signatures = list( + get_plugin_signatures( + python_source=fileContents, + filepath=filename, + ) + ) - signatures = list(get_plugin_signatures( - python_source=fileContents, - filepath=filename, - )) - print(signatures) endpoint_analyses = [] for signature in signatures: @@ -98,19 +107,28 @@ def post(self, filename: str, fileContents: str, **kwargs) -> dict[str, List[dic inferences = signature["suggested_types"] endpoint_analysis = {} - endpoint_analysis['name'] = function_name - endpoint_analysis['inputs'] = function_inputs - endpoint_analysis['outputs'] = function_outputs - + endpoint_analysis["name"] = function_name + endpoint_analysis["inputs"] = function_inputs + endpoint_analysis["outputs"] = function_outputs + # Compute the suggestions for the unknown types missing_types = [] for inference in inferences: - suggested_type = inference['suggestion'] # replace this with resource id's for suggestions - original_annotation = inference['type_annotation'] # do a database lookup with this - missing_types += [{ 'missing_type': original_annotation, 'proposed_type': suggested_type}] - - endpoint_analysis['missing_types'] = missing_types + suggested_type = inference[ + "suggestion" + ] # replace this with resource id's for suggestions + original_annotation = inference[ + "type_annotation" + ] # do a database lookup with this + missing_types += [ + { + "missing_type": original_annotation, + "proposed_type": suggested_type, + } + ] + + endpoint_analysis["missing_types"] = missing_types endpoint_analyses += [endpoint_analysis] - return {"plugins": endpoint_analyses} \ No newline at end of file + return {"plugins": endpoint_analyses} diff --git a/tests/unit/restapi/lib/actions.py b/tests/unit/restapi/lib/actions.py index b7f10c6f7..46e380929 100644 --- a/tests/unit/restapi/lib/actions.py +++ b/tests/unit/restapi/lib/actions.py @@ -769,6 +769,7 @@ def remove_tag( follow_redirects=True, ) + def post_metrics( client: FlaskClient, job_id: int, metric_name: str, metric_value: float ) -> TestResponse: @@ -835,4 +836,4 @@ def post_mlflowruns( ).get_json() responses[key] = mlflowrun_response - return responses \ No newline at end of file + return responses diff --git a/tests/unit/restapi/lib/asserts.py b/tests/unit/restapi/lib/asserts.py index 28839cc1c..631c733a3 100644 --- a/tests/unit/restapi/lib/asserts.py +++ b/tests/unit/restapi/lib/asserts.py @@ -285,9 +285,7 @@ def assert_creating_another_existing_draft_fails( Raises: AssertionError: If the response status code is not 400. 
""" - response = drafts_client.create( - *resource_ids, **payload - ) + response = drafts_client.create(*resource_ids, **payload) assert response.status_code == HTTPStatus.BAD_REQUEST diff --git a/tests/unit/restapi/lib/mock_mlflow.py b/tests/unit/restapi/lib/mock_mlflow.py index a7fb7a851..a13822867 100644 --- a/tests/unit/restapi/lib/mock_mlflow.py +++ b/tests/unit/restapi/lib/mock_mlflow.py @@ -59,7 +59,12 @@ def get_run(self, id: str) -> MockMlflowRun: return run def log_metric( - self, id: str, key: str, value: float, step: Optional[int] = None, timestamp: Optional[int] = None + self, + id: str, + key: str, + value: float, + step: Optional[int] = None, + timestamp: Optional[int] = None, ): if id not in active_runs: active_runs[id] = [] diff --git a/tests/unit/restapi/lib/test_signature_analysis.py b/tests/unit/restapi/lib/test_signature_analysis.py index 626a89312..43314a6d7 100644 --- a/tests/unit/restapi/lib/test_signature_analysis.py +++ b/tests/unit/restapi/lib/test_signature_analysis.py @@ -137,7 +137,7 @@ def not_a_plugin(): def test_plugin_recognition_complex(): source = """\ from dioptra.pyplugs import register -import aaa +import aaa @register() def test_plugin(): @@ -149,7 +149,7 @@ def not_a_plugin(): class SomeClass: pass - + def some_other_func(): pass @@ -199,7 +199,7 @@ def test_plugin( {"name": "arg5", "required": True, "type": "null"}, ], "outputs": [], - "suggested_types": [] + "suggested_types": [], } ] @@ -216,12 +216,7 @@ def my_plugin() -> None: signatures = list(get_plugin_signatures(source)) assert signatures == [ - { - "name": "my_plugin", - "inputs": [], - "outputs": [], - "suggested_types": [] - } + {"name": "my_plugin", "inputs": [], "outputs": [], "suggested_types": []} ] @@ -238,15 +233,11 @@ def the_plugin(arg1: SomeType) -> SomeType: assert signatures == [ { "name": "the_plugin", - "inputs": [ - {"name": "arg1", "required": True, "type": "sometype"} - ], - "outputs": [ - {"name": "output", "type": "sometype"} - ], + "inputs": [{"name": "arg1", "required": True, "type": "sometype"}], + "outputs": [{"name": "output", "type": "sometype"}], "suggested_types": [ {"suggestion": "sometype", "type_annotation": "SomeType"} - ] + ], } ] @@ -264,16 +255,12 @@ def the_plugin(arg1: Optional[str]) -> Union[int, bool]: assert signatures == [ { "name": "the_plugin", - "inputs": [ - {"name": "arg1", "required": True, "type": "optional_str"} - ], - "outputs": [ - {"name": "output", "type": "union_int_bool"} - ], + "inputs": [{"name": "arg1", "required": True, "type": "optional_str"}], + "outputs": [{"name": "output", "type": "union_int_bool"}], "suggested_types": [ {"suggestion": "optional_str", "type_annotation": "Optional[str]"}, - {"suggestion": "union_int_bool", "type_annotation": "Union[int, bool]"} - ] + {"suggestion": "union_int_bool", "type_annotation": "Union[int, bool]"}, + ], } ] @@ -292,15 +279,9 @@ def plugin_func(arg1: foo(2)) -> foo(2): assert signatures == [ { "name": "plugin_func", - "inputs": [ - {"name": "arg1", "required": True, "type": "type1"} - ], - "outputs": [ - {"name": "output", "type": "type1"} - ], - "suggested_types": [ - {"suggestion": "type1", "type_annotation": "foo(2)"} - ] + "inputs": [{"name": "arg1", "required": True, "type": "type1"}], + "outputs": [{"name": "output", "type": "type1"}], + "suggested_types": [{"suggestion": "type1", "type_annotation": "foo(2)"}], } ] @@ -323,15 +304,13 @@ def plugin_func(arg1: foo(2), arg2: Type1) -> foo(2): "name": "plugin_func", "inputs": [ {"name": "arg1", "required": True, "type": "type2"}, - 
{"name": "arg2", "required": True, "type": "type1"} - ], - "outputs": [ - {"name": "output", "type": "type2"} + {"name": "arg2", "required": True, "type": "type1"}, ], + "outputs": [{"name": "output", "type": "type2"}], "suggested_types": [ {"suggestion": "type1", "type_annotation": "Type1"}, - {"suggestion": "type2", "type_annotation": "foo(2)"} - ] + {"suggestion": "type2", "type_annotation": "foo(2)"}, + ], } ] @@ -351,11 +330,11 @@ def do_things(arg1: Optional[str], arg2: int = 123): "name": "do_things", "inputs": [ {"name": "arg1", "required": True, "type": "optional_str"}, - {"name": "arg2", "required": False, "type": "integer"} + {"name": "arg2", "required": False, "type": "integer"}, ], "outputs": [], "suggested_types": [ {"suggestion": "optional_str", "type_annotation": "Optional[str]"} - ] + ], } ] diff --git a/tests/unit/restapi/test_depth_limited_repr.py b/tests/unit/restapi/test_depth_limited_repr.py index 1fbe33d23..248476ba1 100644 --- a/tests/unit/restapi/test_depth_limited_repr.py +++ b/tests/unit/restapi/test_depth_limited_repr.py @@ -15,10 +15,11 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode import pytest -from dioptra.restapi.db.models.utils import depth_limited_repr import sqlalchemy as sa import sqlalchemy.orm as sao +from dioptra.restapi.db.models.utils import depth_limited_repr + class TestBase(sao.DeclarativeBase): pass @@ -44,13 +45,14 @@ class C(TestBase): @pytest.mark.parametrize( - "struct, expected_repr", [ + "struct, expected_repr", + [ (1, "1"), ("a", "'a'"), (False, "False"), (None, "None"), (b"abc", "b'abc'"), - (bytearray(b'abc'), "bytearray(b'abc')"), + (bytearray(b"abc"), "bytearray(b'abc')"), ([], "[]"), ([1], "[1]"), ([1, "foo"], "[1, 'foo']"), @@ -62,7 +64,7 @@ class C(TestBase): (range(3), "[0, 1, 2]"), ({1, 2, 3}, "[1, 2, 3]"), ((1, 2, 3), "[1, 2, 3]"), - ] + ], ) def test_simple_repr(struct, expected_repr): @@ -126,14 +128,7 @@ def test_orm_mixed(): b = B(id=2) a = A(id=1, bs=[b]) - value = { - "a": { - "b": [ - 1, - a - ] - } - } + value = {"a": {"b": [1, a]}} # depth does not affect traversal of plain data structures, only the # ORM instances diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 676725d26..0dd235c63 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -16,12 +16,12 @@ # https://creativecommons.org/licenses/by/4.0/legalcode """Fixtures representing resources needed for test suites""" import textwrap +import uuid from collections.abc import Iterator from http import HTTPStatus from typing import Any, cast import pytest -import uuid from flask import Flask from flask.testing import FlaskClient from flask_sqlalchemy import SQLAlchemy diff --git a/tests/unit/restapi/v1/signature_analysis/test_alias.py b/tests/unit/restapi/v1/signature_analysis/test_alias.py index 66ba8dc2e..904d2cf65 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_alias.py +++ b/tests/unit/restapi/v1/signature_analysis/test_alias.py @@ -1,4 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. 
Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode import dioptra.pyplugs as foo + + @foo.register def test_plugin(): - pass \ No newline at end of file + pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_complex_type.py b/tests/unit/restapi/v1/signature_analysis/test_complex_type.py index dc03cb3dd..f2833120a 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_complex_type.py +++ b/tests/unit/restapi/v1/signature_analysis/test_complex_type.py @@ -1,4 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode import dioptra.pyplugs + + @dioptra.pyplugs.register() def the_plugin(arg1: Optional[str]) -> Union[int, bool]: pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_function_type.py b/tests/unit/restapi/v1/signature_analysis/test_function_type.py index b9e1296dd..bc3242674 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_function_type.py +++ b/tests/unit/restapi/v1/signature_analysis/test_function_type.py @@ -1,4 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. 
To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode import dioptra.pyplugs + + @dioptra.pyplugs.register def plugin_func(arg1: foo(2)) -> foo(2): pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_none_return.py b/tests/unit/restapi/v1/signature_analysis/test_none_return.py index b27c11417..0ed95097e 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_none_return.py +++ b/tests/unit/restapi/v1/signature_analysis/test_none_return.py @@ -1,4 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode from dioptra.pyplugs import register + + @register def my_plugin() -> None: pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_optional.py b/tests/unit/restapi/v1/signature_analysis/test_optional.py index 59aeae3d8..ec847c6ea 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_optional.py +++ b/tests/unit/restapi/v1/signature_analysis/test_optional.py @@ -1,4 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode from dioptra import pyplugs + + @pyplugs.register() def do_things(arg1: Optional[str], arg2: int = 123): - pass \ No newline at end of file + pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py b/tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py index 1eb2bb716..73ab9039a 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py +++ b/tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py @@ -1,4 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode from dioptra import pyplugs as foo + + @foo.register def test_plugin(): pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_real_world.py b/tests/unit/restapi/v1/signature_analysis/test_real_world.py index 5cd1eb6f7..79689c7ef 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_real_world.py +++ b/tests/unit/restapi/v1/signature_analysis/test_real_world.py @@ -17,98 +17,133 @@ from __future__ import annotations from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import mlflow import numpy as np import pandas as pd import scipy.stats import structlog from structlog.stdlib import BoundLogger -from tensorflow.keras.preprocessing.image import ( - DirectoryIterator -) +from tensorflow.keras.preprocessing.image import DirectoryIterator +import mlflow from dioptra import pyplugs -from .tensorflow import get_optimizer, get_model_callbacks, get_performance_metrics, evaluate_metrics_tensorflow, predict_tensorflow -from .estimators_keras_classifiers import init_classifier -from .registry_art import load_wrapped_tensorflow_keras_classifier -from .registry_mlflow import load_tensorflow_keras_classifier -from .random_rng import init_rng -from .random_sample import draw_random_integer + +from .artifacts_mlflow import ( + download_all_artifacts, + upload_data_frame_artifact, + upload_directory_as_tarball_artifact, +) +from .artifacts_restapi import ( + get_uri_for_model, + get_uris_for_artifacts, + get_uris_for_job, +) +from .artifacts_utils import extract_tarfile, make_directories +from .attacks_fgm import fgm +from .attacks_patch import create_adversarial_patch_dataset, create_adversarial_patches from .backend_configs_tensorflow import init_tensorflow -from .tracking_mlflow import 
log_parameters, log_tensorflow_keras_estimator, log_metrics -from .data_tensorflow import get_n_classes_from_directory_iterator, create_image_dataset, predictions_to_df, df_to_predictions +from .data_tensorflow import ( + create_image_dataset, + df_to_predictions, + get_n_classes_from_directory_iterator, + predictions_to_df, +) +from .defenses_image_preprocessing import create_defended_dataset +from .estimators_keras_classifiers import init_classifier from .estimators_methods import fit -from .mlflow import add_model_to_registry -from .artifacts_restapi import get_uri_for_model, get_uris_for_job, get_uris_for_artifacts -from .artifacts_utils import make_directories, extract_tarfile from .metrics_distance import get_distance_metric_list -from .attacks_fgm import fgm -from .artifacts_mlflow import upload_directory_as_tarball_artifact, upload_data_frame_artifact, download_all_artifacts -from .defenses_image_preprocessing import create_defended_dataset -from .attacks_patch import create_adversarial_patches, create_adversarial_patch_dataset -from .metrics_performance import get_performance_metric_list, evaluate_metrics_generic +from .metrics_performance import evaluate_metrics_generic, get_performance_metric_list +from .mlflow import add_model_to_registry +from .random_rng import init_rng +from .random_sample import draw_random_integer +from .registry_art import load_wrapped_tensorflow_keras_classifier +from .registry_mlflow import load_tensorflow_keras_classifier +from .tensorflow import ( + evaluate_metrics_tensorflow, + get_model_callbacks, + get_optimizer, + get_performance_metrics, + predict_tensorflow, +) +from .tracking_mlflow import log_metrics, log_parameters, log_tensorflow_keras_estimator LOGGER: BoundLogger = structlog.stdlib.get_logger() + @pyplugs.register def load_dataset( ep_seed: int = 10145783023, training_dir: str = "/dioptra/data/Mnist/training", testing_dir: str = "/dioptra/data/Mnist/testing", - subsets: List[str] = ['testing'], + subsets: List[str] = ["testing"], image_size: Tuple[int, int, int] = [28, 28, 1], rescale: float = 1.0 / 255, validation_split: Optional[float] = 0.2, batch_size: int = 32, label_mode: str = "categorical", - shuffle: bool = False + shuffle: bool = False, ) -> DirectoryIterator: seed, rng = init_rng(ep_seed) global_seed = draw_random_integer(rng) dataset_seed = draw_random_integer(rng) init_tensorflow(global_seed) log_parameters( - {'entry_point_seed': ep_seed, - 'tensorflow_global_seed':global_seed, - 'dataset_seed':dataset_seed}) - training_dataset = None if "training" not in subsets else create_image_dataset( - data_dir=training_dir, - subset="training", - image_size=image_size, - seed=dataset_seed, - rescale=rescale, - validation_split=validation_split, - batch_size=batch_size, - label_mode=label_mode, - shuffle=shuffle + { + "entry_point_seed": ep_seed, + "tensorflow_global_seed": global_seed, + "dataset_seed": dataset_seed, + } ) - - validation_dataset = None if "validation" not in subsets else create_image_dataset( - data_dir=training_dir, - subset="validation", - image_size=image_size, - seed=dataset_seed, - rescale=rescale, - validation_split=validation_split, - batch_size=batch_size, - label_mode=label_mode, - shuffle=shuffle + training_dataset = ( + None + if "training" not in subsets + else create_image_dataset( + data_dir=training_dir, + subset="training", + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) ) - 
testing_dataset = None if "testing" not in subsets else create_image_dataset( - data_dir=testing_dir, - subset=None, - image_size=image_size, - seed=dataset_seed, - rescale=rescale, - validation_split=validation_split, - batch_size=batch_size, - label_mode=label_mode, - shuffle=shuffle + + validation_dataset = ( + None + if "validation" not in subsets + else create_image_dataset( + data_dir=training_dir, + subset="validation", + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) + ) + testing_dataset = ( + None + if "testing" not in subsets + else create_image_dataset( + data_dir=testing_dir, + subset=None, + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) ) return training_dataset, validation_dataset, testing_dataset + @pyplugs.register def create_model( dataset: DirectoryIterator = None, @@ -116,15 +151,18 @@ def create_model( input_shape: Tuple[int, int, int] = [28, 28, 1], loss: str = "categorical_crossentropy", learning_rate: float = 0.001, - optimizer: str = "Adam", + optimizer: str = "Adam", metrics_list: List[Dict[str, Any]] = None, ): n_classes = get_n_classes_from_directory_iterator(dataset) optim = get_optimizer(optimizer, learning_rate) perf_metrics = get_performance_metrics(metrics_list) - classifier = init_classifier(model_architecture, optim, perf_metrics, input_shape, n_classes, loss) + classifier = init_classifier( + model_architecture, optim, perf_metrics, input_shape, n_classes, loss + ) return classifier + @pyplugs.register def load_model( model_name: str | None = None, @@ -132,74 +170,78 @@ def load_model( imagenet_preprocessing: bool = False, art: bool = False, image_size: Any = None, - classifier_kwargs: Optional[Dict[str, Any]] = None + classifier_kwargs: Optional[Dict[str, Any]] = None, ): uri = get_uri_for_model(model_name, model_version) - if (art): - classifier = load_wrapped_tensorflow_keras_classifier(uri, imagenet_preprocessing, image_size, classifier_kwargs) + if art: + classifier = load_wrapped_tensorflow_keras_classifier( + uri, imagenet_preprocessing, image_size, classifier_kwargs + ) else: classifier = load_tensorflow_keras_classifier(uri) return classifier + @pyplugs.register def train( estimator: Any, x: Any = None, y: Any = None, callbacks_list: List[Dict[str, Any]] = None, - fit_kwargs: Optional[Dict[str, Any]] = None + fit_kwargs: Optional[Dict[str, Any]] = None, ): fit_kwargs = {} if fit_kwargs is None else fit_kwargs callbacks = get_model_callbacks(callbacks_list) - fit_kwargs['callbacks'] = callbacks + fit_kwargs["callbacks"] = callbacks fit(estimator=estimator, x=x, y=y, fit_kwargs=fit_kwargs) return estimator -@pyplugs.register + +@pyplugs.register def save_artifacts_and_models( - artifacts: List[Dict[str, Any]] = None, - models: List[Dict[str, Any]] = None + artifacts: List[Dict[str, Any]] = None, models: List[Dict[str, Any]] = None ): artifacts = [] if artifacts is None else artifacts models = [] if models is None else models for model in models: - log_tensorflow_keras_estimator(model['model'], "model") - add_model_to_registry(model['name'], "model") + log_tensorflow_keras_estimator(model["model"], "model") + add_model_to_registry(model["name"], "model") for artifact in artifacts: - if (artifact['type'] == 'tarball'): + if artifact["type"] == "tarball": upload_directory_as_tarball_artifact( - 
source_dir=artifact['adv_data_dir'], - tarball_filename=artifact['adv_tar_name'] + source_dir=artifact["adv_data_dir"], + tarball_filename=artifact["adv_tar_name"], ) - if (artifact['type'] == 'dataframe'): + if artifact["type"] == "dataframe": upload_data_frame_artifact( - data_frame=artifact['data_frame'], - file_name=artifact['file_name'], - file_format=artifact['file_format'], - file_format_kwargs=artifact['file_format_kwargs'] + data_frame=artifact["data_frame"], + file_name=artifact["file_name"], + file_format=artifact["file_format"], + file_format_kwargs=artifact["file_format_kwargs"], ) + + @pyplugs.register def load_artifacts_for_job( - job_id: str, - files: List[str|Path] = None, - extract_files: List[str|Path] = None + job_id: str, files: List[str | Path] = None, extract_files: List[str | Path] = None ): files = [] if files is None else files extract_files = [] if extract_files is None else extract_files - files += extract_files # need to download them to be able to extract + files += extract_files # need to download them to be able to extract uris = get_uris_for_job(job_id) paths = download_all_artifacts(uris, files) for extract in paths: for ef in extract_files: - if (ef.endswith(str(ef))): + if ef.endswith(str(ef)): extract_tarfile(extract) return paths + @pyplugs.register def load_artifacts( - artifact_ids: List[int] = None, extract_files: List[str|Path] = None + artifact_ids: List[int] = None, extract_files: List[str | Path] = None ): extract_files = [] if extract_files is None else extract_files artifact_ids = [] if artifact_ids is not None else artifact_ids @@ -208,6 +250,7 @@ def load_artifacts( for extract in paths: extract_tarfile(extract) + @pyplugs.register def attack_fgm( dataset: Any, @@ -220,7 +263,7 @@ def attack_fgm( minimal: bool = False, norm: Union[int, float, str] = np.inf, ): - '''generate fgm examples''' + """generate fgm examples""" make_directories([adv_data_dir]) distance_metrics_list = get_distance_metric_list(distance_metrics) fgm_dataset = fgm( @@ -232,10 +275,11 @@ def attack_fgm( eps=eps, eps_step=eps_step, minimal=minimal, - norm=norm + norm=norm, ) return fgm_dataset + @pyplugs.register() def attack_patch( data_flow: Any, @@ -251,9 +295,9 @@ def attack_patch( max_iter: int, patch_shape: Tuple, ): - '''generate patches''' + """generate patches""" make_directories([adv_data_dir]) - create_adversarial_patches( + create_adversarial_patches( data_flow=data_flow, adv_data_dir=adv_data_dir, keras_classifier=model, @@ -268,6 +312,7 @@ def attack_patch( patch_shape=patch_shape, ) + @pyplugs.register() def augment_patch( data_flow: Any, @@ -282,7 +327,7 @@ def augment_patch( scale_min: float = 0.1, scale_max: float = 1.0, ): - '''add patches to a dataset''' + """add patches to a dataset""" make_directories([adv_data_dir]) distance_metrics_list = get_distance_metric_list(distance_metrics) create_adversarial_patch_dataset( @@ -296,23 +341,23 @@ def augment_patch( patch_scale=patch_scale, rotation_max=rotation_max, scale_min=scale_min, - scale_max=scale_max + scale_max=scale_max, ) + + @pyplugs.register -def model_metrics( - classifier: Any, - dataset: Any -): +def model_metrics(classifier: Any, dataset: Any): metrics = evaluate_metrics_tensorflow(classifier, dataset) log_metrics(metrics) return metrics + @pyplugs.register def prediction_metrics( y_true: np.ndarray, y_pred: np.ndarray, metrics_list: List[Dict[str, str]], - func_kwargs: Dict[str, Dict[str, Any]] = None + func_kwargs: Dict[str, Dict[str, Any]] = None, ): func_kwargs = {} if func_kwargs is None 
else func_kwargs callable_list = get_performance_metric_list(metrics_list) @@ -320,6 +365,7 @@ def prediction_metrics( log_metrics(metrics) return pd.DataFrame(metrics, index=[1]) + @pyplugs.register def augment_data( dataset: Any, @@ -343,6 +389,7 @@ def augment_data( ) return defended_dataset + @pyplugs.register def predict( classifier: Any, @@ -351,18 +398,17 @@ def predict( show_target: bool = False, ): predictions = predict_tensorflow(classifier, dataset) - df = predictions_to_df( - predictions, - dataset, - show_actual=show_actual, - show_target=show_target) + df = predictions_to_df( + predictions, dataset, show_actual=show_actual, show_target=show_target + ) return df + @pyplugs.register def load_predictions( paths: List[str], filename: str, - format: str = 'csv', + format: str = "csv", dataset: DirectoryIterator = None, n_classes: int = -1, ): @@ -370,12 +416,9 @@ def load_predictions( for m in paths: if m.endswith(filename): loc = m - if (format == 'csv'): + if format == "csv": df = pd.read_csv(loc) - elif (format == 'json'): + elif format == "json": df = pd.read_json(loc) y_true, y_pred = df_to_predictions(df, dataset, n_classes) return y_true, y_pred - - - diff --git a/tests/unit/restapi/v1/signature_analysis/test_redefinition.py b/tests/unit/restapi/v1/signature_analysis/test_redefinition.py index a33dc6fb3..8978be0a0 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_redefinition.py +++ b/tests/unit/restapi/v1/signature_analysis/test_redefinition.py @@ -1,22 +1,54 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import aaa + from dioptra.pyplugs import register -import aaa + + @register() def test_plugin(): pass + + @aaa.register def not_a_plugin(): pass + + class SomeClass: pass - + + def some_other_func(): pass + + x = 1 + + @register def test_plugin2(): pass + + # re-definition of the "register" symbol from bbb import ccc as register + + @register def also_not_a_plugin(): - pass \ No newline at end of file + pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_register_alias.py b/tests/unit/restapi/v1/signature_analysis/test_register_alias.py index aa63572f8..b5ab0d362 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_register_alias.py +++ b/tests/unit/restapi/v1/signature_analysis/test_register_alias.py @@ -1,4 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode from dioptra.pyplugs import register as foo + + @foo def test_plugin(): pass diff --git a/tests/unit/restapi/v1/signature_analysis/test_type_conflict.py b/tests/unit/restapi/v1/signature_analysis/test_type_conflict.py index b9d75333a..0282d7703 100644 --- a/tests/unit/restapi/v1/signature_analysis/test_type_conflict.py +++ b/tests/unit/restapi/v1/signature_analysis/test_type_conflict.py @@ -1,4 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode import dioptra.pyplugs + + @dioptra.pyplugs.register def plugin_func(arg1: foo(2), arg2: Type1) -> foo(2): - pass \ No newline at end of file + pass diff --git a/tests/unit/restapi/v1/test_signature_analysis.py b/tests/unit/restapi/v1/test_signature_analysis.py index 6b50e6296..714f11803 100644 --- a/tests/unit/restapi/v1/test_signature_analysis.py +++ b/tests/unit/restapi/v1/test_signature_analysis.py @@ -1,331 +1,427 @@ -expected_outputs = {} - -expected_outputs['test_real_world.py'] = [{'name': 'load_dataset', - 'inputs': [{'name': 'ep_seed', 'type': 'integer', 'required': False}, - {'name': 'training_dir', 'type': 'string', 'required': False}, - {'name': 'testing_dir', 'type': 'string', 'required': False}, - {'name': 'subsets', 'type': 'list_str', 'required': False}, - {'name': 'image_size', 'type': 'tuple_int_int_int', 'required': False}, - {'name': 'rescale', 'type': 'number', 'required': False}, - {'name': 'validation_split', 'type': 'optional_float', 'required': False}, - {'name': 'batch_size', 'type': 'integer', 'required': False}, - {'name': 'label_mode', 'type': 'string', 'required': False}, - {'name': 'shuffle', 'type': 'boolean', 'required': False}], - 'outputs': [{'name': 'output', 'type': 'directoryiterator'}], - 'missing_types': [{'proposed_type': 'list_str', 'missing_type': 'List[str]'}, - {'proposed_type': 'tuple_int_int_int', - 'missing_type': 'Tuple[int, int, int]'}, - {'proposed_type': 'optional_float', 'missing_type': 'Optional[float]'}, - {'proposed_type': 'directoryiterator', - 'missing_type': 'DirectoryIterator'}]}, - {'name': 'create_model', - 'inputs': [{'name': 'dataset', - 'type': 'directoryiterator', - 'required': False}, - {'name': 'model_architecture', 'type': 'string', 'required': False}, - {'name': 'input_shape', 'type': 'tuple_int_int_int', 'required': False}, - {'name': 'loss', 'type': 'string', 'required': False}, - {'name': 'learning_rate', 'type': 'number', 'required': False}, - {'name': 'optimizer', 'type': 'string', 'required': False}, - {'name': 'metrics_list', 'type': 'list_dict_str_any', 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'directoryiterator', - 'missing_type': 'DirectoryIterator'}, - {'proposed_type': 'tuple_int_int_int', - 'missing_type': 'Tuple[int, int, int]'}, - {'proposed_type': 'list_dict_str_any', - 'missing_type': 'List[Dict[str, Any]]'}]}, - {'name': 'load_model', - 'inputs': [{'name': 'model_name', 'type': 'str_none', 'required': False}, - {'name': 'model_version', 'type': 'int_none', 'required': False}, - {'name': 'imagenet_preprocessing', 'type': 'boolean', 'required': False}, - {'name': 'art', 'type': 'boolean', 'required': False}, - {'name': 'image_size', 'type': 'any', 'required': False}, - {'name': 'classifier_kwargs', - 'type': 'optional_dict_str_any', - 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'str_none', - 'missing_type': 'str | None'}, - {'proposed_type': 'int_none', 'missing_type': 'int | None'}, - {'proposed_type': 'optional_dict_str_any', - 'missing_type': 'Optional[Dict[str, Any]]'}]}, - {'name': 'train', - 'inputs': [{'name': 'estimator', 'type': 'any', 'required': True}, - {'name': 'x', 'type': 'any', 'required': False}, - {'name': 'y', 'type': 'any', 'required': False}, - {'name': 'callbacks_list', 'type': 'list_dict_str_any', 'required': False}, - {'name': 'fit_kwargs', 'type': 'optional_dict_str_any', 'required': False}], - 'outputs': [], - 
'missing_types': [{'proposed_type': 'list_dict_str_any', - 'missing_type': 'List[Dict[str, Any]]'}, - {'proposed_type': 'optional_dict_str_any', - 'missing_type': 'Optional[Dict[str, Any]]'}]}, - {'name': 'save_artifacts_and_models', - 'inputs': [{'name': 'artifacts', - 'type': 'list_dict_str_any', - 'required': False}, - {'name': 'models', 'type': 'list_dict_str_any', 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'list_dict_str_any', - 'missing_type': 'List[Dict[str, Any]]'}]}, - {'name': 'load_artifacts_for_job', - 'inputs': [{'name': 'job_id', 'type': 'string', 'required': True}, - {'name': 'files', 'type': 'list_str_path', 'required': False}, - {'name': 'extract_files', 'type': 'list_str_path', 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'list_str_path', - 'missing_type': 'List[str | Path]'}]}, - {'name': 'load_artifacts', - 'inputs': [{'name': 'artifact_ids', 'type': 'list_int', 'required': False}, - {'name': 'extract_files', 'type': 'list_str_path', 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'list_int', 'missing_type': 'List[int]'}, - {'proposed_type': 'list_str_path', 'missing_type': 'List[str | Path]'}]}, - {'name': 'attack_fgm', - 'inputs': [{'name': 'dataset', 'type': 'any', 'required': True}, - {'name': 'adv_data_dir', 'type': 'union_str_path', 'required': True}, - {'name': 'classifier', 'type': 'any', 'required': True}, - {'name': 'distance_metrics', 'type': 'list_dict_str_str', 'required': True}, - {'name': 'batch_size', 'type': 'integer', 'required': False}, - {'name': 'eps', 'type': 'number', 'required': False}, - {'name': 'eps_step', 'type': 'number', 'required': False}, - {'name': 'minimal', 'type': 'boolean', 'required': False}, - {'name': 'norm', 'type': 'union_int_float_str', 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'union_str_path', - 'missing_type': 'Union[str, Path]'}, - {'proposed_type': 'list_dict_str_str', - 'missing_type': 'List[Dict[str, str]]'}, - {'proposed_type': 'union_int_float_str', - 'missing_type': 'Union[int, float, str]'}]}, - {'name': 'attack_patch', - 'inputs': [{'name': 'data_flow', 'type': 'any', 'required': True}, - {'name': 'adv_data_dir', 'type': 'union_str_path', 'required': True}, - {'name': 'model', 'type': 'any', 'required': True}, - {'name': 'patch_target', 'type': 'integer', 'required': True}, - {'name': 'num_patch', 'type': 'integer', 'required': True}, - {'name': 'num_patch_samples', 'type': 'integer', 'required': True}, - {'name': 'rotation_max', 'type': 'number', 'required': True}, - {'name': 'scale_min', 'type': 'number', 'required': True}, - {'name': 'scale_max', 'type': 'number', 'required': True}, - {'name': 'learning_rate', 'type': 'number', 'required': True}, - {'name': 'max_iter', 'type': 'integer', 'required': True}, - {'name': 'patch_shape', 'type': 'tuple', 'required': True}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'union_str_path', - 'missing_type': 'Union[str, Path]'}, - {'proposed_type': 'tuple', 'missing_type': 'Tuple'}]}, - {'name': 'augment_patch', - 'inputs': [{'name': 'data_flow', 'type': 'any', 'required': True}, - {'name': 'adv_data_dir', 'type': 'union_str_path', 'required': True}, - {'name': 'patch_dir', 'type': 'union_str_path', 'required': True}, - {'name': 'model', 'type': 'any', 'required': True}, - {'name': 'patch_shape', 'type': 'tuple', 'required': True}, - {'name': 'distance_metrics', 'type': 'list_dict_str_str', 'required': True}, - {'name': 'batch_size', 
'type': 'integer', 'required': False}, - {'name': 'patch_scale', 'type': 'number', 'required': False}, - {'name': 'rotation_max', 'type': 'number', 'required': False}, - {'name': 'scale_min', 'type': 'number', 'required': False}, - {'name': 'scale_max', 'type': 'number', 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'union_str_path', - 'missing_type': 'Union[str, Path]'}, - {'proposed_type': 'tuple', 'missing_type': 'Tuple'}, - {'proposed_type': 'list_dict_str_str', - 'missing_type': 'List[Dict[str, str]]'}]}, - {'name': 'model_metrics', - 'inputs': [{'name': 'classifier', 'type': 'any', 'required': True}, - {'name': 'dataset', 'type': 'any', 'required': True}], - 'outputs': [], - 'missing_types': []}, - {'name': 'prediction_metrics', - 'inputs': [{'name': 'y_true', 'type': 'np_ndarray', 'required': True}, - {'name': 'y_pred', 'type': 'np_ndarray', 'required': True}, - {'name': 'metrics_list', 'type': 'list_dict_str_str', 'required': True}, - {'name': 'func_kwargs', - 'type': 'dict_str_dict_str_any', - 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'np_ndarray', - 'missing_type': 'np.ndarray'}, - {'proposed_type': 'list_dict_str_str', - 'missing_type': 'List[Dict[str, str]]'}, - {'proposed_type': 'dict_str_dict_str_any', - 'missing_type': 'Dict[str, Dict[str, Any]]'}]}, - {'name': 'augment_data', - 'inputs': [{'name': 'dataset', 'type': 'any', 'required': True}, - {'name': 'def_data_dir', 'type': 'union_str_path', 'required': True}, - {'name': 'image_size', 'type': 'tuple_int_int_int', 'required': True}, - {'name': 'distance_metrics', 'type': 'list_dict_str_str', 'required': True}, - {'name': 'batch_size', 'type': 'integer', 'required': False}, - {'name': 'def_type', 'type': 'string', 'required': False}, - {'name': 'defense_kwargs', - 'type': 'optional_dict_str_any', - 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'union_str_path', - 'missing_type': 'Union[str, Path]'}, - {'proposed_type': 'tuple_int_int_int', - 'missing_type': 'Tuple[int, int, int]'}, - {'proposed_type': 'list_dict_str_str', - 'missing_type': 'List[Dict[str, str]]'}, - {'proposed_type': 'optional_dict_str_any', - 'missing_type': 'Optional[Dict[str, Any]]'}]}, - {'name': 'predict', - 'inputs': [{'name': 'classifier', 'type': 'any', 'required': True}, - {'name': 'dataset', 'type': 'any', 'required': True}, - {'name': 'show_actual', 'type': 'boolean', 'required': False}, - {'name': 'show_target', 'type': 'boolean', 'required': False}], - 'outputs': [], - 'missing_types': []}, - {'name': 'load_predictions', - 'inputs': [{'name': 'paths', 'type': 'list_str', 'required': True}, - {'name': 'filename', 'type': 'string', 'required': True}, - {'name': 'format', 'type': 'string', 'required': False}, - {'name': 'dataset', 'type': 'directoryiterator', 'required': False}, - {'name': 'n_classes', 'type': 'integer', 'required': False}], - 'outputs': [], - 'missing_types': [{'proposed_type': 'list_str', 'missing_type': 'List[str]'}, - {'proposed_type': 'directoryiterator', - 'missing_type': 'DirectoryIterator'}]}] - -expected_outputs['test_alias.py'] = [{ - 'name':'test_plugin', - 'inputs': [], - 'outputs': [], - 'missing_types': [] -}] - -expected_outputs['test_complex_type.py'] = [{ - 'name':'the_plugin', - 'inputs': [ - { - 'name': 'arg1', - 'type': 'optional_str', - 'required': True, - } - ], - 'outputs': [ - { - 'name': 'output', - 'type': 'union_int_bool' - } - ], - 'missing_types': [ - {'proposed_type': 'optional_str', 'missing_type': 
'Optional[str]'}, - {'proposed_type': 'union_int_bool', 'missing_type': 'Union[int, bool]'}, - ] -}] - -expected_outputs['test_function_type.py'] = [{ - 'name':'plugin_func', - 'inputs': [ - { - 'name': 'arg1', - 'type': 'type1', - 'required': True, - } - ], - 'outputs': [ - { - 'name': 'output', - 'type': 'type1' - } - ], - 'missing_types': [ - {'proposed_type': 'type1', 'missing_type': 'foo(2)'}, - ] -}] - -expected_outputs['test_none_return.py'] = [{ - 'name':'my_plugin', - 'inputs': [], - 'outputs': [], - 'missing_types': [] -}] - -expected_outputs['test_optional.py'] = [{ - 'name':'do_things', - 'inputs': [ - { - 'name': 'arg1', - 'type': 'optional_str', - 'required': True, - }, - { - 'name': 'arg2', - 'type': 'integer', - 'required': False, - }, - - ], - 'outputs': [], - 'missing_types': [ - {'proposed_type': 'optional_str', 'missing_type': 'Optional[str]'}, - ] -}] - -expected_outputs['test_pyplugs_alias.py'] = [{ - 'name':'test_plugin', - 'inputs': [], - 'outputs': [], - 'missing_types': [] -}] - -expected_outputs['test_redefinition.py'] = [{ - 'name':'test_plugin', - 'inputs': [], - 'outputs': [], - 'missing_types': [] -},{ - 'name':'test_plugin2', - 'inputs': [], - 'outputs': [], - 'missing_types': [] -}] - -expected_outputs['test_register_alias.py'] = [{ - 'name':'test_plugin', - 'inputs': [], - 'outputs': [], - 'missing_types': [] -}] - -expected_outputs['test_type_conflict.py'] = [{ - 'name':'plugin_func', - 'inputs': [ - { - 'name': 'arg1', - 'type': 'type2', - 'required': True, - }, - { - 'name': 'arg2', - 'type': 'type1', - 'required': True, - } - ], - 'outputs': [ - { - 'name': 'output', - 'type': 'type2' - } - ], - 'missing_types': [ - {'proposed_type': 'type2', 'missing_type': 'foo(2)'}, - {'proposed_type': 'type1', 'missing_type': 'Type1'}, - ] -}] - +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from http import HTTPStatus from pathlib import Path from typing import Any -from http import HTTPStatus + from flask_sqlalchemy import SQLAlchemy from dioptra.client.base import DioptraResponseProtocol from dioptra.client.client import DioptraClient +expected_outputs = {} + +expected_outputs["test_real_world.py"] = [ + { + "name": "load_dataset", + "inputs": [ + {"name": "ep_seed", "type": "integer", "required": False}, + {"name": "training_dir", "type": "string", "required": False}, + {"name": "testing_dir", "type": "string", "required": False}, + {"name": "subsets", "type": "list_str", "required": False}, + {"name": "image_size", "type": "tuple_int_int_int", "required": False}, + {"name": "rescale", "type": "number", "required": False}, + {"name": "validation_split", "type": "optional_float", "required": False}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "label_mode", "type": "string", "required": False}, + {"name": "shuffle", "type": "boolean", "required": False}, + ], + "outputs": [{"name": "output", "type": "directoryiterator"}], + "missing_types": [ + {"proposed_type": "list_str", "missing_type": "List[str]"}, + { + "proposed_type": "tuple_int_int_int", + "missing_type": "Tuple[int, int, int]", + }, + {"proposed_type": "optional_float", "missing_type": "Optional[float]"}, + {"proposed_type": "directoryiterator", "missing_type": "DirectoryIterator"}, + ], + }, + { + "name": "create_model", + "inputs": [ + {"name": "dataset", "type": "directoryiterator", "required": False}, + {"name": "model_architecture", "type": "string", "required": False}, + {"name": "input_shape", "type": "tuple_int_int_int", "required": False}, + {"name": "loss", "type": "string", "required": False}, + {"name": "learning_rate", "type": "number", "required": False}, + {"name": "optimizer", "type": "string", "required": False}, + {"name": "metrics_list", "type": "list_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "directoryiterator", "missing_type": "DirectoryIterator"}, + { + "proposed_type": "tuple_int_int_int", + "missing_type": "Tuple[int, int, int]", + }, + { + "proposed_type": "list_dict_str_any", + "missing_type": "List[Dict[str, Any]]", + }, + ], + }, + { + "name": "load_model", + "inputs": [ + {"name": "model_name", "type": "str_none", "required": False}, + {"name": "model_version", "type": "int_none", "required": False}, + {"name": "imagenet_preprocessing", "type": "boolean", "required": False}, + {"name": "art", "type": "boolean", "required": False}, + {"name": "image_size", "type": "any", "required": False}, + { + "name": "classifier_kwargs", + "type": "optional_dict_str_any", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "str_none", "missing_type": "str | None"}, + {"proposed_type": "int_none", "missing_type": "int | None"}, + { + "proposed_type": "optional_dict_str_any", + "missing_type": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "train", + "inputs": [ + {"name": "estimator", "type": "any", "required": True}, + {"name": "x", "type": "any", "required": False}, + {"name": "y", "type": "any", "required": False}, + {"name": "callbacks_list", "type": "list_dict_str_any", "required": False}, + {"name": "fit_kwargs", "type": "optional_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + { + "proposed_type": "list_dict_str_any", + 
"missing_type": "List[Dict[str, Any]]", + }, + { + "proposed_type": "optional_dict_str_any", + "missing_type": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "save_artifacts_and_models", + "inputs": [ + {"name": "artifacts", "type": "list_dict_str_any", "required": False}, + {"name": "models", "type": "list_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + { + "proposed_type": "list_dict_str_any", + "missing_type": "List[Dict[str, Any]]", + } + ], + }, + { + "name": "load_artifacts_for_job", + "inputs": [ + {"name": "job_id", "type": "string", "required": True}, + {"name": "files", "type": "list_str_path", "required": False}, + {"name": "extract_files", "type": "list_str_path", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "list_str_path", "missing_type": "List[str | Path]"} + ], + }, + { + "name": "load_artifacts", + "inputs": [ + {"name": "artifact_ids", "type": "list_int", "required": False}, + {"name": "extract_files", "type": "list_str_path", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "list_int", "missing_type": "List[int]"}, + {"proposed_type": "list_str_path", "missing_type": "List[str | Path]"}, + ], + }, + { + "name": "attack_fgm", + "inputs": [ + {"name": "dataset", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "classifier", "type": "any", "required": True}, + {"name": "distance_metrics", "type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "eps", "type": "number", "required": False}, + {"name": "eps_step", "type": "number", "required": False}, + {"name": "minimal", "type": "boolean", "required": False}, + {"name": "norm", "type": "union_int_float_str", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "union_str_path", "missing_type": "Union[str, Path]"}, + { + "proposed_type": "list_dict_str_str", + "missing_type": "List[Dict[str, str]]", + }, + { + "proposed_type": "union_int_float_str", + "missing_type": "Union[int, float, str]", + }, + ], + }, + { + "name": "attack_patch", + "inputs": [ + {"name": "data_flow", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "model", "type": "any", "required": True}, + {"name": "patch_target", "type": "integer", "required": True}, + {"name": "num_patch", "type": "integer", "required": True}, + {"name": "num_patch_samples", "type": "integer", "required": True}, + {"name": "rotation_max", "type": "number", "required": True}, + {"name": "scale_min", "type": "number", "required": True}, + {"name": "scale_max", "type": "number", "required": True}, + {"name": "learning_rate", "type": "number", "required": True}, + {"name": "max_iter", "type": "integer", "required": True}, + {"name": "patch_shape", "type": "tuple", "required": True}, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "union_str_path", "missing_type": "Union[str, Path]"}, + {"proposed_type": "tuple", "missing_type": "Tuple"}, + ], + }, + { + "name": "augment_patch", + "inputs": [ + {"name": "data_flow", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "patch_dir", "type": "union_str_path", "required": True}, + {"name": "model", "type": "any", "required": True}, + {"name": "patch_shape", "type": "tuple", "required": True}, + {"name": "distance_metrics", 
"type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "patch_scale", "type": "number", "required": False}, + {"name": "rotation_max", "type": "number", "required": False}, + {"name": "scale_min", "type": "number", "required": False}, + {"name": "scale_max", "type": "number", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "union_str_path", "missing_type": "Union[str, Path]"}, + {"proposed_type": "tuple", "missing_type": "Tuple"}, + { + "proposed_type": "list_dict_str_str", + "missing_type": "List[Dict[str, str]]", + }, + ], + }, + { + "name": "model_metrics", + "inputs": [ + {"name": "classifier", "type": "any", "required": True}, + {"name": "dataset", "type": "any", "required": True}, + ], + "outputs": [], + "missing_types": [], + }, + { + "name": "prediction_metrics", + "inputs": [ + {"name": "y_true", "type": "np_ndarray", "required": True}, + {"name": "y_pred", "type": "np_ndarray", "required": True}, + {"name": "metrics_list", "type": "list_dict_str_str", "required": True}, + {"name": "func_kwargs", "type": "dict_str_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "np_ndarray", "missing_type": "np.ndarray"}, + { + "proposed_type": "list_dict_str_str", + "missing_type": "List[Dict[str, str]]", + }, + { + "proposed_type": "dict_str_dict_str_any", + "missing_type": "Dict[str, Dict[str, Any]]", + }, + ], + }, + { + "name": "augment_data", + "inputs": [ + {"name": "dataset", "type": "any", "required": True}, + {"name": "def_data_dir", "type": "union_str_path", "required": True}, + {"name": "image_size", "type": "tuple_int_int_int", "required": True}, + {"name": "distance_metrics", "type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "def_type", "type": "string", "required": False}, + { + "name": "defense_kwargs", + "type": "optional_dict_str_any", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "union_str_path", "missing_type": "Union[str, Path]"}, + { + "proposed_type": "tuple_int_int_int", + "missing_type": "Tuple[int, int, int]", + }, + { + "proposed_type": "list_dict_str_str", + "missing_type": "List[Dict[str, str]]", + }, + { + "proposed_type": "optional_dict_str_any", + "missing_type": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "predict", + "inputs": [ + {"name": "classifier", "type": "any", "required": True}, + {"name": "dataset", "type": "any", "required": True}, + {"name": "show_actual", "type": "boolean", "required": False}, + {"name": "show_target", "type": "boolean", "required": False}, + ], + "outputs": [], + "missing_types": [], + }, + { + "name": "load_predictions", + "inputs": [ + {"name": "paths", "type": "list_str", "required": True}, + {"name": "filename", "type": "string", "required": True}, + {"name": "format", "type": "string", "required": False}, + {"name": "dataset", "type": "directoryiterator", "required": False}, + {"name": "n_classes", "type": "integer", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "list_str", "missing_type": "List[str]"}, + {"proposed_type": "directoryiterator", "missing_type": "DirectoryIterator"}, + ], + }, +] + +expected_outputs["test_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["test_complex_type.py"] = [ + { + "name": "the_plugin", + "inputs": [ + { + "name": 
"arg1", + "type": "optional_str", + "required": True, + } + ], + "outputs": [{"name": "output", "type": "union_int_bool"}], + "missing_types": [ + {"proposed_type": "optional_str", "missing_type": "Optional[str]"}, + {"proposed_type": "union_int_bool", "missing_type": "Union[int, bool]"}, + ], + } +] + +expected_outputs["test_function_type.py"] = [ + { + "name": "plugin_func", + "inputs": [ + { + "name": "arg1", + "type": "type1", + "required": True, + } + ], + "outputs": [{"name": "output", "type": "type1"}], + "missing_types": [ + {"proposed_type": "type1", "missing_type": "foo(2)"}, + ], + } +] + +expected_outputs["test_none_return.py"] = [ + {"name": "my_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["test_optional.py"] = [ + { + "name": "do_things", + "inputs": [ + { + "name": "arg1", + "type": "optional_str", + "required": True, + }, + { + "name": "arg2", + "type": "integer", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"proposed_type": "optional_str", "missing_type": "Optional[str]"}, + ], + } +] + +expected_outputs["test_pyplugs_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["test_redefinition.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []}, + {"name": "test_plugin2", "inputs": [], "outputs": [], "missing_types": []}, +] + +expected_outputs["test_register_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["test_type_conflict.py"] = [ + { + "name": "plugin_func", + "inputs": [ + { + "name": "arg1", + "type": "type2", + "required": True, + }, + { + "name": "arg2", + "type": "type1", + "required": True, + }, + ], + "outputs": [{"name": "output", "type": "type2"}], + "missing_types": [ + {"proposed_type": "type2", "missing_type": "foo(2)"}, + {"proposed_type": "type1", "missing_type": "Type1"}, + ], + } +] + # -- Assertions ------------------------------------------------------------------------ @@ -358,20 +454,26 @@ def assert_signature_analysis_response_matches_expectations( assert isinstance(response["missing_types"], list) assert isinstance(response["inputs"], list) - def sort_by_name(lst, k='name'): + def sort_by_name(lst, k="name"): return sorted(lst, key=lambda x: x[k]) - assert sort_by_name(response["outputs"]) == sort_by_name(expected_contents["outputs"]) + assert sort_by_name(response["outputs"]) == sort_by_name( + expected_contents["outputs"] + ) assert sort_by_name(response["inputs"]) == sort_by_name(expected_contents["inputs"]) - assert sort_by_name(response["missing_types"], k='proposed_type') == sort_by_name(expected_contents["missing_types"], k='proposed_type') - + assert sort_by_name(response["missing_types"], k="proposed_type") == sort_by_name( + expected_contents["missing_types"], k="proposed_type" + ) + def assert_signature_analysis_responses_matches_expectations( responses: list[dict[str, Any]], expected_contents: list[dict[str, Any]] ) -> None: - assert (len(responses) == len(expected_contents)) + assert len(responses) == len(expected_contents) for response in responses: - assert_signature_analysis_response_matches_expectations(response, [a for a in expected_contents if a['name']==response['name']][0]) + assert_signature_analysis_response_matches_expectations( + response, [a for a in expected_contents if a["name"] == response["name"]][0] + ) def assert_signature_analysis_file_load_and_contents( @@ -380,19 +482,27 @@ def 
assert_signature_analysis_file_load_and_contents( ): location = Path("tests/unit/restapi/v1/signature_analysis") / filename file_analysis = dioptra_client.workflows.signature_analysis_file(str(location)) - with location.open('r') as f: - contents = f.read() - contents_analysis = dioptra_client.workflows.signature_analysis_contents(contents, filename) + with location.open("r") as f: + contents = f.read() + contents_analysis = dioptra_client.workflows.signature_analysis_contents( + contents, filename + ) + + assert file_analysis.status_code == HTTPStatus.OK + assert contents_analysis.status_code == HTTPStatus.OK - assert(file_analysis.status_code == HTTPStatus.OK) - assert(contents_analysis.status_code == HTTPStatus.OK) - - assert_signature_analysis_responses_matches_expectations(file_analysis.json()["plugins"], expected_contents=expected_outputs[filename]) - assert_signature_analysis_responses_matches_expectations(contents_analysis.json()["plugins"], expected_contents=expected_outputs[filename]) + assert_signature_analysis_responses_matches_expectations( + file_analysis.json()["plugins"], expected_contents=expected_outputs[filename] + ) + assert_signature_analysis_responses_matches_expectations( + contents_analysis.json()["plugins"], + expected_contents=expected_outputs[filename], + ) # -- Tests ----------------------------------------------------------------------------- + def test_signature_analysis( dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy, @@ -408,4 +518,6 @@ def test_signature_analysis( """ for fn in expected_outputs: - assert_signature_analysis_file_load_and_contents(dioptra_client=dioptra_client, filename=fn) + assert_signature_analysis_file_load_and_contents( + dioptra_client=dioptra_client, filename=fn + ) From 96b5d685c1d4795bd0317c9d48a36567a82b0a75 Mon Sep 17 00:00:00 2001 From: jtsextonMITRE <45762017+jtsextonMITRE@users.noreply.github.com> Date: Fri, 24 Jan 2025 11:29:10 -0500 Subject: [PATCH 5/5] chore: fix linting --- src/dioptra/client/workflows.py | 23 ++++++++++--------- .../{test_alias.py => sample_test_alias.py} | 0 ...ex_type.py => sample_test_complex_type.py} | 0 ...n_type.py => sample_test_function_type.py} | 0 ...e_return.py => sample_test_none_return.py} | 0 ...st_optional.py => sample_test_optional.py} | 0 ..._alias.py => sample_test_pyplugs_alias.py} | 0 ...eal_world.py => sample_test_real_world.py} | 0 ...inition.py => sample_test_redefinition.py} | 0 ...alias.py => sample_test_register_alias.py} | 0 ...nflict.py => sample_test_type_conflict.py} | 0 .../restapi/v1/test_signature_analysis.py | 20 ++++++++-------- 12 files changed, 22 insertions(+), 21 deletions(-) rename tests/unit/restapi/v1/signature_analysis/{test_alias.py => sample_test_alias.py} (100%) rename tests/unit/restapi/v1/signature_analysis/{test_complex_type.py => sample_test_complex_type.py} (100%) rename tests/unit/restapi/v1/signature_analysis/{test_function_type.py => sample_test_function_type.py} (100%) rename tests/unit/restapi/v1/signature_analysis/{test_none_return.py => sample_test_none_return.py} (100%) rename tests/unit/restapi/v1/signature_analysis/{test_optional.py => sample_test_optional.py} (100%) rename tests/unit/restapi/v1/signature_analysis/{test_pyplugs_alias.py => sample_test_pyplugs_alias.py} (100%) rename tests/unit/restapi/v1/signature_analysis/{test_real_world.py => sample_test_real_world.py} (100%) rename tests/unit/restapi/v1/signature_analysis/{test_redefinition.py => sample_test_redefinition.py} (100%) rename 
tests/unit/restapi/v1/signature_analysis/{test_register_alias.py => sample_test_register_alias.py} (100%) rename tests/unit/restapi/v1/signature_analysis/{test_type_conflict.py => sample_test_type_conflict.py} (100%) diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py index 4fd1cb899..2c993b98b 100644 --- a/src/dioptra/client/workflows.py +++ b/src/dioptra/client/workflows.py @@ -89,9 +89,7 @@ def download_job_files( ) def signature_analysis_contents( - self, - fileContents: str, - filename: str ="something.py" + self, fileContents: str, filename: str = "something.py" ) -> T: """ Requests signature analysis for the functions in an annotated python file. @@ -104,12 +102,13 @@ The response from the Dioptra API. """ - return self._session.post(self.url, SIGNATURE_ANALYSIS, json_={"filename":filename, "fileContents":fileContents}) - - def signature_analysis_file( - self, - filename: str - ) -> T: + return self._session.post( + self.url, + SIGNATURE_ANALYSIS, + json_={"filename": filename, "fileContents": fileContents}, + ) + + def signature_analysis_file(self, filename: str) -> T: """ Reads a file, and then requests signature analysis for the functions in an annotated python file. @@ -121,6 +120,8 @@ The response from the Dioptra API. """ - with open(filename, 'r+') as f: + with open(filename, "r") as f: contents = f.read() - return self.signature_analysis_contents(fileContents=contents, filename=filename) + return self.signature_analysis_contents( + fileContents=contents, filename=filename + ) diff --git a/tests/unit/restapi/v1/signature_analysis/test_alias.py b/tests/unit/restapi/v1/signature_analysis/sample_test_alias.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_alias.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_alias.py diff --git a/tests/unit/restapi/v1/signature_analysis/test_complex_type.py b/tests/unit/restapi/v1/signature_analysis/sample_test_complex_type.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_complex_type.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_complex_type.py diff --git a/tests/unit/restapi/v1/signature_analysis/test_function_type.py b/tests/unit/restapi/v1/signature_analysis/sample_test_function_type.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_function_type.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_function_type.py diff --git a/tests/unit/restapi/v1/signature_analysis/test_none_return.py b/tests/unit/restapi/v1/signature_analysis/sample_test_none_return.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_none_return.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_none_return.py diff --git a/tests/unit/restapi/v1/signature_analysis/test_optional.py b/tests/unit/restapi/v1/signature_analysis/sample_test_optional.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_optional.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_optional.py diff --git a/tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py b/tests/unit/restapi/v1/signature_analysis/sample_test_pyplugs_alias.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_pyplugs_alias.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_pyplugs_alias.py diff --git 
a/tests/unit/restapi/v1/signature_analysis/test_real_world.py b/tests/unit/restapi/v1/signature_analysis/sample_test_real_world.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_real_world.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_real_world.py diff --git a/tests/unit/restapi/v1/signature_analysis/test_redefinition.py b/tests/unit/restapi/v1/signature_analysis/sample_test_redefinition.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_redefinition.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_redefinition.py diff --git a/tests/unit/restapi/v1/signature_analysis/test_register_alias.py b/tests/unit/restapi/v1/signature_analysis/sample_test_register_alias.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_register_alias.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_register_alias.py diff --git a/tests/unit/restapi/v1/signature_analysis/test_type_conflict.py b/tests/unit/restapi/v1/signature_analysis/sample_test_type_conflict.py similarity index 100% rename from tests/unit/restapi/v1/signature_analysis/test_type_conflict.py rename to tests/unit/restapi/v1/signature_analysis/sample_test_type_conflict.py diff --git a/tests/unit/restapi/v1/test_signature_analysis.py b/tests/unit/restapi/v1/test_signature_analysis.py index 714f11803..2a5818f25 100644 --- a/tests/unit/restapi/v1/test_signature_analysis.py +++ b/tests/unit/restapi/v1/test_signature_analysis.py @@ -25,7 +25,7 @@ expected_outputs = {} -expected_outputs["test_real_world.py"] = [ +expected_outputs["sample_test_real_world.py"] = [ { "name": "load_dataset", "inputs": [ @@ -321,11 +321,11 @@ }, ] -expected_outputs["test_alias.py"] = [ +expected_outputs["sample_test_alias.py"] = [ {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} ] -expected_outputs["test_complex_type.py"] = [ +expected_outputs["sample_test_complex_type.py"] = [ { "name": "the_plugin", "inputs": [ @@ -343,7 +343,7 @@ } ] -expected_outputs["test_function_type.py"] = [ +expected_outputs["sample_test_function_type.py"] = [ { "name": "plugin_func", "inputs": [ @@ -360,11 +360,11 @@ } ] -expected_outputs["test_none_return.py"] = [ +expected_outputs["sample_test_none_return.py"] = [ {"name": "my_plugin", "inputs": [], "outputs": [], "missing_types": []} ] -expected_outputs["test_optional.py"] = [ +expected_outputs["sample_test_optional.py"] = [ { "name": "do_things", "inputs": [ @@ -386,20 +386,20 @@ } ] -expected_outputs["test_pyplugs_alias.py"] = [ +expected_outputs["sample_test_pyplugs_alias.py"] = [ {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} ] -expected_outputs["test_redefinition.py"] = [ +expected_outputs["sample_test_redefinition.py"] = [ {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []}, {"name": "test_plugin2", "inputs": [], "outputs": [], "missing_types": []}, ] -expected_outputs["test_register_alias.py"] = [ +expected_outputs["sample_test_register_alias.py"] = [ {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} ] -expected_outputs["test_type_conflict.py"] = [ +expected_outputs["sample_test_type_conflict.py"] = [ { "name": "plugin_func", "inputs": [