Modified source-code transformation example from Sebastian F. Walter's blog post.
import ast
import codegen
%%writefile func.py
y = x1 + x1*x2
# Read back the text of the function that is going to be differentiated.
with open("func.py") as f:
    source = f.read()

# Parse the text into a syntax tree; fix_missing_locations returns the very
# node it was given, so parsing and fixing can be chained in one expression.
node = ast.fix_missing_locations(ast.parse(source))
ast.dump(node)
BinOp(
left=Name(id='x1', ctx=Load()),
op=Add(),
right=BinOp(
left=Name(id='x1', ctx=Load()),
op=Mult(),
right=Name(id='x2', ctx=Load())
)
)
class MyVisitor(ast.NodeVisitor):
    """Insert forward-mode derivative statements into a module's AST.

    For every top-level assignment ``y = expr`` the visitor inserts, directly
    in front of it, an assignment ``y_d = <derivative of expr>``.  Derivative
    variables are named by appending ``_d`` to the original name, so the
    generated code expects ``x_d`` to be defined for every input ``x``.

    Supported expressions are plain names combined with ``+``, ``-`` and
    ``*``.  Anything else raises ``NotImplementedError``.  Only assignments
    that sit directly in the module body are differentiated.
    """

    def visit(self, node, d=False):
        """Transform ``node`` in place, or build the derivative of ``node``.

        Args:
            node: the AST node to process.
            d: when false (the default), walk ``node`` and, if it is an
                ``ast.Module``, insert derivative assignments into its body.
                When true, treat ``node`` as an expression and return a new
                expression tree representing its derivative.

        Returns:
            The derivative expression when ``d`` is true, otherwise ``None``.

        Raises:
            NotImplementedError: if an expression, operator or assignment
                target has no differentiation rule in this class.
        """
        if d:
            return self._derivative(node)
        if isinstance(node, ast.Module):
            self._differentiate_module(node)
        else:
            super(MyVisitor, self).visit(node)
        return None

    def _differentiate_module(self, module):
        """Rebuild ``module.body`` with a derivative before each assignment."""
        body = []
        for stmt in module.body:
            if isinstance(stmt, ast.Assign):
                # The derivative statement goes first so that it still reads
                # the values the right-hand side had before the assignment
                # (this matters for updates such as ``x = x * x``).
                body.append(self._derivative_assign(stmt))
            else:
                self.visit(stmt)
            body.append(stmt)
        # Slice assignment mutates the existing list instead of rebinding it,
        # and leaves the module untouched if a rule above raised.
        module.body[:] = body

    def _derivative_assign(self, assign):
        """Return the ``<target>_d = <derivative>`` statement for ``assign``."""
        targets = []
        for target in assign.targets:
            if not isinstance(target, ast.Name):
                raise NotImplementedError(
                    'cannot differentiate assignment to %s'
                    % type(target).__name__)
            targets.append(ast.Name(id='%s_d' % target.id, ctx=ast.Store()))
        value_d = self.visit(assign.value, d=True)
        derived = ast.Assign(targets=targets, value=value_d)
        # Give the generated nodes source locations so that the transformed
        # tree can be handed to compile() as well as to an unparser.
        ast.copy_location(derived, assign)
        return ast.fix_missing_locations(derived)

    def _derivative(self, node):
        """Return a new expression tree for the derivative of ``node``."""
        if isinstance(node, ast.Name):
            return ast.Name(id='%s_d' % node.id, ctx=ast.Load())
        if not isinstance(node, ast.BinOp):
            raise NotImplementedError(
                'no derivative rule for %s' % type(node).__name__)

        if isinstance(node.op, ast.Mult):
            left_d = self.visit(node.left, d=True)
            right_d = self.visit(node.right, d=True)
            # Product rule: (u * v)' = u * v' + u' * v
            return ast.BinOp(
                left=ast.BinOp(left=node.left, op=ast.Mult(), right=right_d),
                op=ast.Add(),
                right=ast.BinOp(left=left_d, op=ast.Mult(), right=node.right),
            )
        if isinstance(node.op, (ast.Add, ast.Sub)):
            left_d = self.visit(node.left, d=True)
            right_d = self.visit(node.right, d=True)
            # Differentiation is linear: (u +/- v)' = u' +/- v'
            return ast.BinOp(left=left_d, op=type(node.op)(), right=right_d)
        raise NotImplementedError(
            'no derivative rule for operator %s' % type(node.op).__name__)
# Transform the tree in place: a derivative assignment is inserted in front
# of every top-level assignment.
MyVisitor().visit(node)
# Unparse the transformed tree back to source code.  `node` is already an
# AST, so it is passed to codegen directly instead of going through
# ast.parse(), which is meant for source text.
print(codegen.to_source(node))
# Show the original source for comparison.
print(source)
ast.dump(node)
BinOp(
left=Name(id='x1_d'),
op=Add(),
right=BinOp(
left=BinOp(
left=Name(id='x1', ctx=Load()),
op=Mult(),
right=Name(id='x2_d')
),
op=Add(),
right=BinOp(
left=Name(id='x1_d'),
op=Mult(),
right=Name(id='x2', ctx=Load())
)
)
)