Skip to content

Commit

Permalink
Prevent non-public classes of method params from breaking interceptor…
Browse files Browse the repository at this point in the history
… handling

Fixes: quarkusio#18477
  • Loading branch information
geoand committed Jul 9, 2021
1 parent d82ba05 commit 0ed51da
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,10 @@ private Map<MethodInfo, DecorationInfo> initDecoratedMethods() {
Collections.sort(bound, Comparator.comparingInt(DecoratorInfo::getPriority).thenComparing(DecoratorInfo::getBeanClass));

Map<MethodKey, DecorationInfo> candidates = new HashMap<>();
addDecoratedMethods(candidates, target.get().asClass(), bound,
new SubclassSkipPredicate(beanDeployment.getAssignabilityCheck()::isAssignableFrom));
ClassInfo classInfo = target.get().asClass();
addDecoratedMethods(candidates, classInfo, classInfo, bound,
new SubclassSkipPredicate(beanDeployment.getAssignabilityCheck()::isAssignableFrom,
beanDeployment.getBeanArchiveIndex()));

Map<MethodInfo, DecorationInfo> decoratedMethods = new HashMap<>(candidates.size());
for (Entry<MethodKey, DecorationInfo> entry : candidates.entrySet()) {
Expand All @@ -545,8 +547,8 @@ private Map<MethodInfo, DecorationInfo> initDecoratedMethods() {
}

private void addDecoratedMethods(Map<MethodKey, DecorationInfo> decoratedMethods, ClassInfo classInfo,
List<DecoratorInfo> boundDecorators, SubclassSkipPredicate skipPredicate) {
skipPredicate.startProcessing(classInfo);
ClassInfo originalClassInfo, List<DecoratorInfo> boundDecorators, SubclassSkipPredicate skipPredicate) {
skipPredicate.startProcessing(classInfo, originalClassInfo);
for (MethodInfo method : classInfo.methods()) {
if (skipPredicate.test(method)) {
continue;
Expand All @@ -560,7 +562,7 @@ private void addDecoratedMethods(Map<MethodKey, DecorationInfo> decoratedMethods
if (!classInfo.superName().equals(DotNames.OBJECT)) {
ClassInfo superClassInfo = getClassByName(beanDeployment.getBeanArchiveIndex(), classInfo.superName());
if (superClassInfo != null) {
addDecoratedMethods(decoratedMethods, superClassInfo, boundDecorators, skipPredicate);
addDecoratedMethods(decoratedMethods, superClassInfo, originalClassInfo, boundDecorators, skipPredicate);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,22 @@ static Set<MethodInfo> addInterceptedMethodCandidates(BeanDeployment beanDeploym
Map<MethodKey, Set<AnnotationInstance>> candidates,
List<AnnotationInstance> classLevelBindings, Consumer<BytecodeTransformer> bytecodeTransformerConsumer,
boolean transformUnproxyableClasses) {
return addInterceptedMethodCandidates(beanDeployment, classInfo, candidates, classLevelBindings,
return addInterceptedMethodCandidates(beanDeployment, classInfo, classInfo, candidates, classLevelBindings,
bytecodeTransformerConsumer, transformUnproxyableClasses,
new SubclassSkipPredicate(beanDeployment.getAssignabilityCheck()::isAssignableFrom), false);
new SubclassSkipPredicate(beanDeployment.getAssignabilityCheck()::isAssignableFrom,
beanDeployment.getBeanArchiveIndex()),
false);
}

static Set<MethodInfo> addInterceptedMethodCandidates(BeanDeployment beanDeployment, ClassInfo classInfo,
ClassInfo originalClassInfo,
Map<MethodKey, Set<AnnotationInstance>> candidates,
List<AnnotationInstance> classLevelBindings, Consumer<BytecodeTransformer> bytecodeTransformerConsumer,
boolean transformUnproxyableClasses, SubclassSkipPredicate skipPredicate, boolean ignoreMethodLevelBindings) {

Set<NameAndDescriptor> methodsFromWhichToRemoveFinal = new HashSet<>();
Set<MethodInfo> finalMethodsFoundAndNotChanged = new HashSet<>();
skipPredicate.startProcessing(classInfo);
skipPredicate.startProcessing(classInfo, originalClassInfo);

for (MethodInfo method : classInfo.methods()) {
if (skipPredicate.test(method)) {
Expand Down Expand Up @@ -220,7 +223,7 @@ static Set<MethodInfo> addInterceptedMethodCandidates(BeanDeployment beanDeploym
ClassInfo superClassInfo = getClassByName(beanDeployment.getBeanArchiveIndex(), classInfo.superName());
if (superClassInfo != null) {
finalMethodsFoundAndNotChanged
.addAll(addInterceptedMethodCandidates(beanDeployment, superClassInfo, candidates,
.addAll(addInterceptedMethodCandidates(beanDeployment, superClassInfo, classInfo, candidates,
classLevelBindings, bytecodeTransformerConsumer, transformUnproxyableClasses, skipPredicate,
ignoreMethodLevelBindings));
}
Expand All @@ -230,7 +233,7 @@ static Set<MethodInfo> addInterceptedMethodCandidates(BeanDeployment beanDeploym
ClassInfo interfaceInfo = getClassByName(beanDeployment.getBeanArchiveIndex(), i);
if (interfaceInfo != null) {
//interfaces can't have final methods
addInterceptedMethodCandidates(beanDeployment, interfaceInfo, candidates,
addInterceptedMethodCandidates(beanDeployment, interfaceInfo, originalClassInfo, candidates,
classLevelBindings, bytecodeTransformerConsumer, transformUnproxyableClasses,
skipPredicate, true);
}
Expand Down Expand Up @@ -448,22 +451,27 @@ public MethodVisitor visitMethod(int access, String name, String descriptor, Str
/**
* This stateful predicate can be used to skip methods that should not be added to the generated subclass.
* <p>
* Don't forget to call {@link SubclassSkipPredicate#startProcessing(ClassInfo)} before the methods are processed and
* Don't forget to call {@link SubclassSkipPredicate#startProcessing(ClassInfo, ClassInfo)} before the methods are processed
* and
* {@link SubclassSkipPredicate#methodsProcessed()} afterwards.
*/
static class SubclassSkipPredicate implements Predicate<MethodInfo> {

private final BiFunction<Type, Type, Boolean> assignableFromFun;
private final IndexView beanArchiveIndex;
private ClassInfo clazz;
private ClassInfo originalClazz;
private List<MethodInfo> regularMethods;
private Set<MethodInfo> bridgeMethods = new HashSet<>();

public SubclassSkipPredicate(BiFunction<Type, Type, Boolean> assignableFromFun) {
public SubclassSkipPredicate(BiFunction<Type, Type, Boolean> assignableFromFun, IndexView beanArchiveIndex) {
this.assignableFromFun = assignableFromFun;
this.beanArchiveIndex = beanArchiveIndex;
}

void startProcessing(ClassInfo clazz) {
void startProcessing(ClassInfo clazz, ClassInfo originalClazz) {
this.clazz = clazz;
this.originalClazz = originalClazz;
this.regularMethods = new ArrayList<>();
for (MethodInfo method : clazz.methods()) {
if (!Modifier.isAbstract(method.flags()) && !method.isSynthetic() && !isBridge(method)) {
Expand Down Expand Up @@ -505,10 +513,40 @@ public boolean test(MethodInfo method) {
// Do not skip default methods - public non-abstract instance methods declared in an interface
return false;
}

List<Type> parameters = method.parameters();
if (!parameters.isEmpty() && (beanArchiveIndex != null)) {
String originalClassPackage = determinePackage(originalClazz.name());
for (Type type : parameters) {
ClassInfo parameterClassInfo = beanArchiveIndex.getClassByName(type.name());
if (parameterClassInfo == null) {
continue; // hope for the best
}
if (Modifier.isPrivate(parameterClassInfo.flags())) {
return true; // parameters whose class is private can not be loaded, as we would end up with IllegalAccessError when trying to access the use the load the class
}
if (!Modifier.isPublic(parameterClassInfo.flags())) {
if (determinePackage(parameterClassInfo.name()).equals(originalClassPackage)) {
return false;
}
// parameters whose class is package-private and the package is not the same as the package of the method for which we are checking can not be loaded,
// as we would end up with IllegalAccessError when trying to access the use the load the class
return true;
}
}
}

// Note that we intentionally do not skip final methods here - these are handled later
return false;
}

private String determinePackage(DotName dotName) {
if (dotName.isInner()) {
dotName = dotName.prefix();
}
return dotName.prefix() == null ? "" : dotName.prefix().toString();
}

private boolean hasImplementation(MethodInfo bridge) {
for (MethodInfo declaredMethod : regularMethods) {
if (bridge.name().equals(declaredMethod.name())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ public class SubclassSkipPredicateTest {
public void testPredicate() throws IOException {
IndexView index = Basics.index(Base.class, Submarine.class, Long.class, Number.class);
AssignabilityCheck assignabilityCheck = new AssignabilityCheck(index, null);
SubclassSkipPredicate predicate = new SubclassSkipPredicate(assignabilityCheck::isAssignableFrom);
SubclassSkipPredicate predicate = new SubclassSkipPredicate(assignabilityCheck::isAssignableFrom, null);

ClassInfo submarineClass = index.getClassByName(DotName.createSimple(Submarine.class.getName()));
predicate.startProcessing(submarineClass);
predicate.startProcessing(submarineClass, submarineClass);

List<MethodInfo> echos = submarineClass.methods().stream().filter(m -> m.name().equals("echo"))
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.quarkus.arc.test.interceptors.methodargs;

import io.quarkus.arc.test.interceptors.methodargs.base.BaseExecutor;
import javax.inject.Singleton;

@Singleton
@Simple
public class CustomExecutor extends BaseExecutor {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkus.arc.test.interceptors.methodargs;

import java.util.List;
import javax.inject.Singleton;

@Singleton
@Simple
public class ExampleResource {

public String create(List<String> strings) {
return String.join(",", strings);
}

String otherCreate(PackagePrivate packagePrivate) {
return packagePrivate.toString();
}

static class PackagePrivate {

@Override
public String toString() {
return "PackagePrivate{}";
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.quarkus.arc.test.interceptors.methodargs;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.test.ArcTestContainer;
import io.quarkus.arc.test.interceptors.Counter;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

public class MethodArgsInterceptionTest {

@RegisterExtension
public ArcTestContainer container = new ArcTestContainer(Simple.class, SimpleInterceptor.class, ExampleResource.class,
CustomExecutor.class, Counter.class);

@Test
public void testInterception() {
ArcContainer container = Arc.container();
Counter counter = container.instance(Counter.class).get();

counter.reset();
ExampleResource exampleResource = container.instance(ExampleResource.class).get();

assertEquals("first,second", exampleResource.create(List.of("first", "second")));
assertEquals(1, counter.get());

assertEquals("PackagePrivate{}", exampleResource.otherCreate(new ExampleResource.PackagePrivate()));
assertEquals(2, counter.get());

CustomExecutor customThreadExecutor = container.instance(CustomExecutor.class).get();
assertEquals("run", customThreadExecutor.run());
assertEquals(3, counter.get());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.quarkus.arc.test.interceptors.methodargs;

import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

import java.lang.annotation.Documented;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import javax.interceptor.InterceptorBinding;

@Target({ TYPE, METHOD })
@Retention(RUNTIME)
@Documented
@InterceptorBinding
public @interface Simple {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.quarkus.arc.test.interceptors.methodargs;

import io.quarkus.arc.test.interceptors.Counter;
import javax.annotation.Priority;
import javax.inject.Inject;
import javax.interceptor.AroundInvoke;
import javax.interceptor.Interceptor;
import javax.interceptor.InvocationContext;

@Simple
@Priority(1)
@Interceptor
public class SimpleInterceptor {

@Inject
Counter counter;

@AroundInvoke
Object mySuperCoolAroundInvoke(InvocationContext ctx) throws Exception {
counter.incrementAndGet();
return ctx.proceed();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkus.arc.test.interceptors.methodargs.base;

public class BaseExecutor {

public String run() {
return "run";
}

protected void runWorker(Worker run) {

}

static class Worker {

}
}

0 comments on commit 0ed51da

Please sign in to comment.