From 3c15db59b0bb30f3677a0bfa09039c11e6ff201a Mon Sep 17 00:00:00 2001 From: ksamoylov Date: Mon, 8 Feb 2021 23:46:37 +0300 Subject: [PATCH] try to keep an SSLSocketFactory instance for an SslConfig instance makes it possible to reuse an HTTPs connection for consequent successful requests --- .../java/com/bettercloud/vault/SslConfig.java | 8 +++ .../com/bettercloud/vault/api/Logical.java | 8 +++ .../java/com/bettercloud/vault/rest/Rest.java | 32 +++++++++-- .../bettercloud/vault/api/LogicalTests.java | 53 +++++++++++++++++++ 4 files changed, 96 insertions(+), 5 deletions(-) diff --git a/src/main/java/com/bettercloud/vault/SslConfig.java b/src/main/java/com/bettercloud/vault/SslConfig.java index 21179b14..23c16452 100644 --- a/src/main/java/com/bettercloud/vault/SslConfig.java +++ b/src/main/java/com/bettercloud/vault/SslConfig.java @@ -30,6 +30,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.SSLSocketFactory; /** *

A container for SSL-related configuration options, meant to be stored within a {@link VaultConfig} instance.

@@ -46,6 +47,7 @@ public class SslConfig implements Serializable { private boolean verify; private transient SSLContext sslContext; + private transient SSLSocketFactory sslSocketFactory; private transient KeyStore trustStore; private transient KeyStore keyStore; private String keyStorePassword; @@ -469,6 +471,10 @@ public SSLContext getSslContext() { return sslContext; } + public SSLSocketFactory getSslSocketFactory() { + return sslSocketFactory; + } + protected String getPemUTF8() { return pemUTF8; } @@ -489,8 +495,10 @@ private void buildSsl() throws VaultException { if (verify) { if (keyStore != null || trustStore != null) { this.sslContext = buildSslContextFromJks(); + this.sslSocketFactory = sslContext.getSocketFactory(); } else if (pemUTF8 != null || clientPemUTF8 != null || clientKeyPemUTF8 != null) { this.sslContext = buildSslContextFromPem(); + this.sslSocketFactory = sslContext.getSocketFactory(); } } } diff --git a/src/main/java/com/bettercloud/vault/api/Logical.java b/src/main/java/com/bettercloud/vault/api/Logical.java index e52bbf07..d34c6301 100644 --- a/src/main/java/com/bettercloud/vault/api/Logical.java +++ b/src/main/java/com/bettercloud/vault/api/Logical.java @@ -91,6 +91,7 @@ private LogicalResponse read(final String path, Boolean shouldRetry, final logic .readTimeoutSeconds(config.getReadTimeout()) .sslVerification(config.getSslConfig().isVerify()) .sslContext(config.getSslConfig().getSslContext()) + .sslSocketFactory(config.getSslConfig().getSslSocketFactory()) .get(); // Validate response - don't treat 4xx class errors as exceptions, we want to return an error as the response @@ -160,6 +161,7 @@ public LogicalResponse read(final String path, Boolean shouldRetry, final Intege .readTimeoutSeconds(config.getReadTimeout()) .sslVerification(config.getSslConfig().isVerify()) .sslContext(config.getSslConfig().getSslContext()) + .sslSocketFactory(config.getSslConfig().getSslSocketFactory()) .get(); // Validate response - don't treat 4xx class errors as exceptions, we want to return an error as the response @@ -261,6 +263,7 @@ private LogicalResponse write(final String path, final Map nameV .readTimeoutSeconds(config.getReadTimeout()) .sslVerification(config.getSslConfig().isVerify()) .sslContext(config.getSslConfig().getSslContext()) + .sslSocketFactory(config.getSslConfig().getSslSocketFactory()) .post(); // HTTP Status should be either 200 (with content - e.g. PKI write) or 204 (no content) @@ -352,6 +355,7 @@ private LogicalResponse delete(final String path, final Logical.logicalOperation .readTimeoutSeconds(config.getReadTimeout()) .sslVerification(config.getSslConfig().isVerify()) .sslContext(config.getSslConfig().getSslContext()) + .sslSocketFactory(config.getSslConfig().getSslSocketFactory()) .delete(); // Validate response @@ -412,6 +416,7 @@ public LogicalResponse delete(final String path, final int[] versions) throws Va .readTimeoutSeconds(config.getReadTimeout()) .sslVerification(config.getSslConfig().isVerify()) .sslContext(config.getSslConfig().getSslContext()) + .sslSocketFactory(config.getSslConfig().getSslSocketFactory()) .body(versionsToDelete.toString().getBytes(StandardCharsets.UTF_8)) .post(); @@ -483,6 +488,7 @@ public LogicalResponse unDelete(final String path, final int[] versions) throws .readTimeoutSeconds(config.getReadTimeout()) .sslVerification(config.getSslConfig().isVerify()) .sslContext(config.getSslConfig().getSslContext()) + .sslSocketFactory(config.getSslConfig().getSslSocketFactory()) .body(versionsToUnDelete.toString().getBytes(StandardCharsets.UTF_8)) .post(); @@ -542,6 +548,7 @@ public LogicalResponse destroy(final String path, final int[] versions) throws V .readTimeoutSeconds(config.getReadTimeout()) .sslVerification(config.getSslConfig().isVerify()) .sslContext(config.getSslConfig().getSslContext()) + .sslSocketFactory(config.getSslConfig().getSslSocketFactory()) .body(versionsToDestroy.toString().getBytes(StandardCharsets.UTF_8)) .post(); @@ -593,6 +600,7 @@ public LogicalResponse upgrade(final String kvPath) throws VaultException { .readTimeoutSeconds(config.getReadTimeout()) .sslVerification(config.getSslConfig().isVerify()) .sslContext(config.getSslConfig().getSslContext()) + .sslSocketFactory(config.getSslConfig().getSslSocketFactory()) .body(kvToUpgrade.toString().getBytes(StandardCharsets.UTF_8)) .post(); diff --git a/src/main/java/com/bettercloud/vault/rest/Rest.java b/src/main/java/com/bettercloud/vault/rest/Rest.java index 3d528d13..f872f7cd 100644 --- a/src/main/java/com/bettercloud/vault/rest/Rest.java +++ b/src/main/java/com/bettercloud/vault/rest/Rest.java @@ -23,6 +23,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; +import javax.net.ssl.SSLSocketFactory; /** *

A simple client for issuing HTTP requests. Supports the HTTP verbs:

@@ -66,6 +67,7 @@ public class Rest { * verification process, to always trust any certificates. */ private static SSLContext DISABLED_SSL_CONTEXT; + private static SSLSocketFactory DISABLED_SSL_SOCKET_FACTORY; static { try { @@ -84,6 +86,7 @@ public X509Certificate[] getAcceptedIssuers() { return new X509Certificate[0]; } }}, new java.security.SecureRandom()); + DISABLED_SSL_SOCKET_FACTORY = DISABLED_SSL_CONTEXT.getSocketFactory(); } catch (NoSuchAlgorithmException | KeyManagementException e) { e.printStackTrace(); } @@ -98,6 +101,7 @@ public X509Certificate[] getAcceptedIssuers() { private Integer readTimeoutSeconds; private Boolean sslVerification; private SSLContext sslContext; + private SSLSocketFactory sslSocketFactory; /** *

Sets the base URL to which the HTTP request will be sent. The URL may or may not include query parameters @@ -248,6 +252,11 @@ public Rest sslContext(final SSLContext sslContext) { return this; } + public Rest sslSocketFactory(final SSLSocketFactory sslSocketFactory) { + this.sslSocketFactory = sslSocketFactory; + return this; + } + /** *

Executes an HTTP GET request with the settings already configured. Parameters and headers are optional, but * a RestException will be thrown if the caller has not first set a base URL with the @@ -446,8 +455,11 @@ private URLConnection initURLConnection(final String urlString, final String met final HttpsURLConnection httpsURLConnection = (HttpsURLConnection) connection; if (sslVerification != null && !sslVerification) { // SSL verification disabled - httpsURLConnection.setSSLSocketFactory(DISABLED_SSL_CONTEXT.getSocketFactory()); + httpsURLConnection.setSSLSocketFactory(DISABLED_SSL_SOCKET_FACTORY); httpsURLConnection.setHostnameVerifier((s, sslSession) -> true); + } else if (sslSocketFactory != null) { + // Socket factory supplied for keep-alive connections + httpsURLConnection.setSSLSocketFactory(sslSocketFactory); } else if (sslContext != null) { // Cert file supplied httpsURLConnection.setSSLSocketFactory(sslContext.getSocketFactory()); @@ -463,11 +475,10 @@ private URLConnection initURLConnection(final String urlString, final String met return connection; } catch (Exception e) { - throw new RestException(e); - } finally { - if (connection instanceof HttpURLConnection) { + if (connection != null && connection instanceof HttpURLConnection) { ((HttpURLConnection) connection).disconnect(); } + throw new RestException(e); } } @@ -499,8 +510,8 @@ private String parametersToQueryString() { * @throws RestException */ private byte[] responseBodyBytes(final URLConnection connection) throws RestException { + InputStream inputStream = null; try { - final InputStream inputStream; final int responseCode = this.connectionStatus(connection); if (200 <= responseCode && responseCode <= 299) { inputStream = connection.getInputStream(); @@ -519,9 +530,20 @@ private byte[] responseBodyBytes(final URLConnection connection) throws RestExce while ((bytesRead = inputStream.read(bytes, 0, bytes.length)) != -1) { byteArrayOutputStream.write(bytes, 0, bytesRead); } + inputStream.close(); byteArrayOutputStream.flush(); return byteArrayOutputStream.toByteArray(); } catch (IOException e) { + try { + if (inputStream == null) { + inputStream = ((HttpURLConnection) connection).getErrorStream(); + } + if (inputStream != null) { + inputStream.close(); + } + } catch (IOException ee) { + //do nothing + } return new byte[0]; } } diff --git a/src/test-integration/java/com/bettercloud/vault/api/LogicalTests.java b/src/test-integration/java/com/bettercloud/vault/api/LogicalTests.java index a5a7d473..af10a61e 100644 --- a/src/test-integration/java/com/bettercloud/vault/api/LogicalTests.java +++ b/src/test-integration/java/com/bettercloud/vault/api/LogicalTests.java @@ -7,6 +7,7 @@ import com.bettercloud.vault.response.LogicalResponse; import com.bettercloud.vault.util.VaultContainer; import java.io.IOException; +import java.io.InputStream; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -443,4 +444,56 @@ public void testVaultUpgrade() throws VaultException { Assert.assertEquals(kVOriginalVersion, "1"); Assert.assertEquals(kVUpgradedVersion, "2"); } + + /** + * Verify that value can be read several times with connection re-usage + * + * @throws VaultException + */ + @Test + public void testReadSeveralTimesWithConnectionReUsage() throws VaultException { + final int readTimes = 10; + final String pathToWrite = "secret/hello"; + final String pathToRead = "secret/hello"; + + final String value = "world"; + final Map testMap = new HashMap<>(); + testMap.put("value", value); + + final Vault vault = container.getRootVault(); + String hostport = String.format("%s.%s", container.getContainerIpAddress(), container.getMappedPort(8200)); + Logical logical = vault.logical(); + + int connBefore = connStat(hostport); + logical.write(pathToWrite, testMap); + for(int i = 0; i < readTimes; i++) { + final String valueRead = logical.read(pathToRead).getData().get("value"); + assertEquals(value, valueRead); + } + int connCreated = connStat(hostport) - connBefore; + assertTrue("Too many new connections to '" + hostport + "' created: " + connCreated, connCreated <= 1); + } + + private int connStat(String host) { + ProcessBuilder pb = new ProcessBuilder(); + pb.command("netstat"); + pb.redirectErrorStream(true); + + try { + Process p = pb.start(); + InputStream inputStream = p.getInputStream(); + String result = new String(inputStream.readAllBytes()); + int conn = 0; + for (String line : result.split("\n")) { + if (line.matches(".*" + host + "\\s+ESTABLISHED")) { + conn++; + System.out.println(line); + } + } + return conn; + } catch (IOException e) { + System.err.println("Error executing netstat: " + e.getMessage()); + return 0; + } + } }