tl;dr Python comprehensions can contain duplicate function calls (e.g. [foo(x) for x in ... if foo(x)]). If these function calls are expensive, we need to rewrite our comprehensions to avoid the cost of calling them multiple times. In this post, we solve this by writing a decorator that converts a function into an AST, optimizes away duplicate function calls, and compiles it at runtime, all in ~200 lines of code.
I love list, dict and set comprehensions in Python. They're the one feature of the language that I feel I could point at and say "Pythonic". It feels like the map and filter functions are really baked into the language. More than that, comprehensions stem from set theory, so they must be good, right?
In this blog post, I’ll briefly describe comprehensions, explain a performance disadvantage that I frequently face, and show you some code that mitigates this disadvantage by transforming them at runtime.
If you’re already familiar with comprehensions, feel free to skip this section! If not, I’ll provide a quick breakdown.
Essentially, comprehensions are just syntactic sugar that exists, to quote PEP 202, to "provide a more concise way to create [lists, sets, dictionaries] in situations where map() and filter() and/or nested loops would currently be used". An example…
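As a minimal sketch, assuming two hypothetical placeholder functions, foo (a predicate) and bar (a transformation), here is a list comprehension next to the for loop it replaces, plus its set and dict siblings:

```python
def foo(x):
    # Hypothetical predicate: keep even numbers.
    return x % 2 == 0


def bar(x):
    # Hypothetical transformation: square the number.
    return x * x


# List comprehension: apply bar to every x in range(100) that passes foo.
squares_of_evens = [bar(x) for x in range(100) if foo(x)]

# The equivalent procedural version.
loop_version = []
for x in range(100):
    if foo(x):
        loop_version.append(bar(x))

# Set and dict comprehensions use the same syntax with different brackets.
evens_set = {x for x in range(10) if foo(x)}
square_lookup = {x: bar(x) for x in range(5)}
```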
The observant Python programmer will notice that the above can also be rewritten using the built-in map and filter functions as map(bar, filter(foo, range(100))) (wrapped in list() on Python 3, where map returns an iterator). Anyway, back to comprehensions…
Comprehensions also allow you to "chain" for loops, with later loops being equivalent to more deeply nested loops. Some more examples…
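A small illustration with made-up data: two chained for clauses behave like one loop nested inside the other, and a later clause can use names bound by an earlier one.

```python
# Two chained for clauses: the second clause is nested inside the first.
pairs = [(x, y) for x in range(3) for y in range(2)]

# The same thing written as explicit nested loops.
nested = []
for x in range(3):
    for y in range(2):
        nested.append((x, y))

# A later clause can refer to a variable bound by an earlier one (row here).
flattened = [item for row in [[1, 2], [3], []] for item in row]
```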
In my opinion, comprehensions have one large shortcoming compared to their procedural cousins. for loops contain statements (e.g. x = foo(y)) whereas comprehensions can only contain expressions (e.g. x + foo(y)) and are in fact expressions themselves. As a result, we can't alias the return value of a function call so that we can use it again. This wouldn't be an issue if assignments in Python were expressions (as in C or C++), but they're statements. (Since Python 3.8, the assignment expressions of PEP 572, written with :=, do allow this inside comprehensions.)
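To make the cost concrete, here is a sketch with a hypothetical foo that records each time it runs. The duplicated call executes once per item for the filter and then again for every item that survives it.

```python
calls = []


def foo(x):
    """Stand-in for an expensive function; records every invocation."""
    calls.append(x)
    return x % 3  # falsy for multiples of 3


# foo(x) appears twice: once in the filter and once in the output expression.
result = [foo(x) for x in range(6) if foo(x)]

print(result)      # [1, 2, 1, 2]
print(len(calls))  # 10: six calls from the filter plus four for kept items
```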
This shortcoming really becomes an issue if the foo function is expensive. There are some work-arounds to this problem, but I don't like any of them.
1. Cache the function's results (e.g. with the functools lru_cache decorator). If the function is in a different module that you don't have access to, you'll have to write a wrapper function and cache the result of the wrapper function. Meh…
2. Use nested for loops whose loop variables act effectively as assigned variables. Unfortunately, this makes your comprehension harder to comprehend (*cough* *cough*), and the whole point of comprehensions is that they're more concise and easier to understand than standard for loops. Here's an example:
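A sketch of both work-arounds, using a hypothetical expensive function that records its calls. The single-element for clause in the second one is a common form of the nested-loop trick, assumed here for illustration.

```python
from functools import lru_cache

calls = []


def expensive(x):
    """Stand-in for an expensive function; records every invocation."""
    calls.append(x)
    return x % 3


# Work-around 1: cache results through a wrapper, which also works for
# functions defined in modules you can't modify.
@lru_cache(maxsize=None)
def cached_expensive(x):
    return expensive(x)


cached = [cached_expensive(x) for x in range(6) if cached_expensive(x)]
cached_calls = len(calls)  # 6: each distinct x is computed once

# Work-around 2: an extra for clause over a one-element list acts as an assignment.
calls.clear()
nested = [y for x in range(6) for y in [expensive(x)] if y]
nested_calls = len(calls)  # 6
```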
One idea is to take an inefficient but clean and concise comprehension and optimize away duplicate and equivalent function calls so that it's efficient but ugly (i.e. the 2nd work-around mentioned above). Taking a program and rewriting it to be faster while sacrificing clarity is something most compilers do for you, so on the surface this idea doesn't seem unreasonable.
We could build a mini-compiler to do these comprehension transformations before running, but since Python is dynamic, let's do this at runtime! This has the nice advantage of not requiring our users to change their build workflow. We can distribute our comprehension optimizer as a Python module that users can install normally (from PyPI). Users could just import our module to have it optimize their modules, or we could have them apply it with a Python decorator. The latter is slightly more explicit, so we'll start with that!
However, building an end-to-end compiler is hard and is usually broken into multiple steps. We really only care about taking an easy-to-manipulate abstract representation of a Python program and manipulating it to be faster.
Luckily, Python's standard library has an ast module that can take some Python source code as input and produce an abstract syntax tree (AST). It also provides a NodeTransformer class that lets us easily walk the structure of a Python program in tree form and transform it as we go. Python also has a built-in compile function that takes an AST (or a string) as input and compiles it into a code object that can then be executed.
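A minimal sketch of that parse, transform, compile round trip. The MinToMax transformer is a hypothetical toy chosen only to show how ast.parse, ast.NodeTransformer and compile fit together:

```python
import ast

source = "result = [min(x, 2) for x in range(5)]"

# 1. ast.parse: source code -> abstract syntax tree.
tree = ast.parse(source)


# 2. ast.NodeTransformer: walk the tree and rewrite nodes as we go.
#    This toy transformer turns every call to min(...) into max(...).
class MinToMax(ast.NodeTransformer):
    def visit_Call(self, node):
        self.generic_visit(node)  # transform any nested calls first
        if isinstance(node.func, ast.Name) and node.func.id == "min":
            node.func.id = "max"
        return node


tree = ast.fix_missing_locations(MinToMax().visit(tree))

# 3. compile: AST -> code object, which exec can then run.
code = compile(tree, filename="<ast>", mode="exec")
namespace = {}
exec(code, namespace)
print(namespace["result"])  # [2, 2, 2, 3, 4] instead of [0, 1, 2, 2, 2]
```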
Using this module, it's actually surprisingly easy to write a Python decorator that accesses the wrapped function's source code, converts it to an AST, transforms it, compiles it and returns that newly compiled function. We'll call the decorator @optimize_comprehensions for fun.
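A rough sketch of how such a decorator could be structured, assuming the work happens once, at decoration time. recompile_function and its transformer hook are hypothetical names, and this version clears every decorator rather than only ours, a deliberate simplification:

```python
import ast
import functools
import inspect
import textwrap


def recompile_function(source, name, global_ns, transformer=None):
    """Parse a function definition, optionally transform it, and rebuild it.

    transformer is any ast.NodeTransformer (a hypothetical hook for the
    comprehension optimizations); with None the function is rebuilt unchanged.
    """
    tree = ast.parse(textwrap.dedent(source))
    # Simplification for this sketch: drop *all* decorators so that running
    # the new definition does not re-apply this decorator (and recurse).
    tree.body[0].decorator_list = []
    if transformer is not None:
        tree = transformer.visit(tree)
    code = compile(ast.fix_missing_locations(tree), "<optimized>", "exec")
    local_ns = {}
    exec(code, global_ns, local_ns)  # executing a def statement creates the function
    return local_ns[name]


def optimize_comprehensions(func):
    """Recompile func once, when the decorator is applied, not on every call."""
    source = inspect.getsource(func)  # requires access to the source file
    new_func = recompile_function(source, func.__name__, func.__globals__)
    return functools.wraps(func)(new_func)
```

Placed above a module-level def, this returns an equivalent, freshly compiled function. Because it relies on inspect.getsource and rebuilds the function without its closure, it won't work for functions typed into a REPL or for nested functions that use their enclosing function's local variables.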
The grammar for the AST is described in the Abstract Syntax Description Language (ASDL). Here's the description of the ListComp, SetComp and DictComp nodes. I've also included the GeneratorExp node, since it has a very similar syntactic structure to comprehensions and we can apply the same optimizations to it. I've also included the description of function call (Call) nodes, as we'll be optimizing out duplicates of those.
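The same structure can be read from Python itself: each ast node class exposes the names of its child fields as _fields, and parsing a small comprehension shows where the duplicated calls live. In this sketch, foo and xs are placeholder names that are never executed.

```python
import ast

# Every AST node class lists the names of its child fields in _fields.
for node_class in (ast.ListComp, ast.SetComp, ast.DictComp,
                   ast.GeneratorExp, ast.comprehension, ast.Call):
    print(node_class.__name__, node_class._fields)
# e.g. DictComp ('key', 'value', 'generators')
#      Call ('func', 'args', 'keywords')

# Pull apart a comprehension that calls foo twice (parsed, not run).
comp = ast.parse("[foo(x) for x in xs if foo(x)]", mode="eval").body
generator = comp.generators[0]  # the single "for x in xs if ..." clause
output_call = comp.elt          # foo(x) in the output position
filter_call = generator.ifs[0]  # foo(x) in the filter
```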
If we're not careful, since we're performing optimizations at runtime, we may end up performing them every time our function (the one wrapped with the @optimize_comprehensions decorator) is called. Worse, this would actually cause infinite recursion when users call the function to be optimized: the recompiled function's source still carries the decorator, so it would be wrapped and re-optimized all over again.
To prevent this, the first transformation we'll perform is to remove the decorator from our function. We'll do this by removing any decorator named optimize_comprehensions. This is by far the simplest solution, so we'll go with it, but it doesn't work if the user renamed the decorator when importing it (e.g. from optimize_comprehensions import optimize_comprehensions as oc) or if the user namespaced the decorator by importing the module (e.g.
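A sketch of that first transformation as a NodeTransformer. The class name RemoveDecorator and the DECORATOR_NAME constant are hypothetical, and because it matches the bare name only, it has the same blind spots described above:

```python
import ast

DECORATOR_NAME = "optimize_comprehensions"


class RemoveDecorator(ast.NodeTransformer):
    """Remove @optimize_comprehensions from function definitions.

    Only the bare name is matched, so @oc (an alias) or
    @some_module.optimize_comprehensions (an attribute) are left in place.
    """

    def _strip(self, node):
        node.decorator_list = [
            dec for dec in node.decorator_list
            if not (isinstance(dec, ast.Name) and dec.id == DECORATOR_NAME)
        ]
        self.generic_visit(node)  # also handle nested function definitions
        return node

    visit_FunctionDef = _strip
    visit_AsyncFunctionDef = _strip


source = '''
@other_decorator
@optimize_comprehensions
def f(xs):
    return [x for x in xs]
'''
tree = RemoveDecorator().visit(ast.parse(source))
remaining = [dec.id for dec in tree.body[0].decorator_list]  # ['other_decorator']
```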
To determine whether there are duplicate function calls, we must first define an equality relation between Call nodes. We could make this super smart by ignoring the ordering of keyword arguments (e.g. foo(a=1, b=2) == foo(b=2, a=1)) and evaluating argument expressions to determine whether their results are equal (e.g. foo(a=1+2) == foo(a=3)). However, for the sake of simplicity, we'll just say two Call nodes are equal if they have the same "formatted dump" (essentially, the same string representation) using the
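Assuming the "formatted dump" is the string returned by ast.dump (an assumption made for this sketch; call_key and key_of are hypothetical helpers), equality reduces to string comparison. The checks also show the simplifications above in action: keyword order and unevaluated arithmetic make otherwise equivalent calls look different.

```python
import ast


def call_key(call):
    """Equality key for Call nodes: the ast.dump() string.

    ast.dump() leaves out line/column attributes by default, so identical
    calls written in different places produce the same key.
    """
    return ast.dump(call)


def key_of(expression):
    return call_key(ast.parse(expression, mode="eval").body)


same_call = key_of("foo(x, a=1)") == key_of("foo( x,  a = 1 )")  # True
reordered = key_of("foo(a=1, b=2)") == key_of("foo(b=2, a=1)")  # False
unevaluated = key_of("foo(a=1+2)") == key_of("foo(a=3)")        # False
```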
The NodeTransformer will work by walking the subtree of a comprehension (or generator) node and replacing duplicate Call nodes with a Name node that reads the value a variable points to. In order to do this, we must first make an initial pass over the subtree of a comprehension to find duplicate nodes. We'll achieve this by creating a NodeVisitor class that will visit Call nodes and return the duplicates.
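A sketch of that first pass, assuming ast.dump strings as the equality key (DuplicateCallFinder is a hypothetical name). It counts every Call it visits, including calls nested inside other calls' arguments, and keeps one representative node per repeated call:

```python
import ast
from collections import Counter


class DuplicateCallFinder(ast.NodeVisitor):
    """Visit every Call node and remember which ones occur more than once.

    Two calls count as "the same" when their ast.dump() strings match.
    """

    def __init__(self):
        self.counts = Counter()
        self.first_seen = {}  # ast.dump(call) -> first Call node with that dump

    def visit_Call(self, node):
        key = ast.dump(node)
        self.counts[key] += 1
        self.first_seen.setdefault(key, node)
        self.generic_visit(node)  # calls can be nested inside arguments

    def duplicates(self):
        """Map ast.dump key -> a representative Call node, for repeated calls only."""
        return {key: self.first_seen[key]
                for key, count in self.counts.items() if count > 1}


comp = ast.parse("[foo(x) for x in xs if foo(x) and bar(x)]", mode="eval").body
finder = DuplicateCallFinder()
finder.visit(comp)
dups = finder.duplicates()  # one entry: foo(x)
```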
The next step is to replace each duplicate Call node with a Name node. In Python terms, this means replacing duplicate function calls with variables.
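One way to finish the job, sketched under loud assumptions: the comprehension has a single for clause, every duplicated call is pure, and the generated names (_dup0, _dup1 and so on) don't clash with user variables. Duplicates in the output expression and filters become Name nodes, and each value is bound once by an extra "for _dupN in [call]" clause, which is the 2nd work-around generated automatically. optimize_comprehension and optimize_expression are hypothetical names.

```python
import ast
import copy
from collections import Counter


class _CallReplacer(ast.NodeTransformer):
    """Top-down: swap any Call whose ast.dump() is a known duplicate for a Name."""

    def __init__(self, names_by_key):
        self.names_by_key = names_by_key
        self.used = set()

    def visit_Call(self, node):
        key = ast.dump(node)
        if key in self.names_by_key:
            self.used.add(key)
            return ast.Name(id=self.names_by_key[key], ctx=ast.Load())
        return self.generic_visit(node)  # look for duplicates inside the args


def optimize_comprehension(comp):
    """Rewrite [f(x) for x in xs if f(x)] as
    [_dup0 for x in xs for _dup0 in [f(x)] if _dup0].

    Sketch assumptions: exactly one for clause, and calls are pure.
    """
    if len(comp.generators) != 1:
        return comp
    gen = comp.generators[0]
    # Only the output expression(s) and the filters run once per item.
    if isinstance(comp, ast.DictComp):
        exprs = [comp.key, comp.value] + gen.ifs
    else:
        exprs = [comp.elt] + gen.ifs
    counts, first_seen = Counter(), {}
    for expr in exprs:
        for node in ast.walk(expr):
            if isinstance(node, ast.Call):
                key = ast.dump(node)
                counts[key] += 1
                first_seen.setdefault(key, node)
    names = {key: f"_dup{i}"
             for i, key in enumerate(k for k, c in counts.items() if c > 1)}
    if not names:
        return comp

    replacer = _CallReplacer(names)
    if isinstance(comp, ast.DictComp):
        comp.key = replacer.visit(comp.key)
        comp.value = replacer.visit(comp.value)
    else:
        comp.elt = replacer.visit(comp.elt)
    new_ifs = [replacer.visit(cond) for cond in gen.ifs]

    # One "for _dupN in [call]" clause per replaced call binds its value once.
    bindings = [
        ast.comprehension(
            target=ast.Name(id=names[key], ctx=ast.Store()),
            iter=ast.List(elts=[copy.deepcopy(first_seen[key])], ctx=ast.Load()),
            ifs=[],
            is_async=0,
        )
        for key in names if key in replacer.used
    ]
    gen.ifs = []
    bindings[-1].ifs = new_ifs  # filters must run after the bindings exist
    comp.generators = [gen] + bindings
    return comp


def optimize_expression(source):
    """Parse one comprehension expression and return an optimized code object."""
    tree = ast.parse(source, mode="eval")
    tree.body = optimize_comprehension(tree.body)
    return compile(ast.fix_missing_locations(tree), "<optimized>", "eval")
```

For example, [foo(x) for x in range(6) if foo(x)] becomes the equivalent of [_dup0 for x in range(6) for _dup0 in [foo(x)] if _dup0], which calls foo six times instead of ten; on Python 3.9+, ast.unparse shows the rewritten expression. The sketch doesn't check whether a duplicated call uses variables bound inside a nested comprehension or lambda, which would break once the call is hoisted.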