Skip to content

Commit

Permalink
Add support for row filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Feb 24, 2020
1 parent 7d80f74 commit 7efb49c
Show file tree
Hide file tree
Showing 22 changed files with 730 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.security.PrestoPrincipal;
import io.prestosql.spi.security.Privilege;
import io.prestosql.spi.security.ViewExpression;

import javax.inject.Inject;

Expand Down Expand Up @@ -267,4 +268,10 @@ public void checkCanShowCurrentRoles(ConnectorSecurityContext context, String ca
public void checkCanShowRoleGrants(ConnectorSecurityContext context, String catalogName)
{
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.prestosql.spi.security.PrestoPrincipal;
import io.prestosql.spi.security.Privilege;
import io.prestosql.spi.security.RoleGrant;
import io.prestosql.spi.security.ViewExpression;

import javax.inject.Inject;

Expand Down Expand Up @@ -389,6 +390,12 @@ public void checkCanShowRoleGrants(ConnectorSecurityContext context, String cata
{
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.empty();
}

private boolean isAdmin(ConnectorSecurityContext context)
{
SemiTransactionalHiveMetastore metastore = metastoreProvider.apply(((HiveTransactionHandle) context.getTransactionHandle()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.prestosql.security;

import com.google.common.collect.ImmutableList;
import io.prestosql.metadata.QualifiedObjectName;
import io.prestosql.spi.connector.CatalogSchemaName;
import io.prestosql.spi.connector.CatalogSchemaTableName;
Expand All @@ -22,6 +23,7 @@
import io.prestosql.spi.security.Identity;
import io.prestosql.spi.security.PrestoPrincipal;
import io.prestosql.spi.security.Privilege;
import io.prestosql.spi.security.ViewExpression;

import java.security.Principal;
import java.util.List;
Expand Down Expand Up @@ -322,4 +324,9 @@ public interface AccessControl
* @throws io.prestosql.spi.security.AccessDeniedException if not allowed
*/
void checkCanShowRoleGrants(SecurityContext context, String catalogName);

default List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObjectName tableName)
{
return ImmutableList.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.prestosql.spi.security.SystemAccessControl;
import io.prestosql.spi.security.SystemAccessControlFactory;
import io.prestosql.spi.security.SystemSecurityContext;
import io.prestosql.spi.security.ViewExpression;
import io.prestosql.transaction.TransactionId;
import io.prestosql.transaction.TransactionManager;
import org.weakref.jmx.Managed;
Expand Down Expand Up @@ -721,6 +722,28 @@ public void checkCanShowRoleGrants(SecurityContext securityContext, String catal
catalogAuthorizationCheck(catalogName, securityContext, (control, context) -> control.checkCanShowRoleGrants(context, catalogName));
}

@Override
public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObjectName tableName)
{
requireNonNull(context, "securityContext is null");
requireNonNull(tableName, "catalogName is null");

ImmutableList.Builder<ViewExpression> filters = ImmutableList.builder();
CatalogAccessControlEntry entry = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName());

if (entry != null) {
entry.getAccessControl().getRowFilter(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName())
.ifPresent(filters::add);
}

for (SystemAccessControl systemAccessControl : systemAccessControls.get()) {
systemAccessControl.getRowFilter(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName())
.ifPresent(filters::add);
}

return filters.build();
}

private CatalogAccessControlEntry getConnectorAccessControl(TransactionId transactionId, String catalogName)
{
return transactionManager.getOptionalCatalogMetadata(transactionId, catalogName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.prestosql.spi.security.Identity;
import io.prestosql.spi.security.PrestoPrincipal;
import io.prestosql.spi.security.Privilege;
import io.prestosql.spi.security.ViewExpression;

import java.security.Principal;
import java.util.List;
Expand Down Expand Up @@ -300,4 +301,10 @@ public void checkCanShowRoleGrants(SecurityContext context, String catalogName)
{
delegate().checkCanShowRoleGrants(context, catalogName);
}

@Override
public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObjectName tableName)
{
return delegate().getRowFilters(context, tableName);
}
}
64 changes: 64 additions & 0 deletions presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
package io.prestosql.sql.analyzer;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
import io.prestosql.metadata.NewTableLayout;
import io.prestosql.metadata.QualifiedObjectName;
import io.prestosql.metadata.ResolvedFunction;
Expand Down Expand Up @@ -56,6 +58,7 @@
import javax.annotation.concurrent.Immutable;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashSet;
Expand Down Expand Up @@ -133,6 +136,9 @@ public class Analysis

private final Map<NodeRef<QuerySpecification>, List<GroupingOperation>> groupingOperations = new LinkedHashMap<>();

private final Multiset<RowFilterScopeEntry> rowFilterScopes = HashMultiset.create();
private final Map<NodeRef<Table>, List<Expression>> rowFilters = new LinkedHashMap<>();

private Optional<Create> create = Optional.empty();
private Optional<Insert> insert = Optional.empty();
private Optional<TableHandle> analyzeTarget = Optional.empty();
Expand Down Expand Up @@ -663,6 +669,32 @@ public boolean isOrderByRedundant(OrderBy orderBy)
return redundantOrderBy.contains(NodeRef.of(orderBy));
}

public boolean hasRowFilter(QualifiedObjectName table, String identity)
{
return rowFilterScopes.contains(new RowFilterScopeEntry(table, identity));
}

public void registerTableForRowFiltering(QualifiedObjectName table, String identity)
{
rowFilterScopes.add(new RowFilterScopeEntry(table, identity));
}

public void unregisterTableForRowFiltering(QualifiedObjectName table, String identity)
{
rowFilterScopes.remove(new RowFilterScopeEntry(table, identity));
}

public void addRowFilter(Table table, Expression filter)
{
rowFilters.computeIfAbsent(NodeRef.of(table), node -> new ArrayList<>())
.add(filter);
}

public List<Expression> getRowFilters(Table node)
{
return rowFilters.getOrDefault(NodeRef.of(node), ImmutableList.of());
}

@Immutable
public static final class SelectExpression
{
Expand Down Expand Up @@ -895,4 +927,36 @@ public String toString()
return format("AccessControl: %s, Identity: %s", accessControl.getClass(), identity);
}
}

private static class RowFilterScopeEntry
{
private final QualifiedObjectName table;
private final String identity;

public RowFilterScopeEntry(QualifiedObjectName table, String identity)
{
this.table = requireNonNull(table, "table is null");
this.identity = requireNonNull(identity, "identity is null");
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
RowFilterScopeEntry that = (RowFilterScopeEntry) o;
return table.equals(that.table) &&
identity.equals(that.identity);
}

@Override
public int hashCode()
{
return Objects.hash(table, identity);
}
}
}
Loading

0 comments on commit 7efb49c

Please sign in to comment.