Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ternary conditional expression in tail call optimization #91

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 45 additions & 43 deletions macropy/experimental/tco.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def trampolined(*args, **kwargs):

@macros.decorator
def tco(tree, **kw):
def replace_call(func, args, keywords, tco_type):
starred = [arg for arg in args if isinstance(arg, ast.Starred)]
kwargs = [kw for kw in keywords if kw.arg is None]

if len(kwargs):
kwargs = kwargs[0].value
if len(starred):
starred = starred[0].value
# get rid of starargs
return hq[(tco_type,
ast_literal[func],
(ast_literal[ast.List(args, ast.Load())] +
list(ast_literal[starred])),
ast_literal[kwargs or ast.Dict([], [])])]
return hq[(tco_type,
ast_literal[func],
ast_literal[ast.List(args, ast.Load())],
ast_literal[kwargs or ast.Dict([], [])])]

def replace_call_node(node, tco_type):
with switch(node):
if ast.Call(func=func, args=args, keywords=keywords):
return replace_call(func, args, keywords, tco_type)
else:
return node

@Walker
# Replace returns of calls
Expand All @@ -80,28 +105,16 @@ def return_replacer(tree, **kw):
func=func,
args=args,
keywords=keywords)):
starred = [arg for arg in args if isinstance(arg, ast.Starred)]
kwargs = [kw for kw in keywords if kw.arg is None]

if len(kwargs):
kwargs = kwargs[0].value
if len(starred):
starred = starred[0].value
with hq as code:
# get rid of starargs
return (TCOType.CALL,
ast_literal[func],
(ast_literal[ast.List(args, ast.Load())] +
list(ast_literal[starred])),
ast_literal[kwargs or ast.Dict([],[])])
else:
with hq as code:
return (TCOType.CALL,
ast_literal[func],
ast_literal[ast.List(args, ast.Load())],
ast_literal[kwargs or ast.Dict([], [])])

return code
return ast.Return(value=replace_call(
func, args, keywords, TCOType.CALL))
elif ast.Return(value=ast.IfExp(
body=body,
orelse=orelse,
test=test)):
return ast.Return(value=ast.IfExp(
body=replace_call_node(body, TCOType.CALL),
orelse=replace_call_node(orelse, TCOType.CALL),
test=test))
else:
return tree

Expand All @@ -113,32 +126,21 @@ def replace_tc_pos(node):
func=func,
args=args,
keywords=keywords)):
starred = [arg for arg in args if isinstance(arg, ast.Starred)]
kwargs = [kw for kw in keywords if kw.arg is None]

if len(kwargs):
kwargs = kwargs[0].value
if len(starred):
starred = starred[0].value
with hq as code:
# get rid of starargs
return (TCOType.IGNORE,
ast_literal[func],
(ast_literal[ast.List(args, ast.Load())] +
list(ast_literal[starred])),
ast_literal[kwargs or ast.Dict([],[])])
else:
with hq as code:
return (TCOType.IGNORE,
ast_literal[func],
ast_literal[ast.List(args, ast.Load())],
ast_literal[kwargs or ast.Dict([], [])])
return code
return ast.Return(value=replace_call(
func, args, keywords, TCOType.IGNORE))
elif ast.If(test=test, body=body, orelse=orelse):
body[-1] = replace_tc_pos(body[-1])
if orelse:
orelse[-1] = replace_tc_pos(orelse[-1])
return ast.If(test, body, orelse)
elif ast.Expr(value=ast.IfExp(
body=body,
orelse=orelse,
test=test)):
return ast.Return(value=ast.IfExp(
body=replace_call_node(body, TCOType.IGNORE),
orelse=replace_call_node(orelse, TCOType.IGNORE),
test=test))
else:
return node

Expand Down
90 changes: 90 additions & 0 deletions macropy/experimental/test/tco.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,96 @@ def helper(n, cumulative):

self.assertEquals(120, fact(5))

def test_tco_ternary(self):
@tco
def foo(n):
return 1 if n == 0 else foo(n-1)
self.assertEquals(1, foo(3000))

def test_tco_returns_ternary(self):

@case
class Cons(x, rest): pass

@case
class Nil(): pass

def my_range(n):
cur = Nil()
for i in reversed(range(n)):
cur = Cons(i, cur)
return cur

@tco
def oddLength(xs):
with switch(xs):
return False if Nil() else evenLength(xs.rest)

@tco
def evenLength(xs):
with switch(xs):
return True if Nil() else oddLength(xs.rest)

self.assertTrue(True, evenLength(my_range(2000)))
self.assertTrue(True, oddLength(my_range(2001)))

def test_implicit_tailcall_ternary(self):
blah = []

@tco
def appendStuff(n):
blah.append(n)
appendStuff(n-1) if n > 1 else None

appendStuff(10000)
self.assertEquals(10000, len(blah))

def test_util_func_compatibility_ternary(self):
def util():
return 3 + 4

@tco
def f(n):
return util() if n == 0 else f(n-1)

self.assertEquals(7, f(1000))

def util2():
return None

@tco
def f2(n):
return util2() if n == 0 else f2(n-1)

self.assertEquals(None, f2(1000))

def test_tailcall_methods_ternary(self):

class Blah(object):
@tco
def foo(self, n):
return 1 if n == 0 else self.foo(n-1)

self.assertEquals(1, Blah().foo(5000))

def test_cross_calls_ternary(self):
def odd(n):
if n == 0:
return False
return even(n-1)

@tco
def even(n):
return True if n == 0 else odd(n-1)

def fact(n):
@tco
def helper(n, cumulative):
return cumulative if n == 0 else helper(n - 1, n * cumulative)
return helper(n, 1)

self.assertEquals(120, fact(5))


if __name__ == '__main__':
unittest.main()