Skip to content

Commit

Permalink
make query context changes backwards compatible (#12564)
Browse files Browse the repository at this point in the history
Adds a default implementation of getQueryContext, which was added to the Query interface in #12396. Query is marked with @ExtensionPoint, and lately we have been trying to be less volatile on these interfaces by providing default implementations to be more chill for extension writers.

The way this default implementation is done in this PR is a bit strange due to the way that getQueryContext is used (mutated with system default and system generated keys); the default implementation has a specific object that it returns, and I added another temporary default method isLegacyContext that checks if the getQueryContext returns that object or not. If not, callers fall back to using getContext and withOverriddenContext to set these default and system values.

I am open to other ideas as well, but this way should work at least without exploding, and added some tests to ensure that it is wired up correctly for QueryLifecycle, including the context authorization stuff.

The added test shows the strange behavior if query context authorization is enabled, mainly that the system default and system generated query context keys also need to be granted as permissions for things to function correctly. This is not great, so I mentioned it in the javadocs as well. Not sure if it needs to be called out anywhere else.
  • Loading branch information
clintropolis authored May 25, 2022
1 parent 9f9faee commit d0c9c37
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 8 deletions.
15 changes: 13 additions & 2 deletions processing/src/main/java/org/apache/druid/query/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,14 @@ public interface Query<T>
Map<String, Object> getContext();

/**
* Returns QueryContext for this query.
* Returns QueryContext for this query. This type distinguishes between user provided, system default, and system
* generated query context keys so that authorization may be employed directly against the user supplied context
* values.
*
* This method is marked @Nullable, but is only so for backwards compatibility with Druid versions older than 0.23.
* Callers should check if the result of this method is null, and if so, they are dealing with a legacy query
* implementation, and should fall back to using {@link #getContext()} and {@link #withOverriddenContext(Map)} to
* manipulate the query context.
*
* Note for query context serialization and deserialization.
* Currently, once a query is serialized, its queryContext can be different from the original queryContext
Expand All @@ -110,7 +117,11 @@ public interface Query<T>
* after it is deserialized. This is because {@link BaseQuery#getContext()} uses
* {@link QueryContext#getMergedParams()} for serialization, and queries accept a map for deserialization.
*/
QueryContext getQueryContext();
@Nullable
default QueryContext getQueryContext()
{
return null;
}

<ContextType> ContextType getContextValue(String key);

Expand Down
187 changes: 187 additions & 0 deletions processing/src/test/java/org/apache/druid/query/QueryContextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,26 @@
package org.apache.druid.query;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Ordering;
import nl.jqno.equalsverifier.EqualsVerifier;
import nl.jqno.equalsverifier.Warning;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.spec.QuerySegmentSpec;
import org.joda.time.DateTimeZone;
import org.joda.time.Duration;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test;

import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class QueryContextTest
{
@Test
Expand Down Expand Up @@ -232,4 +247,176 @@ public void testGetMergedParams()

Assert.assertSame(context.getMergedParams(), context.getMergedParams());
}

@Test
public void testLegacyReturnsLegacy()
{
Query legacy = new LegacyContextQuery(ImmutableMap.of("foo", "bar"));
Assert.assertNull(legacy.getQueryContext());
}

@Test
public void testNonLegacyIsNotLegacyContext()
{
Query timeseries = Druids.newTimeseriesQueryBuilder()
.dataSource("test")
.intervals("2015-01-02/2015-01-03")
.granularity(Granularities.DAY)
.aggregators(Collections.singletonList(new CountAggregatorFactory("theCount")))
.context(ImmutableMap.of("foo", "bar"))
.build();
Assert.assertNotNull(timeseries.getQueryContext());
}

public static class LegacyContextQuery implements Query
{
private final Map<String, Object> context;

public LegacyContextQuery(Map<String, Object> context)
{
this.context = context;
}

@Override
public DataSource getDataSource()
{
return new TableDataSource("fake");
}

@Override
public boolean hasFilters()
{
return false;
}

@Override
public DimFilter getFilter()
{
return null;
}

@Override
public String getType()
{
return "legacy-context-query";
}

@Override
public QueryRunner getRunner(QuerySegmentWalker walker)
{
return new NoopQueryRunner();
}

@Override
public List<Interval> getIntervals()
{
return Collections.singletonList(Intervals.ETERNITY);
}

@Override
public Duration getDuration()
{
return getIntervals().get(0).toDuration();
}

@Override
public Granularity getGranularity()
{
return Granularities.ALL;
}

@Override
public DateTimeZone getTimezone()
{
return DateTimeZone.UTC;
}

@Override
public Map<String, Object> getContext()
{
return context;
}

@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
if (context == null || !context.containsKey(key)) {
return defaultValue;
}
return (boolean) context.get(key);
}

@Override
public boolean isDescending()
{
return false;
}

@Override
public Ordering getResultOrdering()
{
return Ordering.natural();
}

@Override
public Query withQuerySegmentSpec(QuerySegmentSpec spec)
{
return new LegacyContextQuery(context);
}

@Override
public Query withId(String id)
{
context.put(BaseQuery.QUERY_ID, id);
return this;
}

@Nullable
@Override
public String getId()
{
return (String) context.get(BaseQuery.QUERY_ID);
}

@Override
public Query withSubQueryId(String subQueryId)
{
context.put(BaseQuery.SUB_QUERY_ID, subQueryId);
return this;
}

@Nullable
@Override
public String getSubQueryId()
{
return (String) context.get(BaseQuery.SUB_QUERY_ID);
}

@Override
public Query withDataSource(DataSource dataSource)
{
return this;
}

@Override
public Query withOverriddenContext(Map contextOverride)
{
return new LegacyContextQuery(contextOverride);
}

@Override
public Object getContextValue(String key, Object defaultValue)
{
if (!context.containsKey(key)) {
return defaultValue;
}
return context.get(key);
}

@Override
public Object getContextValue(String key)
{
return context.get(key);
}
}
}
25 changes: 20 additions & 5 deletions server/src/main/java/org/apache/druid/server/QueryLifecycle.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.druid.query.DruidMetrics;
import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryMetrics;
Expand Down Expand Up @@ -63,6 +64,7 @@
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -186,11 +188,18 @@ public void initialize(final Query baseQuery)
{
transition(State.NEW, State.INITIALIZED);

baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext());
if (baseQuery.getQueryContext() == null) {
QueryContext context = new QueryContext(baseQuery.getContext());
context.addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
context.addDefaultParams(defaultQueryConfig.getContext());

this.baseQuery = baseQuery;
this.toolChest = warehouse.getToolChest(baseQuery);
this.baseQuery = baseQuery.withOverriddenContext(context.getMergedParams());
} else {
baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext());
this.baseQuery = baseQuery;
}
this.toolChest = warehouse.getToolChest(this.baseQuery);
}

/**
Expand All @@ -204,14 +213,20 @@ public void initialize(final Query baseQuery)
public Access authorize(HttpServletRequest req)
{
transition(State.INITIALIZED, State.AUTHORIZING);
final Set<String> contextKeys;
if (baseQuery.getQueryContext() == null) {
contextKeys = baseQuery.getContext().keySet();
} else {
contextKeys = baseQuery.getQueryContext().getUserParams().keySet();
}
final Iterable<ResourceAction> resourcesToAuthorize = Iterables.concat(
Iterables.transform(
baseQuery.getDataSource().getTableNames(),
AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR
),
authConfig.authorizeQueryContextParams()
? Iterables.transform(
baseQuery.getQueryContext().getUserParams().keySet(),
contextKeys,
contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE)
)
: Collections.emptyList()
Expand Down
10 changes: 9 additions & 1 deletion server/src/main/java/org/apache/druid/server/QueryResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.druid.query.BadQueryException;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryException;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException;
Expand Down Expand Up @@ -373,7 +374,14 @@ private Query<?> readQuery(
String prevEtag = getPreviousEtag(req);

if (prevEtag != null) {
baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
if (baseQuery.getQueryContext() == null) {
QueryContext context = new QueryContext(baseQuery.getContext());
context.addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);

return baseQuery.withOverriddenContext(context.getMergedParams());
} else {
baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
}
}

return baseQuery;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.druid.query.DefaultQueryConfig;
import org.apache.druid.query.Druids;
import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.QueryContextTest;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QuerySegmentWalker;
Expand Down Expand Up @@ -244,6 +245,40 @@ public void testAuthorizeQueryContext_notAuthorized()
Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed());
}

@Test
public void testAuthorizeLegacyQueryContext_authorized()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("fake", ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE)).andReturn(Access.OK);
// to use legacy query context with context authorization, even system generated things like queryId need to be explicitly added
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("queryId", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.OK);

EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();

replayAll();

final QueryContextTest.LegacyContextQuery query = new QueryContextTest.LegacyContextQuery(ImmutableMap.of("foo", "bar", "baz", "qux"));

lifecycle.initialize(query);

Assert.assertNull(lifecycle.getQuery().getQueryContext());
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("foo"));
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("baz"));
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));

Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}

private HttpServletRequest mockRequest()
{
HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);
Expand Down

0 comments on commit d0c9c37

Please sign in to comment.