Automatic Differentiation

UPDATE: This code is now available in both Java and Python!

I’ve been on an automatic differentiation kick ever since reading about dual numbers on Wikipedia.

I implemented a simple forward-mode autodiff system in Rust, thinking it would let me do ML faster. What I failed to realize (or read) is that forward-mode differentiation, while simpler, requires one forward pass to get the derivative of ALL outputs with respect to ONE input variable. Reverse mode, in contrast, gives you the derivative of ONE output with respect to ALL inputs in a single backward pass.

That is to say, if I had f(x, y, z) = [a, b, c], forward mode would give me da/dx, db/dx, dc/dx in a single pass. Reverse mode would give me da/dx, da/dy, da/dz in a single pass.

Forward mode is really easy. I have a repo with code changes here: https://github.com/JosephCatrambone/RustML
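
For a flavor of why forward mode is so easy, here's a minimal dual-number sketch in Python (just an illustration of the idea, not the code from the Rust repo): every value drags its derivative along with it, and each arithmetic operation propagates both.

# A tiny dual-number, forward-mode sketch (illustrative only; not the repo code).
class Dual(object):
    def __init__(self, value, deriv=0.0):
        self.value = value  # The function value.
        self.deriv = deriv  # The derivative with respect to the chosen input.

    def __add__(self, other):
        return Dual(self.value + other.value, self.deriv + other.deriv)

    def __mul__(self, other):
        # Product rule: (uv)' = u'v + uv'.
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

# d/dx of (x + y)*x at x=2, y=3. Seed dx/dx = 1 and dy/dx = 0.
x = Dual(2.0, 1.0)
y = Dual(3.0, 0.0)
b = (x + y) * x
print(b.value, b.deriv)  # 10.0 7.0

One pass per input variable you want to differentiate with respect to, which is exactly the limitation above.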

Reverse mode took me a while to figure out, mostly because I was confused about how adjoints worked. (The adjoint of a node is just the derivative of the chosen output with respect to that node, built up via the chain rule.) I'm still a bit confused, but I'm now so accustomed to the strangeness that I barely notice it. Here's some simple, single-variable reverse-mode autodiff in roughly 100 lines of Python:

#!/usr/bin/env python
# JAD: Joseph's Automatic Differentiation
from collections import deque

class Graph(object):
    def __init__(self):
        self.names = list()
        self.operations = list()
        self.derivatives = list()  # A list of LISTS, where each item is the gradient with respect to that argument.
        self.node_inputs = list()  # A list of the indices of the input nodes.
        self.shapes = list()
        self.graph_inputs = list()
        self.forward = list()  # Cleared on forward pass.
        self.adjoint = list()  # Cleared on reverse pass.

    def get_output(self, input_set, node=-1):
        self.forward = list()
        for i, op in enumerate(self.operations):
            self.forward.append(op(input_set))
        return self.forward[node]

    def get_gradient(self, input_set, node, forward_data=None):
        if forward_data is not None:
            self.forward = forward_data
        else:
            self.forward = list()
            for i, op in enumerate(self.operations):
                self.forward.append(op(input_set))
        # Initialize adjoints to 0 except our target, which is 1.
        self.adjoint = [0.0]*len(self.forward)
        self.adjoint[node] = 1.0
        gradient_stack = deque()
        for arg_index, input_node in enumerate(self.node_inputs[node]):
            gradient_stack.append((input_node, node, arg_index))  # (child, parent, which argument of the parent the child fills).
        while gradient_stack:  # While not empty.
            current_node, parent_node, arg_index = gradient_stack.popleft()
            # Chain rule: accumulate the parent's adjoint times d(parent)/d(that argument).
            dop = self.derivatives[parent_node][arg_index]
            self.adjoint[current_node] += self.adjoint[parent_node]*dop(input_set)
            # Note: this simple traversal visits a node once per consumer, which is fine for
            # graph inputs but would double-count the children of a shared intermediate node.
            for child_index, input_arg in enumerate(self.node_inputs[current_node]):
                gradient_stack.append((input_arg, current_node, child_index))
        return self.adjoint
    def get_shape(self, node):
        return self.shapes[node]

    def add_input(self, name, shape):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: inputs[name])
        self.derivatives.append([lambda inputs: 1])
        self.node_inputs.append([])
        self.graph_inputs.append(index)
        self.shapes.append(shape)
        return index

    def add_add(self, name, left, right):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[left] + self.forward[right])
        self.derivatives.append([lambda inputs: 1, lambda inputs: 1])  # d/dx a + b = 1 + 0 or 0 + 1
        self.node_inputs.append([left, right])
        self.shapes.append(self.get_shape(left))
        return index

    def add_multiply(self, name, left, right):
        index = len(self.names)
        self.names.append(name)
        self.operations.append(lambda inputs: self.forward[left] * self.forward[right])
        self.derivatives.append([lambda inputs: self.forward[right], lambda inputs: self.forward[left]])
        self.node_inputs.append([left, right])
        self.shapes.append(self.get_shape(left))
        return index

if __name__=="__main__":
    g = Graph()
    x = g.add_input("x", (1, 1))
    y = g.add_input("y", (1, 1))
    a = g.add_add("a", x, y)
    b = g.add_multiply("b", a, x)
    input_map = {'x': 2, 'y': 3}
    print(g.get_output(input_map))  # 10
    print(g.get_gradient(input_map, b))  # [7.0, 2.0, 2.0, 1.0]: db/dx, db/dy, db/da, db/db.
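
Checking by hand: b = (x + y)*x = x^2 + x*y, so db/dx = 2x + y = 7 and db/dy = x = 2 at (2, 3). A quick finite-difference sanity check agrees; the lines below aren't part of the listing above, but they drop straight into the __main__ block:

    # Sanity check with finite differences (reuses g, b, and input_map from above).
    eps = 1e-6
    for name in ('x', 'y'):
        bumped = dict(input_map)
        bumped[name] += eps
        approx = (g.get_output(bumped, b) - g.get_output(input_map, b)) / eps
        print(name, approx)  # x: ~7.0, y: ~2.0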
