Skip to content

Commit

Permalink
PPL integration with AD and ml-commons (#468)
Browse files Browse the repository at this point in the history
Signed-off-by: jackieyanghan <[email protected]>
  • Loading branch information
jackiehanyang authored Mar 9, 2022
1 parent 10a9846 commit eb1ede4
Show file tree
Hide file tree
Showing 36 changed files with 1,177 additions and 9 deletions.
4 changes: 4 additions & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ test.finalizedBy(project.tasks.jacocoTestReport)
jacocoTestCoverageVerification {
violationRules {
rule {
element = 'CLASS'
excludes = [
'org.opensearch.sql.utils.MLCommonsConstants'
]
limit {
counter = 'LINE'
minimum = 1.0
Expand Down
44 changes: 44 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand All @@ -31,11 +37,13 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
Expand All @@ -58,11 +66,13 @@
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalDedupe;
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalProject;
import org.opensearch.sql.planner.logical.LogicalRareTopN;
Expand Down Expand Up @@ -395,6 +405,40 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) {
return new LogicalValues(valueExprs);
}

/**
* Build {@link LogicalMLCommons} for Kmeans command.
*/
@Override
public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
List<Argument> options = node.getOptions();

TypeEnvironment currentEnv = context.peek();
currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER);

return new LogicalMLCommons(child, "kmeans", options);
}

/**
* Build {@link LogicalAD} for AD command.
*/
@Override
public LogicalPlan visitAD(AD node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
java.util.Map<String, Literal> options = node.getArguments();

TypeEnvironment currentEnv = context.peek();

currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE);
if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN);
} else {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE);
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_TIMESTAMP), ExprCoreType.TIMESTAMP);
}
return new LogicalAD(child, options);
}

/**
* The first argument is always "asc", others are optional.
* Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC.
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
Expand Down Expand Up @@ -239,4 +241,12 @@ public T visitLimit(Limit node, C context) {
public T visitSpan(Span node, C context) {
return visitChildren(node, context);
}

public T visitKmeans(Kmeans node, C context) {
return visitChildren(node, context);
}

public T visitAD(AD node, C context) {
return visitChildren(node, context);
}
}
8 changes: 8 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ public static Literal longLiteral(Long value) {
return literal(value, DataType.LONG);
}

public static Literal shortLiteral(Short value) {
return literal(value, DataType.SHORT);
}

public static Literal floatLiteral(Float value) {
return literal(value, DataType.FLOAT);
}

public static Literal dateLiteral(String value) {
return literal(value, DataType.DATE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public enum DataType {

INTEGER(ExprCoreType.INTEGER),
LONG(ExprCoreType.LONG),
SHORT(ExprCoreType.SHORT),
FLOAT(ExprCoreType.FLOAT),
DOUBLE(ExprCoreType.DOUBLE),
STRING(ExprCoreType.STRING),
BOOLEAN(ExprCoreType.BOOLEAN),
Expand Down
47 changes: 47 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/AD.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Literal;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class AD extends UnresolvedPlan {
private UnresolvedPlan child;

private final Map<String, Literal> arguments;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitAD(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}
}
46 changes: 46 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Argument;

@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = true)
@RequiredArgsConstructor
@AllArgsConstructor
public class Kmeans extends UnresolvedPlan {
private UnresolvedPlan child;

private final List<Argument> options;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitKmeans(this, context);
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(this.child);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;

/*
* AD logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalAD extends LogicalPlan {
private final Map<String, Literal> arguments;

/**
* Constructor of LogicalAD.
* @param child child logical plan
* @param arguments arguments of the algorithm
*/
public LogicalAD(LogicalPlan child, Map<String, Literal> arguments) {
super(Collections.singletonList(child));
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitAD(this, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Argument;

/**
* ml-commons logical plan.
*/
@Getter
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalMLCommons extends LogicalPlan {
private final String algorithm;

private final List<Argument> arguments;

/**
* Constructor of LogicalMLCommons.
* @param child child logical plan
* @param algorithm algorithm name
* @param arguments arguments of the algorithm
*/
public LogicalMLCommons(LogicalPlan child, String algorithm,
List<Argument> arguments) {
super(Collections.singletonList(child));
this.algorithm = algorithm;
this.arguments = arguments;
}

@Override
public <R, C> R accept(LogicalPlanNodeVisitor<R, C> visitor, C context) {
return visitor.visitMLCommons(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,12 @@ public R visitRareTopN(LogicalRareTopN plan, C context) {
public R visitLimit(LogicalLimit plan, C context) {
return visitNode(plan, context);
}

public R visitMLCommons(LogicalMLCommons plan, C context) {
return visitNode(plan, context);
}

public R visitAD(LogicalAD plan, C context) {
return visitNode(plan, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,13 @@ public R visitLimit(LimitOperator node, C context) {
return visitNode(node, context);
}

public R visitMLCommons(PhysicalPlan node, C context) {
return visitNode(node, context);
}

public R visitAD(PhysicalPlan node, C context) {
return visitNode(node, context);
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.opensearch.sql.utils;

public class MLCommonsConstants {

public static final String SHINGLE_SIZE = "shingle_size";
public static final String TIME_DECAY = "time_decay";
public static final String TIME_FIELD = "time_field";

public static final String RCF_SCORE = "score";
public static final String RCF_ANOMALOUS = "anomalous";
public static final String RCF_ANOMALY_GRADE = "anomaly_grade";
public static final String RCF_TIMESTAMP = "timestamp";
}
Loading

0 comments on commit eb1ede4

Please sign in to comment.