Skip to content

Commit

Permalink
fix: Fixes auth for forwarded requests for pull queries (#6895)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlanConfluent authored Jan 27, 2021
1 parent 4ba542f commit c360c9c
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ PullQueryResult executePullQuery(
analysis
);
return routing.handlePullQuery(
serviceContext,
physicalPlan, statement, routingOptions, physicalPlan.getOutputSchema(),
physicalPlan.getQueryId());
} catch (final Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,28 @@ public final class HARouting implements AutoCloseable {

private final ExecutorService executorService;
private final RoutingFilterFactory routingFilterFactory;
private final ServiceContext serviceContext;
private final Optional<PullQueryExecutorMetrics> pullQueryMetrics;
private final RouteQuery routeQuery;

public HARouting(
final RoutingFilterFactory routingFilterFactory,
final ServiceContext serviceContext,
final Optional<PullQueryExecutorMetrics> pullQueryMetrics,
final KsqlConfig ksqlConfig
) {
this(routingFilterFactory, serviceContext, pullQueryMetrics, ksqlConfig,
this(routingFilterFactory, pullQueryMetrics, ksqlConfig,
HARouting::executeOrRouteQuery);
}


@VisibleForTesting
HARouting(
final RoutingFilterFactory routingFilterFactory,
final ServiceContext serviceContext,
final Optional<PullQueryExecutorMetrics> pullQueryMetrics,
final KsqlConfig ksqlConfig,
final RouteQuery routeQuery
) {
this.routingFilterFactory =
Objects.requireNonNull(routingFilterFactory, "routingFilterFactory");
this.serviceContext = Objects.requireNonNull(serviceContext, "serviceContext");
this.executorService = Executors.newFixedThreadPool(
ksqlConfig.getInt(KsqlConfig.KSQL_QUERY_PULL_THREAD_POOL_SIZE_CONFIG),
new ThreadFactoryBuilder().setNameFormat("pull-query-executor-%d").build());
Expand All @@ -99,6 +95,7 @@ public void close() {
}

public PullQueryResult handlePullQuery(
final ServiceContext serviceContext,
final PullPhysicalPlan pullPhysicalPlan,
final ConfiguredStatement<Query> statement,
final RoutingOptions routingOptions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ public void setUp() {
when(location4.getNodes()).thenReturn(ImmutableList.of(node2, node1));
when(ksqlConfig.getInt(KsqlConfig.KSQL_QUERY_PULL_THREAD_POOL_SIZE_CONFIG)).thenReturn(1);
haRouting = new HARouting(
routingFilterFactory, serviceContext, Optional.empty(), ksqlConfig, routeQuery);

routingFilterFactory, Optional.empty(), ksqlConfig, routeQuery);
}

@After
Expand Down Expand Up @@ -142,7 +141,8 @@ public void shouldCallRouteQuery_success() throws InterruptedException {
});

// When:
PullQueryResult result = haRouting.handlePullQuery(pullPhysicalPlan, statement, routingOptions, logicalSchema, queryId);
PullQueryResult result = haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, statement,
routingOptions, logicalSchema, queryId);

// Then:
verify(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any());
Expand Down Expand Up @@ -192,7 +192,8 @@ public Object answer(InvocationOnMock invocation) {
});

// When:
PullQueryResult result = haRouting.handlePullQuery(pullPhysicalPlan, statement, routingOptions, logicalSchema, queryId);
PullQueryResult result = haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, statement,
routingOptions, logicalSchema, queryId);

// Then:
verify(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any());
Expand Down Expand Up @@ -244,7 +245,8 @@ public Object answer(InvocationOnMock invocation) {
// When:
final Exception e = assertThrows(
MaterializationException.class,
() -> haRouting.handlePullQuery(pullPhysicalPlan, statement, routingOptions, logicalSchema, queryId)
() -> haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, statement, routingOptions,
logicalSchema, queryId)
);

// Then:
Expand Down Expand Up @@ -277,7 +279,8 @@ public void shouldCallRouteQuery_allFiltered() {
// When:
final Exception e = assertThrows(
MaterializationException.class,
() -> haRouting.handlePullQuery(pullPhysicalPlan, statement, routingOptions, logicalSchema, queryId)
() -> haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, statement, routingOptions,
logicalSchema, queryId)
);

// Then:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ static KsqlRestApplication buildApplication(


final HARouting pullQueryRouting = new HARouting(
routingFilterFactory, serviceContext, pullQueryMetrics, ksqlConfig);
routingFilterFactory, pullQueryMetrics, ksqlConfig);

final Optional<LocalCommands> localCommands = createLocalCommands(restConfig, ksqlEngine);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,17 @@ static void waitForClusterToBeDiscovered(

static void waitForStreamsMetadataToInitialize(
final TestKsqlRestApp restApp, List<KsqlHostInfoEntity> hosts, String queryId
) {
waitForStreamsMetadataToInitialize(restApp, hosts, queryId, Optional.empty());
}

static void waitForStreamsMetadataToInitialize(
final TestKsqlRestApp restApp, List<KsqlHostInfoEntity> hosts, String queryId,
final Optional<BasicCredentials> credentials
) {
while (true) {
ClusterStatusResponse clusterStatusResponse = HighAvailabilityTestUtil.sendClusterStatusRequest(restApp);
ClusterStatusResponse clusterStatusResponse
= HighAvailabilityTestUtil.sendClusterStatusRequest(restApp, credentials);
List<KsqlHostInfoEntity> initialized = hosts.stream()
.filter(hostInfo -> Optional.ofNullable(
clusterStatusResponse
Expand Down Expand Up @@ -392,11 +400,24 @@ public static void makeAdminRequest(TestKsqlRestApp restApp, final String sql) {
RestIntegrationTestUtil.makeKsqlRequest(restApp, sql, Optional.empty());
}

public static void makeAdminRequest(
final TestKsqlRestApp restApp,
final String sql,
final Optional<BasicCredentials> userCreds
) {
RestIntegrationTestUtil.makeKsqlRequest(restApp, sql, userCreds);
}

public static List<KsqlEntity> makeAdminRequestWithResponse(
TestKsqlRestApp restApp, final String sql) {
return RestIntegrationTestUtil.makeKsqlRequest(restApp, sql, Optional.empty());
}

public static List<KsqlEntity> makeAdminRequestWithResponse(
TestKsqlRestApp restApp, final String sql, final Optional<BasicCredentials> userCreds) {
return RestIntegrationTestUtil.makeKsqlRequest(restApp, sql, userCreds);
}

public static List<StreamedRow> makePullQueryRequest(
final TestKsqlRestApp target,
final String sql
Expand All @@ -412,5 +433,15 @@ public static List<StreamedRow> makePullQueryRequest(
return RestIntegrationTestUtil.makeQueryRequest(target, sql, Optional.empty(),
properties, ImmutableMap.of(KsqlRequestConfig.KSQL_DEBUG_REQUEST, true));
}

public static List<StreamedRow> makePullQueryRequest(
final TestKsqlRestApp target,
final String sql,
final Map<String, ?> properties,
final Optional<BasicCredentials> userCreds
) {
return RestIntegrationTestUtil.makeQueryRequest(target, sql, userCreds,
properties, ImmutableMap.of(KsqlRequestConfig.KSQL_DEBUG_REQUEST, true));
}
}

Loading

0 comments on commit c360c9c

Please sign in to comment.