Transforming Python ASTs to Optimize Comprehensions

25-30 minute read

tl;dr Python comprehensions can have duplicate function calls (e.g. [foo(x) for x in ... if foo(x)]). If these function calls are expensive, we need to rewrite our comprehensions to avoid the cost of calling them multiple times. In this post, we solve this by writing a decorator that converts a function into an AST, optimizes away duplicate function calls and compiles it at runtime, all in ~200 lines of code.

I love list, dict and set comprehensions in Python. It’s the one feature of the language that I feel like I could point at and say “Pythonic”. It feels like the map and filter functions are really baked into the language. More than that, comprehensions stem from set theory, so they must be good, right?

In this blog post, I’ll briefly describe comprehensions, explain a performance disadvantage that I frequently face, and show you some code that mitigates this disadvantage by transforming them at runtime.

A very relevant stock photo of an iguana

What are comprehensions

If you’re already familiar with comprehensions, feel free to skip this section! If not, I’ll provide a quick breakdown.

Essentially, comprehensions are just syntactic sugar that, quoting PEP 202, “provide a more concise way to create [lists, sets, dictionaries] in situations where map() and filter() and/or nested loops would currently be used”. An example…

# An example of a basic `for` loop in Python
acc = []
for y in range(100):
    if foo(y):
        acc.append(bar(y))

# And here's how you could ~~rewrite~~ simplify the above example using
# a list comprehension
acc = [bar(y) for y in range(100) if foo(y)]

The observant Python programmer will also notice that the above can be rewritten using the built-in map and filter functions as map(bar, filter(foo, range(100))) (wrapped in list(...) on Python 3, where map returns an iterator). Anyway, back to comprehensions…
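As a quick sanity check of that equivalence, here is a minimal sketch. The bodies of `foo` and `bar` are hypothetical stand-ins (an evenness test and a squaring function), chosen only so the snippet runs on its own.

```python
def foo(y):
    # Hypothetical predicate: keep even numbers only
    return y % 2 == 0

def bar(y):
    # Hypothetical transformation: square the value
    return y * y

with_comprehension = [bar(y) for y in range(100) if foo(y)]
# list(...) is needed on Python 3, where map returns a lazy iterator
with_map_filter = list(map(bar, filter(foo, range(100))))

assert with_comprehension == with_map_filter
```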

Comprehensions also allow you to “chain” for loops, with later loops being equivalent to more deeply nested loops. Some more examples…

# This...
# (the collected value, here the pair (x, y), is illustrative)
acc = []
for y in range(100):
    if foo(y):
        for x in range(y):
            if bar(x):
                acc.append((x, y))

# ...is semantically equivalent to...
acc = [
    (x, y)
    for y in range(100)
    if foo(y)
    for x in range(y)
    if bar(x)
]

# ...which also provides the same result as...
acc = [
    (x, y)
    for y in range(100)
    for x in range(y)
    if foo(y) and bar(x)
]

Comprehensive shortcomings

In my opinion, there’s one large shortcoming of comprehensions compared to their procedural cousins. for loops contain statements (e.g. x = foo(y)) whereas comprehensions can only contain expressions (e.g. x + foo(y)) and are in fact expressions themselves. As a result, we can’t alias the return value of a function call in order to use it again. This wouldn’t be an issue if assignments in Python were treated as expressions (like in C or C++), but they’re treated as statements.

# This code calls the function `foo` with argument `y` twice
acc = []
for y in range(100):
    if foo(y):
        acc.append(foo(y))

# We can rewrite this to make sure `foo` is only called once by assigning
# its result to a variable and using that instead
acc = []
for y in range(100):
    foo_result = foo(y)
    if foo_result:
        acc.append(foo_result)

# This comprehension calls the function `foo` with argument `y` twice...
# but we can't rewrite this comprehension in a "clean" way to call `foo`
# only once
acc = [foo(y) for y in range(100) if foo(y)]

This shortcoming really becomes an issue if the foo function is expensive. There are some work-arounds to this problem, but I don’t like any of them.

  1. We could rewrite the list comprehension as a series of map and filter function calls.
  2. We can cache/memoize the expensive function that we’re calling (e.g. using the lru_cache decorator from functools). If the function is in a different module that you can’t modify, you’ll have to write a wrapper function and cache the result of the wrapper function. Meh…
  3. Or… you can alter your list comprehensions by appending additional for loops whose loop variables effectively act as ordinary variables. Unfortunately, this makes your comprehension harder to comprehend (*cough* *cough*) and the whole point of comprehensions is that they’re more concise and easier to understand than standard for loops. Here’s an example:
# We're only calling `foo` once! But we write code for others to read...
acc = [
    foo_result
    for y in range(100)
    for foo_result in [foo(y)]
    if foo_result
]
Let's build a compiler!

One idea is to take an inefficient, but clean and concise comprehension and optimize away duplicate and equivalent function calls so that it’s efficient but ugly (i.e. the 3rd work-around mentioned above). The act of taking a program and rewriting it to be quicker while sacrificing clarity is something most compilers do for you, so on the surface this idea doesn’t seem unreasonable.

We could build a mini-compiler to do these comprehension transformations before running, but since Python is dynamic, let’s do this at runtime! This has the nice advantage of not requiring our users to change their build workflow. We can distribute our comprehension optimization program as a Python module that users can install normally (via PyPI). Users could just import our module to have it optimize their modules, or we could expose the optimization as a Python decorator. The latter is slightly more explicit, so we’ll start with that!

However, building an end-to-end compiler is hard. We really only care about taking an easy-to-manipulate abstract representation of a Python program and manipulating it to be faster, which is just one of the steps that a compiler is usually broken into:

  1. Lexical analysis (or tokenization… or scanning).
  2. Syntax analysis (or parsing).
  3. Semantic analysis.
  4. Intermediate code generation.
  5. Optimization (this is the only step we really care about).
  6. Output code generation.

Python's `ast` module to the rescue!

Luckily, Python’s standard library has an ast module that can take some Python source code as input and produce an abstract syntax tree (AST). It also provides the NodeVisitor and NodeTransformer classes, which let us easily walk the structure of a Python program in tree form and transform it as we go. Python also has a built-in compile function that takes an AST (or string) as input and compiles it into a code object that can then be executed.
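To make that parse, transform, compile pipeline concrete, here is a minimal sketch. The `SwapAddForMult` transformer is a hypothetical toy (it is not the optimization from this post); it only shows how the pieces fit together.

```python
import ast

source = "result = 1 + 2"
tree = ast.parse(source)

class SwapAddForMult(ast.NodeTransformer):
    # Toy transformation: turn every `a + b` into `a * b`
    def visit_BinOp(self, node):
        # Transform nested expressions first
        self.generic_visit(node)
        if isinstance(node.op, ast.Add):
            node.op = ast.Mult()
        return node

# Fill in any missing line numbers on the transformed tree before compiling
new_tree = ast.fix_missing_locations(SwapAddForMult().visit(tree))

namespace = {}
exec(compile(new_tree, "<ast>", "exec"), namespace)
assert namespace["result"] == 2  # 1 * 2 instead of 1 + 2
```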

Using this module, it’s actually surprisingly easy to write a Python decorator that accesses the wrapped function’s source code, converts it to an AST, transforms it, compiles it and returns that newly compiled function. We’ll call the decorator @optimize_comprehensions for fun.

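import inspect
import textwrap
from ast import fix_missing_locations, parse

def optimize_comprehensions(func):
    # Dedent so that functions defined inside a class or another function
    # still parse as top-level code
    source = textwrap.dedent(inspect.getsource(func))
    in_node = parse(source)
    # TODO: Implement OptimizeComprehensions
    out_node = OptimizeComprehensions().visit(in_node)
    # Nodes created by the transformer need line numbers before compiling
    fix_missing_locations(out_node)
    new_func_name = out_node.body[0].name
    func_scope = func.__globals__
    # Compile the new function in the old function's scope. If we don't change
    # the name, this actually overrides the old function with the new one
    exec(compile(out_node, '<string>', 'exec'), func_scope)
    return func_scope[new_func_name]

# Example usage
@optimize_comprehensions
def basic_list_comprehension(expensive_func):
    return [
        expensive_func(x)
        for x in range(100)
        if expensive_func(x)
    ]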
The grammar for the AST is described in the Abstract Syntax Description Language (ASDL). Here’s the description of the ListComp, SetComp and DictComp nodes. I’ve also included the GeneratorExp node since it has a very similar syntactic structure to comprehensions and we can apply the same optimizations to it. Finally, I’ve included the description of function call (Call) nodes as we’ll be optimizing out duplicates of those. Note that this grammar varies between Python versions: from Python 3.5 onwards, for example, Call no longer has the starargs and kwargs fields.

expr = ...
     | ListComp(expr elt, comprehension* generators)
     | SetComp(expr elt, comprehension* generators)
     | DictComp(expr key, expr value, comprehension* generators)
     | GeneratorExp(expr elt, comprehension* generators)
     | Call(expr func, expr* args, keyword* keywords,
            expr? starargs, expr? kwargs)

comprehension = (expr target, expr iter, expr* ifs)
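To see how this grammar maps onto real nodes, the sketch below parses a small comprehension and pokes at the fields named above. The identifiers `foo` and `items` are just placeholders in the parsed source; nothing is executed.

```python
import ast

tree = ast.parse("[foo(x) for x in items if foo(x)]", mode="eval")
comp = tree.body                 # the ListComp node
generator = comp.generators[0]   # the single `for x in items if foo(x)` clause

assert isinstance(comp, ast.ListComp)
assert isinstance(comp.elt, ast.Call)   # elt is the element expression, foo(x)
assert generator.target.id == "x"       # target is the loop variable
assert generator.iter.id == "items"     # iter is the iterable
assert len(generator.ifs) == 1          # ifs holds the filter conditions
```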

Removing the function decorator (or preventing infinite recursion)

If we’re not careful, since we’re performing optimizations at runtime, we may perform these optimizations more often than we need to. Worse, the source code we recompile still contains the @optimize_comprehensions decorator line, so executing it applies the decorator again, which recompiles the source, and so on. Users would recurse infinitely when they attempt to use the function to be optimized.

To prevent this, the first transformation we’ll perform is to remove the decorator from our function. We’ll do this by removing any decorators with the name optimize_comprehensions from the function. This is by far the simplest solution so we’ll go with it, but it doesn’t actually work if the user renamed the decorator when importing it (i.e. from optimize_comprehensions import optimize_comprehensions as oc) or if the user namespaced their decorator by importing the module (i.e. import optimize_comprehensions).

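    def visit_FunctionDef(self, node):
        # Remove the optimize_comprehensions decorator from the function so we
        # don't infinitely recurse and/or attempt to re-optimize the node
        node.decorator_list = [
            decorator
            for decorator in node.decorator_list
            # Only plain names have an `id`; other decorator forms are kept
            if getattr(decorator, 'id', None) != 'optimize_comprehensions'
        ]
        # Keep walking so the comprehensions in the function body are visited
        self.generic_visit(node)
        return node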
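The namespaced case could probably be covered with a slightly more tolerant check. The sketch below is one possible approach, not what the rest of this post uses: `is_optimize_decorator` is a hypothetical helper that matches both a bare name and an attribute access. It still cannot detect a decorator imported under an alias.

```python
import ast

def is_optimize_decorator(decorator, name="optimize_comprehensions"):
    # Bare name: @optimize_comprehensions
    if isinstance(decorator, ast.Name):
        return decorator.id == name
    # Attribute access: @some_module.optimize_comprehensions
    if isinstance(decorator, ast.Attribute):
        return decorator.attr == name
    return False

source = """
@optimize_comprehensions.optimize_comprehensions
@other
def f():
    pass
"""
func_def = ast.parse(source).body[0]
kept = [d for d in func_def.decorator_list if not is_optimize_decorator(d)]
```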

Equivalent function calls

To determine if there are duplicate function calls, we must first define an equality relation between Call nodes. We could make this super smart by ignoring the ordering of keyword arguments (e.g. foo(a=1, b=2) == foo(b=2, a=1)) and evaluating argument expressions to determine if their results are equal (e.g. foo(a=1+2) == foo(a=3)). However, for the sake of simplicity, we’ll just say two Call nodes are equal if they have the same “formatted dump” (essentially, the same string representation) as produced by the ast module’s dump function.
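A small sketch of what this equality relation does and doesn’t treat as equal. The helper `call_node` is a hypothetical convenience for parsing one expression. By default, dump leaves out line and column information, so differences in whitespace don’t matter.

```python
from ast import dump, parse

def call_node(source):
    # Parse a single expression and return its top-level node
    return parse(source, mode="eval").body

# Same call, different whitespace: the dumps match
same_call = dump(call_node("foo(a=1, b=2)")) == dump(call_node("foo(a=1,   b=2)"))

# Reordered keywords: treated as different calls by this simple relation
reordered = dump(call_node("foo(a=1, b=2)")) == dump(call_node("foo(b=2, a=1)"))

# Arguments that would evaluate to the same value: also treated as different
evaluated = dump(call_node("foo(a=1 + 2)")) == dump(call_node("foo(a=3)"))

assert same_call and not reordered and not evaluated
```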

Finding duplicate function calls

Our OptimizeComprehensions NodeTransformer will work by walking the subtree of a comprehension (or generator) node and replacing duplicate Call nodes with a Name node that reads the value of a variable. In order to do this, we must first do an initial pass over the subtree of a comprehension and find duplicate nodes. We’ll achieve this by creating a NodeVisitor class that visits Call nodes and returns the duplicate Calls.

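from ast import NodeVisitor, dump

class DuplicateCallFinder(NodeVisitor):
    def __init__(self):
        # Per-instance dict: a class-level dict would be shared between
        # finders and leak counts from one comprehension into the next
        self.calls = {}

    def visit_Call(self, call):
        # Calls with the same dump are considered equal
        call_hash = dump(call)
        _, current_count = self.calls.get(call_hash, (call, 0))
        self.calls[call_hash] = (call, current_count + 1)

    def duplicate_calls(self):
        return [
            call
            for _, (call, call_count) in self.calls.items()
            if call_count > 1
        ]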
from ast import NodeTransformer

class OptimizeComprehensions(NodeTransformer):
    def __init__(self):
        # Per-instance stack, so state is not shared between transformers
        self.calls_to_replace_stack = []

    def visit_comp(self, node):
        # Find all functions that are called multiple times with the same
        # arguments as we will replace them with one variable
        call_visitor = DuplicateCallFinder()
        call_visitor.visit(node)

        # Keep track of what calls we need to replace using a stack so we
        # support nested comprehensions
        self.calls_to_replace_stack.append(call_visitor.duplicate_calls())

        # TODO: Replace all duplicate calls with a variable, and add new
        # comprehensions (holding the variable and the function call) to the
        # list of comprehensions

        # Make sure we clear the calls to replace so we don't replace other
        # calls outside of the scope of this current list comprehension
        self.calls_to_replace_stack.pop()
        return node

    # Optimize list, set and dict comps, and generators the same way
    visit_ListComp     = visit_comp
    visit_SetComp      = visit_comp
    visit_DictComp     = visit_comp
    visit_GeneratorExp = visit_comp
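To check the counting idea in isolation, here is a self-contained sketch. `CallCounter` is a hypothetical, simplified stand-in for DuplicateCallFinder (it uses a Counter and keeps only the dump strings), run against one small comprehension in which foo(x) appears twice and bar(x) once.

```python
from ast import NodeVisitor, dump, parse
from collections import Counter

class CallCounter(NodeVisitor):
    # Simplified stand-in for DuplicateCallFinder: counts calls by their dump
    def __init__(self):
        self.counts = Counter()

    def visit_Call(self, call):
        self.counts[dump(call)] += 1

comp = parse("[foo(x) for x in items if foo(x) and bar(x)]", mode="eval").body
counter = CallCounter()
counter.visit(comp)

# Only foo(x) appears more than once
duplicates = [key for key, count in counter.counts.items() if count > 1]
```

As in the class above, visit_Call doesn’t descend into a call’s arguments, so calls nested inside other calls are not counted separately.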

Replacing duplicate calls with variables

The next step is to replace each duplicate Call node with a Name node. In Python terms, this means replacing duplicate function calls with variables.