Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sql hint extract when sql contains dbeaver hint comment #32331

Merged
merged 7 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,7 @@ public enum SQLHintTokenType {
/**
* SQL start hint token.
*/
SQL_START_HINT_TOKEN("/* SHARDINGSPHERE_HINT:", "/* ShardingSphere hint:"),

/**
* SQL hint token.
*/
SQL_HINT_TOKEN("shardingsphere_hint:", "shardingsphere hint:");
SQL_START_HINT_TOKEN("SHARDINGSPHERE_HINT", "ShardingSphere hint");

private final String key;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.infra.hint;

import com.cedarsoftware.util.CaseInsensitiveMap;
import com.google.common.base.Splitter;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
Expand All @@ -26,16 +27,18 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Properties;

/**
* SQL hint utility class.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class SQLHintUtils {

private static final String SQL_COMMENT_PREFIX = "/*";

private static final String SQL_COMMENT_SUFFIX = "*/";

private static final String SQL_HINT_SPLIT = ",";
Expand All @@ -53,72 +56,65 @@ public final class SQLHintUtils {
* @return hint value context
*/
public static HintValueContext extractHint(final String sql) {
if (!startWithHint(sql)) {
if (!containsSQLHint(sql)) {
return new HintValueContext();
}
HintValueContext result = new HintValueContext();
String hintText = sql.substring(0, sql.indexOf(SQL_COMMENT_SUFFIX) + 2);
Properties hintProps = getSQLHintProps(hintText);
if (containsPropertyKey(hintProps, SQLHintPropertiesKey.DATASOURCE_NAME_KEY)) {
result.setDataSourceName(getProperty(hintProps, SQLHintPropertiesKey.DATASOURCE_NAME_KEY));
int hintKeyValueBeginIndex = getHintKeyValueBeginIndex(sql);
String hintKeyValueText = sql.substring(hintKeyValueBeginIndex, sql.indexOf(SQL_COMMENT_SUFFIX, hintKeyValueBeginIndex));
Map<String, String> hintKeyValues = getSQLHintKeyValues(hintKeyValueText);
if (containsHintKey(hintKeyValues, SQLHintPropertiesKey.DATASOURCE_NAME_KEY)) {
result.setDataSourceName(getHintValue(hintKeyValues, SQLHintPropertiesKey.DATASOURCE_NAME_KEY));
}
if (containsPropertyKey(hintProps, SQLHintPropertiesKey.WRITE_ROUTE_ONLY_KEY)) {
result.setWriteRouteOnly(Boolean.parseBoolean(getProperty(hintProps, SQLHintPropertiesKey.WRITE_ROUTE_ONLY_KEY)));
if (containsHintKey(hintKeyValues, SQLHintPropertiesKey.WRITE_ROUTE_ONLY_KEY)) {
result.setWriteRouteOnly(Boolean.parseBoolean(getHintValue(hintKeyValues, SQLHintPropertiesKey.WRITE_ROUTE_ONLY_KEY)));
}
if (containsPropertyKey(hintProps, SQLHintPropertiesKey.SKIP_SQL_REWRITE_KEY)) {
result.setSkipSQLRewrite(Boolean.parseBoolean(getProperty(hintProps, SQLHintPropertiesKey.SKIP_SQL_REWRITE_KEY)));
if (containsHintKey(hintKeyValues, SQLHintPropertiesKey.SKIP_SQL_REWRITE_KEY)) {
result.setSkipSQLRewrite(Boolean.parseBoolean(getHintValue(hintKeyValues, SQLHintPropertiesKey.SKIP_SQL_REWRITE_KEY)));
}
if (containsPropertyKey(hintProps, SQLHintPropertiesKey.DISABLE_AUDIT_NAMES_KEY)) {
String property = getProperty(hintProps, SQLHintPropertiesKey.DISABLE_AUDIT_NAMES_KEY);
if (containsHintKey(hintKeyValues, SQLHintPropertiesKey.DISABLE_AUDIT_NAMES_KEY)) {
String property = getHintValue(hintKeyValues, SQLHintPropertiesKey.DISABLE_AUDIT_NAMES_KEY);
result.getDisableAuditNames().addAll(getSplitterSQLHintValue(property));
}
if (containsPropertyKey(hintProps, SQLHintPropertiesKey.SHADOW_KEY)) {
result.setShadow(Boolean.parseBoolean(getProperty(hintProps, SQLHintPropertiesKey.SHADOW_KEY)));
if (containsHintKey(hintKeyValues, SQLHintPropertiesKey.SHADOW_KEY)) {
result.setShadow(Boolean.parseBoolean(getHintValue(hintKeyValues, SQLHintPropertiesKey.SHADOW_KEY)));
}
for (Entry<Object, Object> entry : hintProps.entrySet()) {
Comparable<?> value = entry.getValue() instanceof Comparable ? (Comparable<?>) entry.getValue() : Objects.toString(entry.getValue());
if (containsPropertyKey(Objects.toString(entry.getKey()), SQLHintPropertiesKey.SHARDING_DATABASE_VALUE_KEY)) {
result.getShardingDatabaseValues().put(Objects.toString(entry.getKey()).toUpperCase(), value);
for (Entry<String, String> entry : hintKeyValues.entrySet()) {
Object value = convert(entry.getValue());
Comparable<?> comparable = value instanceof Comparable ? (Comparable<?>) value : Objects.toString(value);
if (containsHintKey(Objects.toString(entry.getKey()), SQLHintPropertiesKey.SHARDING_DATABASE_VALUE_KEY)) {
result.getShardingDatabaseValues().put(Objects.toString(entry.getKey()).toUpperCase(), comparable);
}
if (containsPropertyKey(Objects.toString(entry.getKey()), SQLHintPropertiesKey.SHARDING_TABLE_VALUE_KEY)) {
result.getShardingTableValues().put(Objects.toString(entry.getKey()).toUpperCase(), value);
if (containsHintKey(Objects.toString(entry.getKey()), SQLHintPropertiesKey.SHARDING_TABLE_VALUE_KEY)) {
result.getShardingTableValues().put(Objects.toString(entry.getKey()).toUpperCase(), comparable);
}
}
return result;
}

private static boolean startWithHint(final String sql) {
return null != sql && (sql.startsWith(SQLHintTokenType.SQL_START_HINT_TOKEN.getKey()) || sql.startsWith(SQLHintTokenType.SQL_START_HINT_TOKEN.getAlias()));
private static int getHintKeyValueBeginIndex(final String sql) {
int tokenBeginIndex = sql.contains(SQLHintTokenType.SQL_START_HINT_TOKEN.getKey()) ? sql.indexOf(SQLHintTokenType.SQL_START_HINT_TOKEN.getKey())
: sql.indexOf(SQLHintTokenType.SQL_START_HINT_TOKEN.getAlias());
return sql.indexOf(":", tokenBeginIndex) + 1;
}

private static Properties getSQLHintProps(final String comment) {
Properties result = new Properties();
int startIndex = getStartIndex(comment);
if (startIndex < 0) {
return result;
}
int endIndex = comment.endsWith(SQL_COMMENT_SUFFIX) ? comment.indexOf(SQL_COMMENT_SUFFIX) : comment.length();
Collection<String> sqlHints = Splitter.on(SQL_HINT_SPLIT).trimResults().splitToList(comment.substring(startIndex, endIndex).trim());
private static boolean containsSQLHint(final String sql) {
return null != sql && (sql.contains(SQLHintTokenType.SQL_START_HINT_TOKEN.getKey())
|| sql.contains(SQLHintTokenType.SQL_START_HINT_TOKEN.getAlias())) && sql.contains(SQL_COMMENT_PREFIX) && sql.contains(SQL_COMMENT_SUFFIX);
}

private static Map<String, String> getSQLHintKeyValues(final String hintKeyValueText) {
Collection<String> sqlHints = Splitter.on(SQL_HINT_SPLIT).trimResults().splitToList(hintKeyValueText.trim());
Map<String, String> result = new CaseInsensitiveMap<>(sqlHints.size(), 1F);
for (String each : sqlHints) {
List<String> hintValues = Splitter.on(SQL_HINT_VALUE_SPLIT).limit(SQL_HINT_VALUE_SIZE).trimResults().splitToList(each);
List<String> hintValues = Splitter.on(SQL_HINT_VALUE_SPLIT).trimResults().splitToList(each);
if (SQL_HINT_VALUE_SIZE == hintValues.size()) {
result.put(hintValues.get(0), convert(hintValues.get(1)));
result.put(hintValues.get(0), hintValues.get(1));
}
}
return result;
}

private static int getStartIndex(final String comment) {
String lowerCaseComment = comment.toLowerCase();
int result = lowerCaseComment.startsWith(SQLHintTokenType.SQL_START_HINT_TOKEN.getAlias().toLowerCase())
? lowerCaseComment.indexOf(SQLHintTokenType.SQL_HINT_TOKEN.getAlias())
: lowerCaseComment.indexOf(SQLHintTokenType.SQL_HINT_TOKEN.getKey());
if (result >= 0) {
return result + SQLHintTokenType.SQL_HINT_TOKEN.getKey().length();
}
return result;
}

private static Object convert(final String value) {
try {
return new BigInteger(value);
Expand All @@ -127,17 +123,17 @@ private static Object convert(final String value) {
}
}

private static boolean containsPropertyKey(final Properties hintProps, final SQLHintPropertiesKey sqlHintPropsKey) {
return hintProps.containsKey(sqlHintPropsKey.getKey()) || hintProps.containsKey(sqlHintPropsKey.getAlias());
private static boolean containsHintKey(final Map<String, String> hintKeyValues, final SQLHintPropertiesKey sqlHintPropsKey) {
return hintKeyValues.containsKey(sqlHintPropsKey.getKey()) || hintKeyValues.containsKey(sqlHintPropsKey.getAlias());
}

private static boolean containsPropertyKey(final String hintPropKey, final SQLHintPropertiesKey sqlHintPropsKey) {
private static boolean containsHintKey(final String hintPropKey, final SQLHintPropertiesKey sqlHintPropsKey) {
return hintPropKey.contains(sqlHintPropsKey.getKey()) || hintPropKey.contains(sqlHintPropsKey.getAlias());
}

private static String getProperty(final Properties hintProps, final SQLHintPropertiesKey sqlHintPropsKey) {
String result = hintProps.getProperty(sqlHintPropsKey.getKey());
return null == result ? hintProps.getProperty(sqlHintPropsKey.getAlias()) : result;
private static String getHintValue(final Map<String, String> hintKeyValues, final SQLHintPropertiesKey sqlHintPropsKey) {
String result = hintKeyValues.get(sqlHintPropsKey.getKey());
return null == result ? hintKeyValues.get(sqlHintPropsKey.getAlias()) : result;
}

private static Collection<String> getSplitterSQLHintValue(final String property) {
Expand All @@ -151,6 +147,13 @@ private static Collection<String> getSplitterSQLHintValue(final String property)
* @return SQL after remove hint
*/
public static String removeHint(final String sql) {
return startWithHint(sql) ? sql.substring(sql.indexOf(SQL_COMMENT_SUFFIX) + 2).trim() : sql;
if (containsSQLHint(sql)) {
int hintKeyValueBeginIndex = getHintKeyValueBeginIndex(sql);
int sqlHintBeginIndex = sql.substring(0, hintKeyValueBeginIndex).lastIndexOf(SQL_COMMENT_PREFIX, hintKeyValueBeginIndex);
int sqlHintEndIndex = sql.indexOf(SQL_COMMENT_SUFFIX, hintKeyValueBeginIndex) + SQL_COMMENT_SUFFIX.length();
String removedHintSQL = sql.substring(0, sqlHintBeginIndex) + sql.substring(sqlHintEndIndex);
return removedHintSQL.trim();
}
return sql;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,23 @@ void assertFindHintDataSourceNameAliasExist() {
assertTrue(actual.findHintDataSourceName().isPresent());
assertThat(actual.findHintDataSourceName().get(), is("ds_1"));
}

@Test
void assertFindHintDataSourceNameWithDBeaverHint() {
HintValueContext actual = SQLHintUtils.extractHint("/* ApplicationName=DBeaver 24.1.0 - SQLEditor <Script-84.sql> */ /* SHARDINGSPHERE_HINT: DATA_SOURCE_NAME=ds_1*/ SELECT * FROM t_order");
assertTrue(actual.findHintDataSourceName().isPresent());
assertThat(actual.findHintDataSourceName().get(), is("ds_1"));
}

@Test
void assertRemoveHint() {
String actual = SQLHintUtils.removeHint("/* SHARDINGSPHERE_HINT: DATA_SOURCE_NAME=ds_1*/ SELECT * FROM t_order");
assertThat(actual, is("SELECT * FROM t_order"));
}

@Test
void assertRemoveHintWithDBeaverHint() {
String actual = SQLHintUtils.removeHint("/* ApplicationName=DBeaver 24.1.0 - SQLEditor <Script-84.sql> */ /* SHARDINGSPHERE_HINT: DATA_SOURCE_NAME=ds_1*/ SELECT * FROM t_order");
assertThat(actual, is("/* ApplicationName=DBeaver 24.1.0 - SQLEditor <Script-84.sql> */ SELECT * FROM t_order"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DatabaseConnectionManager;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.ExecutorStatementManager;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;

/**
* DistSQL connection context.
Expand All @@ -31,7 +31,7 @@
@Getter
public final class DistSQLConnectionContext {

private final ConnectionContext connectionContext;
private final QueryContext queryContext;

private final int connectionSize;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public final class DistSQLQueryBackendHandler implements DistSQLBackendHandler {
private MergedResult mergedResult;

public DistSQLQueryBackendHandler(final DistSQLStatement sqlStatement, final ConnectionSession connectionSession) {
DistSQLConnectionContext distsqlConnectionContext = new DistSQLConnectionContext(connectionSession.getConnectionContext(),
DistSQLConnectionContext distsqlConnectionContext = new DistSQLConnectionContext(connectionSession.getQueryContext(),
connectionSession.getDatabaseConnectionManager().getConnectionSize(), connectionSession.getProtocolType(),
connectionSession.getDatabaseConnectionManager(), connectionSession.getStatementManager());
engine = new DistSQLQueryExecuteEngine(sqlStatement, connectionSession.getUsedDatabaseName(), ProxyContext.getInstance().getContextManager(), distsqlConnectionContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.JDBCDriverType;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.hint.SQLHintUtils;
import org.apache.shardingsphere.infra.merge.result.impl.local.LocalDataQueryResultRow;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
Expand Down Expand Up @@ -89,11 +88,12 @@ public Collection<String> getColumnNames(final PreviewStatement sqlStatement) {
@Override
public Collection<LocalDataQueryResultRow> getRows(final PreviewStatement sqlStatement, final ContextManager contextManager) throws SQLException {
ShardingSphereMetaData metaData = contextManager.getMetaDataContexts().getMetaData();
String toBePreviewedSQL = SQLHintUtils.removeHint(sqlStatement.getSql());
HintValueContext hintValueContext = SQLHintUtils.extractHint(sqlStatement.getSql());
String toBePreviewedSQL = sqlStatement.getSql();
SQLStatement toBePreviewedStatement = metaData.getGlobalRuleMetaData().getSingleRule(SQLParserRule.class).getSQLParserEngine(database.getProtocolType()).parse(toBePreviewedSQL, false);
HintValueContext hintValueContext = connectionContext.getQueryContext().getHintValueContext();
SQLStatementContext toBePreviewedStatementContext = new SQLBindEngine(metaData, database.getName(), hintValueContext).bind(toBePreviewedStatement, Collections.emptyList());
QueryContext queryContext = new QueryContext(toBePreviewedStatementContext, toBePreviewedSQL, Collections.emptyList(), hintValueContext, connectionContext.getConnectionContext(), metaData);
QueryContext queryContext =
new QueryContext(toBePreviewedStatementContext, toBePreviewedSQL, Collections.emptyList(), hintValueContext, connectionContext.getQueryContext().getConnectionContext(), metaData);
if (toBePreviewedStatementContext instanceof CursorAvailable && toBePreviewedStatementContext instanceof CursorAware) {
setUpCursorDefinition(toBePreviewedStatementContext);
}
Expand All @@ -112,20 +112,21 @@ private String getSchemaName(final SQLStatementContext sqlStatementContext, fina

private Collection<ExecutionUnit> getExecutionUnits(final ContextManager contextManager, final String schemaName, final ShardingSphereMetaData metaData,
final QueryContext queryContext) {
JDBCExecutor jdbcExecutor = new JDBCExecutor(BackendExecutorContext.getInstance().getExecutorEngine(), connectionContext.getConnectionContext());
JDBCExecutor jdbcExecutor = new JDBCExecutor(BackendExecutorContext.getInstance().getExecutorEngine(), connectionContext.getQueryContext().getConnectionContext());
SQLFederationEngine federationEngine = new SQLFederationEngine(database.getName(), schemaName, metaData, contextManager.getMetaDataContexts().getStatistics(), jdbcExecutor);
if (federationEngine.decide(queryContext, metaData.getGlobalRuleMetaData())) {
return getFederationExecutionUnits(queryContext, metaData, federationEngine);
}
return new KernelProcessor().generateExecutionContext(queryContext, metaData.getGlobalRuleMetaData(), metaData.getProps(), connectionContext.getConnectionContext()).getExecutionUnits();
return new KernelProcessor().generateExecutionContext(queryContext, metaData.getGlobalRuleMetaData(), metaData.getProps(), connectionContext.getQueryContext().getConnectionContext())
.getExecutionUnits();
}

private void setUpCursorDefinition(final SQLStatementContext toBePreviewedStatementContext) {
if (!((CursorAvailable) toBePreviewedStatementContext).getCursorName().isPresent()) {
return;
}
String cursorName = ((CursorAvailable) toBePreviewedStatementContext).getCursorName().get().getIdentifier().getValue().toLowerCase();
CursorStatementContext cursorStatementContext = connectionContext.getConnectionContext().getCursorContext().getCursorStatementContexts().get(cursorName);
CursorStatementContext cursorStatementContext = connectionContext.getQueryContext().getConnectionContext().getCursorContext().getCursorStatementContexts().get(cursorName);
Preconditions.checkNotNull(cursorStatementContext, "Cursor %s does not exist.", cursorName);
((CursorAware) toBePreviewedStatementContext).setCursorStatementContext(cursorStatementContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.config.props.temporary.TemporaryConfigurationProperties;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.exception.kernel.syntax.UnsupportedVariableException;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DatabaseConnectionManager;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.ExecutorStatementManager;
import org.apache.shardingsphere.infra.merge.result.impl.local.LocalDataQueryResultRow;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.infra.exception.kernel.syntax.UnsupportedVariableException;
import org.apache.shardingsphere.test.util.PropertiesBuilder;
import org.apache.shardingsphere.test.util.PropertiesBuilder.Property;
import org.junit.jupiter.api.Test;
Expand All @@ -48,7 +48,7 @@ class ShowDistVariableExecutorTest {
@Test
void assertShowCachedConnections() {
ShowDistVariableExecutor executor = new ShowDistVariableExecutor();
executor.setConnectionContext(new DistSQLConnectionContext(mock(ConnectionContext.class), 1,
executor.setConnectionContext(new DistSQLConnectionContext(mock(QueryContext.class), 1,
mock(DatabaseType.class), mock(DatabaseConnectionManager.class), mock(ExecutorStatementManager.class)));
Collection<LocalDataQueryResultRow> actual = executor.getRows(new ShowDistVariableStatement("CACHED_CONNECTIONS"), contextManager);
assertThat(actual.size(), is(1));
Expand Down
Loading