Skip to content

Commit

Permalink
Visitor for authorization logic AST (#643)
Browse files Browse the repository at this point in the history
Closes #643

COPYBARA_INTEGRATE_REVIEW=#643 from google-research:auth-ast-visitor@aferr d3426cb
PiperOrigin-RevId: 468746566
  • Loading branch information
Andrew Ferraiuolo authored and arcs-c3po committed Aug 19, 2022
1 parent 7b3011c commit 0088890
Show file tree
Hide file tree
Showing 7 changed files with 928 additions and 3 deletions.
31 changes: 31 additions & 0 deletions src/ir/auth_logic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,36 @@ cc_library(
name = "ast",
hdrs = [
"ast.h",
"auth_logic_ast_visitor.h",
],
visibility = ["//visibility:private"],
deps = [
"//src/common/logging",
"//src/common/utils:map_iter",
"//src/common/utils:overloaded",
"//src/common/utils:types",
"//src/ir/datalog:program",
"@absl//absl/hash",
"@absl//absl/strings:str_format",
],
)

cc_library(
name = "auth_logic_ast_traversing_visitor",
hdrs = [
"auth_logic_ast_traversing_visitor.h",
"auth_logic_ast_visitor.h",
],
deps = [
":ast",
"//src/common/logging",
"//src/common/utils:fold",
"//src/common/utils:overloaded",
"//src/common/utils:types",
"//src/ir/datalog:program",
],
)

cc_library(
name = "lowering_ast_datalog",
srcs = ["lowering_ast_datalog.cc"],
Expand Down Expand Up @@ -97,6 +116,18 @@ cc_test(
],
)

cc_test(
name = "ast_visitor_test",
srcs = ["auth_logic_ast_traversing_visitor_test.cc"],
deps = [
":ast",
":auth_logic_ast_traversing_visitor",
"//src/common/testing:gtest",
"//src/ir/datalog:program",
"@absl//absl/container:btree",
],
)

cc_library(
name = "ast_construction",
srcs = ["ast_construction.cc"],
Expand Down
216 changes: 215 additions & 1 deletion src/ir/auth_logic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <vector>

#include "absl/hash/hash.h"
#include "src/common/utils/map_iter.h"
#include "src/ir/auth_logic/auth_logic_ast_visitor.h"
#include "src/ir/datalog/program.h"

namespace raksha::ir::auth_logic {
Expand All @@ -34,6 +36,21 @@ class Principal {
explicit Principal(std::string name) : name_(std::move(name)) {}
const std::string& name() const { return name_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const { return name_; }

private:
std::string name_;
};
Expand All @@ -47,6 +64,23 @@ class Attribute {
const Principal& principal() const { return principal_; }
const datalog::Predicate& predicate() const { return predicate_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat(principal_.name(), predicate_.DebugPrint());
}

private:
Principal principal_;
datalog::Predicate predicate_;
Expand All @@ -62,6 +96,24 @@ class CanActAs {
const Principal& left_principal() const { return left_principal_; }
const Principal& right_principal() const { return right_principal_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat(left_principal_.DebugPrint(), " canActAs ",
right_principal_.DebugPrint());
}

private:
Principal left_principal_;
Principal right_principal_;
Expand All @@ -85,6 +137,26 @@ class BaseFact {
explicit BaseFact(BaseFactVariantType value) : value_(std::move(value)){};
const BaseFactVariantType& GetValue() const { return value_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat(
"BaseFact(",
std::visit([](auto& obj) { return obj.DebugPrint(); }, this->value_),
")");
}

private:
BaseFactVariantType value_;
};
Expand All @@ -103,6 +175,28 @@ class Fact {

const BaseFact& base_fact() const { return base_fact_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
std::vector<std::string> delegations;
for (const Principal& delegatee : delegation_chain_) {
delegations.push_back(delegatee.DebugPrint());
}
return absl::StrCat("deleg: { ", absl::StrJoin(delegations, ", "), " }",
base_fact_.DebugPrint());
}

private:
std::forward_list<Principal> delegation_chain_;
BaseFact base_fact_;
Expand All @@ -118,6 +212,29 @@ class ConditionalAssertion {
const Fact& lhs() const { return lhs_; }
const std::vector<BaseFact>& rhs() const { return rhs_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
std::vector<std::string> rhs_strings;
rhs_strings.reserve(rhs_.size());
for (const BaseFact& base_fact : rhs_) {
rhs_strings.push_back(base_fact.DebugPrint());
}
return absl::StrCat(lhs_.DebugPrint(), ":-",
absl::StrJoin(rhs_strings, ", "));
}

private:
Fact lhs_;
std::vector<BaseFact> rhs_;
Expand All @@ -135,6 +252,26 @@ class Assertion {
explicit Assertion(AssertionVariantType value) : value_(std::move(value)) {}
const AssertionVariantType& GetValue() const { return value_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat(
"Assertion(",
std::visit([](auto& obj) { return obj.DebugPrint(); }, this->value_),
")");
}

private:
AssertionVariantType value_;
};
Expand All @@ -143,10 +280,33 @@ class Assertion {
class SaysAssertion {
public:
explicit SaysAssertion(Principal principal, std::vector<Assertion> assertions)
: principal_(std::move(principal)), assertions_(std::move(assertions)){};
: principal_(std::move(principal)), assertions_(std::move(assertions)) {}
const Principal& principal() const { return principal_; }
const std::vector<Assertion>& assertions() const { return assertions_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
std::vector<std::string> assertion_strings;
assertion_strings.reserve(assertions_.size());
for (const Assertion& assertion : assertions_) {
assertion_strings.push_back(assertion.DebugPrint());
}
return absl::StrCat(principal_.DebugPrint(), "says {\n",
absl::StrJoin(assertion_strings, "\n"), "}");
}

private:
Principal principal_;
std::vector<Assertion> assertions_;
Expand All @@ -164,6 +324,24 @@ class Query {
const Principal& principal() const { return principal_; }
const Fact& fact() const { return fact_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
return absl::StrCat("Query(", name_, principal_.DebugPrint(),
fact_.DebugPrint(), ")");
}

private:
std::string name_;
Principal principal_;
Expand Down Expand Up @@ -191,6 +369,42 @@ class Program {

const std::vector<Query>& queries() const { return queries_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, Mutable>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(
AuthLogicAstVisitor<Derived, Result, Immutable>& visitor) const {
return visitor.Visit(*this);
}

// A potentially ugly print of the state in this class
// for debugging/testing only
std::string DebugPrint() const {
std::vector<std::string> relation_decl_strings;
relation_decl_strings.reserve(relation_declarations_.size());
for (const datalog::RelationDeclaration& rel_decl :
relation_declarations_) {
relation_decl_strings.push_back(rel_decl.DebugPrint());
}
std::vector<std::string> says_assertion_strings;
says_assertion_strings.reserve(says_assertions_.size());
for (const SaysAssertion& says_assertion : says_assertions_) {
says_assertion_strings.push_back(says_assertion.DebugPrint());
}
std::vector<std::string> query_strings;
query_strings.reserve(queries_.size());
for (const Query& query : queries_) {
query_strings.push_back(query.DebugPrint());
}
return absl::StrCat("Program(\n",
absl::StrJoin(relation_decl_strings, "\n"),
absl::StrJoin(says_assertion_strings, "\n"),
absl::StrJoin(query_strings, "\n"), ")");
}

private:
std::vector<datalog::RelationDeclaration> relation_declarations_;
std::vector<SaysAssertion> says_assertions_;
Expand Down
Loading

0 comments on commit 0088890

Please sign in to comment.