diff --git a/extensions/panache/hibernate-orm-panache-common/runtime/src/main/java/io/quarkus/hibernate/orm/panache/common/runtime/CommonPanacheQueryImpl.java b/extensions/panache/hibernate-orm-panache-common/runtime/src/main/java/io/quarkus/hibernate/orm/panache/common/runtime/CommonPanacheQueryImpl.java index 6e4ee88dca82c..f9dd38ee25652 100644 --- a/extensions/panache/hibernate-orm-panache-common/runtime/src/main/java/io/quarkus/hibernate/orm/panache/common/runtime/CommonPanacheQueryImpl.java +++ b/extensions/panache/hibernate-orm-panache-common/runtime/src/main/java/io/quarkus/hibernate/orm/panache/common/runtime/CommonPanacheQueryImpl.java @@ -302,7 +302,7 @@ private String countQuery(String selectQuery) { return countQuery; } - return PanacheJpaUtil.getCountQuery(selectQuery); + return PanacheJpaUtil.getFastCountQuery(selectQuery); } @SuppressWarnings("unchecked") diff --git a/extensions/panache/hibernate-reactive-panache-common/runtime/src/main/java/io/quarkus/hibernate/reactive/panache/common/runtime/CommonPanacheQueryImpl.java b/extensions/panache/hibernate-reactive-panache-common/runtime/src/main/java/io/quarkus/hibernate/reactive/panache/common/runtime/CommonPanacheQueryImpl.java index c1e7516c22765..4173dac1973d7 100644 --- a/extensions/panache/hibernate-reactive-panache-common/runtime/src/main/java/io/quarkus/hibernate/reactive/panache/common/runtime/CommonPanacheQueryImpl.java +++ b/extensions/panache/hibernate-reactive-panache-common/runtime/src/main/java/io/quarkus/hibernate/reactive/panache/common/runtime/CommonPanacheQueryImpl.java @@ -299,7 +299,7 @@ private String countQuery(String selectQuery) { if (countQuery != null) { return countQuery; } - return PanacheJpaUtil.getCountQuery(selectQuery); + return PanacheJpaUtil.getFastCountQuery(selectQuery); } @SuppressWarnings({ "unchecked", "rawtypes" }) diff --git a/extensions/panache/panache-hibernate-common/runtime/pom.xml b/extensions/panache/panache-hibernate-common/runtime/pom.xml index dd0b52b346c16..d2adbad237ab9 100644 --- a/extensions/panache/panache-hibernate-common/runtime/pom.xml +++ b/extensions/panache/panache-hibernate-common/runtime/pom.xml @@ -24,6 +24,14 @@ io.quarkus quarkus-panache-common + + org.hibernate.orm + hibernate-core + + + org.antlr + antlr4-runtime + jakarta.persistence jakarta.persistence-api diff --git a/extensions/panache/panache-hibernate-common/runtime/src/main/java/io/quarkus/panache/hibernate/common/runtime/CountParserVisitor.java b/extensions/panache/panache-hibernate-common/runtime/src/main/java/io/quarkus/panache/hibernate/common/runtime/CountParserVisitor.java new file mode 100644 index 0000000000000..d8eaf36794da5 --- /dev/null +++ b/extensions/panache/panache-hibernate-common/runtime/src/main/java/io/quarkus/panache/hibernate/common/runtime/CountParserVisitor.java @@ -0,0 +1,104 @@ +package io.quarkus.panache.hibernate.common.runtime; + +import org.antlr.v4.runtime.tree.TerminalNode; +import org.hibernate.grammars.hql.HqlParser.JoinContext; +import org.hibernate.grammars.hql.HqlParser.QueryContext; +import org.hibernate.grammars.hql.HqlParser.QueryOrderContext; +import org.hibernate.grammars.hql.HqlParser.SelectClauseContext; +import org.hibernate.grammars.hql.HqlParser.SimpleQueryGroupContext; +import org.hibernate.grammars.hql.HqlParserBaseVisitor; + +public class CountParserVisitor extends HqlParserBaseVisitor { + + private int inSimpleQueryGroup; + private StringBuilder sb = new StringBuilder(); + + @Override + public String visitSimpleQueryGroup(SimpleQueryGroupContext ctx) { + inSimpleQueryGroup++; + try { + return super.visitSimpleQueryGroup(ctx); + } finally { + inSimpleQueryGroup--; + } + } + + @Override + public String visitQuery(QueryContext ctx) { + super.visitQuery(ctx); + if (inSimpleQueryGroup == 1 && ctx.selectClause() == null) { + // insert a count because there's no select + sb.append(" select count( * )"); + } + return null; + } + + @Override + public String visitSelectClause(SelectClauseContext ctx) { + if (ctx.SELECT() != null) { + ctx.SELECT().accept(this); + } + if (ctx.DISTINCT() != null) { + sb.append(" count("); + ctx.DISTINCT().accept(this); + if (ctx.selectionList().children.size() != 1) { + // FIXME: error message should include query + throw new RuntimeException("Cannot count on more than one column"); + } + ctx.selectionList().children.get(0).accept(this); + sb.append(" )"); + } else { + sb.append(" count( * )"); + } + return null; + } + + @Override + public String visitJoin(JoinContext ctx) { + if (inSimpleQueryGroup == 1 && ctx.FETCH() != null) { + // ignore fetch joins for main query + return null; + } + return super.visitJoin(ctx); + } + + @Override + public String visitQueryOrder(QueryOrderContext ctx) { + if (inSimpleQueryGroup == 1) { + // ignore order/limit/offset for main query + return null; + } + return super.visitQueryOrder(ctx); + } + + @Override + public String visitTerminal(TerminalNode node) { + append(node.getText()); + return null; + } + + @Override + protected String defaultResult() { + return null; + } + + @Override + protected String aggregateResult(String aggregate, String nextResult) { + if (nextResult != null) { + append(nextResult); + } + return null; + } + + private void append(String nextResult) { + // don't add space at start, or around dots + if (!sb.isEmpty() && sb.charAt(sb.length() - 1) != '.' && !nextResult.equals(".")) { + sb.append(" "); + } + sb.append(nextResult); + } + + public String result() { + return sb.toString(); + } +} \ No newline at end of file diff --git a/extensions/panache/panache-hibernate-common/runtime/src/main/java/io/quarkus/panache/hibernate/common/runtime/PanacheJpaUtil.java b/extensions/panache/panache-hibernate-common/runtime/src/main/java/io/quarkus/panache/hibernate/common/runtime/PanacheJpaUtil.java index 5c265e0f892d4..922df0707887e 100644 --- a/extensions/panache/panache-hibernate-common/runtime/src/main/java/io/quarkus/panache/hibernate/common/runtime/PanacheJpaUtil.java +++ b/extensions/panache/panache-hibernate-common/runtime/src/main/java/io/quarkus/panache/hibernate/common/runtime/PanacheJpaUtil.java @@ -3,6 +3,12 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.antlr.v4.runtime.CharStreams; +import org.antlr.v4.runtime.CommonTokenStream; +import org.hibernate.grammars.hql.HqlLexer; +import org.hibernate.grammars.hql.HqlParser; +import org.hibernate.grammars.hql.HqlParser.SelectStatementContext; + import io.quarkus.panache.common.Sort; import io.quarkus.panache.common.exception.PanacheQueryException; @@ -17,10 +23,27 @@ public class PanacheJpaUtil { static final Pattern FROM_PATTERN = Pattern.compile("^\\s*FROM\\s+.*", Pattern.CASE_INSENSITIVE | Pattern.DOTALL); - public static String getCountQuery(String query) { + // match a FETCH + static final Pattern FETCH_PATTERN = Pattern.compile(".*\\s+FETCH\\s+.*", + Pattern.CASE_INSENSITIVE | Pattern.DOTALL); + + // match a lone SELECT + static final Pattern LONE_SELECT_PATTERN = Pattern.compile(".*SELECT\\s+.*", + Pattern.CASE_INSENSITIVE | Pattern.DOTALL); + + /** + * This turns an HQL (already expanded from Panache-QL) query into a count query, using text manipulation + * if we can, because it's faster, or fall back to using the ORM HQL parser in {@link #getCountQueryUsingParser(String)} + */ + public static String getFastCountQuery(String query) { // try to generate a good count query from the existing query - Matcher selectMatcher = SELECT_PATTERN.matcher(query); String countQuery; + // there are no fast ways to get rid of fetches + if (FETCH_PATTERN.matcher(query).matches()) { + return getCountQueryUsingParser(query); + } + // if it starts with select, we can optimise + Matcher selectMatcher = SELECT_PATTERN.matcher(query); if (selectMatcher.matches()) { // this one cannot be null String firstSelection = selectMatcher.group(1).trim(); @@ -36,6 +59,9 @@ public static String getCountQuery(String query) { // it's not distinct, forget the column list countQuery = "SELECT COUNT(*) " + selectMatcher.group(3); } + } else if (LONE_SELECT_PATTERN.matcher(query).matches()) { + // a select anywhere else in there might be tricky + return getCountQueryUsingParser(query); } else if (FROM_PATTERN.matcher(query).matches()) { countQuery = "SELECT COUNT(*) " + query; } else { @@ -51,6 +77,20 @@ public static String getCountQuery(String query) { return countQuery; } + /** + * This turns an HQL (already expanded from Panache-QL) query into a count query, using the + * ORM HQL parser. Slow version, see {@link #getFastCountQuery(String)} for the fast version. + */ + public static String getCountQueryUsingParser(String query) { + HqlLexer lexer = new HqlLexer(CharStreams.fromString(query)); + CommonTokenStream tokens = new CommonTokenStream(lexer); + HqlParser parser = new HqlParser(tokens); + SelectStatementContext statement = parser.selectStatement(); + CountParserVisitor visitor = new CountParserVisitor(); + statement.accept(visitor); + return visitor.result(); + } + public static String getEntityName(Class entityClass) { // FIXME: not true? return entityClass.getName(); diff --git a/extensions/panache/panache-hibernate-common/runtime/src/test/java/io/quarkus/panache/hibernate/common/runtime/CountTest.java b/extensions/panache/panache-hibernate-common/runtime/src/test/java/io/quarkus/panache/hibernate/common/runtime/CountTest.java new file mode 100644 index 0000000000000..7db1c1a45fb6e --- /dev/null +++ b/extensions/panache/panache-hibernate-common/runtime/src/test/java/io/quarkus/panache/hibernate/common/runtime/CountTest.java @@ -0,0 +1,72 @@ +package io.quarkus.panache.hibernate.common.runtime; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class CountTest { + @Test + public void testParser() { + // one column, order/limit/offset + assertCountQueryUsingParser("select count( * ) from bar", "select foo from bar order by foo, bar ASC limit 2 offset 3"); + // two columns + assertCountQueryUsingParser("select count( * ) from bar", "select foo,gee from bar"); + // one column distinct + assertCountQueryUsingParser("select count( distinct foo ) from bar", "select distinct foo from bar"); + // two columns distinct + Assertions.assertThrows(RuntimeException.class, + () -> assertCountQueryUsingParser("XX", "select distinct foo,gee from bar")); + // nested order by not touched + assertCountQueryUsingParser("select count( * ) from ( from entity order by id )", + "select foo from (from entity order by id) order by foo, bar ASC"); + // what happens to literals? + assertCountQueryUsingParser("select count( * ) from bar where some = 2 and other = '23'", + "select foo from bar where some = 2 and other = '23'"); + // fetches are gone + assertCountQueryUsingParser("select count( * ) from bar b", "select foo from bar b left join fetch b.things"); + // non-fetches remain + assertCountQueryUsingParser("select count( * ) from bar b left join b.things", + "select foo from bar b left join b.things"); + + // inverted select + assertCountQueryUsingParser("from bar select count( * )", "from bar select foo"); + // from without select + assertCountQueryUsingParser("from bar select count( * )", "from bar"); + } + + @Test + public void testFastVersion() { + // one column, order/limit/offset + assertFastCountQuery("SELECT COUNT(*) from bar", "select foo from bar order by foo, bar ASC limit 2 offset 3"); + // two columns + assertFastCountQuery("SELECT COUNT(*) from bar", "select foo,gee from bar"); + // one column distinct + assertFastCountQuery("SELECT COUNT(distinct foo) from bar", "select distinct foo from bar"); + // two columns distinct + Assertions.assertThrows(RuntimeException.class, () -> assertFastCountQuery("XX", "select distinct foo,gee from bar")); + // nested order by not touched + assertFastCountQuery("SELECT COUNT(*) from (from entity order by id)", + "select foo from (from entity order by id) order by foo, bar ASC"); + // what happens to literals? + assertFastCountQuery("SELECT COUNT(*) from bar where some = 2 and other = '23'", + "select foo from bar where some = 2 and other = '23'"); + // fetches are gone + assertFastCountQuery("select count( * ) from bar b", "select foo from bar b left join fetch b.things"); + // non-fetches remain + assertFastCountQuery("SELECT COUNT(*) from bar b left join b.things", "select foo from bar b left join b.things"); + + // inverted select + assertFastCountQuery("from bar select count( * )", "from bar select foo"); + // from without select + assertFastCountQuery("SELECT COUNT(*) from bar", "from bar"); + } + + private void assertCountQueryUsingParser(String expected, String selectQuery) { + String countQuery = PanacheJpaUtil.getCountQueryUsingParser(selectQuery); + Assertions.assertEquals(expected, countQuery); + } + + private void assertFastCountQuery(String expected, String selectQuery) { + String countQuery = PanacheJpaUtil.getFastCountQuery(selectQuery); + Assertions.assertEquals(expected, countQuery); + } +} diff --git a/integration-tests/hibernate-orm-panache/src/main/java/io/quarkus/it/panache/TestEndpoint.java b/integration-tests/hibernate-orm-panache/src/main/java/io/quarkus/it/panache/TestEndpoint.java index cc367d8d9e4fa..5d0cf2a51e232 100644 --- a/integration-tests/hibernate-orm-panache/src/main/java/io/quarkus/it/panache/TestEndpoint.java +++ b/integration-tests/hibernate-orm-panache/src/main/java/io/quarkus/it/panache/TestEndpoint.java @@ -1804,4 +1804,24 @@ public String testEnhancement27184DeleteDetached() { return "OK"; } + + @GET + @Path("26308") + @Transactional + public String testBug26308() { + testBug26308Query("from Person2 p left join fetch p.address"); + testBug26308Query("from Person2 p left join p.address"); + testBug26308Query("select p from Person2 p left join fetch p.address"); + testBug26308Query("select p from Person2 p left join p.address"); + testBug26308Query("from Person2 p left join fetch p.address select p"); + testBug26308Query("from Person2 p left join p.address select p"); + + return "OK"; + } + + private void testBug26308Query(String hql) { + PanacheQuery query = Person.find(hql); + Assertions.assertEquals(0, query.list().size()); + Assertions.assertEquals(0, query.count()); + } } diff --git a/integration-tests/hibernate-orm-panache/src/test/java/io/quarkus/it/panache/PanacheFunctionalityTest.java b/integration-tests/hibernate-orm-panache/src/test/java/io/quarkus/it/panache/PanacheFunctionalityTest.java index 3625fce017ac3..96ccec9e08403 100644 --- a/integration-tests/hibernate-orm-panache/src/test/java/io/quarkus/it/panache/PanacheFunctionalityTest.java +++ b/integration-tests/hibernate-orm-panache/src/test/java/io/quarkus/it/panache/PanacheFunctionalityTest.java @@ -244,4 +244,9 @@ Person getBug7102(Long id) { void testEnhancement27184DeleteDetached() { RestAssured.when().get("/test/testEnhancement27184DeleteDetached").then().body(is("OK")); } + + @Test + public void testBug26308() { + RestAssured.when().get("/test/26308").then().body(is("OK")); + } } diff --git a/integration-tests/hibernate-reactive-panache/src/main/java/io/quarkus/it/panache/reactive/TestEndpoint.java b/integration-tests/hibernate-reactive-panache/src/main/java/io/quarkus/it/panache/reactive/TestEndpoint.java index 6a0ee6bf30fb4..c9392e8095ccf 100644 --- a/integration-tests/hibernate-reactive-panache/src/main/java/io/quarkus/it/panache/reactive/TestEndpoint.java +++ b/integration-tests/hibernate-reactive-panache/src/main/java/io/quarkus/it/panache/reactive/TestEndpoint.java @@ -2054,4 +2054,30 @@ public Uni testSortByNullPrecedence() { return Person.deleteAll(); }).map(v -> "OK"); } + + @GET + @Path("26308") + @WithTransaction + public Uni testBug26308() { + return testBug26308Query("from Person2 p left join fetch p.address") + .flatMap(p -> testBug26308Query("from Person2 p left join p.address")) + .flatMap(p -> testBug26308Query("select p from Person2 p left join fetch p.address")) + .flatMap(p -> testBug26308Query("select p from Person2 p left join p.address")) + .flatMap(p -> testBug26308Query("from Person2 p left join fetch p.address select p")) + .flatMap(p -> testBug26308Query("from Person2 p left join p.address select p")) + .map(p -> "OK"); + } + + private Uni testBug26308Query(String hql) { + PanacheQuery query = Person.find(hql); + return query.list() + .flatMap(list -> { + Assertions.assertEquals(0, list.size()); + return query.count(); + }) + .map(count -> { + Assertions.assertEquals(0, count); + return null; + }); + } } diff --git a/integration-tests/hibernate-reactive-panache/src/test/java/io/quarkus/it/panache/reactive/PanacheFunctionalityTest.java b/integration-tests/hibernate-reactive-panache/src/test/java/io/quarkus/it/panache/reactive/PanacheFunctionalityTest.java index 8570db6c056d2..2aefcf9d2c1b5 100644 --- a/integration-tests/hibernate-reactive-panache/src/test/java/io/quarkus/it/panache/reactive/PanacheFunctionalityTest.java +++ b/integration-tests/hibernate-reactive-panache/src/test/java/io/quarkus/it/panache/reactive/PanacheFunctionalityTest.java @@ -305,4 +305,8 @@ public void testBeerRepository() { RestAssured.when().get("/test-repo/beers").then().body(is("OK")); } + @Test + public void testBug26308() { + RestAssured.when().get("/test/26308").then().body(is("OK")); + } }