Skip to content

Commit

Permalink
ORM/HR+Panache: Remove FETCH from count queries
Browse files Browse the repository at this point in the history
We do this by turning to the ORM HQLParser for non-trivial queries, but
only for them, because the parser is much more expensive than simple
string manipulation, so we keep the fast/easy logic.

Fixes quarkusio#26308
  • Loading branch information
FroMage committed Feb 13, 2024
1 parent 490992e commit 7093c73
Show file tree
Hide file tree
Showing 10 changed files with 283 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ private String countQuery(String selectQuery) {
return countQuery;
}

return PanacheJpaUtil.getCountQuery(selectQuery);
return PanacheJpaUtil.getFastCountQuery(selectQuery);
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ private String countQuery(String selectQuery) {
if (countQuery != null) {
return countQuery;
}
return PanacheJpaUtil.getCountQuery(selectQuery);
return PanacheJpaUtil.getFastCountQuery(selectQuery);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
Expand Down
8 changes: 8 additions & 0 deletions extensions/panache/panache-hibernate-common/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-panache-common</artifactId>
</dependency>
<dependency>
<groupId>org.hibernate.orm</groupId>
<artifactId>hibernate-core</artifactId>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
</dependency>
<dependency>
<groupId>jakarta.persistence</groupId>
<artifactId>jakarta.persistence-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package io.quarkus.panache.hibernate.common.runtime;

import org.antlr.v4.runtime.tree.TerminalNode;
import org.hibernate.grammars.hql.HqlParser.JoinContext;
import org.hibernate.grammars.hql.HqlParser.QueryContext;
import org.hibernate.grammars.hql.HqlParser.QueryOrderContext;
import org.hibernate.grammars.hql.HqlParser.SelectClauseContext;
import org.hibernate.grammars.hql.HqlParser.SimpleQueryGroupContext;
import org.hibernate.grammars.hql.HqlParserBaseVisitor;

public class CountParserVisitor extends HqlParserBaseVisitor<String> {

private int inSimpleQueryGroup;
private StringBuilder sb = new StringBuilder();

@Override
public String visitSimpleQueryGroup(SimpleQueryGroupContext ctx) {
inSimpleQueryGroup++;
try {
return super.visitSimpleQueryGroup(ctx);
} finally {
inSimpleQueryGroup--;
}
}

@Override
public String visitQuery(QueryContext ctx) {
super.visitQuery(ctx);
if (inSimpleQueryGroup == 1 && ctx.selectClause() == null) {
// insert a count because there's no select
sb.append(" select count( * )");
}
return null;
}

@Override
public String visitSelectClause(SelectClauseContext ctx) {
if (ctx.SELECT() != null) {
ctx.SELECT().accept(this);
}
if (ctx.DISTINCT() != null) {
sb.append(" count(");
ctx.DISTINCT().accept(this);
if (ctx.selectionList().children.size() != 1) {
// FIXME: error message should include query
throw new RuntimeException("Cannot count on more than one column");
}
ctx.selectionList().children.get(0).accept(this);
sb.append(" )");
} else {
sb.append(" count( * )");
}
return null;
}

@Override
public String visitJoin(JoinContext ctx) {
if (inSimpleQueryGroup == 1 && ctx.FETCH() != null) {
// ignore fetch joins for main query
return null;
}
return super.visitJoin(ctx);
}

@Override
public String visitQueryOrder(QueryOrderContext ctx) {
if (inSimpleQueryGroup == 1) {
// ignore order/limit/offset for main query
return null;
}
return super.visitQueryOrder(ctx);
}

@Override
public String visitTerminal(TerminalNode node) {
append(node.getText());
return null;
}

@Override
protected String defaultResult() {
return null;
}

@Override
protected String aggregateResult(String aggregate, String nextResult) {
if (nextResult != null) {
append(nextResult);
}
return null;
}

private void append(String nextResult) {
// don't add space at start, or around dots
if (!sb.isEmpty() && sb.charAt(sb.length() - 1) != '.' && !nextResult.equals(".")) {
sb.append(" ");
}
sb.append(nextResult);
}

public String result() {
return sb.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.hibernate.grammars.hql.HqlLexer;
import org.hibernate.grammars.hql.HqlParser;
import org.hibernate.grammars.hql.HqlParser.SelectStatementContext;

import io.quarkus.panache.common.Sort;
import io.quarkus.panache.common.exception.PanacheQueryException;

Expand All @@ -17,10 +23,27 @@ public class PanacheJpaUtil {
static final Pattern FROM_PATTERN = Pattern.compile("^\\s*FROM\\s+.*",
Pattern.CASE_INSENSITIVE | Pattern.DOTALL);

public static String getCountQuery(String query) {
// match a FETCH
static final Pattern FETCH_PATTERN = Pattern.compile(".*\\s+FETCH\\s+.*",
Pattern.CASE_INSENSITIVE | Pattern.DOTALL);

// match a lone SELECT
static final Pattern LONE_SELECT_PATTERN = Pattern.compile(".*SELECT\\s+.*",
Pattern.CASE_INSENSITIVE | Pattern.DOTALL);

/**
* This turns an HQL (already expanded from Panache-QL) query into a count query, using text manipulation
* if we can, because it's faster, or fall back to using the ORM HQL parser in {@link #getCountQueryUsingParser(String)}
*/
public static String getFastCountQuery(String query) {
// try to generate a good count query from the existing query
Matcher selectMatcher = SELECT_PATTERN.matcher(query);
String countQuery;
// there are no fast ways to get rid of fetches
if (FETCH_PATTERN.matcher(query).matches()) {
return getCountQueryUsingParser(query);
}
// if it starts with select, we can optimise
Matcher selectMatcher = SELECT_PATTERN.matcher(query);
if (selectMatcher.matches()) {
// this one cannot be null
String firstSelection = selectMatcher.group(1).trim();
Expand All @@ -36,6 +59,9 @@ public static String getCountQuery(String query) {
// it's not distinct, forget the column list
countQuery = "SELECT COUNT(*) " + selectMatcher.group(3);
}
} else if (LONE_SELECT_PATTERN.matcher(query).matches()) {
// a select anywhere else in there might be tricky
return getCountQueryUsingParser(query);
} else if (FROM_PATTERN.matcher(query).matches()) {
countQuery = "SELECT COUNT(*) " + query;
} else {
Expand All @@ -51,6 +77,20 @@ public static String getCountQuery(String query) {
return countQuery;
}

/**
* This turns an HQL (already expanded from Panache-QL) query into a count query, using the
* ORM HQL parser. Slow version, see {@link #getFastCountQuery(String)} for the fast version.
*/
public static String getCountQueryUsingParser(String query) {
HqlLexer lexer = new HqlLexer(CharStreams.fromString(query));
CommonTokenStream tokens = new CommonTokenStream(lexer);
HqlParser parser = new HqlParser(tokens);
SelectStatementContext statement = parser.selectStatement();
CountParserVisitor visitor = new CountParserVisitor();
statement.accept(visitor);
return visitor.result();
}

public static String getEntityName(Class<?> entityClass) {
// FIXME: not true?
return entityClass.getName();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package io.quarkus.panache.hibernate.common.runtime;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class CountTest {
@Test
public void testParser() {
// one column, order/limit/offset
assertCountQueryUsingParser("select count( * ) from bar", "select foo from bar order by foo, bar ASC limit 2 offset 3");
// two columns
assertCountQueryUsingParser("select count( * ) from bar", "select foo,gee from bar");
// one column distinct
assertCountQueryUsingParser("select count( distinct foo ) from bar", "select distinct foo from bar");
// two columns distinct
Assertions.assertThrows(RuntimeException.class,
() -> assertCountQueryUsingParser("XX", "select distinct foo,gee from bar"));
// nested order by not touched
assertCountQueryUsingParser("select count( * ) from ( from entity order by id )",
"select foo from (from entity order by id) order by foo, bar ASC");
// what happens to literals?
assertCountQueryUsingParser("select count( * ) from bar where some = 2 and other = '23'",
"select foo from bar where some = 2 and other = '23'");
// fetches are gone
assertCountQueryUsingParser("select count( * ) from bar b", "select foo from bar b left join fetch b.things");
// non-fetches remain
assertCountQueryUsingParser("select count( * ) from bar b left join b.things",
"select foo from bar b left join b.things");

// inverted select
assertCountQueryUsingParser("from bar select count( * )", "from bar select foo");
// from without select
assertCountQueryUsingParser("from bar select count( * )", "from bar");
}

@Test
public void testFastVersion() {
// one column, order/limit/offset
assertFastCountQuery("SELECT COUNT(*) from bar", "select foo from bar order by foo, bar ASC limit 2 offset 3");
// two columns
assertFastCountQuery("SELECT COUNT(*) from bar", "select foo,gee from bar");
// one column distinct
assertFastCountQuery("SELECT COUNT(distinct foo) from bar", "select distinct foo from bar");
// two columns distinct
Assertions.assertThrows(RuntimeException.class, () -> assertFastCountQuery("XX", "select distinct foo,gee from bar"));
// nested order by not touched
assertFastCountQuery("SELECT COUNT(*) from (from entity order by id)",
"select foo from (from entity order by id) order by foo, bar ASC");
// what happens to literals?
assertFastCountQuery("SELECT COUNT(*) from bar where some = 2 and other = '23'",
"select foo from bar where some = 2 and other = '23'");
// fetches are gone
assertFastCountQuery("select count( * ) from bar b", "select foo from bar b left join fetch b.things");
// non-fetches remain
assertFastCountQuery("SELECT COUNT(*) from bar b left join b.things", "select foo from bar b left join b.things");

// inverted select
assertFastCountQuery("from bar select count( * )", "from bar select foo");
// from without select
assertFastCountQuery("SELECT COUNT(*) from bar", "from bar");
}

private void assertCountQueryUsingParser(String expected, String selectQuery) {
String countQuery = PanacheJpaUtil.getCountQueryUsingParser(selectQuery);
Assertions.assertEquals(expected, countQuery);
}

private void assertFastCountQuery(String expected, String selectQuery) {
String countQuery = PanacheJpaUtil.getFastCountQuery(selectQuery);
Assertions.assertEquals(expected, countQuery);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1804,4 +1804,24 @@ public String testEnhancement27184DeleteDetached() {

return "OK";
}

@GET
@Path("26308")
@Transactional
public String testBug26308() {
testBug26308Query("from Person2 p left join fetch p.address");
testBug26308Query("from Person2 p left join p.address");
testBug26308Query("select p from Person2 p left join fetch p.address");
testBug26308Query("select p from Person2 p left join p.address");
testBug26308Query("from Person2 p left join fetch p.address select p");
testBug26308Query("from Person2 p left join p.address select p");

return "OK";
}

private void testBug26308Query(String hql) {
PanacheQuery<Person> query = Person.find(hql);
Assertions.assertEquals(0, query.list().size());
Assertions.assertEquals(0, query.count());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,9 @@ Person getBug7102(Long id) {
void testEnhancement27184DeleteDetached() {
RestAssured.when().get("/test/testEnhancement27184DeleteDetached").then().body(is("OK"));
}

@Test
public void testBug26308() {
RestAssured.when().get("/test/26308").then().body(is("OK"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2054,4 +2054,30 @@ public Uni<String> testSortByNullPrecedence() {
return Person.deleteAll();
}).map(v -> "OK");
}

@GET
@Path("26308")
@WithTransaction
public Uni<String> testBug26308() {
return testBug26308Query("from Person2 p left join fetch p.address")
.flatMap(p -> testBug26308Query("from Person2 p left join p.address"))
.flatMap(p -> testBug26308Query("select p from Person2 p left join fetch p.address"))
.flatMap(p -> testBug26308Query("select p from Person2 p left join p.address"))
.flatMap(p -> testBug26308Query("from Person2 p left join fetch p.address select p"))
.flatMap(p -> testBug26308Query("from Person2 p left join p.address select p"))
.map(p -> "OK");
}

private Uni<Void> testBug26308Query(String hql) {
PanacheQuery<Person> query = Person.find(hql);
return query.list()
.flatMap(list -> {
Assertions.assertEquals(0, list.size());
return query.count();
})
.map(count -> {
Assertions.assertEquals(0, count);
return null;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,8 @@ public void testBeerRepository() {
RestAssured.when().get("/test-repo/beers").then().body(is("OK"));
}

@Test
public void testBug26308() {
RestAssured.when().get("/test/26308").then().body(is("OK"));
}
}

0 comments on commit 7093c73

Please sign in to comment.