modified source code transformation example from Sebastian F. Walter's blog post .

In [1]:
import ast
import codegen
In [2]:
%%writefile func.py
y = x1 + x1*x2
Writing func.py
In [3]:
with open("func.py") as f:
    source = f.read()
# en with open

node = ast.parse(source)
node = ast.fix_missing_locations(node)
ast.dump(node)
Out[3]:
"Module(body=[Assign(targets=[Name(id='y', ctx=Store())], value=BinOp(left=Name(id='x1', ctx=Load()), op=Add(), right=BinOp(left=Name(id='x1', ctx=Load()), op=Mult(), right=Name(id='x2', ctx=Load()))))])"
BinOp(
    left=Name(id='x1', ctx=Load()), 
    op=Add(), 
    right=BinOp(
        left=Name(id='x1', ctx=Load()),
        op=Mult(),
        right=Name(id='x2', ctx=Load())
    )
)
In [4]:
class MyVisitor(ast.NodeVisitor):
    def visit(self, node, d=False):
        if d == False:
            if isinstance(node, ast.Module):
                line_number = 0
                while line_number < len(node.body):
                    line = node.body[line_number]
                    if isinstance(line, ast.Assign):
                        # derivative code
                        expr = node.body[line_number].value
                        id = line.targets[0].id
                        node_d = self.visit(expr, True)
                        node.body.insert(line_number, ast.Assign(targets=[ast.Name(id='%s_d'%id)], value=node_d))
                        line_number += 2
                    else:
                        self.visit(line)
                        line_number += 1
            else:
                super(MyVisitor, self).visit(node)

        else:
            # print 'visit node_d'
            # print type(node)
            if isinstance(node, ast.BinOp):
                # print 'Binop'
                if isinstance(node.op, ast.Mult):
                    # print 'v'*20 + 'Mult'
                    # print ast.dump(node.left)
                    # print ast.dump(node.right)
                    left_d = self.visit(node.left, d=True)
                    right_d = self.visit(node.right, d=True)
                    # print ast.dump(left_d)
                    # print ast.dump(right_d)
                    a_d  = ast.BinOp(op=ast.Mult(), left=node.left, right=right_d)
                    b_d  = ast.BinOp(op=ast.Mult(), left=left_d, right=node.right)
                    node_d = ast.BinOp(op=ast.Add(),  left=a_d, right=b_d)
                    # print ast.dump(node_d)
                    # print '-'*20

                elif isinstance(node.op, ast.Add):
                    # print 'Add'
                    # print ast.dump(node.left)
                    # print ast.dump(node.right)
                    left_d = self.visit(node.left, d=True)
                    right_d = self.visit(node.right, d=True)
                    node_d = ast.BinOp(op=ast.Add(),  left=left_d, right=right_d)
                else:
                    pass#raise NotImplemented()

            elif isinstance(node, ast.Name):
                return ast.Name(id='%s_d'%node.id)
            else:
                pass#raise NotImplemented()
            return node_d
In [5]:
# transform source code instructions
MyVisitor().visit(node)
# unparse instructions to derivative source code
print codegen.to_source(ast.parse(node))
y_d = x1_d + x1 * x2_d + x1_d * x2
y = x1 + x1 * x2
In [6]:
print source
y = x1 + x1*x2
In [7]:
ast.dump(node)
Out[7]:
"Module(body=[Assign(targets=[Name(id='y_d')], value=BinOp(left=Name(id='x1_d'), op=Add(), right=BinOp(left=BinOp(left=Name(id='x1', ctx=Load()), op=Mult(), right=Name(id='x2_d')), op=Add(), right=BinOp(left=Name(id='x1_d'), op=Mult(), right=Name(id='x2', ctx=Load()))))), Assign(targets=[Name(id='y', ctx=Store())], value=BinOp(left=Name(id='x1', ctx=Load()), op=Add(), right=BinOp(left=Name(id='x1', ctx=Load()), op=Mult(), right=Name(id='x2', ctx=Load()))))])"
BinOp(
    left=Name(id='x1_d'), 
    op=Add(), 
    right=BinOp(
        left=BinOp(
            left=Name(id='x1', ctx=Load()), 
            op=Mult(), 
            right=Name(id='x2_d')
        ), 
        op=Add(), 
        right=BinOp(
            left=Name(id='x1_d'), 
            op=Mult(), 
            right=Name(id='x2', ctx=Load())
        )
    )
)