Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support passing custom filters with the same name as built-in flags #413

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/build/unreleased/140.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. change::
:tags: bug, lexer, codegen
:tickets: 140

During the lexical analysis phase, add an additional
prefix for undeclared identifiers that have the same name
as built-in flags, and determine the final filter to be used
during the code generation phase based on the context
provided by the user.
Pull request by Hai Zhu.
44 changes: 35 additions & 9 deletions mako/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mako import filters
from mako import parsetree
from mako import util
from mako.filters import DEFAULT_ESCAPE_PREFIX
from mako.pygen import PythonPrinter


Expand All @@ -26,6 +27,7 @@
# context itself
TOPLEVEL_DECLARED = {"UNDEFINED", "STOP_RENDERING"}
RESERVED_NAMES = {"context", "loop"}.union(TOPLEVEL_DECLARED)
DEFAULT_ESCAPED_N = "%sn" % DEFAULT_ESCAPE_PREFIX


def compile( # noqa
Expand Down Expand Up @@ -522,6 +524,7 @@ def write_variable_declares(self, identifiers, toplevel=False, limit=None):
self.printer.writeline("loop = __M_loop = runtime.LoopStack()")

for ident in to_write:
ident = ident.replace(DEFAULT_ESCAPE_PREFIX, "")
if ident in comp_idents:
comp = comp_idents[ident]
if comp.is_block:
Expand Down Expand Up @@ -785,25 +788,48 @@ def locate_encode(name):
else:
return filters.DEFAULT_ESCAPES.get(name, name)

if "n" not in args:
filter_args = set()
if DEFAULT_ESCAPED_N not in args:
if is_expression:
if self.compiler.pagetag:
args = self.compiler.pagetag.filter_args.args + args
if self.compiler.default_filters and "n" not in args:
filter_args = set(self.compiler.pagetag.filter_args.args)
if (
self.compiler.default_filters
and DEFAULT_ESCAPED_N not in args
):
args = self.compiler.default_filters + args
for e in args:
# if filter given as a function, get just the identifier portion
if e == "n":
if e == DEFAULT_ESCAPED_N:
continue

if e.startswith(DEFAULT_ESCAPE_PREFIX):
render_e = e.replace(DEFAULT_ESCAPE_PREFIX, "")
is_default_filter = True
else:
render_e = e
is_default_filter = False

# if filter given as a function, get just the identifier portion
m = re.match(r"(.+?)(\(.*\))", e)
if m:
ident, fargs = m.group(1, 2)
f = locate_encode(ident)
e = f + fargs
if not is_default_filter:
ident, fargs = m.group(1, 2)
f = locate_encode(ident)
render_e = f + fargs
target = "%s(%s)" % (render_e, target)
elif is_default_filter and e not in filter_args:
target = "%s(%s) if %s is not UNDEFINED else %s(%s)" % (
render_e,
target,
render_e,
locate_encode(render_e),
target,
)
else:
e = locate_encode(e)
e = locate_encode(render_e)
assert e is not None
target = "%s(%s)" % (e, target)
target = "%s(%s)" % (e, target)
return target

def visitExpression(self, node):
Expand Down
3 changes: 3 additions & 0 deletions mako/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,6 @@ def htmlentityreplace_errors(ex):
"str": "str",
"n": "n",
}


DEFAULT_ESCAPE_PREFIX = "__DEFAULT_ESCAPE_"
21 changes: 19 additions & 2 deletions mako/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from mako import compat
from mako import exceptions
from mako import util
from mako.filters import DEFAULT_ESCAPE_PREFIX
from mako.filters import DEFAULT_ESCAPES

# words that cannot be assigned to (notably
# smaller than the total keys in __builtins__)
Expand Down Expand Up @@ -196,9 +198,24 @@ def visit_Tuple(self, node):
p.declared_identifiers
)
lui = self.listener.undeclared_identifiers
self.listener.undeclared_identifiers = lui.union(
p.undeclared_identifiers
undeclared_identifiers = lui.union(p.undeclared_identifiers)
conflict_identifiers = undeclared_identifiers.intersection(
DEFAULT_ESCAPES
)
if conflict_identifiers:
_map = {
i: DEFAULT_ESCAPE_PREFIX + i for i in conflict_identifiers
}
for i, arg in enumerate(self.listener.args):
if arg in _map:
self.listener.args[i] = _map[arg]
self.listener.undeclared_identifiers = (
undeclared_identifiers.symmetric_difference(
conflict_identifiers
).union(_map.values())
)
else:
self.listener.undeclared_identifiers = undeclared_identifiers


class ParseFunc(_ast_util.NodeVisitor):
Expand Down
30 changes: 25 additions & 5 deletions test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,36 @@ def test_python_fragment(self):

def test_argument_list(self):
parsed = ast.ArgumentList(
"3, 5, 'hi', x+5, " "context.get('lala')", **exception_kwargs
"3, 5, 'hi', g+5, " "context.get('lala')", **exception_kwargs
)
eq_(parsed.undeclared_identifiers, {"x", "context"})
eq_(parsed.undeclared_identifiers, {"g", "context"})
eq_(
[x for x in parsed.args],
["3", "5", "'hi'", "(x + 5)", "context.get('lala')"],
["3", "5", "'hi'", "(g + 5)", "context.get('lala')"],
)

parsed = ast.ArgumentList("h", **exception_kwargs)
eq_(parsed.args, ["h"])
parsed = ast.ArgumentList("m", **exception_kwargs)
eq_(parsed.args, ["m"])

def test_conflict_argument_list(self):
parsed = ast.ArgumentList(
"x-2, h*2, '(u)', n+5, trim, entity, unicode, decode, str, other",
**exception_kwargs,
)
eq_(
parsed.undeclared_identifiers,
{
"__DEFAULT_ESCAPE_trim",
"__DEFAULT_ESCAPE_h",
"__DEFAULT_ESCAPE_decode",
"__DEFAULT_ESCAPE_unicode",
"__DEFAULT_ESCAPE_x",
"__DEFAULT_ESCAPE_str",
"__DEFAULT_ESCAPE_entity",
"__DEFAULT_ESCAPE_n",
"other",
},
)

def test_function_decl(self):
"""test getting the arguments from a function"""
Expand Down
33 changes: 33 additions & 0 deletions test/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,36 @@ def test_capture_ccall(self):

# print t.render()
assert flatten_result(t.render()) == "this is foo. body: ccall body"

def test_conflict_filter_ident(self):
class h(object):
foo = str

t = Template(
"""
X:
${"asdf" | h.foo}
"""
)
assert flatten_result(t.render(h=h)) == "X: asdf"

def h(i):
return str(i) + "1"

t = Template(
"""
${123 | h}
"""
)
assert flatten_result(t.render()) == "123"
assert flatten_result(t.render(h=h)) == "1231"

t = Template(
"""
<%def name="foo()" filter="h">
this is foo</%def>
${foo()}
"""
)
assert flatten_result(t.render()) == "this is foo"
assert flatten_result(t.render(h=h)) == "this is foo1"
2 changes: 1 addition & 1 deletion test/test_lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,7 +1327,7 @@ def test_integration(self):
Text(" <tr>\n", (14, 1)),
ControlLine("for", "for x in j:", False, (15, 1)),
Text(" <td>Hello ", (16, 1)),
Expression("x", ["h"], (16, 23)),
Expression("x", ["__DEFAULT_ESCAPE_h"], (16, 23)),
Text("</td>\n", (16, 30)),
ControlLine("for", "endfor", True, (17, 1)),
Text(" </tr>\n", (18, 1)),
Expand Down
Loading