diff --git a/ksql-engine/src/main/java/io/confluent/ksql/schema/registry/KsqlSchemaRegistryClientFactory.java b/ksql-engine/src/main/java/io/confluent/ksql/schema/registry/KsqlSchemaRegistryClientFactory.java index fc433bd28e69..e34617d7cb93 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/schema/registry/KsqlSchemaRegistryClientFactory.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/schema/registry/KsqlSchemaRegistryClientFactory.java @@ -15,6 +15,7 @@ package io.confluent.ksql.schema.registry; +import com.google.common.annotations.VisibleForTesting; import io.confluent.kafka.schemaregistry.client.CachedSchemaRegistryClient; import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.kafka.schemaregistry.client.rest.RestService; @@ -43,14 +44,21 @@ CachedSchemaRegistryClient create(RestService service, Map httpHeaders); } + public KsqlSchemaRegistryClientFactory( + final KsqlConfig config, + final Map schemaRegistryHttpHeaders + ) { + this(config, newSchemaRegistrySslFactory(config), schemaRegistryHttpHeaders); + } public KsqlSchemaRegistryClientFactory( final KsqlConfig config, + final SslFactory sslFactory, final Map schemaRegistryHttpHeaders ) { this(config, () -> new RestService(config.getString(KsqlConfig.SCHEMA_REGISTRY_URL_PROPERTY)), - new SslFactory(Mode.CLIENT), + sslFactory, CachedSchemaRegistryClient::new, schemaRegistryHttpHeaders ); @@ -59,6 +67,7 @@ public KsqlSchemaRegistryClientFactory( config.getString(KsqlConfig.SCHEMA_REGISTRY_URL_PROPERTY); } + @VisibleForTesting KsqlSchemaRegistryClientFactory(final KsqlConfig config, final Supplier serviceSupplier, final SslFactory sslFactory, @@ -69,13 +78,25 @@ public KsqlSchemaRegistryClientFactory( this.schemaRegistryClientConfigs = config.originalsWithPrefix( KsqlConfig.KSQL_SCHEMA_REGISTRY_PREFIX); - this.sslFactory - .configure(config.valuesWithPrefixOverride(KsqlConfig.KSQL_SCHEMA_REGISTRY_PREFIX)); - this.schemaRegistryClientFactory = schemaRegistryClientFactory; this.httpHeaders = httpHeaders; } + /** + * Creates an SslFactory configured to be used with the KsqlSchemaRegistryClient. + */ + public static SslFactory newSchemaRegistrySslFactory(final KsqlConfig config) { + final SslFactory sslFactory = new SslFactory(Mode.CLIENT); + configureSslFactory(config, sslFactory); + return sslFactory; + } + + @VisibleForTesting + static void configureSslFactory(final KsqlConfig config, final SslFactory sslFactory) { + sslFactory + .configure(config.valuesWithPrefixOverride(KsqlConfig.KSQL_SCHEMA_REGISTRY_PREFIX)); + } + public SchemaRegistryClient get() { final RestService restService = serviceSupplier.get(); final SSLContext sslContext = sslFactory.sslEngineBuilder().sslContext(); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/schema/registry/KsqlSchemaRegistryClientFactoryTest.java b/ksql-engine/src/test/java/io/confluent/ksql/schema/registry/KsqlSchemaRegistryClientFactoryTest.java index 4172235dd559..9a2be33e54a2 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/schema/registry/KsqlSchemaRegistryClientFactoryTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/schema/registry/KsqlSchemaRegistryClientFactoryTest.java @@ -88,14 +88,10 @@ public void shouldSetSocketFactoryWhenNoSpecificSslConfig() { final Map expectedConfigs = defaultConfigs(); // When: - final SchemaRegistryClient client = - new KsqlSchemaRegistryClientFactory(config, restServiceSupplier, sslFactory, - srClientFactory, Collections.emptyMap()).get(); + KsqlSchemaRegistryClientFactory.configureSslFactory(config, sslFactory); // Then: - assertThat(client, is(notNullValue())); verify(sslFactory).configure(expectedConfigs); - verify(restService).setSslSocketFactory(isA(SSL_CONTEXT.getSocketFactory().getClass())); } @Test @@ -109,14 +105,10 @@ public void shouldPickUpNonPrefixedSslConfig() { expectedConfigs.put(SslConfigs.SSL_PROTOCOL_CONFIG, "SSLv3"); // When: - final SchemaRegistryClient client = - new KsqlSchemaRegistryClientFactory(config, restServiceSupplier, sslFactory, - srClientFactory, Collections.emptyMap()).get(); + KsqlSchemaRegistryClientFactory.configureSslFactory(config, sslFactory); // Then: - assertThat(client, is(notNullValue())); verify(sslFactory).configure(expectedConfigs); - verify(restService).setSslSocketFactory(isA(SSL_CONTEXT.getSocketFactory().getClass())); } @Test @@ -130,15 +122,11 @@ public void shouldPickUpPrefixedSslConfig() { expectedConfigs.put(SslConfigs.SSL_PROTOCOL_CONFIG, "SSLv3"); // When: - final SchemaRegistryClient client = - new KsqlSchemaRegistryClientFactory(config, restServiceSupplier, sslFactory, - srClientFactory, Collections.emptyMap()).get(); + KsqlSchemaRegistryClientFactory.configureSslFactory(config, sslFactory); // Then: - assertThat(client, is(notNullValue())); verify(sslFactory).configure(expectedConfigs); - verify(restService).setSslSocketFactory(isA(SSL_CONTEXT.getSocketFactory().getClass())); } @Test @@ -160,6 +148,7 @@ public void shouldPassBasicAuthCredentialsToSchemaRegistryClient() { config, restServiceSupplier, sslFactory, srClientFactory, Collections.emptyMap()).get(); // Then: + verify(restService).setSslSocketFactory(isA(SSL_CONTEXT.getSocketFactory().getClass())); srClientFactory.create(same(restService), anyInt(), eq(expectedConfigs), any()); } diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java index 5759222d6429..4a96777fed9b 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/KsqlRestApplication.java @@ -27,6 +27,7 @@ import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.ServiceInfo; import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.function.InternalFunctionRegistry; @@ -69,6 +70,7 @@ import io.confluent.ksql.rest.util.KsqlUncaughtExceptionHandler; import io.confluent.ksql.rest.util.ProcessingLogServerUtils; import io.confluent.ksql.rest.util.RocksDBConfigSetterHandler; +import io.confluent.ksql.schema.registry.KsqlSchemaRegistryClientFactory; import io.confluent.ksql.security.KsqlAuthorizationValidator; import io.confluent.ksql.security.KsqlAuthorizationValidatorFactory; import io.confluent.ksql.security.KsqlDefaultSecurityExtension; @@ -96,6 +98,7 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; import java.util.List; @@ -444,7 +447,8 @@ public T getEndpointInstance(final Class endpointClass) { authorizationValidator, errorHandler, securityExtension, - serverState + serverState, + serviceContext.getSchemaRegistryClientFactory() ); } }) @@ -460,8 +464,11 @@ static KsqlRestApplication buildApplication( final Function, VersionCheckerAgent> versionCheckerFactory ) { final KsqlConfig ksqlConfig = new KsqlConfig(restConfig.getKsqlConfigProperties()); + final Supplier schemaRegistryClientFactory = + new KsqlSchemaRegistryClientFactory(ksqlConfig, Collections.emptyMap())::get; final ServiceContext serviceContext = new LazyServiceContext(() -> - RestServiceContextFactory.create(ksqlConfig, Optional.empty())); + RestServiceContextFactory.create(ksqlConfig, Optional.empty(), + schemaRegistryClientFactory)); return buildApplication( "", @@ -469,7 +476,8 @@ static KsqlRestApplication buildApplication( versionCheckerFactory, Integer.MAX_VALUE, serviceContext, - KsqlSecurityContextBinder::new); + (config, securityExtension) -> + new KsqlSecurityContextBinder(config, securityExtension, schemaRegistryClientFactory)); } static KsqlRestApplication buildApplication( @@ -478,8 +486,7 @@ static KsqlRestApplication buildApplication( final Function, VersionCheckerAgent> versionCheckerFactory, final int maxStatementRetries, final ServiceContext serviceContext, - final BiFunction serviceContextBinderFactory - ) { + final BiFunction serviceContextBinderFactory) { final String ksqlInstallDir = restConfig.getString(KsqlRestConfig.INSTALL_DIR_CONFIG); final KsqlConfig ksqlConfig = new KsqlConfig(restConfig.getKsqlConfigProperties()); diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinder.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinder.java index 1692b14fe1f5..e5c3f6d0d164 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinder.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinder.java @@ -15,9 +15,11 @@ package io.confluent.ksql.rest.server.context; +import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.security.KsqlSecurityContext; import io.confluent.ksql.security.KsqlSecurityExtension; import io.confluent.ksql.util.KsqlConfig; +import java.util.function.Supplier; import org.glassfish.hk2.utilities.binding.AbstractBinder; import org.glassfish.jersey.process.internal.RequestScoped; @@ -31,9 +33,11 @@ public class KsqlSecurityContextBinder extends AbstractBinder { public KsqlSecurityContextBinder( final KsqlConfig ksqlConfig, - final KsqlSecurityExtension securityExtension + final KsqlSecurityExtension securityExtension, + final Supplier schemaRegistryClientFactory ) { - KsqlSecurityContextBinderFactory.configure(ksqlConfig, securityExtension); + KsqlSecurityContextBinderFactory.configure(ksqlConfig, securityExtension, + schemaRegistryClientFactory); } @Override diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinderFactory.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinderFactory.java index 79c30795b607..04a98a137fef 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinderFactory.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinderFactory.java @@ -18,6 +18,7 @@ import static java.util.Objects.requireNonNull; import com.google.common.annotations.VisibleForTesting; +import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.rest.server.services.RestServiceContextFactory; import io.confluent.ksql.rest.server.services.RestServiceContextFactory.DefaultServiceContextFactory; import io.confluent.ksql.rest.server.services.RestServiceContextFactory.UserServiceContextFactory; @@ -26,6 +27,7 @@ import io.confluent.ksql.util.KsqlConfig; import java.security.Principal; import java.util.Optional; +import java.util.function.Supplier; import javax.inject.Inject; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.HttpHeaders; @@ -39,14 +41,18 @@ public class KsqlSecurityContextBinderFactory implements Factory { private static KsqlConfig ksqlConfig; private static KsqlSecurityExtension securityExtension; + private static Supplier schemaRegistryClientFactory; public static void configure( final KsqlConfig ksqlConfig, - final KsqlSecurityExtension securityExtension + final KsqlSecurityExtension securityExtension, + final Supplier schemaRegistryClientFactory ) { KsqlSecurityContextBinderFactory.ksqlConfig = requireNonNull(ksqlConfig, "ksqlConfig"); KsqlSecurityContextBinderFactory.securityExtension = requireNonNull(securityExtension, "securityExtension"); + KsqlSecurityContextBinderFactory.schemaRegistryClientFactory + = requireNonNull(schemaRegistryClientFactory, "schemaRegistryClientFactory"); } private final SecurityContext securityContext; @@ -91,7 +97,7 @@ public KsqlSecurityContext provide() { if (!securityExtension.getUserContextProvider().isPresent()) { return new KsqlSecurityContext( Optional.ofNullable(principal), - defaultServiceContextFactory.create(ksqlConfig, authHeader) + defaultServiceContextFactory.create(ksqlConfig, authHeader, schemaRegistryClientFactory) ); } diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java index bfd46f281f80..04a088e1c5fc 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListeningScheduledExecutorService; +import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.parser.KsqlParser.PreparedStatement; import io.confluent.ksql.parser.tree.PrintTopic; @@ -53,6 +54,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeoutException; +import java.util.function.Supplier; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; import javax.websocket.EndpointConfig; @@ -97,6 +99,7 @@ public class WSQueryEndpoint { private final DefaultServiceContextFactory defaultServiceContextFactory; private final ServerState serverState; private final Errors errorHandler; + private final Supplier schemaRegistryClientFactory; private WebSocketSubscriber subscriber; private KsqlSecurityContext securityContext; @@ -115,7 +118,8 @@ public WSQueryEndpoint( final Optional authorizationValidator, final Errors errorHandler, final KsqlSecurityExtension securityExtension, - final ServerState serverState + final ServerState serverState, + final Supplier schemaRegistryClientFactory ) { this(ksqlConfig, mapper, @@ -133,7 +137,8 @@ public WSQueryEndpoint( securityExtension, RestServiceContextFactory::create, RestServiceContextFactory::create, - serverState); + serverState, + schemaRegistryClientFactory); } // CHECKSTYLE_RULES.OFF: ParameterNumberCheck @@ -155,7 +160,8 @@ public WSQueryEndpoint( final KsqlSecurityExtension securityExtension, final UserServiceContextFactory serviceContextFactory, final DefaultServiceContextFactory defaultServiceContextFactory, - final ServerState serverState + final ServerState serverState, + final Supplier schemaRegistryClientFactory ) { this.ksqlConfig = Objects.requireNonNull(ksqlConfig, "ksqlConfig"); this.mapper = Objects.requireNonNull(mapper, "mapper"); @@ -179,7 +185,9 @@ public WSQueryEndpoint( this.defaultServiceContextFactory = Objects.requireNonNull(defaultServiceContextFactory, "defaultServiceContextFactory"); this.serverState = Objects.requireNonNull(serverState, "serverState"); - this.errorHandler = Objects.requireNonNull(errorHandler, "errorHandler");; + this.errorHandler = Objects.requireNonNull(errorHandler, "errorHandler"); + this.schemaRegistryClientFactory = + Objects.requireNonNull(schemaRegistryClientFactory, "schemaRegistryClientFactory"); } @SuppressWarnings("unused") @@ -288,7 +296,8 @@ private KsqlSecurityContext createSecurityContext(final Principal principal) { final ServiceContext serviceContext; if (!securityExtension.getUserContextProvider().isPresent()) { - serviceContext = defaultServiceContextFactory.create(ksqlConfig, Optional.empty()); + serviceContext = defaultServiceContextFactory.create(ksqlConfig, Optional.empty(), + schemaRegistryClientFactory); } else { // Creates a ServiceContext using the user's credentials, so the WS query topics are // accessed with the user permission context (defaults to KSQL service context) diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java index 2a9d5b89c216..fa108450ac00 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/services/RestServiceContextFactory.java @@ -16,12 +16,10 @@ package io.confluent.ksql.rest.server.services; import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; -import io.confluent.ksql.schema.registry.KsqlSchemaRegistryClientFactory; import io.confluent.ksql.services.DefaultConnectClient; import io.confluent.ksql.services.ServiceContext; import io.confluent.ksql.services.ServiceContextFactory; import io.confluent.ksql.util.KsqlConfig; -import java.util.Collections; import java.util.Optional; import java.util.function.Supplier; import org.apache.kafka.streams.KafkaClientSupplier; @@ -36,7 +34,8 @@ public interface DefaultServiceContextFactory { ServiceContext create( KsqlConfig config, - Optional authHeader + Optional authHeader, + Supplier srClientFactory ); } @@ -52,13 +51,14 @@ ServiceContext create( public static ServiceContext create( final KsqlConfig ksqlConfig, - final Optional authHeader + final Optional authHeader, + final Supplier schemaRegistryClientFactory ) { return create( ksqlConfig, authHeader, new DefaultKafkaClientSupplier(), - new KsqlSchemaRegistryClientFactory(ksqlConfig, Collections.emptyMap())::get + schemaRegistryClientFactory ); } diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/KsqlRestApplicationTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/KsqlRestApplicationTest.java index 3e6bb79664f9..889711990e59 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/KsqlRestApplicationTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/KsqlRestApplicationTest.java @@ -31,6 +31,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.logging.processing.ProcessingLogConfig; import io.confluent.ksql.logging.processing.ProcessingLogContext; @@ -62,6 +63,7 @@ import java.util.Optional; import java.util.Queue; import java.util.function.Consumer; +import java.util.function.Supplier; import javax.ws.rs.core.Configurable; import org.apache.kafka.streams.StreamsConfig; import org.junit.Before; @@ -125,6 +127,10 @@ public class KsqlRestApplicationTest { @Mock private Consumer rocksDBConfigSetterHandler; + @Mock + private SchemaRegistryClient schemaRegistryClient; + + private Supplier schemaRegistryClientFactory; private String logCreateStatement; private KsqlRestApplication app; private KsqlRestConfig restConfig; @@ -136,6 +142,7 @@ public class KsqlRestApplicationTest { @SuppressWarnings("unchecked") @Before public void setUp() { + schemaRegistryClientFactory = () -> schemaRegistryClient; when(processingLogConfig.getBoolean(ProcessingLogConfig.STREAM_AUTO_CREATE)) .thenReturn(true); when(processingLogConfig.getString(ProcessingLogConfig.STREAM_NAME)) @@ -417,7 +424,8 @@ private void givenAppWithRestConfig(final Map restConfigMap) { streamedQueryResource, ksqlResource, versionCheckerAgent, - KsqlSecurityContextBinder::new, + (config, securityExtension) -> + new KsqlSecurityContextBinder(config, securityExtension, schemaRegistryClientFactory), securityExtension, serverState, processingLogContext, diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java index d42dad375290..7c2dd5cbafbd 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/TestKsqlRestApp.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.KsqlExecutionContext; import io.confluent.ksql.json.JsonMapper; import io.confluent.ksql.query.QueryId; @@ -37,7 +38,9 @@ import io.confluent.ksql.rest.entity.StreamsList; import io.confluent.ksql.rest.entity.TablesList; import io.confluent.ksql.rest.server.context.KsqlSecurityContextBinder; +import io.confluent.ksql.rest.server.services.RestServiceContextFactory; import io.confluent.ksql.rest.util.KsqlInternalTopicUtils; +import io.confluent.ksql.schema.registry.KsqlSchemaRegistryClientFactory; import io.confluent.ksql.security.KsqlSecurityContext; import io.confluent.ksql.security.KsqlSecurityExtension; import io.confluent.ksql.services.DisabledKsqlClient; @@ -432,8 +435,7 @@ public static final class Builder { private final Map additionalProps = new HashMap<>(); private Supplier serviceContext; - private BiFunction securityContextBinder - = KsqlSecurityContextBinder::new; + private BiFunction securityContextBinder; private Optional credentials = Optional.empty(); @@ -441,6 +443,9 @@ private Builder(final Supplier bootstrapServers) { this.bootstrapServers = requireNonNull(bootstrapServers, "bootstrapServers"); this.serviceContext = () -> defaultServiceContext(bootstrapServers, buildBaseConfig(additionalProps)); + this.securityContextBinder = (config, securityExtension) -> + new KsqlSecurityContextBinder(config, securityExtension, + new KsqlSchemaRegistryClientFactory(config, Collections.emptyMap())::get); } @SuppressWarnings("unused") // Part of public API diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinderFactoryTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinderFactoryTest.java index 4c4c25dbd218..c6814cc54d97 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinderFactoryTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/context/KsqlSecurityContextBinderFactoryTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; import io.confluent.ksql.rest.server.services.RestServiceContextFactory.DefaultServiceContextFactory; import io.confluent.ksql.rest.server.services.RestServiceContextFactory.UserServiceContextFactory; import io.confluent.ksql.security.KsqlSecurityContext; @@ -64,10 +65,13 @@ public class KsqlSecurityContextBinderFactoryTest { private ServiceContext userServiceContext; @Mock private HttpServletRequest request; + @Mock + private SchemaRegistryClient schemaRegistryClient; @Before public void setUp() { - KsqlSecurityContextBinderFactory.configure(ksqlConfig, securityExtension); + KsqlSecurityContextBinderFactory.configure(ksqlConfig, securityExtension, + () -> schemaRegistryClient); securityContextBinderFactory = new KsqlSecurityContextBinderFactory( securityContext, request, @@ -76,7 +80,8 @@ public void setUp() { ); when(securityContext.getUserPrincipal()).thenReturn(user1); - when(defaultServiceContextProvider.create(any(), any())).thenReturn(defaultServiceContext); + when(defaultServiceContextProvider.create(any(), any(), any())) + .thenReturn(defaultServiceContext); when(userServiceContextFactory.create(any(), any(), any(), any())) .thenReturn(userServiceContext); } @@ -91,7 +96,6 @@ public void shouldCreateDefaultServiceContextIfUserContextProviderIsNotEnabled() final KsqlSecurityContext ksqlSecurityContext = securityContextBinderFactory.provide(); // Then: - verify(defaultServiceContextProvider).create(ksqlConfig, Optional.empty()); assertThat(ksqlSecurityContext.getUserPrincipal(), is(Optional.empty())); assertThat(ksqlSecurityContext.getServiceContext(), is(defaultServiceContext)); } @@ -120,7 +124,7 @@ public void shouldPassAuthHeaderToDefaultFactory() { securityContextBinderFactory.provide(); // Then: - verify(defaultServiceContextProvider).create(any(), eq(Optional.of("some-auth"))); + verify(defaultServiceContextProvider).create(any(), eq(Optional.of("some-auth")), any()); } @Test diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpointTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpointTest.java index d479bb094260..641572587a3b 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpointTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpointTest.java @@ -59,7 +59,6 @@ import io.confluent.ksql.rest.server.state.ServerState; import io.confluent.ksql.security.KsqlAuthorizationProvider; import io.confluent.ksql.security.KsqlAuthorizationValidator; -import io.confluent.ksql.security.KsqlSecurityContext; import io.confluent.ksql.security.KsqlSecurityExtension; import io.confluent.ksql.security.KsqlUserContextProvider; import io.confluent.ksql.services.ConfiguredKafkaClientSupplier; @@ -185,7 +184,6 @@ public void setUp() { when(securityExtension.getAuthorizationProvider()) .thenReturn(Optional.of(authorizationProvider)); when(serviceContextFactory.create(any(), any(), any(), any())).thenReturn(serviceContext); - when(defaultServiceContextProvider.create(any(), any())).thenReturn(serviceContext); when(serviceContext.getTopicClient()).thenReturn(topicClient); when(serverState.checkReady()).thenReturn(Optional.empty()); when(ksqlEngine.getMetaStore()).thenReturn(metaStore); @@ -208,7 +206,8 @@ public void setUp() { securityExtension, serviceContextFactory, defaultServiceContextProvider, - serverState + serverState, + schemaRegistryClientSupplier ); } @@ -490,7 +489,8 @@ public void shouldCreateDefaultServiceContextIfUserContextProviderIsNotEnabled() wsQueryEndpoint.onOpen(session, null); // Then: - verify(defaultServiceContextProvider).create(ksqlConfig, Optional.empty()); + verify(defaultServiceContextProvider).create(ksqlConfig, Optional.empty(), + schemaRegistryClientSupplier); verifyZeroInteractions(userContextProvider); }