diff --git a/jdbc/src/test/java/test/container/ContainerUtil.java b/jdbc/src/test/java/test/container/ContainerUtil.java index a8158fd..058fc5c 100644 --- a/jdbc/src/test/java/test/container/ContainerUtil.java +++ b/jdbc/src/test/java/test/container/ContainerUtil.java @@ -1,8 +1,15 @@ package test.container; +import org.junit.jupiter.api.DynamicTest; import org.testcontainers.containers.MySQLContainer; import org.testcontainers.containers.PostgreSQLContainer; +import javax.sql.DataSource; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.DynamicTest.dynamicTest; + public final class ContainerUtil { public static final PostgreSQLContainer PG = new PostgreSQLContainer<>("postgres:16-alpine"); @@ -13,4 +20,23 @@ public final class ContainerUtil { public static final String MYSQL_LOCATIONS = "classpath:db/mysql"; + public static final List containers = List.of( + JdbcContainer.create(PG, PG_LOCATIONS), + JdbcContainer.create(MYSQL, MYSQL_LOCATIONS) + ); + + public static List testAllContainers(Executable executable) { + List tests = new ArrayList<>(); + for (JdbcContainer container : ContainerUtil.containers) { + DataSource dataSource = container.start(); + String name = container.getDriverClassName(); + tests.addAll( + executable.execute(dataSource) + .stream().map(test -> dynamicTest(name + " " + test.name(),test.executable())) + .toList() + ); + } + return tests; + } + } diff --git a/jdbc/src/test/java/test/container/Executable.java b/jdbc/src/test/java/test/container/Executable.java new file mode 100644 index 0000000..ae213e7 --- /dev/null +++ b/jdbc/src/test/java/test/container/Executable.java @@ -0,0 +1,11 @@ +package test.container; + +import javax.sql.DataSource; +import java.util.List; + +@FunctionalInterface +public interface Executable { + + List execute(DataSource dataSource); + +} diff --git a/jdbc/src/test/java/test/container/JdbcContainer.java b/jdbc/src/test/java/test/container/JdbcContainer.java index c9e738e..e37e44e 100644 --- a/jdbc/src/test/java/test/container/JdbcContainer.java +++ b/jdbc/src/test/java/test/container/JdbcContainer.java @@ -47,4 +47,8 @@ public void stop() { container.stop(); } + public String getDriverClassName() { + return container.getDriverClassName(); + } + } diff --git a/jdbc/src/test/java/test/container/JdbcTest.java b/jdbc/src/test/java/test/container/JdbcTest.java new file mode 100644 index 0000000..301b2b9 --- /dev/null +++ b/jdbc/src/test/java/test/container/JdbcTest.java @@ -0,0 +1,7 @@ +package test.container; + +import org.junit.jupiter.api.function.Executable; + +public record JdbcTest(String name, + Executable executable) { +} diff --git a/jdbc/src/test/java/test/jdbc/StatementTest.java b/jdbc/src/test/java/test/jdbc/StatementTest.java index 59569d5..a7db567 100644 --- a/jdbc/src/test/java/test/jdbc/StatementTest.java +++ b/jdbc/src/test/java/test/jdbc/StatementTest.java @@ -1,74 +1,109 @@ package test.jdbc; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import test.container.ContainerUtil; import test.container.JdbcContainer; +import test.container.JdbcTest; import javax.sql.DataSource; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; +import java.sql.*; +import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class StatementTest { private final static Logger log = LoggerFactory.getLogger(StatementTest.class); - private final static JdbcContainer container = JdbcContainer.create( - ContainerUtil.PG, ContainerUtil.PG_LOCATIONS - ); - - private static DataSource dataSource; - - @BeforeAll - static void beforeAll() { - dataSource = container.start(); + @TestFactory + List testAllDatabase() { + return ContainerUtil.testAllContainers(dataSource -> List.of( + new JdbcTest("execute", () -> this.execute(dataSource)), + new JdbcTest("executeUpdate", () -> this.executeUpdate(dataSource)) + )); } @AfterAll static void afterAll() { - container.stop(); + for (JdbcContainer container : ContainerUtil.containers) { + container.stop(); + } } - @Test - void execute() throws SQLException { + void execute(DataSource dataSource) throws SQLException { + String driverClassName = dataSource.unwrap(HikariDataSource.class).getDriverClassName(); + + // 不指定autoGeneratedKeys, 无法获得生成的key(mysql的实现会抛出异常, 所以统一成必须指定) try (Connection connection = dataSource.getConnection()) { String sql = "insert into swift_user(name, status) values(?, ?)"; - try (PreparedStatement ps = connection.prepareStatement(sql)) { + try (PreparedStatement ps = connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { ps.setQueryTimeout(1); ps.setString(1, "fantasy"); ps.setInt(2, 1); + // 没有返回内容(不包括生成的主键), 所以是false assertEquals(false, ps.execute()); try (ResultSet resultSet = ps.getGeneratedKeys()) { - assertEquals(false, resultSet.next()); + assertEquals(true, resultSet.next()); + log.debug("key: {}", resultSet.getLong(1)); + } + } + } + + // pg专场 + if (driverClassName.contains("postgresql")) { + // 不设置Statement.RETURN_GENERATED_KEYS, 可以获取returning的内容 + try (Connection connection = dataSource.getConnection()) { + String sql = "insert into swift_user(name, status) values(?, ?) returning id, name"; + try (PreparedStatement ps = connection.prepareStatement(sql)) { + ps.setQueryTimeout(1); + ps.setString(1, "fantasy"); + ps.setInt(2, 1); + // 在指定autoGeneratedKeys的情况下, 就算有returning, 也会返回false + assertEquals(true, ps.execute()); + try (ResultSet resultSet = ps.getGeneratedKeys()) { + assertEquals(false, resultSet.next()); + } + try (ResultSet resultSet = ps.getResultSet()) { + assertEquals(true, resultSet.next()); + log.debug("id: {}, name: {}", resultSet.getLong(1), resultSet.getString(2)); + } + } + } + // 设置Statement.RETURN_GENERATED_KEYS, returning的内容将被忽略 + try (Connection connection = dataSource.getConnection()) { + String sql = "insert into swift_user(name, status) values(?, ?) returning id"; + try (PreparedStatement ps = connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { + ps.setQueryTimeout(1); + ps.setString(1, "fantasy"); + ps.setInt(2, 1); + // 在指定autoGeneratedKeys的情况下, 就算有returning, 也会返回false + assertEquals(false, ps.execute()); + try (ResultSet resultSet = ps.getGeneratedKeys()) { + assertEquals(true, resultSet.next()); + log.debug("key: {}", resultSet.getLong(1)); + } + assertNull(ps.getResultSet()); } } } } - @Test - void executeUpdate() throws SQLException { + void executeUpdate(DataSource dataSource) throws SQLException { + String driverClassName = dataSource.unwrap(HikariDataSource.class).getDriverClassName(); + if (driverClassName.contains("postgresql")) { + // returning的结果只能通过execute方法获取 + } try (Connection connection = dataSource.getConnection()) { - String sql = "insert into swift_user(name, status) values(?, ?) returning id"; - try (PreparedStatement ps = connection.prepareStatement(sql)) { + String sql = "insert into swift_user(name, status) values(?, ?)"; + try (PreparedStatement ps = connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { ps.setQueryTimeout(1); ps.setString(1, "fantasy"); ps.setInt(2, 1); assertEquals(1, ps.executeUpdate()); - try (ResultSet resultSet = ps.getResultSet()) { - assertTrue(resultSet.next()); - long id = resultSet.getLong(1); - log.debug("id: {}", id); - assertTrue(id > 0); - } - + assertNull(ps.getResultSet()); try (ResultSet resultSet = ps.getGeneratedKeys()) { assertTrue(resultSet.next()); long id = resultSet.getLong(1); diff --git a/jdbc/src/test/resources/db/mysql/V1.0.0__init.sql b/jdbc/src/test/resources/db/mysql/V1.0.0__init.sql index bb3ea2e..9cd525d 100644 --- a/jdbc/src/test/resources/db/mysql/V1.0.0__init.sql +++ b/jdbc/src/test/resources/db/mysql/V1.0.0__init.sql @@ -12,3 +12,17 @@ insert into student(id, name, status) values (1, '小明', 0), (2, '张三', 1), (3, '李四', 2), (4, '董超', 2), (5, '薛霸', 2); + + +CREATE TABLE swift_user ( + id bigint PRIMARY KEY auto_increment, + name text not null, + locale text, + avatar text, + status integer not null, + is_del boolean, + description text null, + ext json null, + created_at datetime not null default now(), + updated_at datetime not null default now() +);