diff --git a/casbin/model/function.py b/casbin/model/function.py index fdd4a22..0d0b0c5 100644 --- a/casbin/model/function.py +++ b/casbin/model/function.py @@ -31,6 +31,7 @@ def load_function_map(): fm.add_function("keyMatch2", util.key_match2_func) fm.add_function("keyMatch3", util.key_match3_func) fm.add_function("keyMatch4", util.key_match4_func) + fm.add_function("keyMatch5", util.key_match5_func) fm.add_function("regexMatch", util.regex_match_func) fm.add_function("ipMatch", util.ip_match_func) fm.add_function("globMatch", util.glob_match_func) diff --git a/casbin/util/builtin_operators.py b/casbin/util/builtin_operators.py index 1b09d1e..84965f4 100644 --- a/casbin/util/builtin_operators.py +++ b/casbin/util/builtin_operators.py @@ -18,6 +18,7 @@ KEY_MATCH2_PATTERN = re.compile(r"(.*?):[^\/]+(.*?)") KEY_MATCH3_PATTERN = re.compile(r"(.*?){[^\/]+?}(.*?)") KEY_MATCH4_PATTERN = re.compile(r"{([^/]+)}") +KEY_MATCH5_PATTERN = re.compile(r"{[^/]+}") def key_match(key1, key2): @@ -194,6 +195,32 @@ def key_match4_func(*args) -> bool: return key_match4(name1, name2) +def key_match5(key1: str, key2: str) -> bool: + """ + key_match5 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a * + For example, + - "/foo/bar?status=1&type=2" matches "/foo/bar" + - "/parent/child1" and "/parent/child1" matches "/parent/*" + - "/parent/child1?status=1" matches "/parent/*" + """ + i = key1.find("?") + if i != -1: + key1 = key1[:i] + + key2 = key2.replace("/*", "/.*") + + key2 = KEY_MATCH5_PATTERN.sub(r"[^/]+", key2, 0) + + return regex_match(key1, "^" + key2 + "$") + + +def key_match5_func(*args) -> bool: + name1 = args[0] + name2 = args[1] + + return key_match5(name1, name2) + + def regex_match(key1, key2): """determines whether key1 matches the pattern of key2 in regular expression.""" diff --git a/tests/util/test_builtin_operators.py b/tests/util/test_builtin_operators.py index 2dacf0b..b0b47da 100644 --- a/tests/util/test_builtin_operators.py +++ b/tests/util/test_builtin_operators.py @@ -13,6 +13,7 @@ # limitations under the License. from unittest import TestCase + from casbin import util @@ -190,6 +191,43 @@ def test_key_match4(self): self.assertFalse(util.key_match4_func("/parent/123/child/123", "/parent/{i/d}/child/{i/d}")) + def test_key_match5_func(self): + self.assertTrue(util.key_match5_func("/parent/child?status=1&type=2", "/parent/child")) + self.assertFalse(util.key_match5_func("/parent?status=1&type=2", "/parent/child")) + + self.assertTrue(util.key_match5_func("/parent/child/?status=1&type=2", "/parent/child/")) + self.assertFalse(util.key_match5_func("/parent/child/?status=1&type=2", "/parent/child")) + self.assertFalse(util.key_match5_func("/parent/child?status=1&type=2", "/parent/child/")) + + self.assertTrue(util.key_match5_func("/foo", "/foo")) + self.assertTrue(util.key_match5_func("/foo", "/foo*")) + self.assertFalse(util.key_match5_func("/foo", "/foo/*")) + self.assertFalse(util.key_match5_func("/foo/bar", "/foo")) + self.assertFalse(util.key_match5_func("/foo/bar", "/foo*")) + self.assertTrue(util.key_match5_func("/foo/bar", "/foo/*")) + self.assertFalse(util.key_match5_func("/foobar", "/foo")) + self.assertFalse(util.key_match5_func("/foobar", "/foo*")) + self.assertFalse(util.key_match5_func("/foobar", "/foo/*")) + + self.assertFalse(util.key_match5_func("/", "/{resource}")) + self.assertTrue(util.key_match5_func("/resource1", "/{resource}")) + self.assertFalse(util.key_match5_func("/myid", "/{id}/using/{resId}")) + self.assertTrue(util.key_match5_func("/myid/using/myresid", "/{id}/using/{resId}")) + + self.assertFalse(util.key_match5_func("/proxy/myid", "/proxy/{id}/*")) + self.assertTrue(util.key_match5_func("/proxy/myid/", "/proxy/{id}/*")) + self.assertTrue(util.key_match5_func("/proxy/myid/res", "/proxy/{id}/*")) + self.assertTrue(util.key_match5_func("/proxy/myid/res/res2", "/proxy/{id}/*")) + self.assertTrue(util.key_match5_func("/proxy/myid/res/res2/res3", "/proxy/{id}/*")) + self.assertFalse(util.key_match5_func("/proxy/", "/proxy/{id}/*")) + + self.assertFalse(util.key_match5_func("/proxy/myid?status=1&type=2", "/proxy/{id}/*")) + self.assertTrue(util.key_match5_func("/proxy/myid/", "/proxy/{id}/*")) + self.assertTrue(util.key_match5_func("/proxy/myid/res?status=1&type=2", "/proxy/{id}/*")) + self.assertTrue(util.key_match5_func("/proxy/myid/res/res2?status=1&type=2", "/proxy/{id}/*")) + self.assertTrue(util.key_match5_func("/proxy/myid/res/res2/res3?status=1&type=2", "/proxy/{id}/*")) + self.assertFalse(util.key_match5_func("/proxy/", "/proxy/{id}/*")) + def test_regex_match(self): self.assertTrue(util.regex_match_func("/topic/create", "/topic/create")) self.assertTrue(util.regex_match_func("/topic/create/123", "/topic/create"))