Skip to content

Commit

Permalink
Merge pull request #31045 from ldclakmal/dev
Browse files Browse the repository at this point in the history
Improve desugar implementation for websocket auth
  • Loading branch information
ldclakmal authored Jun 9, 2021
2 parents c4b3049 + 3fc4a76 commit 5f7106e
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.wso2.ballerinalang.compiler.desugar;

import io.ballerina.tools.diagnostics.Location;
import org.ballerinalang.model.tree.NodeKind;
import org.wso2.ballerinalang.compiler.semantics.analyzer.SymbolResolver;
import org.wso2.ballerinalang.compiler.semantics.model.SymbolEnv;
import org.wso2.ballerinalang.compiler.semantics.model.SymbolTable;
Expand All @@ -29,6 +30,9 @@
import org.wso2.ballerinalang.compiler.semantics.model.types.BObjectType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BUnionType;
import org.wso2.ballerinalang.compiler.tree.BLangBlockFunctionBody;
import org.wso2.ballerinalang.compiler.tree.BLangExprFunctionBody;
import org.wso2.ballerinalang.compiler.tree.BLangFunction;
import org.wso2.ballerinalang.compiler.tree.BLangIdentifier;
import org.wso2.ballerinalang.compiler.tree.BLangImportPackage;
import org.wso2.ballerinalang.compiler.tree.BLangResourceFunction;
Expand All @@ -37,6 +41,7 @@
import org.wso2.ballerinalang.compiler.tree.expressions.BLangListConstructorExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangLiteral;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangSimpleVarRef;
import org.wso2.ballerinalang.compiler.tree.statements.BLangBlockStmt;
import org.wso2.ballerinalang.compiler.tree.statements.BLangSimpleVariableDef;
import org.wso2.ballerinalang.compiler.tree.statements.BLangStatement;
import org.wso2.ballerinalang.compiler.util.CompilerContext;
Expand All @@ -49,13 +54,14 @@
import static org.ballerinalang.model.symbols.SymbolOrigin.VIRTUAL;

/**
* This class injects the code that invokes the http filters to the first lines of an http resource. The code injected
* is as follows:
* This class injects the code that is required for declarative authentication, to the first lines of an given resource.
* The code injected will be one of the followings:
* <blockquote><pre>
* authenticateResource(self, "resourceMethod", ["resourcePath", "resourcePath"])
* authenticateResource(self)
* </pre></blockquote>
* `authenticateResource` is expected to panic with a distinct error that http:Listener knows and the listener handle
* that error specifically.
* `authenticateResource` is expected to panic with a `distinct error` that the relevant listener knows and the listener
* handles that error specifically.
*
* @since 0.974.1
*/
Expand All @@ -66,7 +72,8 @@ public class HttpFiltersDesugar {
private final Names names;

private static final String ORG_NAME = "ballerina";
private static final String PACKAGE_NAME = "http";
private static final String HTTP_PACKAGE_NAME = "http";
private static final String WEBSOCKET_PACKAGE_NAME = "websocket";
private static final String AUTHENTICATE_RESOURCE = "authenticateResource";

private static final CompilerContext.Key<HttpFiltersDesugar> HTTP_FILTERS_DESUGAR_KEY =
Expand All @@ -88,88 +95,122 @@ private HttpFiltersDesugar(CompilerContext context) {
this.names = Names.getInstance(context);
}

boolean isHttpPackage(List<BType> expressionTypes) {
void desugarFunction(BLangFunction functionNode, SymbolEnv env, List<BType> expressionTypes) {
if (isDefinedInStdLibPackage(expressionTypes, HTTP_PACKAGE_NAME)) {
addAuthDesugarFunctionInvocation(functionNode, env, HTTP_PACKAGE_NAME);
} else if (isDefinedInStdLibPackage(expressionTypes, WEBSOCKET_PACKAGE_NAME)) {
addAuthDesugarFunctionInvocation(functionNode, env, WEBSOCKET_PACKAGE_NAME);
}
}

boolean isDefinedInStdLibPackage(List<BType> expressionTypes, String packageName) {
for (BType expressionType : expressionTypes) {
if (expressionType.tag == TypeTags.UNION) {
for (BType memberType : ((BUnionType) expressionType).getMemberTypes()) {
if (memberType.tag == TypeTags.OBJECT && isHttpPackage((BObjectType) memberType)) {
if (memberType.tag == TypeTags.OBJECT &&
isDefinedInStdLibPackage((BObjectType) memberType, packageName)) {
return true;
}
}
} else if (expressionType.tag == TypeTags.OBJECT && isHttpPackage((BObjectType) expressionType)) {
} else if (expressionType.tag == TypeTags.OBJECT &&
isDefinedInStdLibPackage((BObjectType) expressionType, packageName)) {
return true;
}
}
return false;
}

private boolean isHttpPackage(BObjectType type) {
return type.tsymbol.pkgID.orgName.value.equals(ORG_NAME) && type.tsymbol.pkgID.name.value.equals(PACKAGE_NAME);
}

void addFilterStatements(BLangResourceFunction resourceNode, SymbolEnv env, List<BLangStatement> statements) {
addHttpFilterFunctionInvocation(resourceNode, env, statements);
private boolean isDefinedInStdLibPackage(BObjectType type, String packageName) {
return type.tsymbol.pkgID.orgName.value.equals(ORG_NAME) && type.tsymbol.pkgID.name.value.equals(packageName);
}

private void addHttpFilterFunctionInvocation(BLangResourceFunction resourceNode, SymbolEnv env,
List<BLangStatement> statements) {
BPackageSymbol httpPackageSymbol = getHttpPackageSymbol(env);
if (httpPackageSymbol == null) {
// Couldn't find http package in imports list.
void addAuthDesugarFunctionInvocation(BLangFunction functionNode, SymbolEnv env, String packageName) {
BPackageSymbol packageSymbol = getPackageSymbol(env, packageName);
if (packageSymbol == null) {
// Couldn't find http package in imports list or symbols list.
return;
}
// Expected method type:
// `function authenticateResource(service object {} servieRef, string methodName, string[] resourcePath)`

// Expected method type for HTTP:
// `function authenticateResource(service object {} serviceRef, string methodName, string[] resourcePath)`
// Expected method type for WebSocket:
// `function authenticateResource(service object {} serviceRef)`
// The function is expected to panic with a distinct error when fail to authenticate.
// http listener will handle this error.
BSymbol methodSym = symResolver.lookupMethodInModule(httpPackageSymbol,
names.fromString(AUTHENTICATE_RESOURCE), env);
// Relevant listener will handle this error.
BSymbol methodSym = symResolver.lookupMethodInModule(packageSymbol, names.fromString(AUTHENTICATE_RESOURCE),
env);
if (methodSym == symTable.notFoundSymbol || !(methodSym instanceof BInvokableSymbol)) {
return;
}
BInvokableSymbol filterInvocationSymbol = (BInvokableSymbol) methodSym;

BInvokableSymbol invocationSymbol = (BInvokableSymbol) methodSym;
BLangResourceFunction resourceNode = (BLangResourceFunction) functionNode;
Location pos = resourceNode.getPosition();

// Create method invocation.
BLangSimpleVarRef selfRef = ASTBuilderUtil.createVariableRef(
pos, resourceNode.symbol.receiverSymbol);

BLangLiteral methodNameLiteral = ASTBuilderUtil.createLiteral(
pos, symTable.stringType, resourceNode.methodName.value);

ArrayList<BLangExpression> pathLiterals = new ArrayList<>();
for (BLangIdentifier path : resourceNode.resourcePath) {
pathLiterals.add(ASTBuilderUtil.createLiteral(pos, symTable.stringType, path.value));
}
BLangListConstructorExpr.BLangArrayLiteral resourcePathLiteral = ASTBuilderUtil.createEmptyArrayLiteral(
pos, (BArrayType) symTable.stringArrayType);
resourcePathLiteral.exprs = pathLiterals;
BLangSimpleVarRef selfRef = ASTBuilderUtil.createVariableRef(pos, resourceNode.symbol.receiverSymbol);

ArrayList<BLangExpression> args = new ArrayList<>();
args.add(selfRef);
args.add(methodNameLiteral);
args.add(resourcePathLiteral);

BLangInvocation invocationExpr = ASTBuilderUtil
.createInvocationExprForMethod(pos, filterInvocationSymbol, args, symResolver);
if (packageName.equals(HTTP_PACKAGE_NAME)) {
BLangLiteral methodNameLiteral = ASTBuilderUtil.createLiteral(
pos, symTable.stringType, resourceNode.methodName.value);

ArrayList<BLangExpression> pathLiterals = new ArrayList<>();
for (BLangIdentifier path : resourceNode.resourcePath) {
pathLiterals.add(ASTBuilderUtil.createLiteral(pos, symTable.stringType, path.value));
}
BLangListConstructorExpr.BLangArrayLiteral resourcePathLiteral = ASTBuilderUtil.createEmptyArrayLiteral(
pos, (BArrayType) symTable.stringArrayType);
resourcePathLiteral.exprs = pathLiterals;

args.add(methodNameLiteral);
args.add(resourcePathLiteral);
}

BLangInvocation invocationExpr =
ASTBuilderUtil.createInvocationExprForMethod(pos, invocationSymbol, args, symResolver);
BLangSimpleVariableDef result = ASTBuilderUtil.createVariableDef(pos,
ASTBuilderUtil.createVariable(pos, "$temp$http$filter$result", symTable.anyType, invocationExpr, null));
ASTBuilderUtil.createVariable(pos, "$temp$auth$desugar$result",
symTable.anyType, invocationExpr, null));
List<BLangStatement> statements = getFunctionBodyStatementList(functionNode);
statements.add(0, result);

BVarSymbol resultSymbol = new BVarSymbol(0, names.fromIdNode(result.var.name), env.enclPkg.packageID,
result.var.type,
resourceNode.symbol, pos, VIRTUAL);
result.var.type, resourceNode.symbol, pos, VIRTUAL);
resourceNode.symbol.scope.define(resultSymbol.name, resultSymbol);
result.var.symbol = resultSymbol;
}

private BPackageSymbol getHttpPackageSymbol(SymbolEnv env) {
private BPackageSymbol getPackageSymbol(SymbolEnv env, String packageName) {
// This resolves the package symbol when the code have an import relevant to the particular service
for (BLangImportPackage pkg : env.enclPkg.imports) {
if (pkg.symbol.pkgID.orgName.value.equals(ORG_NAME) && pkg.symbol.pkgID.name.value.equals(PACKAGE_NAME)) {
if (pkg.symbol.pkgID.orgName.value.equals(ORG_NAME) && pkg.symbol.pkgID.name.value.equals(packageName)) {
return pkg.symbol;
}
}
// This resolves the package symbol when the code is at a submodule of the module which have the listener
// definition. In that case, there is no any import relevant to the particular service.
while (env.enclEnv != null) {
if (env.enclEnv.scope.owner.pkgID.orgName.value.equals(ORG_NAME) &&
env.enclEnv.scope.owner.pkgID.name.value.equals(packageName)) {
return env.enclEnv.enclPkg.symbol;
}
env = env.enclEnv;
}
return null;
}

private List<BLangStatement> getFunctionBodyStatementList(BLangFunction functionNode) {
List<BLangStatement> statements;
if (functionNode.body.getKind() == NodeKind.EXPR_FUNCTION_BODY) {
BLangExprFunctionBody exprFunctionBody = (BLangExprFunctionBody) functionNode.body;
BLangBlockStmt blockStmt = ASTBuilderUtil.createBlockStmt(functionNode.getPosition());
statements = blockStmt.stmts;
exprFunctionBody.expr = ASTBuilderUtil.createStatementExpression(blockStmt, exprFunctionBody.expr);
} else {
statements = ((BLangBlockFunctionBody) functionNode.body).stmts;
}
return statements;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BUnionType;
import org.wso2.ballerinalang.compiler.tree.BLangBlockFunctionBody;
import org.wso2.ballerinalang.compiler.tree.BLangExprFunctionBody;
import org.wso2.ballerinalang.compiler.tree.BLangFunction;
import org.wso2.ballerinalang.compiler.tree.BLangResourceFunction;
import org.wso2.ballerinalang.compiler.tree.BLangService;
import org.wso2.ballerinalang.compiler.tree.BLangSimpleVariable;
import org.wso2.ballerinalang.compiler.tree.BLangVariable;
Expand All @@ -45,11 +43,9 @@
import org.wso2.ballerinalang.compiler.tree.expressions.BLangListConstructorExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangLiteral;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangSimpleVarRef;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangStatementExpression;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangTypeConversionExpr;
import org.wso2.ballerinalang.compiler.tree.statements.BLangBlockStmt;
import org.wso2.ballerinalang.compiler.tree.statements.BLangExpressionStmt;
import org.wso2.ballerinalang.compiler.tree.statements.BLangStatement;
import org.wso2.ballerinalang.compiler.util.CompilerContext;
import org.wso2.ballerinalang.compiler.util.Name;
import org.wso2.ballerinalang.compiler.util.Names;
Expand Down Expand Up @@ -303,26 +299,6 @@ private void engageCustomResourceDesugar(BLangFunction functionNode, SymbolEnv e
.createBeginParticipantInvocation(functionNode.pos));
((BLangBlockFunctionBody) functionNode.body).stmts.add(0, stmt);
}
if (httpFiltersDesugar.isHttpPackage(expressionTypes)) {
List<BLangStatement> statements = getFunctionBodyStatementList(functionNode);
httpFiltersDesugar.addFilterStatements((BLangResourceFunction) functionNode, env, statements);
}
}

private List<BLangStatement> getFunctionBodyStatementList(BLangFunction functionNode) {
List<BLangStatement> statements = null;
if (functionNode.body.getKind() == NodeKind.EXPR_FUNCTION_BODY) {
BLangExprFunctionBody exprFunctionBody = (BLangExprFunctionBody) functionNode.body;
BLangBlockStmt blockStmt = ASTBuilderUtil.createBlockStmt(functionNode.getPosition());
statements = blockStmt.stmts;

BLangStatementExpression statementExpression = ASTBuilderUtil.createStatementExpression(
blockStmt,
exprFunctionBody.expr);
exprFunctionBody.expr = statementExpression;
} else {
statements = ((BLangBlockFunctionBody) functionNode.body).stmts;
}
return statements;
httpFiltersDesugar.desugarFunction(functionNode, env, expressionTypes);
}
}

0 comments on commit 5f7106e

Please sign in to comment.