diff --git a/core/src/main/java/io/confluent/rest/ApplicationServer.java b/core/src/main/java/io/confluent/rest/ApplicationServer.java index 7d8a2f0d09..1cacfe0cda 100644 --- a/core/src/main/java/io/confluent/rest/ApplicationServer.java +++ b/core/src/main/java/io/confluent/rest/ApplicationServer.java @@ -39,6 +39,8 @@ import java.net.URI; import java.net.URISyntaxException; import java.net.URL; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -230,6 +232,15 @@ private void configureClientAuth(SslContextFactory sslContextFactory, RestConfig } } + private Path getWatchLocation(RestConfig config) { + Path keystorePath = Paths.get(config.getString(RestConfig.SSL_KEYSTORE_LOCATION_CONFIG)); + String watchLocation = config.getString(RestConfig.SSL_KEYSTORE_WATCH_LOCATION_CONFIG); + if (!watchLocation.isEmpty()) { + keystorePath = Paths.get(watchLocation); + } + return keystorePath; + } + private SslContextFactory createSslContextFactory(RestConfig config) { SslContextFactory sslContextFactory = new SslContextFactory.Server(); if (!config.getString(RestConfig.SSL_KEYSTORE_LOCATION_CONFIG).isEmpty()) { @@ -250,6 +261,23 @@ private SslContextFactory createSslContextFactory(RestConfig config) { sslContextFactory.setKeyManagerFactoryAlgorithm( config.getString(RestConfig.SSL_KEYMANAGER_ALGORITHM_CONFIG)); } + + if (config.getBoolean(RestConfig.SSL_KEYSTORE_RELOAD_CONFIG)) { + Path watchLocation = getWatchLocation(config); + try { + FileWatcher.onFileChange(watchLocation, () -> { + // Need to reset the key store path for symbolic link case + sslContextFactory.setKeyStorePath( + config.getString(RestConfig.SSL_KEYSTORE_LOCATION_CONFIG) + ); + sslContextFactory.reload(scf -> log.info("Reloaded SSL cert")); + } + ); + log.info("Enabled SSL cert auto reload for: " + watchLocation); + } catch (java.io.IOException e) { + log.error("Can not enabled SSL cert auto reload", e); + } + } } configureClientAuth(sslContextFactory, config); diff --git a/core/src/main/java/io/confluent/rest/FileWatcher.java b/core/src/main/java/io/confluent/rest/FileWatcher.java new file mode 100644 index 0000000000..c129847a95 --- /dev/null +++ b/core/src/main/java/io/confluent/rest/FileWatcher.java @@ -0,0 +1,129 @@ +/** + * Copyright 2019 Confluent Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.confluent.rest; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.FileSystems; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardWatchEventKinds; +import java.nio.file.WatchService; +import java.nio.file.WatchKey; +import java.nio.file.WatchEvent; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; + +// reference https://gist.github.com/danielflower/f54c2fe42d32356301c68860a4ab21ed +public class FileWatcher implements Runnable { + private static final Logger log = LoggerFactory.getLogger(FileWatcher.class); + private static final ExecutorService executor = Executors.newFixedThreadPool(1, + new ThreadFactory() { + public Thread newThread(Runnable r) { + Thread t = Executors.defaultThreadFactory().newThread(r); + t.setDaemon(true); + return t; + } + }); + + public interface Callback { + void run() throws Exception; + } + + private volatile boolean shutdown; + private final WatchService watchService; + private final Path file; + private final Callback callback; + + public FileWatcher(Path file, Callback callback) throws IOException { + this.file = file; + this.watchService = FileSystems.getDefault().newWatchService(); + // Listen to both CREATE and MODIFY to reload, so taking care of delete then create. + file.getParent().register(watchService, + StandardWatchEventKinds.ENTRY_CREATE, + StandardWatchEventKinds.ENTRY_MODIFY); + this.callback = callback; + } + + /** + * Starts watching a file calls the callback when it is changed. + * A shutdown hook is registered to stop watching. + */ + public static void onFileChange(Path file, Callback callback) throws IOException { + log.info("Configure watch file change: " + file); + FileWatcher fileWatcher = new FileWatcher(file, callback); + executor.submit(fileWatcher); + } + + public void run() { + try { + while (!shutdown) { + try { + handleNextWatchNotification(); + } catch (InterruptedException e) { + throw e; + } catch (Exception e) { + log.info("Watch service caught exception, will continue:" + e); + } + } + } catch (InterruptedException e) { + log.info("Ending watch due to interrupt"); + } + } + + private void handleNextWatchNotification() throws InterruptedException { + log.debug("Watching file change: " + file); + // wait for key to be signalled + WatchKey key = watchService.take(); + log.info("Watch Key notified"); + for (WatchEvent event : key.pollEvents()) { + WatchEvent.Kind kind = event.kind(); + if (kind == StandardWatchEventKinds.OVERFLOW) { + log.debug("Watch event is OVERFLOW"); + continue; + } + WatchEvent ev = (WatchEvent)event; + Path changed = this.file.getParent().resolve(ev.context()); + log.info("Watch file change: " + ev.context() + "=>" + changed); + // Need to use path equals than isSameFile + if (Files.exists(changed) && changed.equals(this.file)) { + log.debug("Watch matching file: " + file); + try { + callback.run(); + } catch (Exception e) { + log.warn("Hit error callback on file change", e); + } + break; + } + } + key.reset(); + } + + public void shutdown() { + shutdown = true; + try { + watchService.close(); + } catch (IOException e) { + log.info("Error closing watch service", e); + } + } + +} diff --git a/core/src/main/java/io/confluent/rest/RestConfig.java b/core/src/main/java/io/confluent/rest/RestConfig.java index 4908c44cd8..425c648693 100644 --- a/core/src/main/java/io/confluent/rest/RestConfig.java +++ b/core/src/main/java/io/confluent/rest/RestConfig.java @@ -120,10 +120,18 @@ public class RestConfig extends AbstractConfig { + "details, etc."; protected static final String METRICS_TAGS_DEFAULT = ""; + public static final String SSL_KEYSTORE_RELOAD_CONFIG = "ssl.keystore.reload"; + protected static final String SSL_KEYSTORE_RELOAD_DOC = + "Enable auto reload of ssl keystore"; + protected static final boolean SSL_KEYSTORE_RELOAD_DEFAULT = false; public static final String SSL_KEYSTORE_LOCATION_CONFIG = "ssl.keystore.location"; protected static final String SSL_KEYSTORE_LOCATION_DOC = "Location of the keystore file to use for SSL. This is required for HTTPS."; protected static final String SSL_KEYSTORE_LOCATION_DEFAULT = ""; + public static final String SSL_KEYSTORE_WATCH_LOCATION_CONFIG = "ssl.keystore.watch.location"; + protected static final String SSL_KEYSTORE_WATCH_LOCATION_DOC = + "Location to watch keystore file change if it is different from keystore location "; + protected static final String SSL_KEYSTORE_WATCH_LOCATION_DEFAULT = ""; public static final String SSL_KEYSTORE_PASSWORD_CONFIG = "ssl.keystore.password"; protected static final String SSL_KEYSTORE_PASSWORD_DOC = "The store password for the keystore file."; @@ -411,12 +419,24 @@ private static ConfigDef incompleteBaseConfigDef() { METRICS_TAGS_DEFAULT, Importance.LOW, METRICS_TAGS_DOC + ).define( + SSL_KEYSTORE_RELOAD_CONFIG, + Type.BOOLEAN, + SSL_KEYSTORE_RELOAD_DEFAULT, + Importance.LOW, + SSL_KEYSTORE_RELOAD_DOC ).define( SSL_KEYSTORE_LOCATION_CONFIG, Type.STRING, SSL_KEYSTORE_LOCATION_DEFAULT, Importance.HIGH, SSL_KEYSTORE_LOCATION_DOC + ).define( + SSL_KEYSTORE_WATCH_LOCATION_CONFIG, + Type.STRING, + SSL_KEYSTORE_WATCH_LOCATION_DEFAULT, + Importance.LOW, + SSL_KEYSTORE_WATCH_LOCATION_DOC ).define( SSL_KEYSTORE_PASSWORD_CONFIG, Type.PASSWORD, diff --git a/core/src/test/java/io/confluent/rest/SslTest.java b/core/src/test/java/io/confluent/rest/SslTest.java index afb39a8a7a..3c70a4a2ea 100644 --- a/core/src/test/java/io/confluent/rest/SslTest.java +++ b/core/src/test/java/io/confluent/rest/SslTest.java @@ -37,6 +37,8 @@ import java.io.File; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; import java.net.SocketException; import java.security.KeyPair; import java.security.cert.X509Certificate; @@ -66,9 +68,12 @@ public class SslTest { private File trustStore; private File clientKeystore; private File serverKeystore; + private File serverKeystoreBak; + private File serverKeystoreErr; public static final String SSL_PASSWORD = "test1234"; public static final String EXPECTED_200_MSG = "Response status must be 200."; + public static final int CERT_RELOAD_WAIT_TIME = 20000; @Before public void setUp() throws Exception { @@ -76,6 +81,8 @@ public void setUp() throws Exception { trustStore = File.createTempFile("SslTest-truststore", ".jks"); clientKeystore = File.createTempFile("SslTest-client-keystore", ".jks"); serverKeystore = File.createTempFile("SslTest-server-keystore", ".jks"); + serverKeystoreBak = File.createTempFile("SslTest-server-keystore", ".jks.bak"); + serverKeystoreErr = File.createTempFile("SslTest-server-keystore", ".jks.err"); } catch (IOException ioe) { throw new RuntimeException("Unable to create temporary files for trust stores and keystores."); } @@ -83,6 +90,10 @@ public void setUp() throws Exception { createKeystoreWithCert(clientKeystore, "client", certs); createKeystoreWithCert(serverKeystore, "server", certs); TestSslUtils.createTrustStore(trustStore.getAbsolutePath(), new Password(SSL_PASSWORD), certs); + + Files.copy(serverKeystore.toPath(), serverKeystoreBak.toPath(), StandardCopyOption.REPLACE_EXISTING); + certs = new HashMap<>(); + createWrongKeystoreWithCert(serverKeystoreErr, "server", certs); } private void createKeystoreWithCert(File file, String alias, Map certs) throws Exception { @@ -109,6 +120,16 @@ private void enableSslClientAuth(Properties props) { props.put(RestConfig.SSL_CLIENT_AUTH_CONFIG, true); } + private void createWrongKeystoreWithCert(File file, String alias, Map certs) throws Exception { + KeyPair keypair = TestSslUtils.generateKeyPair("RSA"); + CertificateBuilder certificateBuilder = new CertificateBuilder(30, "SHA1withRSA"); + X509Certificate cCert = certificateBuilder.sanDnsName("fail") + .generate("CN=mymachine.local, O=A client", keypair); + TestSslUtils.createKeyStore(file.getPath(), new Password(SSL_PASSWORD), alias, keypair.getPrivate(), cCert); + certs.put(alias, cCert); + } + + @Test public void testHttpAndHttps() throws Exception { TestMetricsReporter.reset(); @@ -134,6 +155,50 @@ public void testHttpAndHttps() throws Exception { } } + @Test + public void testHttpsWithAutoReload() throws Exception { + TestMetricsReporter.reset(); + Properties props = new Properties(); + String httpsUri = "https://localhost:8082"; + props.put(RestConfig.LISTENERS_CONFIG, httpsUri); + props.put(RestConfig.METRICS_REPORTER_CLASSES_CONFIG, "io.confluent.rest.TestMetricsReporter"); + props.put(RestConfig.SSL_KEYSTORE_RELOAD_CONFIG, "true"); + configServerKeystore(props); + TestRestConfig config = new TestRestConfig(props); + SslTestApplication app = new SslTestApplication(config); + try { + app.start(); + int statusCode = makeGetRequest(httpsUri + "/test", + clientKeystore.getAbsolutePath(), SSL_PASSWORD, SSL_PASSWORD); + assertEquals(EXPECTED_200_MSG, 200, statusCode); + assertMetricsCollected(); + + // verify reload -- override the server keystore with a wrong one + Files.copy(serverKeystoreErr.toPath(), serverKeystore.toPath(), StandardCopyOption.REPLACE_EXISTING); + Thread.sleep(CERT_RELOAD_WAIT_TIME); + boolean hitError = false; + try { + makeGetRequest(httpsUri + "/test", + clientKeystore.getAbsolutePath(), SSL_PASSWORD, SSL_PASSWORD); + } catch (Exception e) { + System.out.println(e); + hitError = true; + } + + // verify reload -- override the server keystore with a correct one + Files.copy(serverKeystoreBak.toPath(), serverKeystore.toPath(), StandardCopyOption.REPLACE_EXISTING); + Thread.sleep(CERT_RELOAD_WAIT_TIME); + statusCode = makeGetRequest(httpsUri + "/test", + clientKeystore.getAbsolutePath(), SSL_PASSWORD, SSL_PASSWORD); + assertEquals(EXPECTED_200_MSG, 200, statusCode); + assertEquals("expect hit error with new server cert", true, hitError); + } finally { + if (app != null) { + app.stop(); + } + } + } + @Test(expected = ClientProtocolException.class) public void testHttpsOnly() throws Exception { TestMetricsReporter.reset();