Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try to keep an SSLSocketFactory instance for an SslConfig instance ma… #251

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/main/java/com/bettercloud/vault/SslConfig.java
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLSocketFactory;

/**
* <p>A container for SSL-related configuration options, meant to be stored within a {@link VaultConfig} instance.</p>
@@ -46,6 +47,7 @@ public class SslConfig implements Serializable {

private boolean verify;
private transient SSLContext sslContext;
private transient SSLSocketFactory sslSocketFactory;
private transient KeyStore trustStore;
private transient KeyStore keyStore;
private String keyStorePassword;
@@ -469,6 +471,10 @@ public SSLContext getSslContext() {
return sslContext;
}

public SSLSocketFactory getSslSocketFactory() {
return sslSocketFactory;
}

protected String getPemUTF8() {
return pemUTF8;
}
@@ -489,8 +495,10 @@ private void buildSsl() throws VaultException {
if (verify) {
if (keyStore != null || trustStore != null) {
this.sslContext = buildSslContextFromJks();
this.sslSocketFactory = sslContext.getSocketFactory();
} else if (pemUTF8 != null || clientPemUTF8 != null || clientKeyPemUTF8 != null) {
this.sslContext = buildSslContextFromPem();
this.sslSocketFactory = sslContext.getSocketFactory();
}
}
}
8 changes: 8 additions & 0 deletions src/main/java/com/bettercloud/vault/api/Logical.java
Original file line number Diff line number Diff line change
@@ -91,6 +91,7 @@ private LogicalResponse read(final String path, Boolean shouldRetry, final logic
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.get();

// Validate response - don't treat 4xx class errors as exceptions, we want to return an error as the response
@@ -160,6 +161,7 @@ public LogicalResponse read(final String path, Boolean shouldRetry, final Intege
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.get();

// Validate response - don't treat 4xx class errors as exceptions, we want to return an error as the response
@@ -261,6 +263,7 @@ private LogicalResponse write(final String path, final Map<String, Object> nameV
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.post();

// HTTP Status should be either 200 (with content - e.g. PKI write) or 204 (no content)
@@ -352,6 +355,7 @@ private LogicalResponse delete(final String path, final Logical.logicalOperation
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.delete();

// Validate response
@@ -412,6 +416,7 @@ public LogicalResponse delete(final String path, final int[] versions) throws Va
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.body(versionsToDelete.toString().getBytes(StandardCharsets.UTF_8))
.post();

@@ -483,6 +488,7 @@ public LogicalResponse unDelete(final String path, final int[] versions) throws
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.body(versionsToUnDelete.toString().getBytes(StandardCharsets.UTF_8))
.post();

@@ -542,6 +548,7 @@ public LogicalResponse destroy(final String path, final int[] versions) throws V
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.body(versionsToDestroy.toString().getBytes(StandardCharsets.UTF_8))
.post();

@@ -593,6 +600,7 @@ public LogicalResponse upgrade(final String kvPath) throws VaultException {
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.body(kvToUpgrade.toString().getBytes(StandardCharsets.UTF_8))
.post();

32 changes: 27 additions & 5 deletions src/main/java/com/bettercloud/vault/rest/Rest.java
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import javax.net.ssl.SSLSocketFactory;

/**
* <p>A simple client for issuing HTTP requests. Supports the HTTP verbs:</p>
@@ -66,6 +67,7 @@ public class Rest {
* verification process, to always trust any certificates.
*/
private static SSLContext DISABLED_SSL_CONTEXT;
private static SSLSocketFactory DISABLED_SSL_SOCKET_FACTORY;

static {
try {
@@ -84,6 +86,7 @@ public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
}}, new java.security.SecureRandom());
DISABLED_SSL_SOCKET_FACTORY = DISABLED_SSL_CONTEXT.getSocketFactory();
} catch (NoSuchAlgorithmException | KeyManagementException e) {
e.printStackTrace();
}
@@ -98,6 +101,7 @@ public X509Certificate[] getAcceptedIssuers() {
private Integer readTimeoutSeconds;
private Boolean sslVerification;
private SSLContext sslContext;
private SSLSocketFactory sslSocketFactory;

/**
* <p>Sets the base URL to which the HTTP request will be sent. The URL may or may not include query parameters
@@ -248,6 +252,11 @@ public Rest sslContext(final SSLContext sslContext) {
return this;
}

public Rest sslSocketFactory(final SSLSocketFactory sslSocketFactory) {
this.sslSocketFactory = sslSocketFactory;
return this;
}

/**
* <p>Executes an HTTP GET request with the settings already configured. Parameters and headers are optional, but
* a <code>RestException</code> will be thrown if the caller has not first set a base URL with the
@@ -446,8 +455,11 @@ private URLConnection initURLConnection(final String urlString, final String met
final HttpsURLConnection httpsURLConnection = (HttpsURLConnection) connection;
if (sslVerification != null && !sslVerification) {
// SSL verification disabled
httpsURLConnection.setSSLSocketFactory(DISABLED_SSL_CONTEXT.getSocketFactory());
httpsURLConnection.setSSLSocketFactory(DISABLED_SSL_SOCKET_FACTORY);
httpsURLConnection.setHostnameVerifier((s, sslSession) -> true);
} else if (sslSocketFactory != null) {
// Socket factory supplied for keep-alive connections
httpsURLConnection.setSSLSocketFactory(sslSocketFactory);
} else if (sslContext != null) {
// Cert file supplied
httpsURLConnection.setSSLSocketFactory(sslContext.getSocketFactory());
@@ -463,11 +475,10 @@ private URLConnection initURLConnection(final String urlString, final String met

return connection;
} catch (Exception e) {
throw new RestException(e);
} finally {
if (connection instanceof HttpURLConnection) {
if (connection != null && connection instanceof HttpURLConnection) {
((HttpURLConnection) connection).disconnect();
}
throw new RestException(e);
}
}

@@ -499,8 +510,8 @@ private String parametersToQueryString() {
* @throws RestException
*/
private byte[] responseBodyBytes(final URLConnection connection) throws RestException {
InputStream inputStream = null;
try {
final InputStream inputStream;
final int responseCode = this.connectionStatus(connection);
if (200 <= responseCode && responseCode <= 299) {
inputStream = connection.getInputStream();
@@ -519,9 +530,20 @@ private byte[] responseBodyBytes(final URLConnection connection) throws RestExce
while ((bytesRead = inputStream.read(bytes, 0, bytes.length)) != -1) {
byteArrayOutputStream.write(bytes, 0, bytesRead);
}
inputStream.close();
byteArrayOutputStream.flush();
return byteArrayOutputStream.toByteArray();
} catch (IOException e) {
try {
if (inputStream == null) {
inputStream = ((HttpURLConnection) connection).getErrorStream();
}
if (inputStream != null) {
inputStream.close();
}
} catch (IOException ee) {
//do nothing
}
return new byte[0];
}
}
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import com.bettercloud.vault.response.LogicalResponse;
import com.bettercloud.vault.util.VaultContainer;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -443,4 +444,56 @@ public void testVaultUpgrade() throws VaultException {
Assert.assertEquals(kVOriginalVersion, "1");
Assert.assertEquals(kVUpgradedVersion, "2");
}

/**
* Verify that value can be read several times with connection re-usage
*
* @throws VaultException
*/
@Test
public void testReadSeveralTimesWithConnectionReUsage() throws VaultException {
final int readTimes = 10;
final String pathToWrite = "secret/hello";
final String pathToRead = "secret/hello";

final String value = "world";
final Map<String, Object> testMap = new HashMap<>();
testMap.put("value", value);

final Vault vault = container.getRootVault();
String hostport = String.format("%s.%s", container.getContainerIpAddress(), container.getMappedPort(8200));
Logical logical = vault.logical();

int connBefore = connStat(hostport);
logical.write(pathToWrite, testMap);
for(int i = 0; i < readTimes; i++) {
final String valueRead = logical.read(pathToRead).getData().get("value");
assertEquals(value, valueRead);
}
int connCreated = connStat(hostport) - connBefore;
assertTrue("Too many new connections to '" + hostport + "' created: " + connCreated, connCreated <= 1);
}

private int connStat(String host) {
ProcessBuilder pb = new ProcessBuilder();
pb.command("netstat");
pb.redirectErrorStream(true);

try {
Process p = pb.start();
InputStream inputStream = p.getInputStream();
String result = new String(inputStream.readAllBytes());
int conn = 0;
for (String line : result.split("\n")) {
if (line.matches(".*" + host + "\\s+ESTABLISHED")) {
conn++;
System.out.println(line);
}
}
return conn;
} catch (IOException e) {
System.err.println("Error executing netstat: " + e.getMessage());
return 0;
}
}
}