Skip to content

Commit 9ef81af

Browse files
restructure workflow to handle an entire class at a time instead of a single type param
1 parent 545b912 commit 9ef81af

File tree

9 files changed

+139
-58
lines changed

9 files changed

+139
-58
lines changed

src/main/java/io/github/bldl/Main.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class Main {
1212
public static void main(String[] args) {
1313
AstManipulator manip = new AstManipulator(new StdoutMessager(), "example");
14-
manip.eraseTypesAndInsertCasts("Herd.java", "", "T");
14+
manip.eraseTypesAndInsertCasts("Herd.java", "", null);
1515
manip.applyChanges();
1616
}
1717
}

src/main/java/io/github/bldl/annotationProcessing/VarianceProcessor.java

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
import io.github.bldl.astParsing.util.TypeHandler;
1313
import io.github.bldl.astParsing.visitors.ParameterTypeCollector;
1414
import io.github.bldl.astParsing.visitors.ReturnTypeCollector;
15-
import io.github.bldl.graph.ClassHierarchyGraph;
1615
import io.leangen.geantyref.AnnotationFormatException;
1716
import io.leangen.geantyref.TypeFactory;
17+
18+
import java.util.HashMap;
1819
import java.util.HashSet;
1920
import java.util.Map;
2021
import java.util.Set;
@@ -42,18 +43,21 @@
4243
public class VarianceProcessor extends AbstractProcessor {
4344
private Messager messager;
4445
private AstManipulator astManipulator;
46+
private Map<String, Map<String, MyVariance>> classes = new HashMap<>();
47+
private Map<String, String> packages = new HashMap<>();
4548
private final ImmutableList<Class<? extends Annotation>> supportedAnnotations = ImmutableList.of(MyVariance.class,
4649
Covariant.class, Contravariant.class);
4750

4851
@Override
4952
public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
53+
if (roundEnv.getElementsAnnotatedWithAny(Set.of(MyVariance.class,
54+
Covariant.class, Contravariant.class)).isEmpty())
55+
return false;
56+
boolean workHasBeenDone = false;
5057
messager = processingEnv.getMessager();
5158
astManipulator = new AstManipulator(messager,
5259
System.getProperty("user.dir") + "/src/main/java");
53-
ClassHierarchyGraph<String> classHierarchy = astManipulator.computeClassHierarchy();
54-
messager.printMessage(Kind.NOTE, classHierarchy.toString());
5560
messager.printMessage(Kind.NOTE, "Processing annotations:\n");
56-
boolean workHasBeenDone = false;
5761
for (Class<? extends Annotation> annotationType : supportedAnnotations) {
5862
for (Element e : roundEnv.getElementsAnnotatedWith(annotationType)) {
5963
workHasBeenDone = true;
@@ -67,16 +71,24 @@ else if (annotationType.equals(Contravariant.class))
6771
Map.of("variance", VarianceType.CONTRAVARIANT, "strict", true));
6872

6973
} catch (AnnotationFormatException ex) {
70-
// catch this later
7174
}
7275
if (annotation != null)
7376
processElement(annotation, e);
7477
else
7578
messager.printMessage(Kind.WARNING, "Could not parse annotation for element: " + e);
7679
}
7780
}
78-
if (workHasBeenDone)
79-
astManipulator.applyChanges();
81+
if (!workHasBeenDone) {
82+
messager.printMessage(Kind.NOTE, "No changes made. Not saving.");
83+
return false;
84+
}
85+
86+
for (String className : classes.keySet()) {
87+
astManipulator.eraseTypesAndInsertCasts(className + ".java", packages.get(className),
88+
classes.get(className));
89+
}
90+
91+
astManipulator.applyChanges();
8092
return true;
8193
}
8294

@@ -99,9 +111,12 @@ private void processElement(MyVariance annotation, Element e) {
99111
className));
100112
}
101113

114+
packages.putIfAbsent(className, packageName);
115+
classes.putIfAbsent(className, new HashMap<>());
116+
classes.get(className).put(tE.getSimpleName().toString(), annotation);
102117
checkVariance(className, annotation, packageName, tE.getSimpleName().toString());
103-
astManipulator.eraseTypesAndInsertCasts(className + ".java", packageName,
104-
tE.getSimpleName().toString());
118+
// astManipulator.eraseTypesAndInsertCasts(className + ".java", packageName,
119+
// tE.getSimpleName().toString(), annotation);
105120
}
106121

107122
private void checkVariance(String className, MyVariance annotation, String packageName, String typeOfInterest) {
@@ -110,7 +125,7 @@ private void checkVariance(String className, MyVariance annotation, String packa
110125
+ ".java");
111126
if (annotation.variance() == VarianceType.CONTRAVARIANT)
112127
cu.accept(new ReturnTypeCollector(), types);
113-
else
128+
else if (annotation.variance() == VarianceType.COVARIANT)
114129
cu.accept(new ParameterTypeCollector(), types);
115130

116131
for (Type type : types) {
@@ -122,7 +137,6 @@ private void checkVariance(String className, MyVariance annotation, String packa
122137
className,
123138
annotation.variance(),
124139
annotation.variance() == VarianceType.COVARIANT ? "IN" : "OUT"));
125-
break;
126140
}
127141
}
128142
}

src/main/java/io/github/bldl/astParsing/AstManipulator.java

Lines changed: 77 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.github.bldl.astParsing;
22

33
import com.github.javaparser.ast.CompilationUnit;
4+
import com.github.javaparser.ast.Node;
45
import com.github.javaparser.ast.NodeList;
56
import com.github.javaparser.ast.PackageDeclaration;
67
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
@@ -10,13 +11,18 @@
1011
import com.github.javaparser.ast.expr.MethodCallExpr;
1112
import com.github.javaparser.ast.expr.Name;
1213
import com.github.javaparser.ast.expr.NameExpr;
14+
import com.github.javaparser.ast.nodeTypes.NodeWithAnnotations;
1315
import com.github.javaparser.ast.type.ClassOrInterfaceType;
1416
import com.github.javaparser.ast.type.Type;
1517
import com.github.javaparser.ast.type.TypeParameter;
18+
import com.github.javaparser.ast.visitor.ModifierVisitor;
19+
import com.github.javaparser.ast.visitor.Visitable;
1620
import com.github.javaparser.utils.CodeGenerationUtils;
1721
import com.github.javaparser.utils.SourceRoot;
22+
import io.github.bldl.annotationProcessing.annotations.MyVariance;
1823
import io.github.bldl.astParsing.util.ClassData;
1924
import io.github.bldl.astParsing.util.MethodData;
25+
import io.github.bldl.astParsing.util.ParamData;
2026
import io.github.bldl.astParsing.visitors.CastInsertionVisitor;
2127
import io.github.bldl.astParsing.visitors.MethodCollector;
2228
import io.github.bldl.astParsing.visitors.TypeEraserVisitor;
@@ -27,7 +33,6 @@
2733

2834
import java.io.File;
2935
import java.nio.file.Paths;
30-
import java.util.Arrays;
3136
import java.util.HashMap;
3237
import java.util.HashSet;
3338
import java.util.Map;
@@ -51,9 +56,9 @@ public AstManipulator(Messager messager, String sourceFolder) {
5156

5257
public void applyChanges() {
5358
this.sourceRoot.getCompilationUnits().forEach(cu -> {
54-
// messager.printMessage(Kind.NOTE, "Saving cu: " + cu.toString());
5559
changePackageDeclaration(cu);
5660
});
61+
messager.printMessage(Kind.NOTE, "Saving modified AST's to output directory");
5762
this.sourceRoot.saveAll(
5863
CodeGenerationUtils.mavenModuleRoot(AstManipulator.class).resolve(Paths.get(sourceFolder + "/output")));
5964
}
@@ -62,45 +67,62 @@ public SourceRoot getSourceRoot() {
6267
return sourceRoot;
6368
}
6469

65-
public void eraseTypesAndInsertCasts(String cls, String packageName, String typeOfInterest) {
70+
public void eraseTypesAndInsertCasts(String cls, String packageName, Map<String, MyVariance> mp) {
6671
messager.printMessage(Kind.NOTE,
67-
String.format("Now parsing AST's for class %s and type param %s", cls, typeOfInterest));
72+
String.format("Now parsing AST's for class %s", cls));
6873
File dir = Paths.get(sourceFolder).toFile();
6974
assert dir.exists();
7075
assert dir.isDirectory();
71-
72-
ClassData classData = computeClassData(cls, packageName, typeOfInterest);
76+
eraseAnnotations(cls, packageName);
77+
ClassData classData = computeClassData(cls, packageName, mp);
7378
messager.printMessage(Kind.NOTE, "Collected class data:\n" + classData);
7479
Map<String, MethodData> methodMap = new HashMap<>();
75-
76-
sourceRoot.parse(packageName, cls).accept(new MethodCollector(Arrays.asList(typeOfInterest)),
80+
sourceRoot.parse(packageName, cls).accept(new MethodCollector(mp.keySet()),
7781
methodMap);
7882

7983
messager.printMessage(Kind.NOTE, "Collected methods:\n" + methodMap.toString());
8084
changeAST(dir, classData, methodMap, "");
8185
}
8286

87+
public void eraseAnnotations(String cls, String packageName) {
88+
Set<String> annotations = Set.of("MyVariance", "Covariant", "Contravariant");
89+
CompilationUnit cu = sourceRoot.parse(packageName, cls);
90+
cu.accept(new ModifierVisitor<Void>() {
91+
@Override
92+
public Visitable visit(Parameter n, Void arg) {
93+
n.getAnnotations().removeIf(annotation -> annotations.contains(annotation.getNameAsString()));
94+
return super.visit(n, arg);
95+
}
96+
97+
public Visitable visit(TypeParameter n, Void arg) {
98+
n.getAnnotations().removeIf(annotation -> annotations.contains(annotation.getNameAsString()));
99+
return super.visit(n, arg);
100+
}
101+
}, null);
102+
}
103+
83104
public ClassHierarchyGraph<String> computeClassHierarchy() {
84105
ClassHierarchyGraph<String> g = new ClassHierarchyGraph<>();
85106
g.addVertex("Object");
86107
computeClassHierarchyRec(g, Paths.get(sourceFolder).toFile(), "");
87108
return g;
88109
}
89110

90-
private ClassData computeClassData(String cls, String packageName, String typeOfInterest) {
111+
private ClassData computeClassData(String cls, String packageName, Map<String, MyVariance> mp) {
91112
CompilationUnit cu = sourceRoot.parse(packageName, cls);
113+
Map<String, ParamData> indexAndBound = new HashMap<>();
92114
var a = cu.findAll(ClassOrInterfaceDeclaration.class).get(0).getTypeParameters();
93115
for (int i = 0; i < a.size(); ++i) {
94116
TypeParameter type = a.get(i);
95117
NodeList<ClassOrInterfaceType> boundList = type.getTypeBound();
96118
String leftMostBound = boundList == null || boundList.size() == 0 ? "Object" : boundList.get(0).asString();
97-
if (type.getNameAsString().equals(typeOfInterest)) {
98-
a.get(i);
99-
return new ClassData(cls.replaceFirst("\\.java$", ""), leftMostBound, i);
119+
if (mp.keySet().contains(type.getNameAsString())) {
120+
indexAndBound.put(type.getNameAsString(),
121+
new ParamData(i, leftMostBound, mp.get(type.getNameAsString())));
100122
}
101123

102124
}
103-
return null;
125+
return new ClassData(cls.replaceFirst("\\.java$", ""), indexAndBound);
104126
}
105127

106128
private void changeAST(File dir, ClassData classData, Map<String, MethodData> methodMap,
@@ -117,12 +139,11 @@ private void changeAST(File dir, ClassData classData, Map<String, MethodData> me
117139

118140
CompilationUnit cu = sourceRoot.parse(packageName, fileName);
119141

120-
Set<Pair<String, String>> varsToWatch = new HashSet<>();
142+
Set<Pair<String, ClassOrInterfaceType>> varsToWatch = new HashSet<>();
121143
cu.accept(new VariableCollector(classData), varsToWatch);
122-
messager.printMessage(Kind.NOTE, "Collected variables to watch:\n" + varsToWatch);
123-
performSubtypingChecks(cu, classData, methodMap, varsToWatch);
144+
// performSubtypingChecks(cu, classData, methodMap, varsToWatch);
124145
cu.accept(new TypeEraserVisitor(classData), null);
125-
for (Pair<String, String> var : varsToWatch) {
146+
for (Pair<String, ClassOrInterfaceType> var : varsToWatch) {
126147
CastInsertionVisitor castInsertionVisitor = new CastInsertionVisitor(var, methodMap);
127148
cu.accept(castInsertionVisitor, null);
128149
}
@@ -177,6 +198,10 @@ private void performSubtypingChecks(CompilationUnit cu, ClassData classData,
177198
Map<String, MethodData> methodMap,
178199
Set<Pair<String, String>> varsToWatch) {
179200
Map<String, Map<Integer, Type>> methodParams = collectMethodParams(cu, classData);
201+
Map<String, String> varsToWatchMap = new HashMap<>();
202+
varsToWatch.forEach(p -> {
203+
varsToWatchMap.put(p.first, p.second);
204+
});
180205
cu.findAll(MethodCallExpr.class).forEach(methodCall -> {
181206
if (!methodParams.containsKey(methodCall.getNameAsString()))
182207
return;
@@ -189,20 +214,28 @@ private void performSubtypingChecks(CompilationUnit cu, ClassData classData,
189214
String name = ((NameExpr) e).getNameAsString();
190215
varsToWatch.forEach(p -> {
191216
if (p.first.equals(name)) {
192-
// check subtyping
217+
// boolean valid = isValidSubtype(name, name, annotation);
218+
// if (!valid)
219+
messager.printMessage(Kind.ERROR,
220+
String.format("Invalid subtype for method call: ", methodCall.toString()));
193221
}
194222
});
195223
}
196224

197225
});
198-
cu.findAll(AssignExpr.class).forEach(assignExpr -> {
226+
// cu.findAll(AssignExpr.class).forEach(assignExpr -> {
227+
// if (!(assignExpr.getTarget() instanceof NameExpr))
228+
// return;
229+
// NameExpr name = (NameExpr) assignExpr.getTarget();
230+
// if (!varsToWatchMap.containsKey(name.toString()))
231+
// return;
199232

200-
messager.printMessage(Kind.NOTE, assignExpr.toString());
201-
messager.printMessage(Kind.NOTE, assignExpr.getTarget().getClass().toString());
202-
messager.printMessage(Kind.NOTE, assignExpr.getValue().getClass().toString());
203-
});
233+
// });
204234
// cu.findAll(ForEachStmt.class).forEach(stmt -> {
205235

236+
// });
237+
// cu.findAll(VariableDeclarationExpr.class).forEach(stmt -> {
238+
206239
// });
207240
}
208241

@@ -218,15 +251,33 @@ private Map<String, Map<Integer, Type>> collectMethodParams(CompilationUnit cu,
218251
String methodName = dec.getNameAsString();
219252
if (type.getNameAsString().equals(classData.className())) {
220253
mp.putIfAbsent(methodName, new HashMap<>());
221-
mp.get(methodName).put(i, type.getTypeArguments().get().get(classData.indexOfParam()));
254+
// mp.get(methodName).put(i,
255+
// type.getTypeArguments().get().get(classData.indexOfParam()));
222256
}
223257
}
224258
});
225259
return mp;
226260
}
227261

228-
private String resolveType() {
229-
return null;
262+
private boolean isValidSubtype(String assigneeType, String assignedType, MyVariance annotation) {
263+
if (!classHierarchy.containsVertex(assigneeType)) {
264+
messager.printMessage(Kind.WARNING,
265+
String.format("%s is not a user defined type, so no subtyping checks can be made", assigneeType));
266+
return true;
267+
}
268+
if (!classHierarchy.containsVertex(assignedType)) {
269+
messager.printMessage(Kind.WARNING,
270+
String.format("%s is not a user defined type, so no subtyping checks can be made", assignedType));
271+
return true;
272+
}
273+
switch (annotation.variance()) {
274+
case COVARIANT:
275+
return classHierarchy.isDescendant(assignedType, assigneeType);
276+
case CONTRAVARIANT:
277+
return classHierarchy.isDescendant(assigneeType, assignedType);
278+
default:
279+
return false;
280+
}
230281
}
231282

232283
}
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package io.github.bldl.astParsing.util;
22

3-
public record ClassData(String className, String leftmostBound, int indexOfParam) {
3+
import java.util.Map;
4+
5+
public record ClassData(String className, Map<String, ParamData> params) {
46

57
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package io.github.bldl.astParsing.util;
2+
3+
import io.github.bldl.annotationProcessing.annotations.MyVariance;
4+
5+
public record ParamData(int index, String leftmostBound, MyVariance variance) {
6+
}

src/main/java/io/github/bldl/astParsing/visitors/CastInsertionVisitor.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@
33
import java.util.Map;
44
import java.util.Optional;
55
import com.github.javaparser.ast.expr.MethodCallExpr;
6+
import com.github.javaparser.ast.NodeList;
67
import com.github.javaparser.ast.expr.CastExpr;
78
import com.github.javaparser.ast.expr.EnclosedExpr;
89
import com.github.javaparser.ast.expr.Expression;
910
import com.github.javaparser.ast.expr.NameExpr;
1011
import com.github.javaparser.ast.type.ClassOrInterfaceType;
1112
import com.github.javaparser.ast.visitor.ModifierVisitor;
1213
import com.github.javaparser.ast.visitor.Visitable;
14+
import com.github.javaparser.ast.type.Type;
1315

1416
import io.github.bldl.astParsing.util.MethodData;
1517
import io.github.bldl.util.Pair;
1618

1719
public class CastInsertionVisitor extends ModifierVisitor<Void> {
18-
private final Pair<String, String> ref;
20+
private final Pair<String, ClassOrInterfaceType> ref;
1921
private final Map<String, MethodData> methodMap;
2022

21-
public CastInsertionVisitor(Pair<String, String> ref, Map<String, MethodData> methodMap) {
23+
public CastInsertionVisitor(Pair<String, ClassOrInterfaceType> ref, Map<String, MethodData> methodMap) {
2224
this.ref = ref;
2325
this.methodMap = methodMap;
2426
}
@@ -32,7 +34,10 @@ public Visitable visit(MethodCallExpr n, Void arg) {
3234
if (expr.getNameAsString().equals(ref.first)) {
3335
MethodData data = methodMap.get(n.getNameAsString());
3436
if (data != null && data.shouldCast()) {
35-
String castString = data.castString().replace("*", ref.second);
37+
NodeList<Type> arguments = ref.second.getTypeArguments().get();
38+
String castString = data.castString();
39+
for (int i = 0; i < arguments.size(); ++i)
40+
castString = data.castString().replace(Integer.toString(i), arguments.get(i).asString());
3641
ClassOrInterfaceType castType = new ClassOrInterfaceType(null, castString);
3742
CastExpr cast = new CastExpr(castType, n);
3843
EnclosedExpr enclosedCast = new EnclosedExpr(cast);

0 commit comments

Comments
 (0)