Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visitor for authorization logic AST #643

Merged
merged 12 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/ir/auth_logic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,34 @@ cc_library(
name = "ast",
hdrs = [
"ast.h",
Cypher1 marked this conversation as resolved.
Show resolved Hide resolved
"auth_logic_ast_visitor.h",
],
visibility = ["//visibility:private"],
deps = [
"//src/common/logging",
"//src/common/utils:overloaded",
"//src/common/utils:types",
"//src/ir/datalog:program",
"@absl//absl/hash",
"@absl//absl/strings:str_format",
],
)

cc_library(
name = "auth_logic_ast_traversing_visitor",
hdrs = [
"auth_logic_ast_traversing_visitor.h",
"auth_logic_ast_visitor.h",
],
deps = [
":ast",
"//src/common/logging",
"//src/common/utils:fold",
"//src/common/utils:overloaded",
"//src/common/utils:types",
],
)

cc_library(
name = "lowering_ast_datalog",
srcs = ["lowering_ast_datalog.cc"],
Expand Down Expand Up @@ -94,6 +111,18 @@ cc_test(
],
)

cc_test(
name = "ast_visitor_test",
srcs = ["auth_logic_ast_traversing_visitor_test.cc"],
deps = [
":ast",
":auth_logic_ast_traversing_visitor",
"//src/common/testing:gtest",
"//src/ir/datalog:program",
"@absl//absl/container:flat_hash_set",
],
)

cc_library(
name = "ast_construction",
srcs = ["ast_construction.cc"],
Expand Down
101 changes: 101 additions & 0 deletions src/ir/auth_logic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <vector>

#include "absl/hash/hash.h"
#include "src/ir/auth_logic/auth_logic_ast_visitor.h"
#include "src/ir/datalog/program.h"

namespace raksha::ir::auth_logic {
Expand All @@ -34,6 +35,16 @@ class Principal {
explicit Principal(std::string name) : name_(std::move(name)) {}
const std::string& name() const { return name_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
aferr marked this conversation as resolved.
Show resolved Hide resolved
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
std::string name_;
};
Expand All @@ -47,6 +58,16 @@ class Attribute {
const Principal& principal() const { return principal_; }
const datalog::Predicate& predicate() const { return predicate_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
Principal principal_;
datalog::Predicate predicate_;
Expand All @@ -62,6 +83,16 @@ class CanActAs {
const Principal& left_principal() const { return left_principal_; }
const Principal& right_principal() const { return right_principal_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
Principal left_principal_;
Principal right_principal_;
Expand All @@ -85,6 +116,16 @@ class BaseFact {
explicit BaseFact(BaseFactVariantType value) : value_(std::move(value)){};
const BaseFactVariantType& GetValue() const { return value_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
BaseFactVariantType value_;
};
Expand All @@ -103,6 +144,16 @@ class Fact {

const BaseFact& base_fact() const { return base_fact_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
std::forward_list<Principal> delegation_chain_;
BaseFact base_fact_;
Expand All @@ -118,6 +169,16 @@ class ConditionalAssertion {
const Fact& lhs() const { return lhs_; }
const std::vector<BaseFact>& rhs() const { return rhs_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
Fact lhs_;
std::vector<BaseFact> rhs_;
Expand All @@ -135,6 +196,16 @@ class Assertion {
explicit Assertion(AssertionVariantType value) : value_(std::move(value)) {}
const AssertionVariantType& GetValue() const { return value_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
AssertionVariantType value_;
};
Expand All @@ -147,6 +218,16 @@ class SaysAssertion {
const Principal& principal() const { return principal_; }
const std::vector<Assertion>& assertions() const { return assertions_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
Principal principal_;
std::vector<Assertion> assertions_;
Expand All @@ -164,6 +245,16 @@ class Query {
const Principal& principal() const { return principal_; }
const Fact& fact() const { return fact_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
std::string name_;
Principal principal_;
Expand Down Expand Up @@ -191,6 +282,16 @@ class Program {

const std::vector<Query>& queries() const { return queries_; }

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, false>& visitor) {
return visitor.Visit(*this);
}

template <typename Derived, typename Result>
Result Accept(AuthLogicAstVisitor<Derived, Result, true>& visitor) const {
return visitor.Visit(*this);
}

private:
std::vector<datalog::RelationDeclaration> relation_declarations_;
std::vector<SaysAssertion> says_assertions_;
Expand Down
Loading