Skip to content

Commit

Permalink
fix: add key_get() for builtin_operators (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
cs1137195420 authored Jul 30, 2022
1 parent f11bd4c commit 29d3e39
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 1 deletion.
66 changes: 65 additions & 1 deletion casbin/util/builtin_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import re

KEY_MATCH2_PATTERN = re.compile(r"(.*?):[^\/]+(.*?)")
KEY_MATCH3_PATTERN = re.compile(r"(.*?){[^\/]+}(.*?)")
KEY_MATCH3_PATTERN = re.compile(r"(.*?){[^\/]+?}(.*?)")
KEY_MATCH4_PATTERN = re.compile(r"{([^/]+)}")


Expand All @@ -42,6 +42,22 @@ def key_match_func(*args):
return key_match(name1, name2)


def key_get(key1, key2):
"""
key_get returns the matched part
For example, "/foo/bar/foo" matches "/foo/*"
"bar/foo" will been returned
"""
i = key2.find("*")
if i == -1:
return ""

if len(key1) > i:
if key1[:i] == key2[:i]:
return key1[i:]
return ""


def key_match2(key1, key2):
"""determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *.
For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/:resource"
Expand All @@ -63,6 +79,30 @@ def key_match2_func(*args):
return key_match2(name1, name2)


def key_get2(key1, key2, path_var):
"""
key_get2 returns value matched pattern
For example, "/resource1" matches "/:resource"
if the pathVar == "resource", then "resource1" will be returned
"""
key2 = key2.replace("/*", "/.*")

keys = re.findall(":[^/]+", key2)
key2 = KEY_MATCH2_PATTERN.sub(r"\g<1>([^\/]+)\g<2>", key2, 0)

if key2 == "*":
key2 = "(.*)"

key2 = "^" + key2 + "$"
values = re.match(key2, key1)
if values is None:
return ""
for i, key in enumerate(keys):
if path_var == key[1:]:
return values.groups()[i]
return ""


def key_match3(key1, key2):
"""determines determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *.
For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/{resource}"
Expand All @@ -81,6 +121,30 @@ def key_match3_func(*args):
return key_match3(name1, name2)


def key_get3(key1, key2, path_var):
"""
key_get3 returns value matched pattern
For example, "project/proj_project1_admin/" matches "project/proj_{project}_admin/"
if the pathVar == "project", then "project1" will be returned
"""
key2 = key2.replace("/*", "/.*")

keys = re.findall(r"{[^/]+?}", key2)
key2 = KEY_MATCH3_PATTERN.sub(r"\g<1>([^/]+?)\g<2>", key2, 0)

if key2 == "*":
key2 = "(.*)"

key2 = "^" + key2 + "$"
values = re.match(key2, key1)
if values is None:
return ""
for i, key in enumerate(keys):
if path_var == key[1 : len(key) - 1]:
return values.groups()[i]
return ""


def key_match4(key1: str, key2: str) -> bool:
"""
key_match4 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *.
Expand Down
88 changes: 88 additions & 0 deletions tests/util/test_builtin_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def test_key_match(self):

self.assertFalse(util.key_match2_func("/alice/all", "/:/all"))

def test_key_get(self):
self.assertEqual(util.key_get("/foo", "/foo"), "")
self.assertEqual(util.key_get("/foo", "/foo*"), "")
self.assertEqual(util.key_get("/foo", "/foo/*"), "")
self.assertEqual(util.key_get("/foo/bar", "/foo"), "")
self.assertEqual(util.key_get("/foo/bar", "/foo*"), "/bar")
self.assertEqual(util.key_get("/foo/bar", "/foo/*"), "bar")
self.assertEqual(util.key_get("/foobar", "/foo"), "")
self.assertEqual(util.key_get("/foobar", "/foo*"), "bar")
self.assertEqual(util.key_get("/foobar", "/foo/*"), "")

def test_key_match2(self):
self.assertFalse(util.key_match2_func("/foo", "/"))
self.assertTrue(util.key_match2_func("/foo", "/foo"))
Expand Down Expand Up @@ -62,6 +73,38 @@ def test_key_match2(self):

self.assertFalse(util.key_match2_func("/alice/all", "/:/all"))

def test_key_get2(self):
self.assertEqual(util.key_get2("/foo", "/foo", "id"), "")
self.assertEqual(util.key_get2("/foo", "/foo*", "id"), "")
self.assertEqual(util.key_get2("/foo", "/foo/*", "id"), "")
self.assertEqual(util.key_get2("/foo/bar", "/foo", "id"), "")
self.assertEqual(util.key_get2("/foo/bar", "/foo*", "id"), "")
self.assertEqual(util.key_get2("/foo/bar", "/foo/*", "id"), "")
self.assertEqual(util.key_get2("/foobar", "/foo", "id"), "")
self.assertEqual(util.key_get2("/foobar", "/foo*", "id"), "")
self.assertEqual(util.key_get2("/foobar", "/foo/*", "id"), "")

self.assertEqual(util.key_get2("/", "/:resource", "resource"), "")
self.assertEqual(util.key_get2("/resource1", "/:resource", "resource"), "resource1")
self.assertEqual(util.key_get2("/myid", "/:id/using/:resId", "id"), "")
self.assertEqual(util.key_get2("/myid/using/myresid", "/:id/using/:resId", "id"), "myid")
self.assertEqual(util.key_get2("/myid/using/myresid", "/:id/using/:resId", "resId"), "myresid")

self.assertEqual(util.key_get2("/proxy/myid", "/proxy/:id/*", "id"), "")
self.assertEqual(util.key_get2("/proxy/myid/", "/proxy/:id/*", "id"), "myid")
self.assertEqual(util.key_get2("/proxy/myid/res", "/proxy/:id/*", "id"), "myid")
self.assertEqual(util.key_get2("/proxy/myid/res/res2", "/proxy/:id/*", "id"), "myid")
self.assertEqual(util.key_get2("/proxy/myid/res/res2/res3", "/proxy/:id/*", "id"), "myid")
self.assertEqual(util.key_get2("/proxy/myid/res/res2/res3", "/proxy/:id/res/*", "id"), "myid")
self.assertEqual(util.key_get2("/proxy/", "/proxy/:id/*", "id"), "")

self.assertEqual(util.key_get2("/alice", "/:id", "id"), "alice")
self.assertEqual(util.key_get2("/alice/all", "/:id/all", "id"), "alice")
self.assertEqual(util.key_get2("/alice", "/:id/all", "id"), "")
self.assertEqual(util.key_get2("/alice/all", "/:id", "id"), "")

self.assertEqual(util.key_get2("/alice/all", "/:/all", ""), "")

def test_key_match3(self):
self.assertTrue(util.key_match3_func("/foo", "/foo"))
self.assertTrue(util.key_match3_func("/foo", "/foo*"))
Expand All @@ -87,6 +130,51 @@ def test_key_match3(self):

self.assertFalse(util.key_match3_func("/myid/using/myresid", "/{id/using/{resId}"))

def test_key_get3(self):
self.assertEqual(util.key_get3("/foo", "/foo", "id"), "")
self.assertEqual(util.key_get3("/foo", "/foo*", "id"), "")
self.assertEqual(util.key_get3("/foo", "/foo/*", "id"), "")
self.assertEqual(util.key_get3("/foo/bar", "/foo", "id"), "")
self.assertEqual(util.key_get3("/foo/bar", "/foo*", "id"), "")
self.assertEqual(util.key_get3("/foo/bar", "/foo/*", "id"), "")
self.assertEqual(util.key_get3("/foobar", "/foo", "id"), "")
self.assertEqual(util.key_get3("/foobar", "/foo*", "id"), "")
self.assertEqual(util.key_get3("/foobar", "/foo/*", "id"), "")

self.assertEqual(util.key_get3("/", "/{resource}", "resource"), "")
self.assertEqual(util.key_get3("/resource1", "/{resource}", "resource"), "resource1")
self.assertEqual(util.key_get3("/myid", "/{id}/using/{resId}", "id"), "")
self.assertEqual(util.key_get3("/myid/using/myresid", "/{id}/using/{resId}", "id"), "myid")
self.assertEqual(util.key_get3("/myid/using/myresid", "/{id}/using/{resId}", "resId"), "myresid")

self.assertEqual(util.key_get3("/proxy/myid", "/proxy/{id}/*", "id"), "")
self.assertEqual(util.key_get3("/proxy/myid/", "/proxy/{id}/*", "id"), "myid")
self.assertEqual(util.key_get3("/proxy/myid/res", "/proxy/{id}/*", "id"), "myid")
self.assertEqual(util.key_get3("/proxy/myid/res/res2", "/proxy/{id}/*", "id"), "myid")
self.assertEqual(util.key_get3("/proxy/myid/res/res2/res3", "/proxy/{id}/*", "id"), "myid")
self.assertEqual(util.key_get3("/proxy/", "/proxy/{id}/*", "id"), "")

self.assertEqual(
util.key_get3("/api/group1_group_name/project1_admin/info", "/api/{proj}_admin/info", "proj"), ""
)
self.assertEqual(util.key_get3("/{id/using/myresid", "/{id/using/{resId}", "resId"), "myresid")
self.assertEqual(util.key_get3("/{id/using/myresid/status}", "/{id/using/{resId}/status}", "resId"), "myresid")

self.assertEqual(util.key_get3("/proxy/myid/res/res2/res3", "/proxy/{id}/*/{res}", "res"), "res3")
self.assertEqual(util.key_get3("/api/project1_admin/info", "/api/{proj}_admin/info", "proj"), "project1")
self.assertEqual(
util.key_get3("/api/group1_group_name/project1_admin/info", "/api/{g}_{gn}/{proj}_admin/info", "g"),
"group1",
)
self.assertEqual(
util.key_get3("/api/group1_group_name/project1_admin/info", "/api/{g}_{gn}/{proj}_admin/info", "gn"),
"group_name",
)
self.assertEqual(
util.key_get3("/api/group1_group_name/project1_admin/info", "/api/{g}_{gn}/{proj}_admin/info", "proj"),
"project1",
)

def test_key_match4(self):
self.assertTrue(util.key_match4_func("/parent/123/child/123", "/parent/{id}/child/{id}"))
self.assertFalse(util.key_match4_func("/parent/123/child/456", "/parent/{id}/child/{id}"))
Expand Down

0 comments on commit 29d3e39

Please sign in to comment.