diff --git a/.gitignore b/.gitignore index 29e3930f6..0694fd045 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ target Cargo.lock tmp_test* env/ +**.egg-info # Dependencies diff --git a/plugins/pyproject.toml b/plugins/pyproject.toml new file mode 100644 index 000000000..4d80dcf66 --- /dev/null +++ b/plugins/pyproject.toml @@ -0,0 +1,22 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.poetry] +name = "scala_test" +version = "0.0.1" +description = "Rules to migrate 'scaletest'" +# Add any other metadata you need + +[tool.poetry.dependencies] +python = "^3.9" +polyglot_piranha = "*" + +[tool.poetry.dev-dependencies] +pytest = "*" + +[tool.poetry.scripts."scala_test"] +main = "scala_test.main:main" + +[tool.poetry.scripts."pytest"] +main = "pytest" diff --git a/plugins/scala_test/README.md b/plugins/scala_test/README.md new file mode 100644 index 000000000..0858f7c73 --- /dev/null +++ b/plugins/scala_test/README.md @@ -0,0 +1,37 @@ +# `scalatest` Migration Plugin (WIP) + +This piranha plugin updates `scalatest` to a new version. + + +Currently, it updates to [v.3.2.2](https://mvnrepository.com/artifact/org.scalatest/scalatest_2.12/3.2.2) only. The following import statements are updated: +* `org.scalatest.Matchers`-> `org.scalatest.matchers.should.Matchers` +* `org.scalatest.mock.MockitoSugar`-> `org.scalatestplus.mockito.MockitoSugar` +* `org.scalatest.FunSuite`->`org.scalatest.funsuite.AnyFunSuite` +* `org.scalatest.junit.JUnitRunner`->`org.scalatestplus.junit.JUnitRunner` +* `org.scalatest.FlatSpec`-> `org.scalatest.flatspec.AnyFlatSpec` +* `org.scalatest.junit.AssertionsForJUnit`-> `org.scalatestplus.junit.AssertionsForJUnit` + +## Usage: + +Clone the repository - `git clone https://github.com/uber/piranha.git` + +Install the dependencies - `pip3 install -r plugins/scala_test/requirements.txt` + +Run the tool - `python3 plugins/scala_test/main.py -h` + +CLI: +``` +usage: main.py [-h] --path_to_codebase PATH_TO_CODEBASE + +Updates the codebase to use a new version of `scalatest_2.12`. + +options: + -h, --help show this help message and exit + --path_to_codebase PATH_TO_CODEBASE + Path to the codebase directory. +``` + +## Test +``` +pytest plugins/scala_test +``` diff --git a/plugins/scala_test/main.py b/plugins/scala_test/main.py new file mode 100644 index 000000000..63f77498f --- /dev/null +++ b/plugins/scala_test/main.py @@ -0,0 +1,25 @@ +import argparse +from update_imports import update_imports + +def _parse_args(): + parser = argparse.ArgumentParser(description="Updates the codebase to use a new version of `scalatest_2.12`") + parser.add_argument( + "--path_to_codebase", + required=True, + help="Path to the codebase directory.", + ) + parser.add_argument( + "--new_version", + required=True, + default="3.2.2", + help="Version of `scalatest` to update to.", + ) + args = parser.parse_args() + return args + +def main(): + args = _parse_args() + update_imports(args.path_to_codebase, args.new_version, dry_run=True) + +if __name__ == "__main__": + main() diff --git a/plugins/scala_test/recipes.py b/plugins/scala_test/recipes.py new file mode 100644 index 000000000..d6f9ab65e --- /dev/null +++ b/plugins/scala_test/recipes.py @@ -0,0 +1,110 @@ +from polyglot_piranha import ( + Rule, + OutgoingEdges, + RuleGraph, + PiranhaArguments, + execute_piranha, +) + + +def replace_imports( + target_new_types: dict[str, str], + search_heuristic: str, + path_to_codebase: str, + dry_run=False, +): + """This function replaces the imports of the target types with the new types. + The search heuristic is used to find the files that contain the target types. + + Args: + target_new_types (dict[str, str]): A dictionary from target type to new type (fully qualified names) + search_heuristic (str): The search heuristic to find the files that contain the target types + path_to_codebase (str): The path to the codebase + dry_run (bool, optional): True if the changes should not be written to disk. Defaults to False. + + Returns: + _type_: A list of PiranhaOutput objects + """ + find_relevant_files = Rule( + name="find_relevant_files", + query='((identifier) @x (#eq? @x "@search_heuristic"))', + holes={"search_heuristic"}, + ) + find_relevant_files_andThen_update_import = OutgoingEdges( + "find_relevant_files", to=["update_import"], scope="File" + ) + + rules = [find_relevant_files] + edges = [find_relevant_files_andThen_update_import] + + for target_type, new_type in target_new_types.items(): + rs, es = replace_import_rules_and_edges(target_type, new_type) + rules.extend(rs) + edges.extend(es) + + rule_graph = RuleGraph(rules=rules, edges=edges) + + args = PiranhaArguments( + language="scala", + path_to_codebase=path_to_codebase, + rule_graph=rule_graph, + substitutions={"search_heuristic": f"{search_heuristic}"}, + dry_run=dry_run, + ) + + return execute_piranha(args) + + +def replace_import_rules_and_edges( + target_qualified_type_name: str, new_qualified_type_name: str +) -> (list[Rule], list[OutgoingEdges]): + """This function generates the rules and edges to replace the imports of the target type with the new type. + It supports both simple and nested imports. While the simple imports are replaced directly, the nested imports are deleted and the new type is imported (as a simple non-nested import). + Assume that the target type is "a.b.c.d" and the new type is "x.y.z". Then the following rules are generated: + import a.b.c.d -> import x.y.z + import a.b.c.{d, e} -> import x.y.z \n import a.b.c.{d} + """ + name_components = target_qualified_type_name.split(".") + type_name = name_components[-1] + + qualifier_predicate = "\n".join( + [f'(#match? @import_decl "{n}")' for n in name_components[:-1]] + ) + + delete_nested_import = Rule( + name=f"delete_nested_import_{type_name}", + query=f"""( + (import_declaration (namespace_selectors (_) @tn )) @import_decl + (#eq? @tn "{type_name}") + {qualifier_predicate} + )""", + replace_node="tn", + replace="", + is_seed_rule=False, + groups={"update_import"}, + ) + + update_simple_import = Rule( + name=f"update_simple_import_{type_name}", + query=f"cs import {target_qualified_type_name}", + replace_node="*", + replace=f"import {new_qualified_type_name}", + is_seed_rule=False, + groups={"update_import"}, + ) + + insert_import = Rule( + name=f"insert_import_{type_name}", + query="(import_declaration) @import_decl", + replace_node="import_decl", + replace=f"@import_decl\nimport {new_qualified_type_name}\n", + is_seed_rule=False, + ) + + e2 = OutgoingEdges( + f"delete_nested_import_{type_name}", + to=[f"insert_import_{type_name}"], + scope="Parent", + ) + + return [delete_nested_import, update_simple_import, insert_import], [e2] diff --git a/plugins/scala_test/requirements.txt b/plugins/scala_test/requirements.txt new file mode 100644 index 000000000..b3cd0ac6e --- /dev/null +++ b/plugins/scala_test/requirements.txt @@ -0,0 +1,2 @@ +polyglot-piranha +pytest diff --git a/plugins/scala_test/tests/__init__.py b/plugins/scala_test/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/scala_test/tests/resources/expected/sample.scala b/plugins/scala_test/tests/resources/expected/sample.scala new file mode 100644 index 000000000..eabe14b08 --- /dev/null +++ b/plugins/scala_test/tests/resources/expected/sample.scala @@ -0,0 +1,8 @@ +package com.scala.piranha + +import com.uber.michelangelo.AbstractSparkSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} +import org.scalatest.{BeforeAndAfter} +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar diff --git a/plugins/scala_test/tests/resources/input/sample.scala b/plugins/scala_test/tests/resources/input/sample.scala new file mode 100644 index 000000000..189bea734 --- /dev/null +++ b/plugins/scala_test/tests/resources/input/sample.scala @@ -0,0 +1,7 @@ +package com.scala.piranha + +import com.uber.michelangelo.AbstractSparkSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} +import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.mock.MockitoSugar diff --git a/plugins/scala_test/tests/test_update_imports.py b/plugins/scala_test/tests/test_update_imports.py new file mode 100644 index 000000000..3db94789b --- /dev/null +++ b/plugins/scala_test/tests/test_update_imports.py @@ -0,0 +1,37 @@ +from logging import error +from pathlib import Path +from os.path import join, basename +from os import listdir +from update_imports import update_imports + +def test_update_imports(): + summary = update_imports("plugins/scala_test/tests/resources/input/", "3.2.2", dry_run=True) + assert is_as_expected("plugins/scala_test/tests/resources/", summary) + +def is_as_expected(path_to_scenario, output_summary): + expected_output = join(path_to_scenario, "expected") + print("Summary", output_summary) + input_dir = join(path_to_scenario, "input") + for file_name in listdir(expected_output): + with open(join(expected_output, file_name), "r") as f: + file_content = f.read() + expected_content = "".join(file_content.split()) + + # Search for the file in the output summary + updated_content = [ + "".join(o.content.split()) + for o in output_summary + if basename(o.path) == file_name + ] + print(file_name) + # Check if the file was rewritten + if updated_content: + if expected_content != updated_content[0]: + error("----update" + updated_content[0] ) + return False + else: + # The scenario where the file is not expected to be rewritten + original_content= Path(join(input_dir, file_name)).read_text() + if expected_content != "".join(original_content.split()): + return False + return True diff --git a/plugins/scala_test/update_imports.py b/plugins/scala_test/update_imports.py new file mode 100644 index 000000000..b6d802146 --- /dev/null +++ b/plugins/scala_test/update_imports.py @@ -0,0 +1,17 @@ +from recipes import replace_imports + + +IMPORT_MAPPING = { + "org.scalatest.Matchers": "org.scalatest.matchers.should.Matchers", + "org.scalatest.mock.MockitoSugar": "org.scalatestplus.mockito.MockitoSugar", + # Todo write test scenarios for these + "org.scalatest.FunSuite":"org.scalatest.funsuite.AnyFunSuite", + "org.scalatest.junit.JUnitRunner":"org.scalatestplus.junit.JUnitRunner", + "org.scalatest.FlatSpec": "org.scalatest.flatspec.AnyFlatSpec", + "org.scalatest.junit.AssertionsForJUnit": "org.scalatestplus.junit.AssertionsForJUnit", +} + +def update_imports(path_to_codebase: str, scalatest_version,dry_run = False): + if scalatest_version == "3.2.2": + return replace_imports(IMPORT_MAPPING, "scalatest", path_to_codebase, dry_run) + raise Exception(f"Unsupported version: {scalatest_version}")