diff --git a/plugins/pyproject.toml b/plugins/pyproject.toml index c4c679100..4d80dcf66 100644 --- a/plugins/pyproject.toml +++ b/plugins/pyproject.toml @@ -15,9 +15,6 @@ polyglot_piranha = "*" [tool.poetry.dev-dependencies] pytest = "*" -# [tool.poetry.scripts] -# scala_test = "scala_test.main:main" - [tool.poetry.scripts."scala_test"] main = "scala_test.main:main" diff --git a/plugins/scala_test/README.md b/plugins/scala_test/README.md index 6d98d772a..0858f7c73 100644 --- a/plugins/scala_test/README.md +++ b/plugins/scala_test/README.md @@ -1,4 +1,15 @@ -# `scalatest` Migration Plugin +# `scalatest` Migration Plugin (WIP) + +This piranha plugin updates `scalatest` to a new version. + + +Currently, it updates to [v.3.2.2](https://mvnrepository.com/artifact/org.scalatest/scalatest_2.12/3.2.2) only. The following import statements are updated: +* `org.scalatest.Matchers`-> `org.scalatest.matchers.should.Matchers` +* `org.scalatest.mock.MockitoSugar`-> `org.scalatestplus.mockito.MockitoSugar` +* `org.scalatest.FunSuite`->`org.scalatest.funsuite.AnyFunSuite` +* `org.scalatest.junit.JUnitRunner`->`org.scalatestplus.junit.JUnitRunner` +* `org.scalatest.FlatSpec`-> `org.scalatest.flatspec.AnyFlatSpec` +* `org.scalatest.junit.AssertionsForJUnit`-> `org.scalatestplus.junit.AssertionsForJUnit` ## Usage: @@ -12,7 +23,7 @@ CLI: ``` usage: main.py [-h] --path_to_codebase PATH_TO_CODEBASE -Migrates scala tests!!! +Updates the codebase to use a new version of `scalatest_2.12`. options: -h, --help show this help message and exit diff --git a/plugins/scala_test/main.py b/plugins/scala_test/main.py index 42249aa4b..63f77498f 100644 --- a/plugins/scala_test/main.py +++ b/plugins/scala_test/main.py @@ -2,19 +2,24 @@ from update_imports import update_imports def _parse_args(): - parser = argparse.ArgumentParser(description="Migrates scala tests!!!") + parser = argparse.ArgumentParser(description="Updates the codebase to use a new version of `scalatest_2.12`") parser.add_argument( "--path_to_codebase", required=True, help="Path to the codebase directory.", ) - + parser.add_argument( + "--new_version", + required=True, + default="3.2.2", + help="Version of `scalatest` to update to.", + ) args = parser.parse_args() return args def main(): args = _parse_args() - update_imports(args.path_to_codebase, dry_run=True) + update_imports(args.path_to_codebase, args.new_version, dry_run=True) if __name__ == "__main__": main() diff --git a/plugins/scala_test/recipes.py b/plugins/scala_test/recipes.py index 8d9ecbde7..d6f9ab65e 100644 --- a/plugins/scala_test/recipes.py +++ b/plugins/scala_test/recipes.py @@ -1,42 +1,69 @@ -from polyglot_piranha import Rule, OutgoingEdges, RuleGraph, PiranhaArguments, execute_piranha +from polyglot_piranha import ( + Rule, + OutgoingEdges, + RuleGraph, + PiranhaArguments, + execute_piranha, +) + def replace_imports( - target_new_types: dict[str, str], search_heuristic: str, path_to_codebase: str, - dry_run = False + target_new_types: dict[str, str], + search_heuristic: str, + path_to_codebase: str, + dry_run=False, ): + """This function replaces the imports of the target types with the new types. + The search heuristic is used to find the files that contain the target types. + + Args: + target_new_types (dict[str, str]): A dictionary from target type to new type (fully qualified names) + search_heuristic (str): The search heuristic to find the files that contain the target types + path_to_codebase (str): The path to the codebase + dry_run (bool, optional): True if the changes should not be written to disk. Defaults to False. + + Returns: + _type_: A list of PiranhaOutput objects + """ find_relevant_files = Rule( name="find_relevant_files", - query="((identifier) @x (#eq? @x \"@search_heuristic\"))", + query='((identifier) @x (#eq? @x "@search_heuristic"))', holes={"search_heuristic"}, ) - e1 = OutgoingEdges("find_relevant_files", to=[f"update_import"], scope="File") + find_relevant_files_andThen_update_import = OutgoingEdges( + "find_relevant_files", to=["update_import"], scope="File" + ) rules = [find_relevant_files] - edges = [e1] + edges = [find_relevant_files_andThen_update_import] for target_type, new_type in target_new_types.items(): - rs, es = replace_import_rules_edges(target_type, new_type) + rs, es = replace_import_rules_and_edges(target_type, new_type) rules.extend(rs) edges.extend(es) rule_graph = RuleGraph(rules=rules, edges=edges) - args= PiranhaArguments( + args = PiranhaArguments( language="scala", path_to_codebase=path_to_codebase, rule_graph=rule_graph, substitutions={"search_heuristic": f"{search_heuristic}"}, - dry_run=dry_run + dry_run=dry_run, ) - + return execute_piranha(args) - -def replace_import_rules_edges( +def replace_import_rules_and_edges( target_qualified_type_name: str, new_qualified_type_name: str ) -> (list[Rule], list[OutgoingEdges]): - + """This function generates the rules and edges to replace the imports of the target type with the new type. + It supports both simple and nested imports. While the simple imports are replaced directly, the nested imports are deleted and the new type is imported (as a simple non-nested import). + Assume that the target type is "a.b.c.d" and the new type is "x.y.z". Then the following rules are generated: + import a.b.c.d -> import x.y.z + import a.b.c.{d, e} -> import x.y.z \n import a.b.c.{d} + """ name_components = target_qualified_type_name.split(".") type_name = name_components[-1] diff --git a/plugins/scala_test/tests/test_update_imports.py b/plugins/scala_test/tests/test_update_imports.py index 405b5a352..3db94789b 100644 --- a/plugins/scala_test/tests/test_update_imports.py +++ b/plugins/scala_test/tests/test_update_imports.py @@ -1,14 +1,11 @@ -from logging import debug, error +from logging import error from pathlib import Path - from os.path import join, basename from os import listdir - from update_imports import update_imports -# from update_imports import update_imports def test_update_imports(): - summary = update_imports("plugins/scala_test/tests/resources/input/", dry_run=True) + summary = update_imports("plugins/scala_test/tests/resources/input/", "3.2.2", dry_run=True) assert is_as_expected("plugins/scala_test/tests/resources/", summary) def is_as_expected(path_to_scenario, output_summary): diff --git a/plugins/scala_test/update_imports.py b/plugins/scala_test/update_imports.py index 181a729fb..b6d802146 100644 --- a/plugins/scala_test/update_imports.py +++ b/plugins/scala_test/update_imports.py @@ -11,5 +11,7 @@ "org.scalatest.junit.AssertionsForJUnit": "org.scalatestplus.junit.AssertionsForJUnit", } -def update_imports(path_to_codebase: str, dry_run = False): - return replace_imports(IMPORT_MAPPING, "scalatest", path_to_codebase, dry_run) +def update_imports(path_to_codebase: str, scalatest_version,dry_run = False): + if scalatest_version == "3.2.2": + return replace_imports(IMPORT_MAPPING, "scalatest", path_to_codebase, dry_run) + raise Exception(f"Unsupported version: {scalatest_version}")