From 0088890f63ef17f3e791eb64d6720de39e3f09b1 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Fri, 19 Aug 2022 11:26:17 -0700 Subject: [PATCH] Visitor for authorization logic AST (#643) Closes #643 COPYBARA_INTEGRATE_REVIEW=https://github.com/google-research/raksha/pull/643 from google-research:auth-ast-visitor@aferr d3426cbdcacb7bc1d3b8ba86af0076d7c469adfc PiperOrigin-RevId: 468746566 --- src/ir/auth_logic/BUILD | 31 ++ src/ir/auth_logic/ast.h | 216 +++++++++++- .../auth_logic_ast_traversing_visitor.h | 316 ++++++++++++++++++ .../auth_logic_ast_traversing_visitor_test.cc | 279 ++++++++++++++++ src/ir/auth_logic/auth_logic_ast_visitor.h | 59 ++++ src/ir/datalog/program.h | 26 ++ src/ir/ir_visitor.h | 4 +- 7 files changed, 928 insertions(+), 3 deletions(-) create mode 100644 src/ir/auth_logic/auth_logic_ast_traversing_visitor.h create mode 100644 src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc create mode 100644 src/ir/auth_logic/auth_logic_ast_visitor.h diff --git a/src/ir/auth_logic/BUILD b/src/ir/auth_logic/BUILD index 1b93f526f..dc19ff57e 100644 --- a/src/ir/auth_logic/BUILD +++ b/src/ir/auth_logic/BUILD @@ -25,17 +25,36 @@ cc_library( name = "ast", hdrs = [ "ast.h", + "auth_logic_ast_visitor.h", ], visibility = ["//visibility:private"], deps = [ "//src/common/logging", + "//src/common/utils:map_iter", "//src/common/utils:overloaded", + "//src/common/utils:types", "//src/ir/datalog:program", "@absl//absl/hash", "@absl//absl/strings:str_format", ], ) +cc_library( + name = "auth_logic_ast_traversing_visitor", + hdrs = [ + "auth_logic_ast_traversing_visitor.h", + "auth_logic_ast_visitor.h", + ], + deps = [ + ":ast", + "//src/common/logging", + "//src/common/utils:fold", + "//src/common/utils:overloaded", + "//src/common/utils:types", + "//src/ir/datalog:program", + ], +) + cc_library( name = "lowering_ast_datalog", srcs = ["lowering_ast_datalog.cc"], @@ -97,6 +116,18 @@ cc_test( ], ) +cc_test( + name = "ast_visitor_test", + srcs = ["auth_logic_ast_traversing_visitor_test.cc"], + deps = [ + ":ast", + ":auth_logic_ast_traversing_visitor", + "//src/common/testing:gtest", + "//src/ir/datalog:program", + "@absl//absl/container:btree", + ], +) + cc_library( name = "ast_construction", srcs = ["ast_construction.cc"], diff --git a/src/ir/auth_logic/ast.h b/src/ir/auth_logic/ast.h index 78163f179..adc32acd0 100644 --- a/src/ir/auth_logic/ast.h +++ b/src/ir/auth_logic/ast.h @@ -25,6 +25,8 @@ #include #include "absl/hash/hash.h" +#include "src/common/utils/map_iter.h" +#include "src/ir/auth_logic/auth_logic_ast_visitor.h" #include "src/ir/datalog/program.h" namespace raksha::ir::auth_logic { @@ -34,6 +36,21 @@ class Principal { explicit Principal(std::string name) : name_(std::move(name)) {} const std::string& name() const { return name_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { return name_; } + private: std::string name_; }; @@ -47,6 +64,23 @@ class Attribute { const Principal& principal() const { return principal_; } const datalog::Predicate& predicate() const { return predicate_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + return absl::StrCat(principal_.name(), predicate_.DebugPrint()); + } + private: Principal principal_; datalog::Predicate predicate_; @@ -62,6 +96,24 @@ class CanActAs { const Principal& left_principal() const { return left_principal_; } const Principal& right_principal() const { return right_principal_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + return absl::StrCat(left_principal_.DebugPrint(), " canActAs ", + right_principal_.DebugPrint()); + } + private: Principal left_principal_; Principal right_principal_; @@ -85,6 +137,26 @@ class BaseFact { explicit BaseFact(BaseFactVariantType value) : value_(std::move(value)){}; const BaseFactVariantType& GetValue() const { return value_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + return absl::StrCat( + "BaseFact(", + std::visit([](auto& obj) { return obj.DebugPrint(); }, this->value_), + ")"); + } + private: BaseFactVariantType value_; }; @@ -103,6 +175,28 @@ class Fact { const BaseFact& base_fact() const { return base_fact_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + std::vector delegations; + for (const Principal& delegatee : delegation_chain_) { + delegations.push_back(delegatee.DebugPrint()); + } + return absl::StrCat("deleg: { ", absl::StrJoin(delegations, ", "), " }", + base_fact_.DebugPrint()); + } + private: std::forward_list delegation_chain_; BaseFact base_fact_; @@ -118,6 +212,29 @@ class ConditionalAssertion { const Fact& lhs() const { return lhs_; } const std::vector& rhs() const { return rhs_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + std::vector rhs_strings; + rhs_strings.reserve(rhs_.size()); + for (const BaseFact& base_fact : rhs_) { + rhs_strings.push_back(base_fact.DebugPrint()); + } + return absl::StrCat(lhs_.DebugPrint(), ":-", + absl::StrJoin(rhs_strings, ", ")); + } + private: Fact lhs_; std::vector rhs_; @@ -135,6 +252,26 @@ class Assertion { explicit Assertion(AssertionVariantType value) : value_(std::move(value)) {} const AssertionVariantType& GetValue() const { return value_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + return absl::StrCat( + "Assertion(", + std::visit([](auto& obj) { return obj.DebugPrint(); }, this->value_), + ")"); + } + private: AssertionVariantType value_; }; @@ -143,10 +280,33 @@ class Assertion { class SaysAssertion { public: explicit SaysAssertion(Principal principal, std::vector assertions) - : principal_(std::move(principal)), assertions_(std::move(assertions)){}; + : principal_(std::move(principal)), assertions_(std::move(assertions)) {} const Principal& principal() const { return principal_; } const std::vector& assertions() const { return assertions_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + std::vector assertion_strings; + assertion_strings.reserve(assertions_.size()); + for (const Assertion& assertion : assertions_) { + assertion_strings.push_back(assertion.DebugPrint()); + } + return absl::StrCat(principal_.DebugPrint(), "says {\n", + absl::StrJoin(assertion_strings, "\n"), "}"); + } + private: Principal principal_; std::vector assertions_; @@ -164,6 +324,24 @@ class Query { const Principal& principal() const { return principal_; } const Fact& fact() const { return fact_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + return absl::StrCat("Query(", name_, principal_.DebugPrint(), + fact_.DebugPrint(), ")"); + } + private: std::string name_; Principal principal_; @@ -191,6 +369,42 @@ class Program { const std::vector& queries() const { return queries_; } + template + Result Accept(AuthLogicAstVisitor& visitor) { + return visitor.Visit(*this); + } + + template + Result Accept( + AuthLogicAstVisitor& visitor) const { + return visitor.Visit(*this); + } + + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + std::vector relation_decl_strings; + relation_decl_strings.reserve(relation_declarations_.size()); + for (const datalog::RelationDeclaration& rel_decl : + relation_declarations_) { + relation_decl_strings.push_back(rel_decl.DebugPrint()); + } + std::vector says_assertion_strings; + says_assertion_strings.reserve(says_assertions_.size()); + for (const SaysAssertion& says_assertion : says_assertions_) { + says_assertion_strings.push_back(says_assertion.DebugPrint()); + } + std::vector query_strings; + query_strings.reserve(queries_.size()); + for (const Query& query : queries_) { + query_strings.push_back(query.DebugPrint()); + } + return absl::StrCat("Program(\n", + absl::StrJoin(relation_decl_strings, "\n"), + absl::StrJoin(says_assertion_strings, "\n"), + absl::StrJoin(query_strings, "\n"), ")"); + } + private: std::vector relation_declarations_; std::vector says_assertions_; diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h new file mode 100644 index 000000000..2a77ab89d --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor.h @@ -0,0 +1,316 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- +#ifndef SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ +#define SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ + +#include "src/common/logging/logging.h" +#include "src/common/utils/fold.h" +#include "src/common/utils/overloaded.h" +#include "src/common/utils/types.h" +#include "src/ir/auth_logic/ast.h" +#include "src/ir/auth_logic/auth_logic_ast_visitor.h" +#include "src/ir/datalog/program.h" + +// The implementation of this visitor over the AST nodes of authorizaiton logic +// directly follows the one for the IR in /src/ir/ir_traversing_visitor.h + +namespace raksha::ir::auth_logic { + +// A visitor that also traverses the children of a node and allows performing +// different actions before (PreVisit) and after (PostVisit) the children are +// visited. Override any of the `PreVisit` and `PostVisit` methods as needed. +template +class AuthLogicAstTraversingVisitor + : public AuthLogicAstVisitor { + private: + template + struct DefaultValueGetter { + static ValueType Get() { + LOG(FATAL) << "Override required for non-default-constructible type."; + } + }; + + template + struct DefaultValueGetter< + ValueType, std::enable_if_t>> { + static ValueType Get() { return ValueType(); } + }; + + public: + virtual ~AuthLogicAstTraversingVisitor() {} + + // Gives a default value for all 'PreVisit's to start with. + // Should be over-ridden if the Result is not default constructable. + virtual Result GetDefaultValue() { return DefaultValueGetter::Get(); } + + // Used to combine two `Result`s into one result while visiting a node + virtual Result CombineResult(Result left_result, Result right_result) { + return left_result; + } + // Invoked before all the children of `principal` are visited. + virtual Result PreVisit(CopyConst& principal) { + return GetDefaultValue(); + } + // Invoked after all the children of `principal` are visited. + virtual Result PostVisit(CopyConst& principal, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `attribute` are visited. + virtual Result PreVisit(CopyConst& attribute) { + return GetDefaultValue(); + } + // Invoked after all the children of `attribute` are visited. + virtual Result PostVisit(CopyConst& attribute, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `can_act_as` are visited. + virtual Result PreVisit(CopyConst& can_act_as) { + return GetDefaultValue(); + } + // Invoked after all the children of `canActAs` are visited. + virtual Result PostVisit(CopyConst& can_act_as, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `baseFact` are visited. + virtual Result PreVisit(CopyConst& base_fact) { + return GetDefaultValue(); + } + // Invoked after all the children of `baseFact` are visited. + virtual Result PostVisit(CopyConst& base_fact, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `fact` are visited. + virtual Result PreVisit(CopyConst& fact) { + return GetDefaultValue(); + } + // Invoked after all the children of `fact` are visited. + virtual Result PostVisit(CopyConst& fact, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `conditionalAssertion` are visited. + virtual Result PreVisit( + CopyConst& conditional_assertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `conditionalAssertion` are visited. + virtual Result PostVisit( + CopyConst& conditional_assertion, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `assertion` are visited. + virtual Result PreVisit(CopyConst& assertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `assertion` are visited. + virtual Result PostVisit(CopyConst& assertion, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `saysAssertion` are visited. + virtual Result PreVisit(CopyConst& says_assertion) { + return GetDefaultValue(); + } + // Invoked after all the children of `saysAssertion` are visited. + virtual Result PostVisit(CopyConst& says_assertion, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `query` are visited. + virtual Result PreVisit(CopyConst& query) { + return GetDefaultValue(); + } + // Invoked after all the children of `query` are visited. + virtual Result PostVisit(CopyConst& query, + Result in_order_result) { + return in_order_result; + } + // Invoked before all the children of `program` are visited. + virtual Result PreVisit(CopyConst& program) { + return GetDefaultValue(); + } + // Invoked after all the children of `program` are visited. + virtual Result PostVisit(CopyConst& program, + Result in_order_result) { + return in_order_result; + } + + // TODO (#644) aferr + // The Visits for the Datalog IR classes (RelationDeclaration, Predciate) + // are here temporarily until these AST classes are refactored out + // of the Datalog IR. + + virtual Result Visit( + CopyConst& relation_declaration) { + return GetDefaultValue(); + } + + virtual Result Visit(CopyConst& predicate) { + return GetDefaultValue(); + } + + // The remaining Visits are meant to follow the convention + + Result Visit(CopyConst& principal) final override { + Result pre_visit_result = PreVisit(principal); + return PostVisit(principal, std::move(pre_visit_result)); + } + + Result Visit(CopyConst& attribute) final override { + Result pre_visit_result = PreVisit(attribute); + Result fold_result = + CombineResult(CombineResult(std::move(pre_visit_result), + attribute.principal().Accept(*this)), + // TODO(#644 aferr): fix this to use predicate().Accept + // once predicate has been refactored into ast.h + Visit(attribute.predicate())); + return PostVisit(attribute, std::move(fold_result)); + } + + Result Visit(CopyConst& can_act_as) final override { + Result pre_visit_result = PreVisit(can_act_as); + Result fold_result = + CombineResult(CombineResult(std::move(pre_visit_result), + can_act_as.left_principal().Accept(*this)), + can_act_as.right_principal().Accept(*this)); + return PostVisit(can_act_as, std::move(fold_result)); + } + + Result Visit(CopyConst& base_fact) final override { + Result pre_visit_result = PreVisit(base_fact); + Result variant_visit_result = std::visit( + raksha::utils::overloaded{ + [this](const datalog::Predicate& pred) { + return VariantVisit(pred); + }, + [this](const Attribute& attrib) { return VariantVisit(attrib); }, + [this](const CanActAs& can_act_as) { + return VariantVisit(can_act_as); + }}, + base_fact.GetValue()); + Result fold_result = CombineResult(std::move(pre_visit_result), + std::move(variant_visit_result)); + return PostVisit(base_fact, std::move(fold_result)); + } + + Result Visit(CopyConst& fact) final override { + Result pre_visit_result = PreVisit(fact); + Result deleg_result = FoldAccept>( + fact.delegation_chain(), pre_visit_result); + Result base_fact_result = + CombineResult(std::move(deleg_result), fact.base_fact().Accept(*this)); + return PostVisit(fact, std::move(base_fact_result)); + } + + Result Visit(CopyConst& conditional_assertion) + final override { + Result pre_visit_result = PreVisit(conditional_assertion); + Result lhs_result = CombineResult( + std::move(pre_visit_result), conditional_assertion.lhs().Accept(*this)); + Result fold_result = FoldAccept>( + conditional_assertion.rhs(), lhs_result); + return PostVisit(conditional_assertion, std::move(fold_result)); + } + + Result Visit(CopyConst& assertion) final override { + Result pre_visit_result = PreVisit(assertion); + Result variant_visit_result = + std::visit(raksha::utils::overloaded{ + [this](const Fact& fact) { return VariantVisit(fact); }, + [this](const ConditionalAssertion& cond_assertion) { + return VariantVisit(cond_assertion); + }}, + assertion.GetValue()); + Result fold_result = CombineResult(std::move(pre_visit_result), + std::move(variant_visit_result)); + return PostVisit(assertion, std::move(fold_result)); + } + + Result Visit( + CopyConst& says_assertion) final override { + Result pre_visit_result = PreVisit(says_assertion); + Result principal_result = CombineResult( + std::move(pre_visit_result), says_assertion.principal().Accept(*this)); + Result fold_result = FoldAccept>( + says_assertion.assertions(), principal_result); + return PostVisit(says_assertion, fold_result); + } + + Result Visit(CopyConst& query) final override { + Result pre_visit_result = PreVisit(query); + Result fold_result = + CombineResult(std::move(pre_visit_result), + CombineResult(query.principal().Accept(*this), + query.fact().Accept(*this))); + return PostVisit(query, fold_result); + } + + Result Visit(CopyConst& program) final override { + Result pre_visit_result = PreVisit(program); + Result declarations_result = common::utils::fold( + program.relation_declarations(), std::move(pre_visit_result), + [this](Result acc, CopyConst + relation_declaration) { + // TODO(#644 aferr) Fix this to accept once once relationDeclaration + // has been refactored into ast.h + return CombineResult(std::move(acc), Visit(relation_declaration)); + }); + Result says_assertions_result = + FoldAccept>( + program.says_assertions(), declarations_result); + Result queries_result = FoldAccept>( + program.queries(), says_assertions_result); + return PostVisit(program, queries_result); + } + + // The VariantVisit methods use overloading to help visit + // the alternatives for the underlying std::variants in the AST + + // For BaseFactVariantType + Result VariantVisit(datalog::Predicate predicate) { + // TODO(#644 aferr) once a separate predicate has been added to ast.h + // this should use predicate.Accept(*this); + return Visit(predicate); + } + Result VariantVisit(Attribute attribute) { return attribute.Accept(*this); } + Result VariantVisit(CanActAs can_act_as) { return can_act_as.Accept(*this); } + + // For AssertionVariantType + Result VariantVisit(Fact fact) { return fact.Accept(*this); } + Result VariantVisit(ConditionalAssertion conditional_assertion) { + return conditional_assertion.Accept(*this); + } + + private: + template + Result FoldAccept(Container container, Result initial) { + return common::utils::fold( + container, std::move(initial), + [this](Result acc, CopyConst element) { + return CombineResult(std::move(acc), element.Accept(*this)); + }); + } +}; + +} // namespace raksha::ir::auth_logic + +#endif // SRC_IR_AUTH_LOGIC_AST_TRAVERSING_VISITOR_H_ diff --git a/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc new file mode 100644 index 000000000..d5702e79f --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_traversing_visitor_test.cc @@ -0,0 +1,279 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- + +#include "src/ir/auth_logic/auth_logic_ast_traversing_visitor.h" + +#include + +#include "absl/container/btree_set.h" +#include "src/common/testing/gtest.h" +#include "src/ir/auth_logic/ast.h" +#include "src/ir/datalog/program.h" + +namespace raksha::ir::auth_logic { +namespace { + +// A visitor that makes a set of all the names of principals in the program +class PrincipalNameCollectorVisitor + : public AuthLogicAstTraversingVisitor> { + public: + absl::btree_set GetDefaultValue() override { return {}; } + + absl::btree_set CombineResult( + absl::btree_set acc, + absl::btree_set child_result) override { + acc.merge(std::move(child_result)); + return acc; + } + + absl::btree_set PreVisit( + const Principal& principal) override { + return {principal.name()}; + } +}; + +Program BuildTestProgram1() { + SaysAssertion assertion1 = SaysAssertion( + Principal("PrincipalA"), + {Assertion(Fact({}, BaseFact(datalog::Predicate("foo", {"bar", "baz"}, + datalog::kPositive))))}); + SaysAssertion assertion2 = SaysAssertion( + Principal("PrincipalA"), + {Assertion( + Fact({}, BaseFact(datalog::Predicate("foo", {"barbar", "bazbaz"}, + datalog::kPositive))))}); + SaysAssertion assertion3 = SaysAssertion( + Principal("PrincipalB"), + {Assertion(Fact({}, BaseFact(CanActAs(Principal("PrincipalA"), + Principal("PrincipalC")))))}); + std::vector assertion_list = { + std::move(assertion1), std::move(assertion2), std::move(assertion3)}; + return Program({}, std::move(assertion_list), {}); +} + +TEST(AuthLogicAstTraversingVisitorTest, PrincipalNameCollectorTest) { + Program test_prog = BuildTestProgram1(); + PrincipalNameCollectorVisitor collector_visitor; + const absl::btree_set result = + test_prog.Accept(collector_visitor); + const absl::btree_set expected = {"PrincipalA", "PrincipalB", + "PrincipalC"}; + EXPECT_EQ(result, expected); +} + +enum class TraversalType { kPre = 0x1, kPost = 0x2, kBoth = 0x3 }; +class TraversalOrderVisitor + : public AuthLogicAstTraversingVisitor { + public: + TraversalOrderVisitor(TraversalType traversal_type) + : pre_visits_(traversal_type == TraversalType::kPre || + traversal_type == TraversalType::kBoth), + post_visits_(traversal_type == TraversalType::kPost || + traversal_type == TraversalType::kBoth) {} + + Unit PreVisit(const Principal& prin) override { + if (pre_visits_) nodes_.push_back(prin.DebugPrint()); + return Unit(); + } + Unit PostVisit(const Principal& prin, Unit result) override { + if (post_visits_) nodes_.push_back(prin.DebugPrint()); + return result; + } + + Unit PreVisit(const Attribute& attrib) override { + if (pre_visits_) nodes_.push_back(attrib.DebugPrint()); + return Unit(); + } + Unit PostVisit(const Attribute& attrib, Unit result) override { + if (post_visits_) nodes_.push_back(attrib.DebugPrint()); + return result; + } + + Unit PreVisit(const CanActAs& canActAs) override { + if (pre_visits_) nodes_.push_back(canActAs.DebugPrint()); + return Unit(); + } + Unit PostVisit(const CanActAs& canActAs, Unit result) override { + if (post_visits_) nodes_.push_back(canActAs.DebugPrint()); + return result; + } + + Unit PreVisit(const BaseFact& baseFact) override { + if (pre_visits_) nodes_.push_back(baseFact.DebugPrint()); + return Unit(); + } + Unit PostVisit(const BaseFact& baseFact, Unit result) override { + if (post_visits_) nodes_.push_back(baseFact.DebugPrint()); + return result; + } + + Unit PreVisit(const Fact& fact) override { + if (pre_visits_) nodes_.push_back(fact.DebugPrint()); + return Unit(); + } + Unit PostVisit(const Fact& fact, Unit result) override { + if (post_visits_) nodes_.push_back(fact.DebugPrint()); + return result; + } + + Unit PreVisit(const ConditionalAssertion& condAssertion) override { + if (pre_visits_) nodes_.push_back(condAssertion.DebugPrint()); + return Unit(); + } + Unit PostVisit(const ConditionalAssertion& condAssertion, + Unit result) override { + if (post_visits_) nodes_.push_back(condAssertion.DebugPrint()); + return result; + } + + Unit PreVisit(const Assertion& assertion) override { + if (pre_visits_) nodes_.push_back(assertion.DebugPrint()); + return Unit(); + } + Unit PostVisit(const Assertion& assertion, Unit result) override { + if (post_visits_) nodes_.push_back(assertion.DebugPrint()); + return result; + } + + Unit PreVisit(const SaysAssertion& saysAssertion) override { + if (pre_visits_) nodes_.push_back(saysAssertion.DebugPrint()); + return Unit(); + } + Unit PostVisit(const SaysAssertion& saysAssertion, Unit result) override { + if (post_visits_) nodes_.push_back(saysAssertion.DebugPrint()); + return result; + } + + Unit PreVisit(const Query& query) override { + if (pre_visits_) nodes_.push_back(query.DebugPrint()); + return Unit(); + } + Unit PostVisit(const Query& query, Unit result) override { + if (post_visits_) nodes_.push_back(query.DebugPrint()); + return result; + } + + Unit PreVisit(const Program& program) override { + if (pre_visits_) nodes_.push_back(program.DebugPrint()); + return Unit(); + } + Unit PostVisit(const Program& program, Unit result) override { + if (post_visits_) nodes_.push_back(program.DebugPrint()); + return result; + } + + const std::vector& nodes() const { return nodes_; } + + private: + bool pre_visits_; + bool post_visits_; + std::vector nodes_; +}; + +TEST(AuthLogicAstTraversingVisitorTest, SimpleTraversalTest) { + Principal prinA("PrincipalA"); + Principal prinB("PrincipalB"); + Principal prinC("PrincipalC"); + + datalog::Predicate pred1("pred1", {}, datalog::kPositive); + datalog::Predicate pred2("pred2", {}, datalog::kPositive); + + BaseFact baseFact1(pred1); + BaseFact baseFact2(pred2); + + Fact fact1({}, baseFact1); + Fact fact2({prinB}, baseFact2); + + Assertion assertion1(fact1); + Assertion assertion2(fact2); + + SaysAssertion saysAssertion1(prinA, {assertion1}); + SaysAssertion saysAssertion2(prinC, {assertion1, assertion2}); + + Query query1("query1", prinA, fact1); + Query query2("query2", prinB, fact2); + + Program program1({}, {saysAssertion1, saysAssertion2}, {query1, query2}); + + TraversalOrderVisitor preorder_visitor(TraversalType::kPre); + program1.Accept(preorder_visitor); + EXPECT_THAT( + preorder_visitor.nodes(), + testing::ElementsAre( + program1.DebugPrint(), saysAssertion1.DebugPrint(), + prinA.DebugPrint(), assertion1.DebugPrint(), fact1.DebugPrint(), + baseFact1.DebugPrint(), saysAssertion2.DebugPrint(), + prinC.DebugPrint(), assertion1.DebugPrint(), fact1.DebugPrint(), + baseFact1.DebugPrint(), assertion2.DebugPrint(), fact2.DebugPrint(), + prinB.DebugPrint(), baseFact2.DebugPrint(), query1.DebugPrint(), + prinA.DebugPrint(), fact1.DebugPrint(), baseFact1.DebugPrint(), + query2.DebugPrint(), prinB.DebugPrint(), fact2.DebugPrint(), + prinB.DebugPrint(), baseFact2.DebugPrint())); + + TraversalOrderVisitor postorder_visitor(TraversalType::kPost); + program1.Accept(postorder_visitor); + EXPECT_THAT( + postorder_visitor.nodes(), + testing::ElementsAre( + prinA.DebugPrint(), baseFact1.DebugPrint(), fact1.DebugPrint(), + assertion1.DebugPrint(), saysAssertion1.DebugPrint(), + prinC.DebugPrint(), baseFact1.DebugPrint(), fact1.DebugPrint(), + assertion1.DebugPrint(), prinB.DebugPrint(), baseFact2.DebugPrint(), + fact2.DebugPrint(), assertion2.DebugPrint(), + saysAssertion2.DebugPrint(), prinA.DebugPrint(), + baseFact1.DebugPrint(), fact1.DebugPrint(), query1.DebugPrint(), + prinB.DebugPrint(), prinB.DebugPrint(), baseFact2.DebugPrint(), + fact2.DebugPrint(), query2.DebugPrint(), program1.DebugPrint())); + + // The bits of syntax not in program1 + Attribute attribute1(prinC, pred2); + CanActAs canActAs1(prinA, prinB); + BaseFact baseFact3(attribute1); + BaseFact baseFact4(canActAs1); + Fact fact3({}, baseFact3); + ConditionalAssertion conditionalAssertion1(fact3, {baseFact4}); + + TraversalOrderVisitor preorder_visitor2(TraversalType::kPre); + conditionalAssertion1.Accept(preorder_visitor2); + EXPECT_THAT( + preorder_visitor2.nodes(), + testing::ElementsAre(conditionalAssertion1.DebugPrint(), + fact3.DebugPrint(), baseFact3.DebugPrint(), + attribute1.DebugPrint(), prinC.DebugPrint(), + baseFact4.DebugPrint(), canActAs1.DebugPrint(), + prinA.DebugPrint(), prinB.DebugPrint())); + + TraversalOrderVisitor postorder_visitor2(TraversalType::kPost); + conditionalAssertion1.Accept(postorder_visitor2); + EXPECT_THAT( + postorder_visitor2.nodes(), + testing::ElementsAre(prinC.DebugPrint(), attribute1.DebugPrint(), + baseFact3.DebugPrint(), fact3.DebugPrint(), + prinA.DebugPrint(), prinB.DebugPrint(), + canActAs1.DebugPrint(), baseFact4.DebugPrint(), + conditionalAssertion1.DebugPrint())); + + TraversalOrderVisitor both_order_visitor(TraversalType::kBoth); + canActAs1.Accept(both_order_visitor); + EXPECT_THAT(both_order_visitor.nodes(), + testing::ElementsAre(canActAs1.DebugPrint(), prinA.DebugPrint(), + prinA.DebugPrint(), prinB.DebugPrint(), + prinB.DebugPrint(), canActAs1.DebugPrint())); +} + +} // namespace +} // namespace raksha::ir::auth_logic diff --git a/src/ir/auth_logic/auth_logic_ast_visitor.h b/src/ir/auth_logic/auth_logic_ast_visitor.h new file mode 100644 index 000000000..856b48b67 --- /dev/null +++ b/src/ir/auth_logic/auth_logic_ast_visitor.h @@ -0,0 +1,59 @@ +//----------------------------------------------------------------------------- +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//---------------------------------------------------------------------------- +#ifndef SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ +#define SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ + +#include "src/common/utils/types.h" + +namespace raksha::ir::auth_logic { + +class Principal; +class Attribute; +class CanActAs; +class BaseFact; +class Fact; +class ConditionalAssertion; +class Assertion; +class SaysAssertion; +class Query; +class Program; + +enum AstNodeMutability : bool { + Mutable = false, + Immutable = true +}; + + +template +class AuthLogicAstVisitor { + public: + virtual ~AuthLogicAstVisitor() {} + virtual Result Visit(CopyConst& principal) = 0; + virtual Result Visit(CopyConst& attribute) = 0; + virtual Result Visit(CopyConst& canActAs) = 0; + virtual Result Visit(CopyConst& baseFact) = 0; + virtual Result Visit(CopyConst& fact) = 0; + virtual Result Visit( + CopyConst& conditionalAssertion) = 0; + virtual Result Visit(CopyConst& assertion) = 0; + virtual Result Visit(CopyConst& saysAssertion) = 0; + virtual Result Visit(CopyConst& query) = 0; + virtual Result Visit(CopyConst& program) = 0; +}; + +} // namespace raksha::ir::auth_logic + +#endif // SRC_IR_AUTH_LOGIC_AST_VISITOR_H_ diff --git a/src/ir/datalog/program.h b/src/ir/datalog/program.h index e37d2a6e2..92d0fb81e 100644 --- a/src/ir/datalog/program.h +++ b/src/ir/datalog/program.h @@ -66,6 +66,12 @@ class Predicate { return this->name() < otherPredicate.name(); } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + return absl::StrCat(sign_, name_, absl::StrJoin(args_, ", ")); + } + private: std::string name_; std::vector args_; @@ -84,6 +90,8 @@ class ArgumentType { Kind kind() const { return kind_; } absl::string_view name() const { return name_; } + std::string DebugPrint() const { return absl::StrCat(kind_, name_); } + private: Kind kind_; std::string name_; @@ -97,6 +105,12 @@ class Argument { absl::string_view argument_name() const { return argument_name_; } ArgumentType argument_type() const { return argument_type_; } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + return absl::StrCat(argument_name_, " : ", argument_type_.DebugPrint()); + } + private: std::string argument_name_; ArgumentType argument_type_; @@ -114,6 +128,18 @@ class RelationDeclaration { bool is_attribute() const { return is_attribute_; } const std::vector& arguments() const { return arguments_; } + // A potentially ugly print of the state in this class + // for debugging/testing only + std::string DebugPrint() const { + std::vector arg_strings; + arg_strings.reserve(arguments_.size()); + for (const Argument& arg : arguments_) { + arg_strings.push_back(arg.DebugPrint()); + } + return absl::StrCat(".decl ", relation_name_, is_attribute_, + absl::StrJoin(arg_strings, ", ")); + } + private: std::string relation_name_; bool is_attribute_; diff --git a/src/ir/ir_visitor.h b/src/ir/ir_visitor.h index a412a6743..fcca721ca 100644 --- a/src/ir/ir_visitor.h +++ b/src/ir/ir_visitor.h @@ -26,12 +26,12 @@ class Operation; // An interface for the visitor class. We will also pass in the `Derived` class // as template argument if we want to support CRTP at a later point. -template +template class IRVisitor { public: virtual ~IRVisitor() {} virtual Result Visit(CopyConst& module) = 0; - virtual Result Visit(CopyConst& operation) = 0; + virtual Result Visit(CopyConst& block) = 0; virtual Result Visit(CopyConst& operation) = 0; };