Transforming Python ASTs to Optimize Comprehensions
Posted April 20, 2018; 25-30 min read
tl;dr Python comprehensions can have duplicate function calls (e.g. [foo(x) for x in ... if foo(x)]). If these function calls are expensive, we need to rewrite our comprehensions to avoid the cost of calling them multiple times. In this post, we solve this by writing a decorator that converts a function into an AST, optimizes away duplicate function calls and compiles it at runtime, all in ~200 lines of code.
I love list, dict and set comprehensions in Python. They’re the one feature of the language that I feel like I could point at and say “Pythonic”. It feels like the map and filter functions are really baked into the language. More than that, comprehensions stem from set theory, so they must be good, right?
In this blog post, I’ll briefly describe comprehensions, explain a performance disadvantage that I frequently face, and show you some code that mitigates this disadvantage by transforming them at runtime.
What are comprehensions
If you’re already familiar with comprehensions, feel free to skip this section! If not, I’ll provide a quick breakdown.
Essentially, comprehensions are just syntactic sugar that Python has to, quoting PEP 202, “provide a more concise way to create [lists, sets, dictionaries] in situations where map() and filter() and/or nested loops would currently be used”. An example…
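As a small illustration, here is a loop and the comprehension that replaces it. The functions foo and bar are hypothetical stand-ins (a predicate and a transformation), chosen only so the snippet runs on its own.

```python
def foo(x):
    # Hypothetical predicate: keep even numbers.
    return x % 2 == 0


def bar(x):
    # Hypothetical transformation: square the value.
    return x * x


# Procedural version using a for loop.
result_loop = []
for x in range(100):
    if foo(x):
        result_loop.append(bar(x))

# Equivalent list comprehension.
result_comp = [bar(x) for x in range(100) if foo(x)]
```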
The observant Python programmer will notice that the above can also be rewritten using the built-in map and filter functions as map(bar, filter(foo, range(100))). Anyway, back to comprehensions…
Comprehensions also allow you to “chain” for loops, with later loops being equivalent to deeper loops. Some more examples…
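A few illustrative cases (the variable names are arbitrary) showing chained loops next to their nested-loop equivalent, plus the dict and set forms:

```python
# Nested loops, written procedurally.
pairs_loop = []
for x in range(3):
    for y in range(x):
        pairs_loop.append((x, y))

# The same thing as a comprehension: the later `for` is the deeper loop.
pairs_comp = [(x, y) for x in range(3) for y in range(x)]

# Dict and set comprehensions follow the same pattern.
squares = {x: x * x for x in range(4)}
remainders = {x % 3 for x in range(10)}
```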
Comprehensive shortcomings
In my opinion, there’s one large shortcoming of comprehensions compared to their procedural cousins: for loops contain statements (e.g. x = foo(y)) whereas comprehensions can only contain expressions (e.g. x + foo(y)) and are in fact expressions themselves. As a result, we can’t alias the return value of a function call in order to use it again. This wouldn’t be an issue if assignments in Python were treated as expressions (like in C or C++), but they’re treated as statements.
This shortcoming really becomes an issue if the foo function is expensive. There are some work-arounds to this problem, but I don’t like any of them.
- We could rewrite the list comprehension as a series of map and filter function calls.
- We can cache/memoize the expensive function that we’re calling (e.g. using functools’ lru_cache decorator). If the function is in a different module that you don’t have access to, you’ll have to write a wrapper function and cache the result of the wrapper function. Meh…
- Or… you can alter your list comprehensions by appending additional for loops whose loop variables effectively work the same as ordinary variables. Unfortunately, this makes your comprehension harder to comprehend (*cough* *cough*) and the whole point of comprehensions is that they’re more concise and easier to understand than standard for loops. Here’s an example:
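A minimal sketch of that last work-around, assuming a hypothetical foo that stands in for something expensive:

```python
def foo(x):
    # Hypothetical expensive function; imagine this takes a long time.
    return x * x if x % 2 else None


# Clean, but calls foo twice for every element that passes the filter.
clean = [foo(x) for x in range(10) if foo(x)]

# Work-around: a single-element inner loop binds the result to `y`,
# so foo is called only once per element.
fast = [y for x in range(10) for y in [foo(x)] if y]
```

Both produce the same list; the second just hides an assignment inside a for clause.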
Let's build a compiler!
One idea is to take an inefficient, but clean and concise comprehension and optimize away duplicate and equivalent function calls so that it’s efficient but ugly (i.e. the last work-around mentioned above). The act of taking a program and rewriting it to be quicker while sacrificing clarity is something most compilers do for you, so on the surface this idea doesn’t seem unreasonable.
We could build a mini-compiler to do these comprehension transformations before running, but since Python is dynamic, let’s do this at runtime! This has the nice advantage of not requiring our users to change their build workflow. We can distribute our comprehension optimization program as a Python module that users can install normally (from PyPI). Users could just import our module to have it optimize their modules, or we could have them opt in using a Python decorator. The latter is slightly more explicit, so we’ll start with that!
However, building an end-to-end compiler is hard. We really only care about taking an easy-to-manipulate abstract representation of a Python program and manipulating it to be faster. A compiler is usually broken into multiple steps:
- Lexical analysis (or tokenization… or scanning).
- Syntax analysis (or parsing).
- Semantic analysis.
- Intermediate code generation.
- Optimization (this is the only step we really care about).
- Output code generation.
Python's ast module to the rescue!
Luckily, Python’s standard library has an ast module that can take some Python source code as input and produce an abstract syntax tree (AST). It also provides NodeVisitor and NodeTransformer classes that let us easily walk the structure of a Python program in tree form and transform it as we go. Python also has a built-in compile function that takes an AST (or string) as input and compiles it into a code object that can then be executed.
Using this module, it’s actually surprisingly easy to write a Python decorator that accesses the wrapped function’s source code, converts it to an AST, transforms it, compiles it and returns that newly compiled function. We’ll call the decorator @optimize_comprehensions for fun.
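Here is a rough sketch of the parse, transform, compile and execute pipeline. The helper name recompile is hypothetical, and to keep the snippet self-contained it takes the source as a string; a decorator would more likely obtain it from the wrapped function with inspect.getsource.

```python
import ast


def recompile(source, name, transformer=None):
    """Parse `source`, optionally transform its AST, execute the compiled
    code and return whatever object ends up bound to `name`."""
    tree = ast.parse(source)
    if transformer is not None:
        tree = transformer.visit(tree)
        # Nodes created by a transformer have no line numbers yet,
        # and compile() requires them.
        ast.fix_missing_locations(tree)
    code = compile(tree, filename="<ast>", mode="exec")
    namespace = {}
    exec(code, namespace)
    return namespace[name]


SOURCE = """
def double_all(xs):
    return [2 * x for x in xs]
"""

double_all = recompile(SOURCE, "double_all")
```

Passing an ast.NodeTransformer instance as the third argument is where the optimizations would plug in.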
The grammar for the AST is described in the Abstract Syntax Description Language (ASDL). Here’s the description of the ListComp, SetComp and DictComp nodes. I’ve also included the GeneratorExp node since it has a very similar syntactic structure to comprehensions and we can apply the same optimizations to it. I’ve also included the description of function call (Call) nodes as we’ll be optimizing out duplicates of those.
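One way to see the same structure without reading the ASDL is to parse a small expression and look at each node’s _fields attribute. This is only an inspection sketch; the exact fields can differ slightly between Python versions (for example, newer versions add is_async to comprehension).

```python
import ast

tree = ast.parse("[foo(x) for x in xs if foo(x)]", mode="eval")

list_comp = tree.body                 # the ListComp node
generator = list_comp.generators[0]   # a `comprehension` node
call = list_comp.elt                  # the Call node for foo(x)

# Print each node type along with the names of its child fields.
for node in (list_comp, generator, call):
    print(type(node).__name__, node._fields)
```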
Removing the function decorator (or preventing infinite recursion)
If we’re not careful, since we’re performing optimizations at runtime, we may perform these optimizations every time our function (the one wrapped with the @optimize_comprehensions decorator) is called. Moreover, this would actually cause users to infinitely recurse when they attempt to call the function to be optimized.
To prevent this, the first transformation we’ll perform is to remove the decorator from our function. We’ll do this by removing any decorators named optimize_comprehensions from the function. This is by far the simplest solution so we’ll go with it, but it doesn’t actually work if the user renamed the decorator when importing it (i.e. from optimize_comprehensions import optimize_comprehensions as oc) or if the user namespaced the decorator by importing the module (i.e. import optimize_comprehensions).
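A minimal sketch of such a transformation might look like the following. The class name RemoveDecorator is hypothetical, and it shares the limitation described above: it only matches a bare name.

```python
import ast


class RemoveDecorator(ast.NodeTransformer):
    """Strip decorators named `optimize_comprehensions` from functions."""

    def visit_FunctionDef(self, node):
        node.decorator_list = [
            d for d in node.decorator_list
            if not (isinstance(d, ast.Name) and d.id == "optimize_comprehensions")
        ]
        self.generic_visit(node)  # also handle nested function definitions
        return node


source = """
@optimize_comprehensions
def f(xs):
    return [x for x in xs]
"""

tree = RemoveDecorator().visit(ast.parse(source))
func_def = tree.body[0]
```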
Equivalent function calls
To determine if there are duplicate function calls, we must first define an equality relation between Call nodes. We could make this super smart by ignoring the ordering of keyword arguments (e.g. foo(a=1, b=2) == foo(b=2, a=1)) and evaluating argument expressions to determine if their results are equal (e.g. foo(a=1+2) == foo(a=3)). However, for the sake of simplicity, we’ll just say two Call nodes are equal if they have the same “formatted dump” (essentially, the same string representation) as produced by the ast.dump function.
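A sketch of that equality check, with hypothetical helper names calls_equal and parse_call. It also shows the two “smart” cases that this simple definition deliberately does not catch.

```python
import ast


def calls_equal(a, b):
    # Two nodes are "equal" if their dumps match. By default ast.dump
    # leaves out line and column numbers, so layout does not matter.
    return ast.dump(a) == ast.dump(b)


def parse_call(src):
    # Parse a single expression and return its top-level node.
    return ast.parse(src, mode="eval").body


same = calls_equal(parse_call("foo(x, 1)"), parse_call("foo(x,   1)"))
reordered = calls_equal(parse_call("foo(a=1, b=2)"), parse_call("foo(b=2, a=1)"))
folded = calls_equal(parse_call("foo(a=1+2)"), parse_call("foo(a=3)"))
```

Here same is True, while reordered and folded are both False.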
Finding duplicate function calls
Our OptimizeComprehensions NodeTransformer will work by walking the subtree of a comprehension (or generator) node and replacing duplicate Call nodes with a Name node that will read the value that a variable points to. In order to do this, we must first do an initial pass over the subtree of a comprehension and find duplicate nodes. We’ll achieve this by creating a NodeVisitor class that will visit Call nodes and return the duplicate Calls.
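One possible shape for that visitor is sketched below. The class name DuplicateCallFinder and the use of a Counter keyed by dump are my assumptions, not necessarily how the original code does it.

```python
import ast
from collections import Counter


class DuplicateCallFinder(ast.NodeVisitor):
    """Count Call nodes by their dump and report those seen more than once."""

    def __init__(self):
        self.counts = Counter()

    def visit_Call(self, node):
        self.counts[ast.dump(node)] += 1
        self.generic_visit(node)  # calls can be nested inside arguments

    @property
    def duplicates(self):
        return {dump for dump, n in self.counts.items() if n > 1}


comp = ast.parse("[foo(x) for x in xs if foo(x) and bar(x)]", mode="eval").body
finder = DuplicateCallFinder()
finder.visit(comp)
```

After the visit, finder.duplicates holds the dump of foo(x) but not of bar(x), since bar(x) only appears once.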
Replacing duplicate calls with variables
The next step is to replace each duplicate Call node with a Name node. In Python terms, this means replacing duplicate function calls with variables.
This is fairly trivial, but there is one decision we have to make. How do we generate variable names from a function call? I’ve chosen a simple but ugly solution: hash the formatted function dump and prefix it with double underscores (as Python variable names cannot start with numbers). If we were to inspect our transformed comprehension, we may see variables that look like __258792. This solution is “ugly” as it’s vulnerable to hash collisions. However, for the purpose of this blog post, we’re going to pretend they don’t exist.
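A sketch of the naming scheme and the replacement step. The names variable_name_for and ReplaceCalls are hypothetical, and taking abs() of the hash is my assumption for dealing with negative hash values.

```python
import ast


def variable_name_for(call):
    # abs() because hash() can be negative and "-" is not valid in a name.
    return "__" + str(abs(hash(ast.dump(call))))


class ReplaceCalls(ast.NodeTransformer):
    """Replace every Call whose dump is in `duplicates` with a Name node."""

    def __init__(self, duplicates):
        self.duplicates = duplicates

    def visit_Call(self, node):
        if ast.dump(node) in self.duplicates:
            new_node = ast.Name(id=variable_name_for(node), ctx=ast.Load())
            return ast.copy_location(new_node, node)
        return self.generic_visit(node)


call = ast.parse("foo(x)", mode="eval").body
expr = ast.parse("foo(x) + foo(x)", mode="eval")
expr = ReplaceCalls({ast.dump(call)}).visit(expr)
name = variable_name_for(call)
```

After the transformation, expr reads two copies of the generated variable instead of calling foo twice. Nothing assigns that variable yet; that is the next step.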
Moving the function call to a new "comprehension"
Now that we’re using variables (Name nodes) instead of duplicate function calls (Call nodes), we must make sure we assign the result of the function calls to the new variables. We do this by adding another comprehension to the ListComp, SetComp, DictComp or GeneratorExp node’s list of comprehensions that looks like for __256792 in [foo(y)].
However, if we reference the ASDL, we’ll see that the description of a comprehension is (expr target, expr iter, expr* ifs), which put plainly means that comprehensions contain if clauses. If we leave these if clauses as is (or is it “as are”?) and if we append the new comprehensions to the end of the list of comprehensions, we’ll have issues looking up new variables before they’re defined. Consider the following example:
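To illustrate, here is a hand-written version of what the transformed comprehension would look like in each case. The function foo and the variable name __tmp are hypothetical stand-ins for an expensive call and a generated name.

```python
def foo(x):
    # Hypothetical stand-in for an expensive function.
    return x - 1


def broken(xs):
    # The `if` still belongs to the first generator, so it runs before
    # the appended `for __tmp in [foo(x)]` has bound __tmp.
    return [__tmp for x in xs if __tmp > 0 for __tmp in [foo(x)]]


def fixed(xs):
    # With the `if` moved to the last generator, __tmp is bound first.
    return [__tmp for x in xs for __tmp in [foo(x)] if __tmp > 0]


try:
    broken([1, 2, 3])
    error = None
except NameError as exc:  # UnboundLocalError is a subclass of NameError
    error = exc
```

The first version fails at runtime because the variable is read before it is bound; the second works.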
One solution would be to be very smart about our placement of new comprehensions such that they’re always placed before any uses of the new variable. A simpler, but less efficient solution is to just move all if clauses to the last comprehension in the list of comprehensions. This is less efficient because if clauses interleaved with for loops allow us to skip iterations of the deeper loops earlier. We’ll go with the latter solution for the sake of this blog post, but it’s not the ideal solution.
This is what our visit_comp method looks like now:
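The original listing is not reproduced here, so what follows is only a sketch of how such a visit_comp could be put together from the pieces described so far. All helper names (SCOPES, var_name, iter_calls, Replace) are hypothetical. As a simplification it only looks for duplicates in the element expression and the if clauses, skips nested comprehensions and lambdas, and assumes the duplicated calls have no side effects.

```python
import ast
from collections import Counter

SCOPES = (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp, ast.Lambda)


def var_name(call):
    # Hypothetical naming scheme: "__" plus the hash of the call's dump.
    return "__" + str(abs(hash(ast.dump(call))))


def iter_calls(node):
    # Yield Call nodes under `node`, skipping nested scopes whose
    # variables are not visible in the enclosing comprehension.
    if isinstance(node, SCOPES):
        return
    if isinstance(node, ast.Call):
        yield node
    for child in ast.iter_child_nodes(node):
        yield from iter_calls(child)


class Replace(ast.NodeTransformer):
    def __init__(self, duplicates):
        self.duplicates = duplicates
        self.replaced = {}  # variable name -> the Call it stands for

    def visit(self, node):
        if isinstance(node, SCOPES):
            return node  # leave nested scopes untouched
        if isinstance(node, ast.Call) and ast.dump(node) in self.duplicates:
            name = var_name(node)
            self.replaced.setdefault(name, node)
            return ast.Name(id=name, ctx=ast.Load())
        return self.generic_visit(node)


class OptimizeComprehensions(ast.NodeTransformer):
    def visit_comp(self, node):
        self.generic_visit(node)  # optimize nested comprehensions first
        is_dict = isinstance(node, ast.DictComp)
        elts = [node.key, node.value] if is_dict else [node.elt]
        ifs = [test for gen in node.generators for test in gen.ifs]
        counts = Counter(ast.dump(c) for e in elts + ifs for c in iter_calls(e))
        duplicates = {dump for dump, n in counts.items() if n > 1}
        if not duplicates:
            return node
        replace = Replace(duplicates)
        elts = [replace.visit(e) for e in elts]
        ifs = [replace.visit(test) for test in ifs]
        if is_dict:
            node.key, node.value = elts
        else:
            node.elt = elts[0]
        for gen in node.generators:
            gen.ifs = []
        # Append one `for <name> in [<call>]` generator per replaced call.
        for name, call in replace.replaced.items():
            node.generators.append(ast.comprehension(
                target=ast.Name(id=name, ctx=ast.Store()),
                iter=ast.List(elts=[call], ctx=ast.Load()),
                ifs=[], is_async=0))
        # All if clauses move to the last generator, after every variable.
        node.generators[-1].ifs = ifs
        return node

    visit_ListComp = visit_SetComp = visit_DictComp = visit_GeneratorExp = visit_comp


source = "result = [foo(x) for x in range(6) if foo(x)]"
tree = OptimizeComprehensions().visit(ast.parse(source))
tree = ast.fix_missing_locations(tree)
```

Compiling tree and executing it with a foo in scope should give the same result as the original source while calling foo once per element.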
Unique target variable names
Our optimizer is nearly there. It currently fails when a nested comprehension has the same target variable name as another child comprehension within the same top-level comprehension (it also fails if it has the same target variable name as the top-level comprehension). So what’s the solution? Another stupidly simple one: change all target variable names that are children of a comprehension node so that they’re unique.
To make variable names unique, I’ve decided to append random numbers to the end of variable names. This also suffers from naming conflicts, but for the purpose of this blog post, it’ll do. We’ll also use a stack of variable names to replace, similar to what we did with the calls_to_replace_stack in the OptimizeComprehensions NodeTransformer, to keep track of scope.
Here’s an example implementation of a NodeTransformer that renames variable names in comprehensions. You’ll notice we only need to call it once within our OptimizeComprehensions.
Another thing to note is that we push a map of variables to replace for each comprehension in a comprehension node’s generators. By doing so, we can handle the case where we re-assign target variable names across generators within the same comprehension node. For example, the following test case passes.
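A minimal sketch of such a renaming transformer, followed by a check in the spirit of that test case. The class name RenameTargets is hypothetical, and I use an incrementing counter instead of random numbers so the output is predictable. It does not try to handle lambdas that shadow a comprehension variable.

```python
import ast
import itertools


class RenameTargets(ast.NodeTransformer):
    """Give every comprehension target variable a unique name."""

    def __init__(self):
        self.scopes = []  # stack of {old name: new name} maps
        self.counter = itertools.count()

    def visit_Name(self, node):
        # Innermost binding wins, so search the stack from the top.
        for scope in reversed(self.scopes):
            if node.id in scope:
                node.id = scope[node.id]
                break
        return node

    def visit_comp(self, node):
        depth = len(self.scopes)
        for gen in node.generators:
            # The iterable only sees targets bound by earlier generators.
            gen.iter = self.visit(gen.iter)
            # Push one map per generator for the names its target binds.
            self.scopes.append({
                n.id: "%s_%d" % (n.id, next(self.counter))
                for n in ast.walk(gen.target)
                if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)
            })
            gen.target = self.visit(gen.target)
            gen.ifs = [self.visit(test) for test in gen.ifs]
        if isinstance(node, ast.DictComp):
            node.key = self.visit(node.key)
            node.value = self.visit(node.value)
        else:
            node.elt = self.visit(node.elt)
        del self.scopes[depth:]  # leave this comprehension's scope
        return node

    visit_ListComp = visit_SetComp = visit_DictComp = visit_GeneratorExp = visit_comp


# The target `x` is re-assigned by the second generator.
source = "result = [x for x in range(3) for x in range(x)]"
tree = RenameTargets().visit(ast.parse(source))
namespace = {}
exec(compile(tree, "<ast>", "exec"), namespace)
```

After renaming, the two targets have different names, and the comprehension still evaluates to the same list as the original.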
To conclude, we've done it!
In 200 lines of Python (including docstrings, comments and imports), we’ve built an optimizer that takes potentially slow but elegant comprehensions and produces ugly but fast comprehensions at runtime.
For simple comprehensions with a duplicate function call, the execution time of our optimized version converges to half the execution time of the original as the time it takes to execute the function increases. The line chart below illustrates this statement and was generated by comparing the execution time of a list comprehension before and after optimization (see GitHub gist for script).
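This is not the script from the gist, just a rough way to reproduce the comparison by hand. The function foo is a hypothetical expensive function that uses a short sleep, and optimized is a hand-written version of the transformed comprehension. Every element passes the filter here, which is the case where the original calls foo exactly twice per element.

```python
import time
import timeit


def foo(x):
    # Hypothetical expensive function: the sleep stands in for real work.
    time.sleep(0.001)
    return x + 1  # always truthy, so every element passes the filter


def original(xs):
    return [foo(x) for x in xs if foo(x)]


def optimized(xs):
    # Hand-written form of what the optimizer is meant to produce.
    return [y for x in xs for y in [foo(x)] if y]


xs = range(20)
t_original = timeit.timeit(lambda: original(xs), number=3)
t_optimized = timeit.timeit(lambda: optimized(xs), number=3)
print("original:  %.4fs" % t_original)
print("optimized: %.4fs" % t_optimized)
```

Increasing the sleep should push the ratio of the two timings closer to one half, since the fixed overhead of the comprehension itself matters less and less.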
Anyway, I hope you’ve enjoyed reading this walk-through and here’s a GitHub gist of the optimizer code we’ve gone through! Feel free to run it, modify it and distribute it. I may polish it up one day and publish it to PyPI for all to use.
Some takeaways:
- Writing a code optimizer is generally pretty simple, especially with Python’s ast module. Go ahead and try!
- You can never trust the Python code you’re calling… even if you have the source code (which you almost always have access to when working with Python). If someone can inject some malicious code into your Python process, they now have entire control of it.
Caveats
Writing optimizers and/or compilers has the potential of creating really-hard-to-debug bugs. Although this has been a fun exercise, I don’t recommend using the optimizer written in this post at the moment in any production environments.
It’s also important to properly trace and profile code before applying optimizations like the ones above. Runtime optimizations incur overhead and may actually slow down code instead of speeding it up. For example, our optimizer only works if the overhead of transforming functions is lower than the overhead of executing duplicate function calls.
Update (August 11, 2018): I thought I would add that PEP 572 introduces assignment expressions to Python with the new := operator! This solves the comprehension issue described in this post, but this should still be a helpful guide on how to use the ast module to transform Python programs at runtime!
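For comparison, here is how the running example could look with an assignment expression. It needs Python 3.8 or newer, and foo is again a hypothetical expensive function.

```python
def foo(x):
    # Hypothetical expensive function.
    return x * x if x % 2 else None


# Original: foo is called twice for every element that passes the filter.
clean = [foo(x) for x in range(10) if foo(x)]

# With := the result is bound once in the filter and reused as the element.
with_walrus = [y for x in range(10) if (y := foo(x))]
```

One difference from the extra for loop trick is that the name bound with := lives in the enclosing scope rather than inside the comprehension.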