Skip to content

Commit

Permalink
Add CallerInfo annotation type validation
Browse files Browse the repository at this point in the history
  • Loading branch information
chamil321 committed Apr 21, 2021
1 parent 37ef8af commit bb455d2
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 19 deletions.
2 changes: 1 addition & 1 deletion http-ballerina/http_annotation.bal
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public annotation HttpPayload Payload on parameter, return;
#
# + respondType - Specifies the type of response
public type HttpCallerInfo record {|
string respondType?;
typedesc<ResponseMessage> respondType?;
|};

# The annotation which is used to configure the type of the response.
Expand Down
1 change: 1 addition & 0 deletions http-ballerina/http_commons.bal
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ isolated function uuid() returns string {
# + path - Resource path
# + method - http method of the request
# + statusCode - status code of the response
# + url - The request URL
isolated function addObservabilityInformation(string path, string method, int statusCode, string url) {
string statusCodeConverted = statusCode.toString();
_ = checkpanic observe:addTagToSpan(HTTP_URL, path);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[package]
org = "http_test"
name = "sample_10"
version = "0.1.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import ballerina/http;

type Person record {|
int id;
string name;
|};

type AA record {|
string a;
|};

service http:Service on new http:Listener(9090) {
resource function get callerInfo1(int xyz, @http:CallerInfo {respondType: string} http:Caller abc) {
checkpanic abc->respond("done");
}

resource function get callerInfo2(@http:CallerInfo {respondType: int} http:Caller abc) {
error? err = abc->respond(56);
}

resource function get callerInfo3(@http:CallerInfo {respondType: int} http:Caller abc) {
var err = abc->respond({a:"abc"});
if (err is error) {
}
}

resource function get callerInfo4(@http:CallerInfo {respondType: decimal} http:Caller abc) returns error? {
return abc->respond(5.6);
}

resource function get callerInfo5(@http:CallerInfo {respondType: string} http:Caller abc) returns error? {
var a = check abc->respond("done");
}

resource function get callerInfo6(@http:CallerInfo {respondType: decimal} http:Caller abc) returns error? {
return abc->respond(5565.6d);
}

resource function get callerInfo7(@http:CallerInfo {} http:Caller abc) returns error? {
return abc->respond(54341.6); // no validation
}

resource function get callerInfo8(@http:CallerInfo http:Caller abc) returns error? {
var a = abc->respond("ufww"); // no validation
if (a is error) {
}
}

resource function get callerInfo9(@http:CallerInfo {respondType: Person}http:Caller abc) returns error? {
error? a = abc->respond({id:123, name:"elle"});
}

resource function get callerInfo10(@http:CallerInfo {} http:Caller abc) returns error? {
checkpanic abc->respond(); // empty annotation value exp
}

resource function get callerInfo11(@http:CallerInfo {respondType: Person}http:Caller abc) returns error? {
return abc->'continue(); // different remote method call
}

resource function get callerInfo12(int xyz, @http:CallerInfo {respondType: string} http:Caller abc) {
int a = 5;
if (a > 0) {
checkpanic abc->respond("Go");
} else {
error? ab = abc->respond({a:"hello"}); //error
}
}

resource function get callerInfo13(@http:CallerInfo {respondType: string} http:Caller abc, http:Caller xyz) {
checkpanic xyz->respond("done"); // error:multiple callers
}

resource function get callerInfo14(@untainted @http:CallerInfo {respondType: string} http:Caller abc) {
checkpanic abc->respond("done"); // multiple annotations
}

resource function get callerInfo15(@http:CallerInfo {respondType: string} http:Caller abc) returns error? {
http:Client c = check new("path");
var a = c->get("done"); // different remote method call
if (a is error) {
}
}

resource function get callerInfo16(@http:CallerInfo {respondType: Person}http:Caller abc) returns error? {
Person p = {id:123, name:"elle"};
error? a = abc->respond(p);
}

resource function get callerInfo17(@http:CallerInfo {respondType: Person}http:Caller abc) returns error? {
error? a = abc->respond({school:1.23}); // This getting passed as map
}

resource function get callerInfo18(@http:CallerInfo {respondType: Person}http:Caller abc) returns error? {
AA val = { a: "hello" };
error? a = abc->respond(val); // error
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,8 @@ public class Constants {
public static final String PAYLOAD_ANNOTATION_TYPE = "HttpPayload";
public static final String CALLER_ANNOTATION_TYPE = "HttpCallerInfo";
public static final String HEADER_ANNOTATION_TYPE = "HttpHeader";
public static final String CALLER_ANNOTATION_NAME = "CallerInfo";
public static final String FIELD_RESPONSE_TYPE = "respondType";
public static final String RESPOND_METHOD_NAME = "respond";
public static final String ALLOWED_RETURN_UNION = "anydata|http:Response|http:StatusCodeRecord|error";
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ public enum HttpDiagnosticCodes {
HTTP_112("HTTP_112", "invalid type of query param '%s': expected one of the 'string', 'int', 'float', " +
"'boolean', 'decimal' types or the array types of them", ERROR),
HTTP_113("HTTP_113", "invalid union type of query param '%s': 'string', 'int', 'float', 'boolean', " +
"'decimal' type or the array types of them can only be union with '()'. Eg: string? or int[]?", ERROR);
"'decimal' type or the array types of them can only be union with '()'. Eg: string? or int[]?", ERROR),
HTTP_114("HTTP_114", "incompatible respond method argument type : expected '%s' according " +
"to the 'http:CallerInfo' annotation", ERROR);

private final String code;
private final String message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@
import io.ballerina.compiler.api.symbols.UnionTypeSymbol;
import io.ballerina.compiler.syntax.tree.AnnotationNode;
import io.ballerina.compiler.syntax.tree.FunctionDefinitionNode;
import io.ballerina.compiler.syntax.tree.MappingConstructorExpressionNode;
import io.ballerina.compiler.syntax.tree.MetadataNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeList;
import io.ballerina.compiler.syntax.tree.ParameterNode;
import io.ballerina.compiler.syntax.tree.PositionalArgumentNode;
import io.ballerina.compiler.syntax.tree.RequiredParameterNode;
import io.ballerina.compiler.syntax.tree.ReturnTypeDescriptorNode;
import io.ballerina.compiler.syntax.tree.SeparatedNodeList;
import io.ballerina.compiler.syntax.tree.SpecificFieldNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext;
import io.ballerina.tools.diagnostics.DiagnosticFactory;
Expand All @@ -48,8 +54,10 @@
import java.util.stream.Collectors;

import static io.ballerina.stdlib.http.compiler.Constants.BALLERINA;
import static io.ballerina.stdlib.http.compiler.Constants.CALLER_ANNOTATION_NAME;
import static io.ballerina.stdlib.http.compiler.Constants.CALLER_ANNOTATION_TYPE;
import static io.ballerina.stdlib.http.compiler.Constants.CALLER_OBJ_NAME;
import static io.ballerina.stdlib.http.compiler.Constants.FIELD_RESPONSE_TYPE;
import static io.ballerina.stdlib.http.compiler.Constants.HEADER_ANNOTATION_TYPE;
import static io.ballerina.stdlib.http.compiler.Constants.HEADER_OBJ_NAME;
import static io.ballerina.stdlib.http.compiler.Constants.HTTP;
Expand Down Expand Up @@ -100,7 +108,9 @@ private static void extractInputParamTypeAndValidate(SyntaxNodeAnalysisContext c
if (parametersOptional.isEmpty()) {
return;
}
int paramIndex = -1;
for (ParameterSymbol param : parametersOptional.get()) {
paramIndex++;
String paramType = param.typeDescriptor().signature();
Optional<String> nameOptional = param.getName();
String paramName = nameOptional.isEmpty() ? "" : nameOptional.get();
Expand Down Expand Up @@ -137,8 +147,10 @@ private static void extractInputParamTypeAndValidate(SyntaxNodeAnalysisContext c
continue;
}
String typeName = typeNameOptional.get();
if (!CALLER_OBJ_NAME.equals(typeName) && !REQUEST_OBJ_NAME.equals(typeName) &&
!HEADER_OBJ_NAME.equals(typeName)) {
if (CALLER_OBJ_NAME.equals(typeName) || REQUEST_OBJ_NAME.equals(typeName) ||
HEADER_OBJ_NAME.equals(typeName)) {

} else {
reportInvalidParameterType(ctx, member, paramType);
}
} else {
Expand Down Expand Up @@ -288,7 +300,9 @@ private static void extractInputParamTypeAndValidate(SyntaxNodeAnalysisContext c
continue;
}
String callerTypeName = typeNameOptional.get();
if (!CALLER_OBJ_NAME.equals(callerTypeName)) {
if (CALLER_OBJ_NAME.equals(callerTypeName)) {
extractCallerInfoValueAndValidate(ctx, member, paramIndex);
} else {
reportInvalidCallerParameterType(ctx, member, paramName);
}
} else {
Expand Down Expand Up @@ -352,6 +366,52 @@ private static void extractInputParamTypeAndValidate(SyntaxNodeAnalysisContext c
}
}

private static void extractCallerInfoValueAndValidate(SyntaxNodeAnalysisContext ctx,
FunctionDefinitionNode member, int paramIndex) {
ParameterNode parameterNode = member.functionSignature().parameters().get(paramIndex);
NodeList<AnnotationNode> annotations = ((RequiredParameterNode) parameterNode).annotations();
for (AnnotationNode annotationNode : annotations) {
if (!annotationNode.annotReference().toString().contains(CALLER_ANNOTATION_NAME)) {
continue;
}
Optional<MappingConstructorExpressionNode> annotValue = annotationNode.annotValue();
if (annotValue.isEmpty()) {
continue;
}
if (annotValue.get().fields().size() == 0) {
continue;
}
SeparatedNodeList fields = annotValue.get().fields();
for (Object node : fields) {
SpecificFieldNode specificFieldNode = (SpecificFieldNode) node;
if (!specificFieldNode.fieldName().toString().equals(FIELD_RESPONSE_TYPE)) {
continue;
}
String expectedType = specificFieldNode.valueExpr().get().toString();
String callerToken = ((RequiredParameterNode) parameterNode).paramName().get().text();
List<PositionalArgumentNode> respondParamNodes = getRespondParamNode(ctx, member, callerToken);
if (respondParamNodes.isEmpty()) {
continue;
}
for (PositionalArgumentNode argumentNode : respondParamNodes) {
TypeSymbol argTypeSymbol = ctx.semanticModel().type(argumentNode.expression()).get();
TypeSymbol annotValueSymbol =
(TypeSymbol) ctx.semanticModel().symbol(specificFieldNode.valueExpr().get()).get();
if (!annotValueSymbol.assignableTo(argTypeSymbol)) {
reportInCompatibleCallerInfoType(ctx, argumentNode, expectedType);
}
}
}
}
}

private static List<PositionalArgumentNode> getRespondParamNode(SyntaxNodeAnalysisContext ctx,
FunctionDefinitionNode member, String callerToken) {
RespondExpressionVisitor respondNodeVisitor = new RespondExpressionVisitor(ctx, callerToken);
member.accept(respondNodeVisitor);
return respondNodeVisitor.getRespondStatementNodes();
}

private static boolean isAllowedQueryParamType(TypeDescKind kind) {
return kind == TypeDescKind.STRING || kind == TypeDescKind.INT || kind == TypeDescKind.FLOAT ||
kind == TypeDescKind.DECIMAL || kind == TypeDescKind.BOOLEAN;
Expand Down Expand Up @@ -390,8 +450,8 @@ private static void validateReturnType(SyntaxNodeAnalysisContext ctx, FunctionDe
validateReturnType(ctx, member, returnTypeStringValue, typeSymbol.typeKind(), typeSymbol);
}
} else if (kind == TypeDescKind.ARRAY) {
TypeDescKind elementKind = ((ArrayTypeSymbol) returnTypeSymbol).memberTypeDescriptor().typeKind();
validateArrayElementType(ctx, member, returnTypeStringValue, elementKind, returnTypeSymbol);
TypeSymbol memberTypeDescriptor = ((ArrayTypeSymbol) returnTypeSymbol).memberTypeDescriptor();
validateArrayElementType(ctx, member, returnTypeStringValue, memberTypeDescriptor);
} else if (kind == TypeDescKind.TYPE_REFERENCE) {
TypeSymbol typeDescriptor = ((TypeReferenceTypeSymbol) returnTypeSymbol).typeDescriptor();
TypeDescKind typeDescKind = typeDescriptor.typeKind();
Expand All @@ -403,14 +463,9 @@ private static void validateReturnType(SyntaxNodeAnalysisContext ctx, FunctionDe
reportInvalidReturnType(ctx, member, returnTypeStringValue);
}
} else if (kind == TypeDescKind.MAP) {
Optional<TypeSymbol> typeSymbol = ((MapTypeSymbol) returnTypeSymbol).typeParameter();
if (typeSymbol.isEmpty()) {
reportInvalidReturnType(ctx, member, returnTypeStringValue);
} else {
TypeSymbol elementTypeSymbol = typeSymbol.get();
TypeDescKind typeDescKind = elementTypeSymbol.typeKind();
validateReturnType(ctx, member, returnTypeStringValue, typeDescKind, elementTypeSymbol);
}
TypeSymbol typeSymbol = ((MapTypeSymbol) returnTypeSymbol).typeParam();
TypeDescKind typeDescKind = typeSymbol.typeKind();
validateReturnType(ctx, member, returnTypeStringValue, typeDescKind, typeSymbol);
} else if (kind == TypeDescKind.TABLE) {
TypeSymbol typeSymbol = ((TableTypeSymbol) returnTypeSymbol).rowTypeParameter();
if (typeSymbol == null) {
Expand All @@ -425,12 +480,12 @@ private static void validateReturnType(SyntaxNodeAnalysisContext ctx, FunctionDe
}

private static void validateArrayElementType(SyntaxNodeAnalysisContext ctx, FunctionDefinitionNode member,
String typeStringValue, TypeDescKind kind,
TypeSymbol returnTypeSymbol) {
String typeStringValue, TypeSymbol memberTypeDescriptor) {
TypeDescKind kind = memberTypeDescriptor.typeKind();
if (isBasicTypeDesc(kind)) {
return;
} else if (kind == TypeDescKind.TYPE_REFERENCE) {
TypeSymbol typeDescriptor = ((TypeReferenceTypeSymbol) returnTypeSymbol).typeDescriptor();
TypeSymbol typeDescriptor = ((TypeReferenceTypeSymbol) memberTypeDescriptor).typeDescriptor();
TypeDescKind typeDescKind = typeDescriptor.typeKind();
if (typeDescKind == TypeDescKind.OBJECT) {
reportInvalidReturnType(ctx, member, typeStringValue);
Expand Down Expand Up @@ -527,7 +582,12 @@ private static void reportInvalidUnionQueryType(SyntaxNodeAnalysisContext ctx, F
updateDiagnostic(ctx, node, paramName, HttpDiagnosticCodes.HTTP_113);
}

private static void updateDiagnostic(SyntaxNodeAnalysisContext ctx, FunctionDefinitionNode node, String returnType,
private static void reportInCompatibleCallerInfoType(SyntaxNodeAnalysisContext ctx, PositionalArgumentNode node,
String paramName) {
updateDiagnostic(ctx, node, paramName, HttpDiagnosticCodes.HTTP_114);
}

private static void updateDiagnostic(SyntaxNodeAnalysisContext ctx, Node node, String returnType,
HttpDiagnosticCodes httpDiagnosticCodes) {
DiagnosticInfo diagnosticInfo = getDiagnosticInfo(httpDiagnosticCodes, returnType);
ctx.reportDiagnostic(DiagnosticFactory.createDiagnostic(diagnosticInfo, node.location()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (c) 2021, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
*
* WSO2 Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package io.ballerina.stdlib.http.compiler;

import io.ballerina.compiler.api.ModuleID;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.syntax.tree.NodeVisitor;
import io.ballerina.compiler.syntax.tree.PositionalArgumentNode;
import io.ballerina.compiler.syntax.tree.RemoteMethodCallActionNode;
import io.ballerina.compiler.syntax.tree.SimpleNameReferenceNode;
import io.ballerina.projects.plugins.SyntaxNodeAnalysisContext;

import java.util.ArrayList;
import java.util.List;

/**
* A class for visiting respond expression.
*/
public class RespondExpressionVisitor extends NodeVisitor {
private final SyntaxNodeAnalysisContext ctx;
private final String callerToken;
private List<PositionalArgumentNode> respondStatementNodes = new ArrayList<>();

public RespondExpressionVisitor(SyntaxNodeAnalysisContext ctx, String callerToken) {
this.ctx = ctx;
this.callerToken = callerToken;
}

@Override
public void visit(RemoteMethodCallActionNode node) {
TypeSymbol typeSymbol = ctx.semanticModel().type(node.expression()).get();
ModuleID moduleID = typeSymbol.getModule().get().id();
if (!Constants.BALLERINA.equals(moduleID.orgName())) {
return;
}
if (!Constants.HTTP.equals(moduleID.moduleName())) {
return;
}
if (!Constants.CALLER_OBJ_NAME.equals(typeSymbol.getName().get())) {
return;
}
if (!callerToken.equals(node.expression().toString())) {
return;
}
SimpleNameReferenceNode simpleNameReferenceNode = node.methodName();
if (simpleNameReferenceNode.name().text().equals(Constants.RESPOND_METHOD_NAME)) {
if (node.arguments().size() > 0) {
respondStatementNodes.add((PositionalArgumentNode) node.arguments().get(0));
}
}
}

List<PositionalArgumentNode> getRespondStatementNodes() {
return respondStatementNodes;
}
}

0 comments on commit bb455d2

Please sign in to comment.