Skip to content

Commit

Permalink
Add null checks to generated GRPC API (#2455)
Browse files Browse the repository at this point in the history
Motivation:

At the moment the generated GRPC code does not perform input validation
in all necessary methods and might pass invalid/null inputs
down into the GRPC layer where it is harder to track down what
went wrong in the first place.

Modifications:

This changeset adds explicit requireNonNull to all generated methods
which accept user input and adds a test case to verify that the
requireNonNull is performed at the earliest point in time and
not deeper in the call stack.

Result:

Early argument null checks to improve debugability in case of
invalid argument values.
  • Loading branch information
daschl authored Dec 9, 2022
1 parent d5918cc commit e83c645
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright © 2022 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.grpc.netty;

import io.servicetalk.grpc.api.GrpcClientMetadata;

import io.grpc.examples.helloworld.Greeter;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

class GrpcInputValidationTest {

@Test
void verifyClientEarlyNonNullArgumentsCheck() throws Exception {
try (Greeter.BlockingGreeterClient client = GrpcClients
.forAddress("localhost", 0)
.buildBlocking(new Greeter.ClientFactory())) {

assertEarlyRequireNonNull(() -> client.sayHello(null));
assertEarlyRequireNonNull(() -> client.sayHello((GrpcClientMetadata) null, null));
}
}

@Test
void verifyServiceEarlyNonNullArgumentsCheck() {
assertEarlyRequireNonNull(() -> GrpcServers
.forAddress(localAddress(0))
.listenAndAwait(new Greeter.ServiceFactory.Builder().sayHello(null).build()));

assertEarlyRequireNonNull(() -> GrpcServers
.forAddress(localAddress(0))
.listenAndAwait(new Greeter.ServiceFactory.Builder().sayHello(null, null).build()));
}

private static void assertEarlyRequireNonNull(final Executable executable) {
NullPointerException ex = assertThrows(NullPointerException.class, executable);
assertEquals("requireNonNull", ex.getStackTrace()[0].getMethodName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ private TypeSpec.Builder addServiceFactory(final State state, final TypeSpec.Bui
.addModifiers(PUBLIC)
.addParameter(rpcInterface.className, rpc, FINAL)
.returns(builderClass)
.addStatement("$T.requireNonNull($L)", Objects, rpc)
.addCode(addRouteCode)
.addStatement("return this")
.build())
Expand All @@ -629,6 +630,8 @@ private TypeSpec.Builder addServiceFactory(final State state, final TypeSpec.Bui
.addParameter(GrpcExecutionStrategy, strategy, FINAL)
.addParameter(rpcInterface.className, rpc, FINAL)
.returns(builderClass)
.addStatement("$T.requireNonNull($L)", Objects, strategy)
.addStatement("$T.requireNonNull($L)", Objects, rpc)
.addCode(addRouteExecCode)
.addStatement("return this")
.build());
Expand All @@ -638,6 +641,7 @@ private TypeSpec.Builder addServiceFactory(final State state, final TypeSpec.Bui
.addModifiers(PUBLIC)
.returns(builderClass)
.addParameter(state.serviceClass, service, FINAL)
.addStatement("$T.requireNonNull($L)", Objects, service)
.addStatement("$L($L)", registerRoutes, service)
.addStatement("return this")
.build());
Expand All @@ -661,7 +665,8 @@ private TypeSpec.Builder addServiceFactory(final State state, final TypeSpec.Bui
final MethodSpec.Builder addBlockingServiceMethodSpecBuilder = methodBuilder(addBlockingService)
.addModifiers(PUBLIC)
.returns(builderClass)
.addParameter(state.blockingServiceClass, service, FINAL);
.addParameter(state.blockingServiceClass, service, FINAL)
.addStatement("$T.requireNonNull($L)", Objects, service);
final MethodSpec.Builder registerRoutesMethodSpecBuilder = methodBuilder(registerRoutes)
.addModifiers(PROTECTED)
.addAnnotation(Override.class)
Expand Down Expand Up @@ -1328,10 +1333,14 @@ private void addClientFieldsAndMethods(final State state, final TypeSpec.Builder
(__, b) -> b.addAnnotation(Deprecated.class)
.addAnnotation(Override.class)
.addParameter(clientMetaData.className, metadata, FINAL)
.addStatement("$T.requireNonNull($L)", Objects, metadata)
.addStatement("$T.requireNonNull($L)", Objects, request)
.addStatement("return $L.$L($L, $L)", callFieldName, request, metadata, request)))
.addMethod(newRpcMethodSpec(clientMetaData.methodProto, rpcMethodSpecsFlags, false,
(__, b) -> b.addAnnotation(Override.class)
.addParameter(GrpcClientMetadata, metadata, FINAL)
.addStatement("$T.requireNonNull($L)", Objects, metadata)
.addStatement("$T.requireNonNull($L)", Objects, request)
.addStatement("return $L.$L($L, $L)", callFieldName, request, metadata, request)));

constructorBuilder
Expand Down

0 comments on commit e83c645

Please sign in to comment.