Skip to content

Commit

Permalink
restructure workflow to handle an entire class at a time instead of a…
Browse files Browse the repository at this point in the history
… single type param
  • Loading branch information
FelixAnthonisen committed Nov 19, 2024
1 parent 545b912 commit 9ef81af
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/main/java/io/github/bldl/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class Main {
public static void main(String[] args) {
AstManipulator manip = new AstManipulator(new StdoutMessager(), "example");
manip.eraseTypesAndInsertCasts("Herd.java", "", "T");
manip.eraseTypesAndInsertCasts("Herd.java", "", null);
manip.applyChanges();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import io.github.bldl.astParsing.util.TypeHandler;
import io.github.bldl.astParsing.visitors.ParameterTypeCollector;
import io.github.bldl.astParsing.visitors.ReturnTypeCollector;
import io.github.bldl.graph.ClassHierarchyGraph;
import io.leangen.geantyref.AnnotationFormatException;
import io.leangen.geantyref.TypeFactory;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -42,18 +43,21 @@
public class VarianceProcessor extends AbstractProcessor {
private Messager messager;
private AstManipulator astManipulator;
private Map<String, Map<String, MyVariance>> classes = new HashMap<>();
private Map<String, String> packages = new HashMap<>();
private final ImmutableList<Class<? extends Annotation>> supportedAnnotations = ImmutableList.of(MyVariance.class,
Covariant.class, Contravariant.class);

@Override
public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
if (roundEnv.getElementsAnnotatedWithAny(Set.of(MyVariance.class,
Covariant.class, Contravariant.class)).isEmpty())
return false;
boolean workHasBeenDone = false;
messager = processingEnv.getMessager();
astManipulator = new AstManipulator(messager,
System.getProperty("user.dir") + "/src/main/java");
ClassHierarchyGraph<String> classHierarchy = astManipulator.computeClassHierarchy();
messager.printMessage(Kind.NOTE, classHierarchy.toString());
messager.printMessage(Kind.NOTE, "Processing annotations:\n");
boolean workHasBeenDone = false;
for (Class<? extends Annotation> annotationType : supportedAnnotations) {
for (Element e : roundEnv.getElementsAnnotatedWith(annotationType)) {
workHasBeenDone = true;
Expand All @@ -67,16 +71,24 @@ else if (annotationType.equals(Contravariant.class))
Map.of("variance", VarianceType.CONTRAVARIANT, "strict", true));

} catch (AnnotationFormatException ex) {
// catch this later
}
if (annotation != null)
processElement(annotation, e);
else
messager.printMessage(Kind.WARNING, "Could not parse annotation for element: " + e);
}
}
if (workHasBeenDone)
astManipulator.applyChanges();
if (!workHasBeenDone) {
messager.printMessage(Kind.NOTE, "No changes made. Not saving.");
return false;
}

for (String className : classes.keySet()) {
astManipulator.eraseTypesAndInsertCasts(className + ".java", packages.get(className),
classes.get(className));
}

astManipulator.applyChanges();
return true;
}

Expand All @@ -99,9 +111,12 @@ private void processElement(MyVariance annotation, Element e) {
className));
}

packages.putIfAbsent(className, packageName);
classes.putIfAbsent(className, new HashMap<>());
classes.get(className).put(tE.getSimpleName().toString(), annotation);
checkVariance(className, annotation, packageName, tE.getSimpleName().toString());
astManipulator.eraseTypesAndInsertCasts(className + ".java", packageName,
tE.getSimpleName().toString());
// astManipulator.eraseTypesAndInsertCasts(className + ".java", packageName,
// tE.getSimpleName().toString(), annotation);
}

private void checkVariance(String className, MyVariance annotation, String packageName, String typeOfInterest) {
Expand All @@ -110,7 +125,7 @@ private void checkVariance(String className, MyVariance annotation, String packa
+ ".java");
if (annotation.variance() == VarianceType.CONTRAVARIANT)
cu.accept(new ReturnTypeCollector(), types);
else
else if (annotation.variance() == VarianceType.COVARIANT)
cu.accept(new ParameterTypeCollector(), types);

for (Type type : types) {
Expand All @@ -122,7 +137,6 @@ private void checkVariance(String className, MyVariance annotation, String packa
className,
annotation.variance(),
annotation.variance() == VarianceType.COVARIANT ? "IN" : "OUT"));
break;
}
}
}
Expand Down
103 changes: 77 additions & 26 deletions src/main/java/io/github/bldl/astParsing/AstManipulator.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.github.bldl.astParsing;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.PackageDeclaration;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
Expand All @@ -10,13 +11,18 @@
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.Name;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.nodeTypes.NodeWithAnnotations;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.Type;
import com.github.javaparser.ast.type.TypeParameter;
import com.github.javaparser.ast.visitor.ModifierVisitor;
import com.github.javaparser.ast.visitor.Visitable;
import com.github.javaparser.utils.CodeGenerationUtils;
import com.github.javaparser.utils.SourceRoot;
import io.github.bldl.annotationProcessing.annotations.MyVariance;
import io.github.bldl.astParsing.util.ClassData;
import io.github.bldl.astParsing.util.MethodData;
import io.github.bldl.astParsing.util.ParamData;
import io.github.bldl.astParsing.visitors.CastInsertionVisitor;
import io.github.bldl.astParsing.visitors.MethodCollector;
import io.github.bldl.astParsing.visitors.TypeEraserVisitor;
Expand All @@ -27,7 +33,6 @@

import java.io.File;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
Expand All @@ -51,9 +56,9 @@ public AstManipulator(Messager messager, String sourceFolder) {

public void applyChanges() {
this.sourceRoot.getCompilationUnits().forEach(cu -> {
// messager.printMessage(Kind.NOTE, "Saving cu: " + cu.toString());
changePackageDeclaration(cu);
});
messager.printMessage(Kind.NOTE, "Saving modified AST's to output directory");
this.sourceRoot.saveAll(
CodeGenerationUtils.mavenModuleRoot(AstManipulator.class).resolve(Paths.get(sourceFolder + "/output")));
}
Expand All @@ -62,45 +67,62 @@ public SourceRoot getSourceRoot() {
return sourceRoot;
}

public void eraseTypesAndInsertCasts(String cls, String packageName, String typeOfInterest) {
public void eraseTypesAndInsertCasts(String cls, String packageName, Map<String, MyVariance> mp) {
messager.printMessage(Kind.NOTE,
String.format("Now parsing AST's for class %s and type param %s", cls, typeOfInterest));
String.format("Now parsing AST's for class %s", cls));
File dir = Paths.get(sourceFolder).toFile();
assert dir.exists();
assert dir.isDirectory();

ClassData classData = computeClassData(cls, packageName, typeOfInterest);
eraseAnnotations(cls, packageName);
ClassData classData = computeClassData(cls, packageName, mp);
messager.printMessage(Kind.NOTE, "Collected class data:\n" + classData);
Map<String, MethodData> methodMap = new HashMap<>();

sourceRoot.parse(packageName, cls).accept(new MethodCollector(Arrays.asList(typeOfInterest)),
sourceRoot.parse(packageName, cls).accept(new MethodCollector(mp.keySet()),
methodMap);

messager.printMessage(Kind.NOTE, "Collected methods:\n" + methodMap.toString());
changeAST(dir, classData, methodMap, "");
}

public void eraseAnnotations(String cls, String packageName) {
Set<String> annotations = Set.of("MyVariance", "Covariant", "Contravariant");
CompilationUnit cu = sourceRoot.parse(packageName, cls);
cu.accept(new ModifierVisitor<Void>() {
@Override
public Visitable visit(Parameter n, Void arg) {
n.getAnnotations().removeIf(annotation -> annotations.contains(annotation.getNameAsString()));
return super.visit(n, arg);
}

public Visitable visit(TypeParameter n, Void arg) {
n.getAnnotations().removeIf(annotation -> annotations.contains(annotation.getNameAsString()));
return super.visit(n, arg);
}
}, null);
}

public ClassHierarchyGraph<String> computeClassHierarchy() {
ClassHierarchyGraph<String> g = new ClassHierarchyGraph<>();
g.addVertex("Object");
computeClassHierarchyRec(g, Paths.get(sourceFolder).toFile(), "");
return g;
}

private ClassData computeClassData(String cls, String packageName, String typeOfInterest) {
private ClassData computeClassData(String cls, String packageName, Map<String, MyVariance> mp) {
CompilationUnit cu = sourceRoot.parse(packageName, cls);
Map<String, ParamData> indexAndBound = new HashMap<>();
var a = cu.findAll(ClassOrInterfaceDeclaration.class).get(0).getTypeParameters();
for (int i = 0; i < a.size(); ++i) {
TypeParameter type = a.get(i);
NodeList<ClassOrInterfaceType> boundList = type.getTypeBound();
String leftMostBound = boundList == null || boundList.size() == 0 ? "Object" : boundList.get(0).asString();
if (type.getNameAsString().equals(typeOfInterest)) {
a.get(i);
return new ClassData(cls.replaceFirst("\\.java$", ""), leftMostBound, i);
if (mp.keySet().contains(type.getNameAsString())) {
indexAndBound.put(type.getNameAsString(),
new ParamData(i, leftMostBound, mp.get(type.getNameAsString())));
}

}
return null;
return new ClassData(cls.replaceFirst("\\.java$", ""), indexAndBound);
}

private void changeAST(File dir, ClassData classData, Map<String, MethodData> methodMap,
Expand All @@ -117,12 +139,11 @@ private void changeAST(File dir, ClassData classData, Map<String, MethodData> me

CompilationUnit cu = sourceRoot.parse(packageName, fileName);

Set<Pair<String, String>> varsToWatch = new HashSet<>();
Set<Pair<String, ClassOrInterfaceType>> varsToWatch = new HashSet<>();
cu.accept(new VariableCollector(classData), varsToWatch);
messager.printMessage(Kind.NOTE, "Collected variables to watch:\n" + varsToWatch);
performSubtypingChecks(cu, classData, methodMap, varsToWatch);
// performSubtypingChecks(cu, classData, methodMap, varsToWatch);
cu.accept(new TypeEraserVisitor(classData), null);
for (Pair<String, String> var : varsToWatch) {
for (Pair<String, ClassOrInterfaceType> var : varsToWatch) {
CastInsertionVisitor castInsertionVisitor = new CastInsertionVisitor(var, methodMap);
cu.accept(castInsertionVisitor, null);
}
Expand Down Expand Up @@ -177,6 +198,10 @@ private void performSubtypingChecks(CompilationUnit cu, ClassData classData,
Map<String, MethodData> methodMap,
Set<Pair<String, String>> varsToWatch) {
Map<String, Map<Integer, Type>> methodParams = collectMethodParams(cu, classData);
Map<String, String> varsToWatchMap = new HashMap<>();
varsToWatch.forEach(p -> {
varsToWatchMap.put(p.first, p.second);
});
cu.findAll(MethodCallExpr.class).forEach(methodCall -> {
if (!methodParams.containsKey(methodCall.getNameAsString()))
return;
Expand All @@ -189,20 +214,28 @@ private void performSubtypingChecks(CompilationUnit cu, ClassData classData,
String name = ((NameExpr) e).getNameAsString();
varsToWatch.forEach(p -> {
if (p.first.equals(name)) {
// check subtyping
// boolean valid = isValidSubtype(name, name, annotation);
// if (!valid)
messager.printMessage(Kind.ERROR,
String.format("Invalid subtype for method call: ", methodCall.toString()));
}
});
}

});
cu.findAll(AssignExpr.class).forEach(assignExpr -> {
// cu.findAll(AssignExpr.class).forEach(assignExpr -> {
// if (!(assignExpr.getTarget() instanceof NameExpr))
// return;
// NameExpr name = (NameExpr) assignExpr.getTarget();
// if (!varsToWatchMap.containsKey(name.toString()))
// return;

messager.printMessage(Kind.NOTE, assignExpr.toString());
messager.printMessage(Kind.NOTE, assignExpr.getTarget().getClass().toString());
messager.printMessage(Kind.NOTE, assignExpr.getValue().getClass().toString());
});
// });
// cu.findAll(ForEachStmt.class).forEach(stmt -> {

// });
// cu.findAll(VariableDeclarationExpr.class).forEach(stmt -> {

// });
}

Expand All @@ -218,15 +251,33 @@ private Map<String, Map<Integer, Type>> collectMethodParams(CompilationUnit cu,
String methodName = dec.getNameAsString();
if (type.getNameAsString().equals(classData.className())) {
mp.putIfAbsent(methodName, new HashMap<>());
mp.get(methodName).put(i, type.getTypeArguments().get().get(classData.indexOfParam()));
// mp.get(methodName).put(i,
// type.getTypeArguments().get().get(classData.indexOfParam()));
}
}
});
return mp;
}

private String resolveType() {
return null;
private boolean isValidSubtype(String assigneeType, String assignedType, MyVariance annotation) {
if (!classHierarchy.containsVertex(assigneeType)) {
messager.printMessage(Kind.WARNING,
String.format("%s is not a user defined type, so no subtyping checks can be made", assigneeType));
return true;
}
if (!classHierarchy.containsVertex(assignedType)) {
messager.printMessage(Kind.WARNING,
String.format("%s is not a user defined type, so no subtyping checks can be made", assignedType));
return true;
}
switch (annotation.variance()) {
case COVARIANT:
return classHierarchy.isDescendant(assignedType, assigneeType);
case CONTRAVARIANT:
return classHierarchy.isDescendant(assigneeType, assignedType);
default:
return false;
}
}

}
4 changes: 3 additions & 1 deletion src/main/java/io/github/bldl/astParsing/util/ClassData.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.bldl.astParsing.util;

public record ClassData(String className, String leftmostBound, int indexOfParam) {
import java.util.Map;

public record ClassData(String className, Map<String, ParamData> params) {

}
6 changes: 6 additions & 0 deletions src/main/java/io/github/bldl/astParsing/util/ParamData.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.github.bldl.astParsing.util;

import io.github.bldl.annotationProcessing.annotations.MyVariance;

public record ParamData(int index, String leftmostBound, MyVariance variance) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@
import java.util.Map;
import java.util.Optional;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.expr.CastExpr;
import com.github.javaparser.ast.expr.EnclosedExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.visitor.ModifierVisitor;
import com.github.javaparser.ast.visitor.Visitable;
import com.github.javaparser.ast.type.Type;

import io.github.bldl.astParsing.util.MethodData;
import io.github.bldl.util.Pair;

public class CastInsertionVisitor extends ModifierVisitor<Void> {
private final Pair<String, String> ref;
private final Pair<String, ClassOrInterfaceType> ref;
private final Map<String, MethodData> methodMap;

public CastInsertionVisitor(Pair<String, String> ref, Map<String, MethodData> methodMap) {
public CastInsertionVisitor(Pair<String, ClassOrInterfaceType> ref, Map<String, MethodData> methodMap) {
this.ref = ref;
this.methodMap = methodMap;
}
Expand All @@ -32,7 +34,10 @@ public Visitable visit(MethodCallExpr n, Void arg) {
if (expr.getNameAsString().equals(ref.first)) {
MethodData data = methodMap.get(n.getNameAsString());
if (data != null && data.shouldCast()) {
String castString = data.castString().replace("*", ref.second);
NodeList<Type> arguments = ref.second.getTypeArguments().get();
String castString = data.castString();
for (int i = 0; i < arguments.size(); ++i)
castString = data.castString().replace(Integer.toString(i), arguments.get(i).asString());
ClassOrInterfaceType castType = new ClassOrInterfaceType(null, castString);
CastExpr cast = new CastExpr(castType, n);
EnclosedExpr enclosedCast = new EnclosedExpr(cast);
Expand Down
Loading

0 comments on commit 9ef81af

Please sign in to comment.