From 5b18385634c30681b5343b07e8247faff305debf Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Fri, 11 Aug 2023 16:02:21 -0700 Subject: [PATCH 1/4] Adding plugin to migrate scalatest --- plugins/scala_test/main.py | 23 +++++ plugins/scala_test/recipes.py | 83 +++++++++++++++++++ plugins/scala_test/tests/__init__.py | 0 .../tests/resources/expected/sample.scala | 8 ++ .../tests/resources/input/sample.scala | 7 ++ .../scala_test/tests/test_update_imports.py | 40 +++++++++ plugins/scala_test/update_imports.py | 15 ++++ plugins/setup.py | 39 +++++++++ 8 files changed, 215 insertions(+) create mode 100644 plugins/scala_test/main.py create mode 100644 plugins/scala_test/recipes.py create mode 100644 plugins/scala_test/tests/__init__.py create mode 100644 plugins/scala_test/tests/resources/expected/sample.scala create mode 100644 plugins/scala_test/tests/resources/input/sample.scala create mode 100644 plugins/scala_test/tests/test_update_imports.py create mode 100644 plugins/scala_test/update_imports.py create mode 100644 plugins/setup.py diff --git a/plugins/scala_test/main.py b/plugins/scala_test/main.py new file mode 100644 index 000000000..6f7eb58f7 --- /dev/null +++ b/plugins/scala_test/main.py @@ -0,0 +1,23 @@ +import argparse +from update_imports import update_imports + + + + +def _parse_args(): + parser = argparse.ArgumentParser(description="Migrates scala tests!!!") + parser.add_argument( + "--path_to_codebase", + required=True, + help="Path to the codebase directory.", + ) + + args = parser.parse_args() + return args + +def main(): + args = _parse_args() + update_imports(args.path_to_codebase, dry_run=True) + +if __name__ == "__main__": + main() diff --git a/plugins/scala_test/recipes.py b/plugins/scala_test/recipes.py new file mode 100644 index 000000000..8d9ecbde7 --- /dev/null +++ b/plugins/scala_test/recipes.py @@ -0,0 +1,83 @@ +from polyglot_piranha import Rule, OutgoingEdges, RuleGraph, PiranhaArguments, execute_piranha + +def replace_imports( + target_new_types: dict[str, str], search_heuristic: str, path_to_codebase: str, + dry_run = False +): + find_relevant_files = Rule( + name="find_relevant_files", + query="((identifier) @x (#eq? @x \"@search_heuristic\"))", + holes={"search_heuristic"}, + ) + e1 = OutgoingEdges("find_relevant_files", to=[f"update_import"], scope="File") + + rules = [find_relevant_files] + edges = [e1] + + for target_type, new_type in target_new_types.items(): + rs, es = replace_import_rules_edges(target_type, new_type) + rules.extend(rs) + edges.extend(es) + + rule_graph = RuleGraph(rules=rules, edges=edges) + + args= PiranhaArguments( + language="scala", + path_to_codebase=path_to_codebase, + rule_graph=rule_graph, + substitutions={"search_heuristic": f"{search_heuristic}"}, + dry_run=dry_run + ) + + return execute_piranha(args) + + + +def replace_import_rules_edges( + target_qualified_type_name: str, new_qualified_type_name: str +) -> (list[Rule], list[OutgoingEdges]): + + name_components = target_qualified_type_name.split(".") + type_name = name_components[-1] + + qualifier_predicate = "\n".join( + [f'(#match? @import_decl "{n}")' for n in name_components[:-1]] + ) + + delete_nested_import = Rule( + name=f"delete_nested_import_{type_name}", + query=f"""( + (import_declaration (namespace_selectors (_) @tn )) @import_decl + (#eq? @tn "{type_name}") + {qualifier_predicate} + )""", + replace_node="tn", + replace="", + is_seed_rule=False, + groups={"update_import"}, + ) + + update_simple_import = Rule( + name=f"update_simple_import_{type_name}", + query=f"cs import {target_qualified_type_name}", + replace_node="*", + replace=f"import {new_qualified_type_name}", + is_seed_rule=False, + groups={"update_import"}, + ) + + insert_import = Rule( + name=f"insert_import_{type_name}", + query="(import_declaration) @import_decl", + replace_node="import_decl", + replace=f"@import_decl\nimport {new_qualified_type_name}\n", + is_seed_rule=False, + ) + + e2 = OutgoingEdges( + f"delete_nested_import_{type_name}", + to=[f"insert_import_{type_name}"], + scope="Parent", + ) + + return [delete_nested_import, update_simple_import, insert_import], [e2] diff --git a/plugins/scala_test/tests/__init__.py b/plugins/scala_test/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plugins/scala_test/tests/resources/expected/sample.scala b/plugins/scala_test/tests/resources/expected/sample.scala new file mode 100644 index 000000000..eabe14b08 --- /dev/null +++ b/plugins/scala_test/tests/resources/expected/sample.scala @@ -0,0 +1,8 @@ +package com.scala.piranha + +import com.uber.michelangelo.AbstractSparkSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} +import org.scalatest.{BeforeAndAfter} +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar diff --git a/plugins/scala_test/tests/resources/input/sample.scala b/plugins/scala_test/tests/resources/input/sample.scala new file mode 100644 index 000000000..189bea734 --- /dev/null +++ b/plugins/scala_test/tests/resources/input/sample.scala @@ -0,0 +1,7 @@ +package com.scala.piranha + +import com.uber.michelangelo.AbstractSparkSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} +import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.mock.MockitoSugar diff --git a/plugins/scala_test/tests/test_update_imports.py b/plugins/scala_test/tests/test_update_imports.py new file mode 100644 index 000000000..405b5a352 --- /dev/null +++ b/plugins/scala_test/tests/test_update_imports.py @@ -0,0 +1,40 @@ +from logging import debug, error +from pathlib import Path + +from os.path import join, basename +from os import listdir + +from update_imports import update_imports +# from update_imports import update_imports + +def test_update_imports(): + summary = update_imports("plugins/scala_test/tests/resources/input/", dry_run=True) + assert is_as_expected("plugins/scala_test/tests/resources/", summary) + +def is_as_expected(path_to_scenario, output_summary): + expected_output = join(path_to_scenario, "expected") + print("Summary", output_summary) + input_dir = join(path_to_scenario, "input") + for file_name in listdir(expected_output): + with open(join(expected_output, file_name), "r") as f: + file_content = f.read() + expected_content = "".join(file_content.split()) + + # Search for the file in the output summary + updated_content = [ + "".join(o.content.split()) + for o in output_summary + if basename(o.path) == file_name + ] + print(file_name) + # Check if the file was rewritten + if updated_content: + if expected_content != updated_content[0]: + error("----update" + updated_content[0] ) + return False + else: + # The scenario where the file is not expected to be rewritten + original_content= Path(join(input_dir, file_name)).read_text() + if expected_content != "".join(original_content.split()): + return False + return True diff --git a/plugins/scala_test/update_imports.py b/plugins/scala_test/update_imports.py new file mode 100644 index 000000000..181a729fb --- /dev/null +++ b/plugins/scala_test/update_imports.py @@ -0,0 +1,15 @@ +from recipes import replace_imports + + +IMPORT_MAPPING = { + "org.scalatest.Matchers": "org.scalatest.matchers.should.Matchers", + "org.scalatest.mock.MockitoSugar": "org.scalatestplus.mockito.MockitoSugar", + # Todo write test scenarios for these + "org.scalatest.FunSuite":"org.scalatest.funsuite.AnyFunSuite", + "org.scalatest.junit.JUnitRunner":"org.scalatestplus.junit.JUnitRunner", + "org.scalatest.FlatSpec": "org.scalatest.flatspec.AnyFlatSpec", + "org.scalatest.junit.AssertionsForJUnit": "org.scalatestplus.junit.AssertionsForJUnit", +} + +def update_imports(path_to_codebase: str, dry_run = False): + return replace_imports(IMPORT_MAPPING, "scalatest", path_to_codebase, dry_run) diff --git a/plugins/setup.py b/plugins/setup.py new file mode 100644 index 000000000..5f72c7ac4 --- /dev/null +++ b/plugins/setup.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 Uber Technologies, Inc. +# +#
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +# except in compliance with the License. You may obtain a copy of the License at +#
http://www.apache.org/licenses/LICENSE-2.0 +# +#
Unless required by applicable law or agreed to in writing, software distributed under the +# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +from setuptools import find_packages, setup + +setup( + name="scala_test", + version="0.0.1", + description="Rules to migrate `scaletest`", + # long_description=open("README.md").read(), + # long_description_content_type="text/markdown", + # url="https://github.com/uber/piranha", + packages=find_packages(), + include_package_data=True, + install_requires=[ + # "polyglot-piranha", + "pytest", + ], + entry_points={ + "console_scripts": ["scala_test = scala_test.main:main"] + }, + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], + python_requires=">=3.9", + tests_require=["pytest"], + # Define the test suite + test_suite="tests", +) From f075ebc062af856138a3deabc781d9139544c9b3 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Fri, 11 Aug 2023 16:43:18 -0700 Subject: [PATCH 2/4] Update the plugin --- .gitignore | 1 + plugins/pyproject.toml | 25 +++++++++++++++++++++++ plugins/scala_test/README.md | 19 ++++++++++++++++++ plugins/scala_test/main.py | 3 --- plugins/setup.py | 39 ------------------------------------ 5 files changed, 45 insertions(+), 42 deletions(-) create mode 100644 plugins/pyproject.toml create mode 100644 plugins/scala_test/README.md delete mode 100644 plugins/setup.py diff --git a/.gitignore b/.gitignore index 29e3930f6..0694fd045 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ target Cargo.lock tmp_test* env/ +**.egg-info # Dependencies diff --git a/plugins/pyproject.toml b/plugins/pyproject.toml new file mode 100644 index 000000000..c4c679100 --- /dev/null +++ b/plugins/pyproject.toml @@ -0,0 +1,25 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.poetry] +name = "scala_test" +version = "0.0.1" +description = "Rules to migrate 'scaletest'" +# Add any other metadata you need + +[tool.poetry.dependencies] +python = "^3.9" +polyglot_piranha = "*" + +[tool.poetry.dev-dependencies] +pytest = "*" + +# [tool.poetry.scripts] +# scala_test = "scala_test.main:main" + +[tool.poetry.scripts."scala_test"] +main = "scala_test.main:main" + +[tool.poetry.scripts."pytest"] +main = "pytest" diff --git a/plugins/scala_test/README.md b/plugins/scala_test/README.md new file mode 100644 index 000000000..2d205730c --- /dev/null +++ b/plugins/scala_test/README.md @@ -0,0 +1,19 @@ +# `scalatest` Migration Plugin + +## Usage: +``` +python3 plugins/scala_test/main.py -h +usage: main.py [-h] --path_to_codebase PATH_TO_CODEBASE + +Migrates scala tests!!! + +options: + -h, --help show this help message and exit + --path_to_codebase PATH_TO_CODEBASE + Path to the codebase directory. +``` + +## Test +``` +pytest plugins/scala_test +``` diff --git a/plugins/scala_test/main.py b/plugins/scala_test/main.py index 6f7eb58f7..42249aa4b 100644 --- a/plugins/scala_test/main.py +++ b/plugins/scala_test/main.py @@ -1,9 +1,6 @@ import argparse from update_imports import update_imports - - - def _parse_args(): parser = argparse.ArgumentParser(description="Migrates scala tests!!!") parser.add_argument( diff --git a/plugins/setup.py b/plugins/setup.py deleted file mode 100644 index 5f72c7ac4..000000000 --- a/plugins/setup.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2023 Uber Technologies, Inc. -# -#
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file -# except in compliance with the License. You may obtain a copy of the License at -#
http://www.apache.org/licenses/LICENSE-2.0 -# -#
Unless required by applicable law or agreed to in writing, software distributed under the -# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing permissions and -# limitations under the License. - -from setuptools import find_packages, setup - -setup( - name="scala_test", - version="0.0.1", - description="Rules to migrate `scaletest`", - # long_description=open("README.md").read(), - # long_description_content_type="text/markdown", - # url="https://github.com/uber/piranha", - packages=find_packages(), - include_package_data=True, - install_requires=[ - # "polyglot-piranha", - "pytest", - ], - entry_points={ - "console_scripts": ["scala_test = scala_test.main:main"] - }, - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - ], - python_requires=">=3.9", - tests_require=["pytest"], - # Define the test suite - test_suite="tests", -) From 5fbdb7bc315087cb76a75a53698ee395034ecd37 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Mon, 14 Aug 2023 08:52:14 -0700 Subject: [PATCH 3/4] Update the plugin --- plugins/scala_test/README.md | 9 ++++++++- plugins/scala_test/requirements.txt | 2 ++ 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 plugins/scala_test/requirements.txt diff --git a/plugins/scala_test/README.md b/plugins/scala_test/README.md index 2d205730c..6d98d772a 100644 --- a/plugins/scala_test/README.md +++ b/plugins/scala_test/README.md @@ -1,8 +1,15 @@ # `scalatest` Migration Plugin ## Usage: + +Clone the repository - `git clone https://github.com/uber/piranha.git` + +Install the dependencies - `pip3 install -r plugins/scala_test/requirements.txt` + +Run the tool - `python3 plugins/scala_test/main.py -h` + +CLI: ``` -python3 plugins/scala_test/main.py -h usage: main.py [-h] --path_to_codebase PATH_TO_CODEBASE Migrates scala tests!!! diff --git a/plugins/scala_test/requirements.txt b/plugins/scala_test/requirements.txt new file mode 100644 index 000000000..b3cd0ac6e --- /dev/null +++ b/plugins/scala_test/requirements.txt @@ -0,0 +1,2 @@ +polyglot-piranha +pytest From 43a7628b168de1313e816a6fc2d827549eef1335 Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Tue, 15 Aug 2023 09:45:05 -0700 Subject: [PATCH 4/4] Update the plugin --- plugins/pyproject.toml | 3 -- plugins/scala_test/README.md | 15 +++++- plugins/scala_test/main.py | 11 ++-- plugins/scala_test/recipes.py | 53 ++++++++++++++----- .../scala_test/tests/test_update_imports.py | 7 +-- plugins/scala_test/update_imports.py | 6 ++- 6 files changed, 67 insertions(+), 28 deletions(-) diff --git a/plugins/pyproject.toml b/plugins/pyproject.toml index c4c679100..4d80dcf66 100644 --- a/plugins/pyproject.toml +++ b/plugins/pyproject.toml @@ -15,9 +15,6 @@ polyglot_piranha = "*" [tool.poetry.dev-dependencies] pytest = "*" -# [tool.poetry.scripts] -# scala_test = "scala_test.main:main" - [tool.poetry.scripts."scala_test"] main = "scala_test.main:main" diff --git a/plugins/scala_test/README.md b/plugins/scala_test/README.md index 6d98d772a..0858f7c73 100644 --- a/plugins/scala_test/README.md +++ b/plugins/scala_test/README.md @@ -1,4 +1,15 @@ -# `scalatest` Migration Plugin +# `scalatest` Migration Plugin (WIP) + +This piranha plugin updates `scalatest` to a new version. + + +Currently, it updates to [v.3.2.2](https://mvnrepository.com/artifact/org.scalatest/scalatest_2.12/3.2.2) only. The following import statements are updated: +* `org.scalatest.Matchers`-> `org.scalatest.matchers.should.Matchers` +* `org.scalatest.mock.MockitoSugar`-> `org.scalatestplus.mockito.MockitoSugar` +* `org.scalatest.FunSuite`->`org.scalatest.funsuite.AnyFunSuite` +* `org.scalatest.junit.JUnitRunner`->`org.scalatestplus.junit.JUnitRunner` +* `org.scalatest.FlatSpec`-> `org.scalatest.flatspec.AnyFlatSpec` +* `org.scalatest.junit.AssertionsForJUnit`-> `org.scalatestplus.junit.AssertionsForJUnit` ## Usage: @@ -12,7 +23,7 @@ CLI: ``` usage: main.py [-h] --path_to_codebase PATH_TO_CODEBASE -Migrates scala tests!!! +Updates the codebase to use a new version of `scalatest_2.12`. options: -h, --help show this help message and exit diff --git a/plugins/scala_test/main.py b/plugins/scala_test/main.py index 42249aa4b..63f77498f 100644 --- a/plugins/scala_test/main.py +++ b/plugins/scala_test/main.py @@ -2,19 +2,24 @@ from update_imports import update_imports def _parse_args(): - parser = argparse.ArgumentParser(description="Migrates scala tests!!!") + parser = argparse.ArgumentParser(description="Updates the codebase to use a new version of `scalatest_2.12`") parser.add_argument( "--path_to_codebase", required=True, help="Path to the codebase directory.", ) - + parser.add_argument( + "--new_version", + required=True, + default="3.2.2", + help="Version of `scalatest` to update to.", + ) args = parser.parse_args() return args def main(): args = _parse_args() - update_imports(args.path_to_codebase, dry_run=True) + update_imports(args.path_to_codebase, args.new_version, dry_run=True) if __name__ == "__main__": main() diff --git a/plugins/scala_test/recipes.py b/plugins/scala_test/recipes.py index 8d9ecbde7..d6f9ab65e 100644 --- a/plugins/scala_test/recipes.py +++ b/plugins/scala_test/recipes.py @@ -1,42 +1,69 @@ -from polyglot_piranha import Rule, OutgoingEdges, RuleGraph, PiranhaArguments, execute_piranha +from polyglot_piranha import ( + Rule, + OutgoingEdges, + RuleGraph, + PiranhaArguments, + execute_piranha, +) + def replace_imports( - target_new_types: dict[str, str], search_heuristic: str, path_to_codebase: str, - dry_run = False + target_new_types: dict[str, str], + search_heuristic: str, + path_to_codebase: str, + dry_run=False, ): + """This function replaces the imports of the target types with the new types. + The search heuristic is used to find the files that contain the target types. + + Args: + target_new_types (dict[str, str]): A dictionary from target type to new type (fully qualified names) + search_heuristic (str): The search heuristic to find the files that contain the target types + path_to_codebase (str): The path to the codebase + dry_run (bool, optional): True if the changes should not be written to disk. Defaults to False. + + Returns: + _type_: A list of PiranhaOutput objects + """ find_relevant_files = Rule( name="find_relevant_files", - query="((identifier) @x (#eq? @x \"@search_heuristic\"))", + query='((identifier) @x (#eq? @x "@search_heuristic"))', holes={"search_heuristic"}, ) - e1 = OutgoingEdges("find_relevant_files", to=[f"update_import"], scope="File") + find_relevant_files_andThen_update_import = OutgoingEdges( + "find_relevant_files", to=["update_import"], scope="File" + ) rules = [find_relevant_files] - edges = [e1] + edges = [find_relevant_files_andThen_update_import] for target_type, new_type in target_new_types.items(): - rs, es = replace_import_rules_edges(target_type, new_type) + rs, es = replace_import_rules_and_edges(target_type, new_type) rules.extend(rs) edges.extend(es) rule_graph = RuleGraph(rules=rules, edges=edges) - args= PiranhaArguments( + args = PiranhaArguments( language="scala", path_to_codebase=path_to_codebase, rule_graph=rule_graph, substitutions={"search_heuristic": f"{search_heuristic}"}, - dry_run=dry_run + dry_run=dry_run, ) - + return execute_piranha(args) - -def replace_import_rules_edges( +def replace_import_rules_and_edges( target_qualified_type_name: str, new_qualified_type_name: str ) -> (list[Rule], list[OutgoingEdges]): - + """This function generates the rules and edges to replace the imports of the target type with the new type. + It supports both simple and nested imports. While the simple imports are replaced directly, the nested imports are deleted and the new type is imported (as a simple non-nested import). + Assume that the target type is "a.b.c.d" and the new type is "x.y.z". Then the following rules are generated: + import a.b.c.d -> import x.y.z + import a.b.c.{d, e} -> import x.y.z \n import a.b.c.{d} + """ name_components = target_qualified_type_name.split(".") type_name = name_components[-1] diff --git a/plugins/scala_test/tests/test_update_imports.py b/plugins/scala_test/tests/test_update_imports.py index 405b5a352..3db94789b 100644 --- a/plugins/scala_test/tests/test_update_imports.py +++ b/plugins/scala_test/tests/test_update_imports.py @@ -1,14 +1,11 @@ -from logging import debug, error +from logging import error from pathlib import Path - from os.path import join, basename from os import listdir - from update_imports import update_imports -# from update_imports import update_imports def test_update_imports(): - summary = update_imports("plugins/scala_test/tests/resources/input/", dry_run=True) + summary = update_imports("plugins/scala_test/tests/resources/input/", "3.2.2", dry_run=True) assert is_as_expected("plugins/scala_test/tests/resources/", summary) def is_as_expected(path_to_scenario, output_summary): diff --git a/plugins/scala_test/update_imports.py b/plugins/scala_test/update_imports.py index 181a729fb..b6d802146 100644 --- a/plugins/scala_test/update_imports.py +++ b/plugins/scala_test/update_imports.py @@ -11,5 +11,7 @@ "org.scalatest.junit.AssertionsForJUnit": "org.scalatestplus.junit.AssertionsForJUnit", } -def update_imports(path_to_codebase: str, dry_run = False): - return replace_imports(IMPORT_MAPPING, "scalatest", path_to_codebase, dry_run) +def update_imports(path_to_codebase: str, scalatest_version,dry_run = False): + if scalatest_version == "3.2.2": + return replace_imports(IMPORT_MAPPING, "scalatest", path_to_codebase, dry_run) + raise Exception(f"Unsupported version: {scalatest_version}")