diff --git a/lookout/style/format/__main__.py b/lookout/style/format/__main__.py index ef852a277..7d42be4f7 100644 --- a/lookout/style/format/__main__.py +++ b/lookout/style/format/__main__.py @@ -8,7 +8,7 @@ def create_parser(): parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatterNoNone) - subparsers = parser.add_subparsers(help="Commands", dest="command") + subparsers = parser.add_subparsers(help="Commands") def add_parser(name, help): return subparsers.add_parser( @@ -17,7 +17,7 @@ def add_parser(name, help): # Evaluation eval_parser = add_parser("eval", "Evaluate trained model on given dataset.") eval_parser.set_defaults(handler=quality_report) - eval_parser.add_argument("-i", "--input", required=True, type=str, + eval_parser.add_argument("-i", "--input-pattern", required=True, type=str, help="Path to folder with source code - " "should be in a format compatible with glob (ends with**/* " "and surrounded by quotes. Ex: `path/**/*`).") @@ -33,8 +33,8 @@ def add_parser(name, help): # Visualization vis_parser = add_parser("vis", "Visualize mispredictions of the model on the given file.") vis_parser.set_defaults(handler=visualize) - vis_parser.add_argument("-i", "--input", required=True, help="Path to folder with source " - "code.") + vis_parser.add_argument("-i", "--input-filename", required=True, + help="Path to file to analyze.") vis_parser.add_argument("--bblfsh", default="0.0.0.0:9432", help="Babelfish server's address.") vis_parser.add_argument("-l", "--language", default="javascript", @@ -48,12 +48,13 @@ def main(): args = parser.parse_args() try: handler = args.handler + delattr(args, "handler") except AttributeError: def print_usage(_): parser.print_usage() handler = print_usage - return handler(args) + return handler(**vars(args)) if __name__ == "__main__": diff --git a/lookout/style/format/quality_report.py b/lookout/style/format/quality_report.py index 49fa92c1a..5bd81e99d 100644 --- a/lookout/style/format/quality_report.py +++ b/lookout/style/format/quality_report.py @@ -2,7 +2,7 @@ import glob import os -import bblfsh +from bblfsh import BblfshClient from bblfsh.client import NonUTF8ContentException import numpy from sklearn.metrics import classification_report,confusion_matrix @@ -36,16 +36,17 @@ def prepare_files(folder, client, language): return files -def quality_report(args): - client = bblfsh.BblfshClient(args.bblfsh) - files = prepare_files(args.input, client, args.language) +def quality_report(input_pattern: str, bblfsh: str, language: str, n_files: int, + model: str) -> None: + client = BblfshClient(bblfsh) + files = prepare_files(input_pattern, client, language) print("Number of files: %s" % (len(files))) - fe = FeatureExtractor(language=args.language) + fe = FeatureExtractor(language=language) X, y, nodes = fe.extract_features(files) - analyzer = FormatModel().load(args.model) - rules = analyzer._rules_by_lang[args.language] + analyzer = FormatModel().load(model) + rules = analyzer._rules_by_lang[language] y_pred = rules.predict(X) target_names = [CLASSES[cls_ind] for cls_ind in numpy.unique(y)] @@ -60,7 +61,7 @@ def quality_report(args): file_stat = Counter(file_mispred) to_show = file_stat.most_common() - if args.n_files > 0: - to_show = to_show[:args.n_files] + if n_files > 0: + to_show = to_show[:n_files] print("Files with most errors:\n" + "\n".join(map(str, to_show))) diff --git a/lookout/style/format/visualization.py b/lookout/style/format/visualization.py index dd7b3e5b5..091b7a39d 100644 --- a/lookout/style/format/visualization.py +++ b/lookout/style/format/visualization.py @@ -1,7 +1,7 @@ from collections import namedtuple import os -import bblfsh +from bblfsh import BblfshClient from lookout.core.api.service_data_pb2 import File from lookout.style.format.features import FeatureExtractor, CLASSES @@ -28,15 +28,15 @@ def prepare_file(filename, client, language): return File(content=content, uast=res.uast, path=filename) -def visualize(args): - client = bblfsh.BblfshClient(args.bblfsh) - file = prepare_file(args.input, client, args.language) +def visualize(input_filename: str, bblfsh: str, language: str, model: str) -> None: + client = BblfshClient(bblfsh) + file = prepare_file(input_filename, client, language) - fe = FeatureExtractor(language=args.language) + fe = FeatureExtractor(language=language) X, y, nodes = fe.extract_features([file]) - analyzer = FormatModel().load(args.model) - rules = analyzer._rules_by_lang[args.language] + analyzer = FormatModel().load(model) + rules = analyzer._rules_by_lang[language] y_pred = rules.predict(X) mispred = []