Skip to content

Commit

Permalink
Ensure queries returned via REST API are redacted
Browse files Browse the repository at this point in the history
@JsonConstructor for TrimmedBasicQueryInfo was introduced to facilitate
the deserialization of server responses in tests.
  • Loading branch information
piotrrzysko committed Aug 22, 2024
1 parent e915378 commit c444c3e
Show file tree
Hide file tree
Showing 5 changed files with 439 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server.ui;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.errorprone.annotations.Immutable;
import io.trino.execution.QueryState;
Expand Down Expand Up @@ -53,6 +54,43 @@ public class TrimmedBasicQueryInfo
private final Optional<QueryType> queryType;
private final RetryPolicy retryPolicy;

@JsonCreator
public TrimmedBasicQueryInfo(
@JsonProperty("queryId") QueryId queryId,
@JsonProperty("sessionUser") String sessionUser,
@JsonProperty("sessionPrincipal") Optional<String> sessionPrincipal,
@JsonProperty("sessionSource") Optional<String> sessionSource,
@JsonProperty("resourceGroupId") Optional<ResourceGroupId> resourceGroupId,
@JsonProperty("state") QueryState state,
@JsonProperty("scheduled") boolean scheduled,
@JsonProperty("self") URI self,
@JsonProperty("queryTextPreview") String queryTextPreview,
@JsonProperty("updateType") Optional<String> updateType,
@JsonProperty("preparedQuery") Optional<String> preparedQuery,
@JsonProperty("queryStats") BasicQueryStats queryStats,
@JsonProperty("errorType") Optional<ErrorType> errorType,
@JsonProperty("errorCode") Optional<ErrorCode> errorCode,
@JsonProperty("queryType") Optional<QueryType> queryType,
@JsonProperty("retryPolicy") RetryPolicy retryPolicy)
{
this.queryId = requireNonNull(queryId, "queryId is null");
this.sessionUser = requireNonNull(sessionUser, "sessionUser is null");
this.sessionPrincipal = requireNonNull(sessionPrincipal, "sessionPrincipal is null");
this.sessionSource = requireNonNull(sessionSource, "sessionSource is null");
this.resourceGroupId = requireNonNull(resourceGroupId, "resourceGroupId is null");
this.state = requireNonNull(state, "state is null");
this.scheduled = scheduled;
this.self = requireNonNull(self, "self is null");
this.queryTextPreview = requireNonNull(queryTextPreview, "queryTextPreview is null");
this.updateType = requireNonNull(updateType, "updateType is null");
this.preparedQuery = requireNonNull(preparedQuery, "preparedQuery is null");
this.queryStats = requireNonNull(queryStats, "queryStats is null");
this.errorType = requireNonNull(errorType, "errorType is null");
this.errorCode = requireNonNull(errorCode, "errorCode is null");
this.queryType = requireNonNull(queryType, "queryType is null");
this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null");
}

public TrimmedBasicQueryInfo(BasicQueryInfo queryInfo)
{
this.queryId = requireNonNull(queryInfo.getQueryId(), "queryId is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server;

import com.google.common.collect.ImmutableSet;
import com.google.inject.Key;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.HttpUriBuilder;
Expand All @@ -25,6 +26,8 @@
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.trace.Span;
import io.trino.client.QueryResults;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorPlugin;
import io.trino.execution.QueryInfo;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.server.testing.TestingTrinoServer;
Expand Down Expand Up @@ -61,6 +64,7 @@
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.KILL_QUERY;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.VIEW_QUERY;
import static io.trino.testing.TestingAccessControlManager.privilege;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand All @@ -84,6 +88,9 @@ public void setup()
{
client = new JettyHttpClient();
server = TestingTrinoServer.create();
server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder()
.withSecuritySensitivePropertyNames(ImmutableSet.of("password"))
.build()));
server.installPlugin(new TpchPlugin());
server.createCatalog("tpch", "tpch");
}
Expand Down Expand Up @@ -217,6 +224,47 @@ public void testGetQueryInfoExecutionFailure()
assertThat(info.getFailureInfo().getErrorCode()).isEqualTo(DIVISION_BY_ZERO.toErrorCode());
}

@Test
public void testGetQueryInfosWithRedactedSecrets()
{
String catalog = "catalog_" + randomNameSuffix();
runToCompletion("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '1234'
)""".formatted(catalog));

List<BasicQueryInfo> infos = getQueryInfos("/v1/query");
assertThat(infos.size()).isEqualTo(1);
assertThat(infos.getFirst().getQuery()).isEqualTo("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '***'
)""".formatted(catalog));
}

@Test
public void testGetQueryInfoWithRedactedSecrets()
{
String catalog = "catalog_" + randomNameSuffix();
String queryId = runToCompletion("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '1234'
)""".formatted(catalog));

QueryInfo queryInfo = getQueryInfo(queryId);
assertThat(queryInfo.getQuery()).isEqualTo("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '***'
)""".formatted(catalog));
}

@Test
public void testCancel()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server;

import com.google.common.collect.ImmutableSet;
import com.google.common.io.Closer;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
Expand All @@ -21,6 +22,8 @@
import io.airlift.json.JsonCodec;
import io.airlift.units.Duration;
import io.trino.client.QueryResults;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorPlugin;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.server.testing.TestingTrinoServer;
import io.trino.spi.ErrorCode;
Expand All @@ -43,6 +46,7 @@
import static io.airlift.json.JsonCodec.listJsonCodec;
import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
import static io.trino.execution.QueryState.FAILED;
import static io.trino.execution.QueryState.FINISHING;
import static io.trino.execution.QueryState.RUNNING;
import static io.trino.server.TestQueryResource.BASIC_QUERY_INFO_CODEC;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.VIEW_QUERY;
Expand All @@ -65,11 +69,15 @@ public class TestQueryStateInfoResource
private TestingTrinoServer server;
private HttpClient client;
private QueryResults queryResults;
private QueryResults createCatalogResults;

@BeforeAll
public void setUp()
{
server = TestingTrinoServer.create();
server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder()
.withSecuritySensitivePropertyNames(ImmutableSet.of("password"))
.build()));
server.installPlugin(new TpchPlugin());
server.createCatalog("tpch", "tpch");
client = new JettyHttpClient();
Expand All @@ -90,6 +98,19 @@ public void setUp()
QueryResults queryResults2 = client.execute(request2, createJsonResponseHandler(jsonCodec(QueryResults.class)));
client.execute(prepareGet().setUri(queryResults2.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC));

Request createCatalogRequest = preparePost()
.setUri(uriBuilderFrom(server.getBaseUrl()).replacePath("/v1/statement").build())
.setBodyGenerator(createStaticBodyGenerator("""
CREATE CATALOG test_catalog USING mock
WITH (
"user" = 'bob',
"password" = '1234'
)""", UTF_8))
.setHeader(TRINO_HEADERS.requestUser(), "catalogCreator")
.build();
createCatalogResults = client.execute(createCatalogRequest, createJsonResponseHandler(jsonCodec(QueryResults.class)));
client.execute(prepareGet().setUri(createCatalogResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC));

// queries are started in the background, so they may not all be immediately visible
long start = System.nanoTime();
while (Duration.nanosSince(start).compareTo(new Duration(5, MINUTES)) < 0) {
Expand All @@ -99,8 +120,8 @@ public void setUp()
.setHeader(TRINO_HEADERS.requestUser(), "unknown")
.build(),
createJsonResponseHandler(BASIC_QUERY_INFO_CODEC));
if (queryInfos.size() == 2) {
if (queryInfos.stream().allMatch(info -> info.getState() == RUNNING)) {
if (queryInfos.size() == 3) {
if (queryInfos.stream().allMatch(info -> info.getState() == RUNNING || info.getState() == FINISHING)) {
break;
}

Expand Down Expand Up @@ -137,7 +158,12 @@ public void testGetAllQueryStateInfos()
.build(),
createJsonResponseHandler(listJsonCodec(QueryStateInfo.class)));

assertThat(infos.size()).isEqualTo(2);
assertThat(infos.size()).isEqualTo(3);
QueryStateInfo createCatalogInfo = infos.stream()
.filter(info -> info.getQueryId().getId().equals(createCatalogResults.getId()))
.findFirst()
.orElse(null);
assertCreateCatalogQueryIsRedacted(createCatalogInfo);
}

@Test
Expand Down Expand Up @@ -179,6 +205,19 @@ public void testGetQueryStateInfo()
assertThat(info).isNotNull();
}

@Test
public void testGetQueryStateInfoWithRedactedSecrets()
{
QueryStateInfo info = client.execute(
prepareGet()
.setUri(server.resolve("/v1/queryState/" + createCatalogResults.getId()))
.setHeader(TRINO_HEADERS.requestUser(), "unknown")
.build(),
createJsonResponseHandler(jsonCodec(QueryStateInfo.class)));

assertCreateCatalogQueryIsRedacted(info);
}

@Test
public void testGetAllQueryStateInfosDenied()
{
Expand All @@ -188,7 +227,7 @@ public void testGetAllQueryStateInfosDenied()
.setHeader(TRINO_HEADERS.requestUser(), "any-other-user")
.build(),
createJsonResponseHandler(listJsonCodec(QueryStateInfo.class)));
assertThat(infos.size()).isEqualTo(2);
assertThat(infos.size()).isEqualTo(3);

testGetAllQueryStateInfosDenied("user1", 1);
testGetAllQueryStateInfosDenied("any-other-user", 0);
Expand Down Expand Up @@ -243,4 +282,15 @@ public void testGetQueryStateInfoNo()
.isInstanceOf(UnexpectedResponseException.class)
.hasMessageMatching("Expected response code .*, but was 404");
}

private static void assertCreateCatalogQueryIsRedacted(QueryStateInfo info)
{
assertThat(info).isNotNull();
assertThat(info.getQuery()).isEqualTo("""
CREATE CATALOG test_catalog USING mock
WITH (
"user" = 'bob',
"password" = '***'
)""");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.server;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
import io.airlift.http.client.jetty.JettyHttpClient;
import io.trino.client.QueryResults;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorPlugin;
import io.trino.server.testing.TestingTrinoServer;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;

import java.net.URI;
import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler;
import static io.airlift.http.client.Request.Builder.prepareGet;
import static io.airlift.http.client.Request.Builder.preparePost;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.json.JsonCodec.jsonCodec;
import static io.airlift.testing.Closeables.closeAll;
import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;

@TestInstance(PER_CLASS)
@Execution(CONCURRENT)
final class TestResourceGroupStateInfoResource
{
private TestingTrinoServer server;
private HttpClient client;

@BeforeAll
public void setup()
{
client = new JettyHttpClient();
server = TestingTrinoServer.builder()
.setProperties(ImmutableMap.<String, String>builder()
.put("web-ui.authentication.type", "fixed")
.put("web-ui.user", "test-user")
.buildOrThrow())
.build();
server.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder()
.withSecuritySensitivePropertyNames(ImmutableSet.of("password"))
.build()));
}

@AfterAll
public void teardown()
throws Exception
{
closeAll(server, client);
server = null;
client = null;
}

@Test
void testGetResourceGroupInfoWithRedactedSecrets()
{
String catalog = "catalog_" + randomNameSuffix();
startQuery("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '1234'
)""".formatted(catalog));

ResourceGroupInfo resourceGroupInfo = getResourceGroupInfo("global");
Optional<List<QueryStateInfo>> queryStateInfos = resourceGroupInfo.runningQueries();
assertThat(queryStateInfos.isPresent()).isTrue();
List<QueryStateInfo> queryStates = queryStateInfos.get();
assertThat(queryStates.size()).isEqualTo(1);
assertThat(queryStates.getFirst().getQuery()).isEqualTo("""
CREATE CATALOG %s USING mock
WITH (
"user" = 'bob',
"password" = '***'
)""".formatted(catalog));
}

private void startQuery(String sql)
{
Request request = preparePost()
.setUri(server.resolve("/v1/statement"))
.setBodyGenerator(createStaticBodyGenerator(sql, UTF_8))
.setHeader(TRINO_HEADERS.requestUser(), "unknown")
.build();
QueryResults queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class)));
checkState(queryResults.getNextUri() != null && queryResults.getNextUri().toString().contains("/v1/statement/queued/"), "nextUri should point to /v1/statement/queued/");
request = prepareGet()
.setHeader(TRINO_HEADERS.requestUser(), "unknown")
.setUri(queryResults.getNextUri())
.build();
client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class)));
}

private ResourceGroupInfo getResourceGroupInfo(String resourceGroupId)
{
URI uri = uriBuilderFrom(server.getBaseUrl())
.replacePath("/v1/resourceGroupState")
.appendPath(resourceGroupId)
.build();
Request request = prepareGet()
.setUri(uri)
.setHeader(TRINO_HEADERS.requestUser(), "unknown")
.build();
return client.execute(request, createJsonResponseHandler(jsonCodec(ResourceGroupInfo.class)));
}
}
Loading

0 comments on commit c444c3e

Please sign in to comment.