From 1336e9eb6afeb22e32bbf2449d0206a2fa1d3876 Mon Sep 17 00:00:00 2001 From: Shivansh Yadav Date: Wed, 6 Apr 2022 16:20:35 +0530 Subject: [PATCH] feat: update_filtered_policies --- casbin/internal_enforcer.py | 31 +++++++++++++++++++++++ casbin/management_enforcer.py | 14 ++++++++++ casbin/persist/adapters/update_adapter.py | 8 ++++++ tests/test_management_api.py | 18 +++++++++++++ 4 files changed, 71 insertions(+) diff --git a/casbin/internal_enforcer.py b/casbin/internal_enforcer.py index 2a15fe0a..ee9067d5 100644 --- a/casbin/internal_enforcer.py +++ b/casbin/internal_enforcer.py @@ -88,6 +88,37 @@ def _update_policies(self, sec, ptype, old_rules, new_rules): return rules_updated + def _update_filtered_policies( + self, sec, ptype, new_rules, field_index, *field_values + ): + """_update_filtered_policies deletes old rules and adds new rules.""" + + old_rules = self.model.get_filtered_policy( + sec, ptype, field_index, *field_values + ) + + if self.adapter and self.auto_save: + try: + old_rules = self.adapter.update_filtered_policies( + sec, ptype, new_rules, field_index, *field_values + ) + except: + pass + + if not old_rules: + return False + + is_rule_changed = self.model.remove_policies(sec, ptype, old_rules) + self.model.add_policies(sec, ptype, new_rules) + is_rule_changed = is_rule_changed and len(new_rules) != 0 + if not is_rule_changed: + return is_rule_changed + if sec == "g": + self.build_role_links() + if self.watcher: + self.watcher.update() + return is_rule_changed + def _remove_policy(self, sec, ptype, rule): """removes a rule from the current policy.""" rule_removed = self.model.remove_policy(sec, ptype, rule) diff --git a/casbin/management_enforcer.py b/casbin/management_enforcer.py index 3f1155f9..e90b286b 100644 --- a/casbin/management_enforcer.py +++ b/casbin/management_enforcer.py @@ -151,6 +151,20 @@ def update_named_policies(self, ptype, old_rules, new_rules): """updates authorization rules from the current named policy.""" return self._update_policies("p", ptype, old_rules, new_rules) + def update_filtered_policies(self, new_rules, field_index, *field_values): + """update_filtered_policies deletes old rules and adds new rules.""" + return self.update_filtered_named_policies( + "p", new_rules, field_index, *field_values + ) + + def update_filtered_named_policies( + self, ptype, new_rules, field_index, *field_values + ): + """update_filtered_named_policies deletes old rules and adds new rules.""" + return self._update_filtered_policies( + "p", ptype, new_rules, field_index, *field_values + ) + def remove_policy(self, *params): """removes an authorization rule from the current policy.""" return self.remove_named_policy("p", *params) diff --git a/casbin/persist/adapters/update_adapter.py b/casbin/persist/adapters/update_adapter.py index 590c1891..1ac09a3f 100644 --- a/casbin/persist/adapters/update_adapter.py +++ b/casbin/persist/adapters/update_adapter.py @@ -28,3 +28,11 @@ def update_policies(self, sec, ptype, old_rules, new_rules): UpdatePolicies updates some policy rules to storage, like db, redis. """ pass + + def update_filtered_policies( + self, sec, ptype, new_rules, field_index, *field_values + ): + """ + update_filtered_policies deletes old rules and adds new rules. + """ + pass diff --git a/tests/test_management_api.py b/tests/test_management_api.py index 7bcfcd2b..79735311 100644 --- a/tests/test_management_api.py +++ b/tests/test_management_api.py @@ -115,6 +115,24 @@ def test_get_policy_api(self): self.assertTrue(e.has_grouping_policy(["alice", "data2_admin"])) self.assertFalse(e.has_grouping_policy(["bob", "data2_admin"])) + def test_update_filtered_policies(self): + e = self.get_enforcer( + get_examples("rbac_model.conf"), + get_examples("rbac_policy.csv"), + ) + + e.update_filtered_policies( + [ + ["data2_admin", "data3", "read"], + ["data2_admin", "data3", "write"], + ["bob", "data3", "write"], + ], + 0, + "data2_admin", + ) + self.assertTrue(e.enforce("data2_admin", "data3", "write")) + self.assertTrue(e.enforce("data2_admin", "data3", "read")) + def test_get_policy_matching_function(self): e = self.get_enforcer( get_examples("rbac_with_domain_and_policy_pattern_model.conf"),