# --- plugins/scala_test/main.py ---
# CLI entry point for the scalatest migration plugin.
import argparse

from update_imports import update_imports


def _parse_args():
    """Parse command-line arguments for the migration tool."""
    parser = argparse.ArgumentParser(description="Migrates scala tests!!!")
    parser.add_argument(
        "--path_to_codebase",
        required=True,
        help="Path to the codebase directory.",
    )
    return parser.parse_args()


def main():
    """Run the scalatest import migration over the given codebase."""
    args = _parse_args()
    # BUG FIX: dry_run was hard-coded to True here, so the CLI never wrote
    # any rewrites to disk. Real runs must apply changes; tests pass
    # dry_run=True explicitly when they want a preview.
    update_imports(args.path_to_codebase, dry_run=False)


if __name__ == "__main__":
    main()


# --- plugins/scala_test/recipes.py ---
from polyglot_piranha import (
    OutgoingEdges,
    PiranhaArguments,
    Rule,
    RuleGraph,
    execute_piranha,
)


def replace_imports(
    target_new_types: dict[str, str],
    search_heuristic: str,
    path_to_codebase: str,
    dry_run: bool = False,
):
    """Rewrite imports of each key in `target_new_types` to its mapped value.

    Only files containing an identifier equal to `search_heuristic` are
    considered — a cheap pre-filter before the per-type rewrite rules run.

    :param target_new_types: fully-qualified old type -> new type mapping
    :param search_heuristic: identifier that marks a file as relevant
    :param path_to_codebase: root directory to scan
    :param dry_run: when True, report rewrites without writing files
    :return: Piranha summaries for the rewritten files
    """
    # Seed rule: a file is "relevant" iff it mentions the heuristic identifier.
    find_relevant_files = Rule(
        name="find_relevant_files",
        query='((identifier) @x (#eq? @x "@search_heuristic"))',
        holes={"search_heuristic"},
    )
    # File-scoped edge: once a file is relevant, run every rule in the
    # "update_import" group on it. (Was f"update_import" — no interpolation.)
    e1 = OutgoingEdges("find_relevant_files", to=["update_import"], scope="File")

    rules = [find_relevant_files]
    edges = [e1]

    for target_type, new_type in target_new_types.items():
        rs, es = replace_import_rules_edges(target_type, new_type)
        rules.extend(rs)
        edges.extend(es)

    args = PiranhaArguments(
        language="scala",
        path_to_codebase=path_to_codebase,
        rule_graph=RuleGraph(rules=rules, edges=edges),
        substitutions={"search_heuristic": search_heuristic},
        dry_run=dry_run,
    )
    return execute_piranha(args)


def replace_import_rules_edges(
    target_qualified_type_name: str, new_qualified_type_name: str
) -> tuple[list[Rule], list[OutgoingEdges]]:
    # FIX: the return annotation was `(list[Rule], list[OutgoingEdges])`,
    # which evaluates to a tuple *object*, not a type; `tuple[...]` is the
    # correct annotation for a 2-tuple return.
    """Build the rules and edges that migrate one fully-qualified type import.

    Three rules are produced:
    - delete the type from a nested selector import (``import a.b.{X, Y}``),
    - rewrite a simple import line (``import a.b.X``) in place,
    - insert a stand-alone import of the new type (fired only after the
      nested-selector deletion, via the Parent-scoped edge).
    """
    name_components = target_qualified_type_name.split(".")
    type_name = name_components[-1]

    # Require every qualifier component to appear in the import declaration,
    # so e.g. `Matchers` from an unrelated package is left alone.
    qualifier_predicate = "\n".join(
        f'(#match? @import_decl "{n}")' for n in name_components[:-1]
    )

    delete_nested_import = Rule(
        name=f"delete_nested_import_{type_name}",
        query=f"""(
 (import_declaration (namespace_selectors (_) @tn )) @import_decl
 (#eq? @tn "{type_name}")
 {qualifier_predicate}
)""",
        replace_node="tn",
        replace="",
        is_seed_rule=False,
        groups={"update_import"},
    )

    update_simple_import = Rule(
        name=f"update_simple_import_{type_name}",
        query=f"cs import {target_qualified_type_name}",
        replace_node="*",
        replace=f"import {new_qualified_type_name}",
        is_seed_rule=False,
        groups={"update_import"},
    )

    insert_import = Rule(
        name=f"insert_import_{type_name}",
        query="(import_declaration) @import_decl",
        replace_node="import_decl",
        replace=f"@import_decl\nimport {new_qualified_type_name}\n",
        is_seed_rule=False,
    )

    # After a nested selector is deleted, add a stand-alone import of the
    # replacement type in the enclosing (Parent) scope.
    e2 = OutgoingEdges(
        f"delete_nested_import_{type_name}",
        to=[f"insert_import_{type_name}"],
        scope="Parent",
    )

    return [delete_nested_import, update_simple_import, insert_import], [e2]
# --- plugins/scala_test/tests/test_update_imports.py ---
from logging import error
from os import listdir
from os.path import basename, join
from pathlib import Path

from update_imports import update_imports


def test_update_imports():
    """End-to-end check: a dry-run migration of the sample input must match
    the expected output fixtures."""
    summary = update_imports("plugins/scala_test/tests/resources/input/", dry_run=True)
    assert is_as_expected("plugins/scala_test/tests/resources/", summary)


def is_as_expected(path_to_scenario, output_summary):
    """Compare each file under `expected/` against the dry-run summary.

    All whitespace is stripped before comparison so formatting differences
    are ignored. Files absent from the summary are expected to be untouched
    and must equal their counterpart under `input/`.
    """
    expected_output = join(path_to_scenario, "expected")
    input_dir = join(path_to_scenario, "input")
    for file_name in listdir(expected_output):
        with open(join(expected_output, file_name), "r") as f:
            expected_content = "".join(f.read().split())

        # Find the rewritten content for this file, if Piranha touched it.
        updated_content = [
            "".join(o.content.split())
            for o in output_summary
            if basename(o.path) == file_name
        ]
        if updated_content:
            if expected_content != updated_content[0]:
                # Lazy %-style args: the message is only built if the
                # logging level actually emits it.
                error("unexpected rewrite for %s: %s", file_name, updated_content[0])
                return False
        else:
            # The scenario where the file is not expected to be rewritten.
            original_content = Path(join(input_dir, file_name)).read_text()
            if expected_content != "".join(original_content.split()):
                return False
    return True


# --- plugins/scala_test/update_imports.py ---
from recipes import replace_imports

# Old fully-qualified scalatest type -> its new home after the
# scalatest 3.x / scalatestplus reorganization.
IMPORT_MAPPING = {
    "org.scalatest.Matchers": "org.scalatest.matchers.should.Matchers",
    "org.scalatest.mock.MockitoSugar": "org.scalatestplus.mockito.MockitoSugar",
    # TODO: write test scenarios for these
    "org.scalatest.FunSuite": "org.scalatest.funsuite.AnyFunSuite",
    "org.scalatest.junit.JUnitRunner": "org.scalatestplus.junit.JUnitRunner",
    "org.scalatest.FlatSpec": "org.scalatest.flatspec.AnyFlatSpec",
    "org.scalatest.junit.AssertionsForJUnit": "org.scalatestplus.junit.AssertionsForJUnit",
}


def update_imports(path_to_codebase: str, dry_run: bool = False):
    """Migrate every scalatest import under `path_to_codebase`.

    "scalatest" is used as the search heuristic: only files mentioning that
    identifier are rewritten. Returns the Piranha rewrite summaries.
    """
    return replace_imports(IMPORT_MAPPING, "scalatest", path_to_codebase, dry_run)

# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from setuptools import find_packages, setup

# Packaging metadata for the scalatest-migration plugin.
setup(
    name="scala_test",
    # FIX: description read "scaletest"; the project migrates "scalatest".
    description="Rules to migrate `scalatest`",
    version="0.0.1",
    # long_description=open("README.md").read(),
    # long_description_content_type="text/markdown",
    # url="https://github.com/uber/piranha",
    packages=find_packages(),
    include_package_data=True,
    install_requires=[
        # "polyglot-piranha",
        "pytest",
    ],
    # Expose the CLI defined in scala_test/main.py as `scala_test`.
    entry_points={"console_scripts": ["scala_test = scala_test.main:main"]},
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
    ],
    # dict[str, str] annotations elsewhere in the package require 3.9+.
    python_requires=">=3.9",
    tests_require=["pytest"],
    # Define the test suite
    test_suite="tests",
)