Skip to content

Commit

Permalink
增加InsertBuilder。
Browse files Browse the repository at this point in the history
  • Loading branch information
fantasy0v0 committed Nov 19, 2024
1 parent 41f9554 commit 6edbeae
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package com.github.fantasy0v0.swift.jdbc;

import com.github.fantasy0v0.swift.jdbc.exception.SwiftSQLException;
import com.github.fantasy0v0.swift.jdbc.util.LogUtil;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static com.github.fantasy0v0.swift.jdbc.Utils.*;

public class InsertBuilder extends UpdateBuilder {

InsertBuilder(DataSource dataSource, StatementConfiguration statementConfiguration, String sql) {
super(dataSource, statementConfiguration, sql);
}

private <T> List<T> _fetchKey(FetchMapper<T> mapper,
List<Object> params, boolean firstOnly) {
try (ConnectionReference ref = ConnectionPoolUtil.getReference(dataSource)) {
Connection conn = ref.unwrap();
LogUtil.performance().info("fetchKey begin");
long startTime = System.nanoTime() / 1000;
String callerInfo = printCallerInfo();
LogUtil.sql().debug("fetchKey: [{}], caller: {}", sql, callerInfo);
try (PreparedStatement statement = prepareStatement(conn, sql, PreparedStatement.RETURN_GENERATED_KEYS, statementConfiguration)) {
fillStatementParams(conn, statement, params, parameterHandler);
int updated = statement.executeUpdate();
LogUtil.sql().debug("executeUpdate: {}", updated);
return fetchByResultSet(statement.getGeneratedKeys(), mapper, firstOnly);
} finally {
long cost = System.nanoTime() / 1000 - startTime;
NumberFormat format = NumberFormat.getNumberInstance();
LogUtil.performance().info("fetchKey end, cost: {} μs", format.format(cost));
}
} catch (SQLException e) {
throw new SwiftSQLException(e);
}
}

public <T> List<T> fetchKey(FetchMapper<T> mapper, List<Object> params) {
return _fetchKey(mapper, params, false);
}

public <T> List<T> fetchKey(FetchMapper<T> mapper, Object... params) {
return fetchKey(mapper, Arrays.stream(params).toList());
}

public <T> List<T> fetchKey(FetchMapper<T> mapper) {
return fetchKey(mapper, (List<Object>) null);
}

public List<Object[]> fetchKey(List<Object> params) {
return fetchKey(Utils::fetchByRow, params);
}

public List<Object[]> fetchKey(Object... params) {
return fetchKey(Arrays.stream(params).toList());
}

public List<Object[]> fetchKey() {
return fetchKey(Utils::fetchByRow, (List<Object>) null);
}

public <T> T fetchKeyOne(FetchMapper<T> mapper, List<Object> params) {
List<T> list = _fetchKey(mapper, params, true);
return list.isEmpty() ? null : list.getFirst();
}

public <T> T fetchKeyOne(FetchMapper<T> mapper, Object... params) {
return fetchKeyOne(mapper, Arrays.stream(params).toList());
}

public <T> T fetchKeyOne(FetchMapper<T> mapper) {
return fetchKeyOne(mapper, (List<Object>) null);
}

public Object[] fetchKeyOne(List<Object> params) {
return fetchKeyOne(Utils::fetchByRow, params);
}

public Object[] fetchKeyOne(Object... params) {
return fetchKeyOne(Arrays.stream(params).toList());
}

public Object[] fetchKeyOne() {
return fetchKeyOne((List<Object>) null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

public class UpdateBuilder implements StatementConfigurator<UpdateBuilder> {

private final DataSource dataSource;
protected final DataSource dataSource;

private StatementConfiguration statementConfiguration;
protected StatementConfiguration statementConfiguration;

private final String sql;
protected final String sql;

private ParameterHandler parameterHandler;
protected ParameterHandler parameterHandler;

UpdateBuilder(DataSource dataSource, StatementConfiguration statementConfiguration, String sql) {
this.dataSource = dataSource;
Expand Down Expand Up @@ -117,7 +117,7 @@ public List<Object[]> fetch(List<Object> params) {
}

public List<Object[]> fetch(Object... params) {
return fetch(Utils::fetchByRow, params);
return fetch(Utils::fetchByRow, Arrays.stream(params).toList());
}

public List<Object[]> fetch() {
Expand All @@ -143,7 +143,7 @@ public Object[] fetchOne(List<Object> params) {
}

public Object[] fetchOne(Object... params) {
return fetchOne(Utils::fetchByRow, params);
return fetchOne(Utils::fetchByRow, Arrays.stream(params).toList());
}

public Object[] fetchOne() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ static void fillStatementParams(Connection conn,
}
}

private static PreparedStatement prepareStatement(Connection conn,
static PreparedStatement prepareStatement(Connection conn,
String sql,
int autoGeneratedKeys,
StatementConfiguration configuration) throws SQLException {
Expand All @@ -265,7 +265,7 @@ private static PreparedStatement prepareStatement(Connection conn,
return ps;
}

private static PreparedStatement prepareStatement(Connection conn,
static PreparedStatement prepareStatement(Connection conn,
String sql,
StatementConfiguration configuration) throws SQLException {
return prepareStatement(conn, sql, Statement.NO_GENERATED_KEYS, configuration);
Expand Down
79 changes: 73 additions & 6 deletions jdbc/src/test/java/test/jdbc/StatementTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import com.zaxxer.hikari.HikariDataSource;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;
import org.postgresql.util.PSQLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import test.container.ContainerUtil;
import test.container.JdbcTest;

import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
Expand All @@ -22,7 +25,8 @@ public class StatementTest {
List<DynamicTest> testAllDatabase() {
return ContainerUtil.testAllContainers(() -> List.of(
JdbcTest.of("execute", this::execute),
JdbcTest.of("executeUpdate", this::executeUpdate)
JdbcTest.of("executeUpdate", this::executeUpdate),
JdbcTest.of("executeBatch", this::executeBatch)
));
}

Expand All @@ -47,14 +51,13 @@ void execute(DataSource dataSource) throws SQLException {

// pg专场
if (driverClassName.contains("postgresql")) {
// 不设置Statement.RETURN_GENERATED_KEYS, 可以获取returning的内容
// 不设置Statement.RETURN_GENERATED_KEYS, execute可以获取returning的内容
try (Connection connection = dataSource.getConnection()) {
String sql = "insert into swift_user(name, status) values(?, ?) returning id, name";
try (PreparedStatement ps = connection.prepareStatement(sql)) {
ps.setQueryTimeout(1);
ps.setString(1, "fantasy");
ps.setInt(2, 1);
// 在指定autoGeneratedKeys的情况下, 就算有returning, 也会返回false
assertTrue(ps.execute());
try (ResultSet resultSet = ps.getGeneratedKeys()) {
assertFalse(resultSet.next());
Expand Down Expand Up @@ -85,14 +88,28 @@ void execute(DataSource dataSource) throws SQLException {
assertNull(ps.getResultSet());
}
}

// 尝试用execute取出返回的id
try (Connection connection = dataSource.getConnection()) {
String sql = "update swift_user set status = 1 returning id";
try (PreparedStatement ps = connection.prepareStatement(sql)) {
assertTrue(ps.execute());
assertNotNull(ps.getResultSet());
try (ResultSet resultSet = ps.getResultSet()) {
while (resultSet.next()) {
long id = resultSet.getLong(1);
log.debug("key: {}", id);
assertTrue(id > 0);
}
}
}
}
}
}

void executeUpdate(DataSource dataSource) throws SQLException {
String driverClassName = dataSource.unwrap(HikariDataSource.class).getDriverClassName();
if (driverClassName.contains("postgresql")) {
// returning的结果只能通过execute方法获取
}
// 取出生成的key
try (Connection connection = dataSource.getConnection()) {
String sql = "insert into swift_user(name, status) values(?, ?)";
try (PreparedStatement ps = connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) {
Expand All @@ -108,6 +125,56 @@ void executeUpdate(DataSource dataSource) throws SQLException {
assertTrue(id > 0);
}
}
// 尝试用executeUpdate取出returning返回的内容
if (driverClassName.contains("postgresql")) {
assertThrows(PSQLException.class, () -> {
String sql2 = "insert into swift_user(name, status) values(?, ?) returning id";
try (PreparedStatement ps = connection.prepareStatement(sql2)) {
ps.setString(1, "fantasy2");
ps.setInt(2, 2);
ps.executeUpdate();
// assertEquals(1, ps.executeUpdate());
assertNotNull(ps.getResultSet());
try (ResultSet resultSet = ps.getResultSet()) {
assertTrue(resultSet.next());
long id = resultSet.getLong(1);
log.debug("key: {}", id);
assertTrue(id > 0);
}
}
});
}
}
}

void executeBatch(DataSource dataSource) throws SQLException {
List<Object[]> params = List.of(
new Object[]{"fantasy1", 1},
new Object[]{"fantasy2", 2}
);
try (Connection connection = dataSource.getConnection()) {
String sql = "insert into swift_user(name, status) values(?, ?)";
try (PreparedStatement ps = connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) {
for (Object[] param : params) {
ps.setString(1, param[0].toString());
ps.setInt(2, (int)param[1]);
ps.addBatch();
}
int[] batchResult = ps.executeBatch();
assertTrue(Arrays.stream(batchResult).allMatch(i -> i == 1));
assertEquals(params.size(), batchResult.length);

List<Long> keyResult = new ArrayList<>(batchResult.length);
try (ResultSet resultSet = ps.getGeneratedKeys()) {
while (resultSet.next()) {
long id = resultSet.getLong(1);
log.debug("key: {}", id);
assertTrue(id > 0);
keyResult.add(id);
}
}
assertEquals(params.size(), keyResult.size());
}
}
}

Expand Down

0 comments on commit 6edbeae

Please sign in to comment.