Simplifying Logic In Your Python Code

Last week I took a deeper look into some ideas covering boolean logic and how we can derive expressions from truth tables. In that same spirit, I wanted to share my absolute favorite example from my seminar on logic: using expression simplification to simplify the conditions in Python flow control statements. Additionally, this approach let us detect unreachable branches, meaning branches of code that can never be executed because of a poorly formed conditional statement.

A Little AST (Abstract Syntax Tree)

Since we covered how to use the simplify_logic function from SymPy to simplify boolean expressions last week, I want to dive straight into parsing Python code with the built-in ast (Abstract Syntax Trees) module. This module parses Python source according to Python's own grammar and hands the result back as objects we can work with programmatically. We can essentially represent valid Python code as a tree of nodes for its various expressions and statements, inspect those nodes in more detail, and even inject some custom processing behavior.

To start off with the basics, most Python code exists as text. The code you write in a .py file exists as a text document, which can be parsed. In this case, instead of using a .py file, we can use a multiline string (which is what you would end up with if you called .read() on a file object).

from ast import parse

code = '''
def add_one(x):
    return x + 1

y = add_one(5)
print(y)
'''

module = parse(code)
print(
    f'{module      = }',
    f'{module.body = }',
    sep='\n'
)
module      = <ast.Module object at 0x70361c1f25c0>
module.body = [<ast.FunctionDef object at 0x70361c1f24a0>, <ast.Assign object at 0x70361c1f0ac0>, <ast.Expr object at 0x70361c1f2980>]

Using the ast.parse function, we let the module do all of the heavy lifting for us, and it returns an ast.Module object. In Python a module is essentially a single source file, so in other words we now have an AST representation of this file. From this object, we can inspect all of the nodes of our syntax tree. For example, the body of the Module is composed of 3 nodes: a function definition, a variable assignment, and an expression statement. These visually correspond to the 3 parts of our input code.

We can inspect these nodes in more detail by using ast.dump to help recursively print information about all of the nodes that live under the root Module object.

from ast import dump

print(
    dump(module, indent=2)
)
Module(
  body=[
    FunctionDef(
      name='add_one',
      args=arguments(
        posonlyargs=[],
        args=[
          arg(arg='x')],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[]),
      body=[
        Return(
          value=BinOp(
            left=Name(id='x', ctx=Load()),
            op=Add(),
            right=Constant(value=1)))],
      decorator_list=[]),
    Assign(
      targets=[
        Name(id='y', ctx=Store())],
      value=Call(
        func=Name(id='add_one', ctx=Load()),
        args=[
          Constant(value=5)],
        keywords=[])),
    Expr(
      value=Call(
        func=Name(id='print', ctx=Load()),
        args=[
          Name(id='y', ctx=Load())],
        keywords=[]))],
  type_ignores=[])

Inspecting each of the nodes, we can see:

  1. A FunctionDef corresponding to the function we defined, add_one, which takes a single argument x.

  • In the body of this function, we return the result of a binary operation (addition) of a variable Named ‘x’ and a Constant 1.

  2. An assignment into the target variable Named ‘y’ whose value is the result of Calling add_one with a value of 5.

  3. An Expression statement (an ast.Expr wraps an expression, such as a function call, that stands on its own and whose result is not stored) that simply Calls print on the variable Named ‘y’.

Take a close look at the ast nodes, this descriptive text, and the input code. They should start to pair up, and you can begin to see which features of the code correspond to which nodes in the AST.
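You can also check that pairing programmatically, because each statement node remembers where it came from in the source. The sketch below reuses the same add_one source. It prints each top-level node's type next to its line range, using the lineno and end_lineno attributes that ast statement nodes carry in Python 3.8+:

```python
from ast import parse

code = '''
def add_one(x):
    return x + 1

y = add_one(5)
print(y)
'''

module = parse(code)

# (node type, first line, last line) for each top-level statement;
# line 1 is the empty line right after the opening triple quote
spans = [
    (type(node).__name__, node.lineno, node.end_lineno)
    for node in module.body
]

for name, start, end in spans:
    print(name, start, end)
# FunctionDef 2 3
# Assign 5 5
# Expr 6 6
```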

Specific Operations on Specific Nodes

For most use cases, we’re not interested in the entire tree. More often than not we’re examining specific nodes and applying operations that are node dependent. For our use case, we want to identify whether or not the expression within an if statement can be simplified.

While we could manually traverse the AST and filter for ast.If nodes, we can instead subclass ast.NodeVisitor and define a visit_If method. (ast.NodeVisitor lets us define a visit_{NodeType} method for any node type, which makes it easy to operate on specific nodes within the tree.)

Something else to consider is the structure of an ast.If node itself. An if statement is stored as an If node with a test attribute (the expression to test), a body, and an orelse attribute. orelse is always a list. For an elif it holds a single nested If node, for an else it holds the statements of the else block, and it is empty when there is neither. To navigate this, we can use a simple recursive traversal that yields the test of each If node within the nested orelse structure.
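A quick way to see the three shapes of orelse is to parse a small chain and look at what each level holds. Note that the final else block contains ordinary statements with no test attribute, so any traversal should check for If before reading .test. The AST also does not record whether the source said elif or an else containing only an if. Both produce the same nested If.

```python
from ast import parse, If

chain_src = '''
if a:
    x = 1
elif b:
    x = 2
else:
    x = 3
    y = 4
'''

top = parse(chain_src).body[0]

# an elif becomes a one-element orelse list holding a nested If...
elif_link = top.orelse
print(len(elif_link), type(elif_link[0]).__name__)   # 1 If

# ...while the final else block is stored as its plain statements
else_block = elif_link[0].orelse
print([type(stmt).__name__ for stmt in else_block])  # ['Assign', 'Assign']

# an if with neither elif nor else has an empty orelse
print(parse('if a:\n    pass').body[0].orelse)        # []
```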

Finally, we can use ast.unparse (Python 3.9+) to turn a node back into source text, which gives us an easy way to display that node. Because unparse regenerates text from the tree rather than reading the original source, its output can differ from what was written. Redundant parentheses, for instance, are dropped. If those differences matter, we could use ast.get_source_segment instead.
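Here is a small illustration of that difference, using the first test from the example code below. unparse rebuilds the expression from the tree. get_source_segment slices the original string using the node's recorded positions, so it needs the source text passed in alongside the node:

```python
from ast import parse, unparse, get_source_segment

source = 'if (p and q) or not p:\n    pass\n'
test = parse(source).body[0].test

# regenerated from the tree: the redundant parentheses are gone
print(unparse(test))                      # p and q or not p

# copied from the original text via lineno/col_offset positions
print(get_source_segment(source, test))   # (p and q) or not p
```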

from ast import NodeVisitor, unparse, If

class Visitor(NodeVisitor):
    def visit_If(self, node):
        def traverse_nodes(*nodes):
            for n in nodes:
                # an `else` block holds plain statements with no `test`; skip them
                if isinstance(n, If):
                    yield n.test
                    yield from traverse_nodes(*n.orelse)

        # recursively traverse the If node and print every `test` in the chain
        print(
            *map(unparse, traverse_nodes(node)),
            sep='\n'
        )


code = '''
if (p and q) or not p:
    x = 1
elif p or True:
    x = 2
elif p:
    x = 4
elif p and q and not p:
    pass
elif p and q or (not p and q):
    pass

print(x + 2)
'''

tree = parse(code)
Visitor().visit(tree)
p and q or not p
p or True
p
p and q and (not p)
p and q or (not p and q)
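One side effect of the visitor above: because visit_If never calls self.generic_visit(node), NodeVisitor stops descending at every if statement it handles. That is convenient here, since the elif links are reached only through traverse_nodes and no test is printed twice. However, an if nested inside a branch body is never visited. The following is a minimal sketch of one way around this. ChainVisitor is a hypothetical class that follows the elif chain itself and then visits each branch body:

```python
from ast import NodeVisitor, If, parse

class ChainVisitor(NodeVisitor):
    """Hypothetical visitor: records the line of every if-chain head,
    including chains nested inside branch bodies."""

    def __init__(self):
        self.heads = []

    def visit_If(self, node):
        self.heads.append(node.lineno)
        link = node
        while True:
            # statements inside this branch may contain nested if-chains
            for stmt in link.body:
                self.visit(stmt)
            orelse = link.orelse
            if len(orelse) == 1 and isinstance(orelse[0], If):
                # an elif (or an else holding only an if): same chain
                link = orelse[0]
            else:
                # a plain else block (or nothing): visit it, then stop
                for stmt in orelse:
                    self.visit(stmt)
                break

source = '''
if a:
    if b:
        pass
elif c:
    pass
else:
    if d:
        pass
    x = 1
'''

visitor = ChainVisitor()
visitor.visit(parse(source))
print(visitor.heads)  # [2, 3, 8]
```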

Simplifying Python code in SymPy

Now that we have an easy way to traverse each test within some code with ast, let’s think about how we can run these tests through SymPy’s logic simplification algorithm. SymPy cannot use Python’s and, or, and not operators directly. Those operators cannot be overloaded (PEP 335 – Overloadable Boolean Operators, which proposed allowing this, was rejected), so SymPy spells its logic with the bitwise operators &, |, and ~ instead. That means we’ll have to hack our way around to get this to work.

Primarily, we will need to convert Python’s and/or/not operators into a SymPy expression. The robust way of going about this would probably be to dive a little deeper into the ast and convert ast.And/ast.Or/ast.Not nodes into their corresponding SymPy counterparts. The simple (though less robust) alternative is to use a regular expression to substitute the boolean operators with their bitwise counterparts (and → &, or → |, not → ~). This was the admittedly quick solution I used for this problem.
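For comparison, the more robust AST route could look roughly like the sketch below. to_sympy is a hypothetical helper, not part of SymPy or ast. It only handles names, True/False, and/or, and not, and it raises for anything else. Because it builds SymPy objects directly instead of rewriting text, it fails loudly on bitwise operators rather than misreading them. It is also unaffected by keywords that appear inside longer names:

```python
import ast

from sympy import And, Or, Not, Symbol, true, false
from sympy.logic import simplify_logic


def to_sympy(node):
    """Hypothetical helper: convert a boolean ast expression into SymPy.

    Supports names, True/False, and/or, and not; anything else raises.
    """
    if isinstance(node, ast.BoolOp):
        # ast.BoolOp.op is either ast.And or ast.Or
        combine = And if isinstance(node.op, ast.And) else Or
        return combine(*(to_sympy(value) for value in node.values))
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
        return Not(to_sympy(node.operand))
    if isinstance(node, ast.Name):
        return Symbol(node.id)
    if isinstance(node, ast.Constant) and isinstance(node.value, bool):
        return true if node.value else false
    raise NotImplementedError(f'unsupported node: {ast.dump(node)}')


# mode='eval' parses a lone expression; .body is the expression node itself
test = ast.parse('p and q or (not p and q)', mode='eval').body
print(simplify_logic(to_sympy(test)))  # q
```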

So the final approach for this application is:

  1. Parse code into an AST

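2. For each encountered test, replace the boolean operators with bitwise ones (and → &, or → |, not → ~)

  3. Run simplify_logic on the resulting expression

  4. Take the string representation of the result from step 3

  5. Reverse-transform the bitwise operators back to boolean operators and parse the result back into an If node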

  6. Print results, highlighting tests that could be simplified

This obviously will fail for code that actually uses a bitwise operator as a part of a test in an if statement, but nonetheless the process does make for a fun example.

from sympy.logic import simplify_logic
from re import compile as re_compile, escape
from functools import partial

class Visitor(NodeVisitor):
    def visit_If(self, node):
        new_node = simplify_if_node(node)
        node_lines = unparse(node).splitlines()
        newnode_lines = unparse(new_node).splitlines()

        print(
            'before', 'after', sep=' '*31,
            end='\n{}\n'.format('\N{box drawings light horizontal}'*50)
        )
        sep = '\N{box drawings light vertical}'
        for lineno, (l, r) in enumerate(zip(node_lines, newnode_lines), start=node.lineno):
            if l != r:
                r = f'\033[42;39;1m{r}\033[0m'
            print(f'{lineno:<3} {l:<30} {sep} {r}')

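def make_sub(repl):
    def guard(key):
        # add \b only on word-character edges, so `or` is not matched
        # inside a longer name such as `order`
        pattern = escape(key)
        if key[0].isalnum():
            pattern = r'\b' + pattern
        if key[-1].isalnum():
            pattern += r'\b'
        return pattern

    # one alternation over every key; each match is swapped for its mapped value
    re_sub = re_compile('|'.join(map(guard, repl)))
    return partial(re_sub.sub, repl=lambda m: repl[m.group(0)])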

repl = {'and': '&', 'or': '|', 'not ': '~'}
rev_repl = {v: k for k, v in repl.items()}

to_bitwise = make_sub(repl)
to_boolean = make_sub(rev_repl)

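def simplify_if_node(node):
    # `else` bodies hold ordinary statements; return them untouched
    if not isinstance(node, If):
        return node

    # Python boolean syntax -> SymPy bitwise syntax -> simplified -> back again
    test_str = to_bitwise(string=unparse(node.test))
    simplified_test = to_boolean(string=str(simplify_logic(test_str)))

    # rebuild the If with the simplified test, recursing into any elif/else
    return If(
        test=parse(simplified_test).body[0].value,
        body=node.body,
        orelse=[simplify_if_node(n) for n in node.orelse]
    )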

tree = parse(code)
Visitor().visit(tree)
before                               after
──────────────────────────────────────────────────
2   if p and q or not p:           │ if q or not p:
3       x = 1                      │     x = 1
4   elif p or True:                │ elif True:
5       x = 2                      │     x = 2
6   elif p:                        │ elif p:
7       x = 4                      │     x = 4
8   elif p and q and (not p):      │ elif False:
9       pass                       │     pass
10  elif p and q or (not p and q): │ elif q:
11      pass                       │     pass

And there you have it! Using SymPy, we can simplify Python if statements and flag impossible branches. See the test that was simplified to elif False? That branch can never be reached. A further extension of this idea would test across branches: does an earlier branch make a later branch impossible to reach? In this example, the first elif test was simplified to elif True, which means every branch after it is unreachable.
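That cross-branch check could be sketched as follows. This assumes each test has already been converted into a SymPy expression; here they are typed in by hand to mirror the example above. unreachable_branches is a hypothetical helper. Branch i can only run when its own test holds and none of the earlier tests did. So if simplify_logic reduces test_i & ~(test_0 | ... | test_(i-1)) to false, that branch is dead:

```python
from sympy import Symbol, And, Or, Not, true, false
from sympy.logic import simplify_logic


def unreachable_branches(tests):
    """Hypothetical helper: given the tests of an if/elif chain as SymPy
    expressions, return the indices of branches that can never run."""
    dead = []
    earlier = false  # "some earlier branch already matched"
    for i, test in enumerate(tests):
        # branch i runs only if its own test holds and no earlier test did
        reachable_when = And(test, Not(earlier))
        if simplify_logic(reachable_when) == false:
            dead.append(i)
        earlier = Or(earlier, test)
    return dead


p, q = Symbol('p'), Symbol('q')
tests = [
    Or(And(p, q), Not(p)),           # if p and q or not p
    true,                            # elif p or True
    p,                               # elif p
    And(p, q, Not(p)),               # elif p and q and (not p)
    Or(And(p, q), And(Not(p), q)),   # elif p and q or (not p and q)
]
print(unreachable_branches(tests))  # [2, 3, 4]
```

Branch 3 is dead on its own terms. Branches 2 and 4 are dead only because the elif True before them catches everything.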

Wrap Up

Hopefully you learned a little about Python’s abstract syntax tree and how you can use it to parse code for some fun metaprogramming projects. Thanks for tuning in this week. Talk to you all later!