Differentiation on a graph

Here is the setup we work with: a function is shown as an (input, output) edge, and the Jacobian of the function is written on the edge itself:

$x \in \mathbb{R}^n \xrightarrow{\ \partial y/\partial x,\ \text{shape } (m,n)\ } y \in \mathbb{R}^m$

This representation extends naturally to a composition of functions $z = g(y),\ y = f(x)$. Next we see how the chain rule looks on this graph.
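As a quick check of the shape convention, here is a minimal PyTorch sketch; the map f below is just an illustrative choice, not anything from the text. It confirms that the Jacobian of a function from $\mathbb{R}^n$ to $\mathbb{R}^m$ has shape $(m, n)$:

import torch

n, m = 3, 2
A = torch.randn(m, n, dtype=torch.float64)

def f(x):
    # an arbitrary map from R^n to R^m
    return torch.tanh(A @ x)

x = torch.randn(n, dtype=torch.float64)
J = torch.autograd.functional.jacobian(f, x)
print(J.shape)  # torch.Size([2, 3]), i.e. (m, n)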
0. Chain rule on a graph
Consider two cases:
In the first case, we have a simple composition of functions: $z = g(y),\ y = f(x)$.
1. Jacobian wrt input is the product of Jacobians on each edge
$x \in \mathbb{R}^n \xrightarrow{\ J_y,\ \text{shape } (m,n)\ } y \in \mathbb{R}^m \xrightarrow{\ J_z,\ \text{shape } (p,m)\ } z \in \mathbb{R}^p$

$\partial z/\partial x = J_z J_y$, shape $(p,n)$
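A small numerical check of this case, assuming PyTorch; the functions f and g below are arbitrary illustrative choices:

import torch
from torch.autograd.functional import jacobian

n, m, p = 4, 3, 2
A = torch.randn(m, n, dtype=torch.float64)
B = torch.randn(p, m, dtype=torch.float64)

f = lambda x: torch.sin(A @ x)   # y = f(x), maps R^n -> R^m
g = lambda y: torch.tanh(B @ y)  # z = g(y), maps R^m -> R^p

x = torch.randn(n, dtype=torch.float64)
y = f(x)

J_y = jacobian(f, x)  # shape (m, n)
J_z = jacobian(g, y)  # shape (p, m)

# Jacobian of the composition equals the product of the edge Jacobians
assert torch.allclose(J_z @ J_y, jacobian(lambda x: g(f(x)), x))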
Next we look at a DAG whose output node takes two inputs, so that there are two paths from the input to the output. In this case, the Jacobian wrt the input is the sum of the Jacobians of each path: $\frac{dz}{dx} = \underbrace{\tfrac{\partial z}{\partial y}}_{J_3}\underbrace{\tfrac{\partial y}{\partial x}}_{J_2} + \underbrace{\tfrac{\partial z}{\partial x}}_{J_1}$. The first path contributes $J_3 J_2$ and the second path contributes $J_1$.
2. Jacobian wrt input is the sum of Jacobians of each path
Nodes: $x \in \mathbb{R}^n$, $y \in \mathbb{R}^m$, $z \in \mathbb{R}^p$. Edges: $x \xrightarrow{\ J_2,\ (m,n)\ } y$, $y \xrightarrow{\ J_3,\ (p,m)\ } z$, and a direct edge $x \xrightarrow{\ J_1,\ (p,n)\ } z$.

$\partial z/\partial x = J_3 J_2 + J_1$, where both terms have shape $(p,n)$
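Again a small check, assuming PyTorch, with $z = g(x, y)$ and $y = f(x)$; f and g are illustrative placeholders:

import torch
from torch.autograd.functional import jacobian

n, m, p = 4, 3, 2
A = torch.randn(m, n, dtype=torch.float64)
B = torch.randn(p, m, dtype=torch.float64)
C = torch.randn(p, n, dtype=torch.float64)

f = lambda x: torch.sin(A @ x)              # y = f(x)
g = lambda x, y: torch.tanh(C @ x + B @ y)  # z = g(x, y)

x = torch.randn(n, dtype=torch.float64)
y = f(x)

J1, J3 = jacobian(g, (x, y))  # partials of z wrt x and y: shapes (p, n) and (p, m)
J2 = jacobian(f, x)           # shape (m, n)

# sum of the two path-products equals the full Jacobian dz/dx
assert torch.allclose(J3 @ J2 + J1, jacobian(lambda x: g(x, f(x)), x))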
This is all we need to fully define the chain rule on a graph. In the next example we combine both these cases.
0.1 Chain rule on a larger graph
Consider a DAG of a function $w = f(x)$, with $w \in \mathbb{R}$ and $x \in \mathbb{R}^n$.

Nodes: $x \in \mathbb{R}^n$, $y \in \mathbb{R}^m$, $z \in \mathbb{R}^p$, $w \in \mathbb{R}$. Edges: $x \xrightarrow{\ J_1,\ (p,n)\ } z$, $x \xrightarrow{\ J_2,\ (m,n)\ } y$, $y \xrightarrow{\ J_3,\ (p,m)\ } z$, $z \xrightarrow{\ J_4,\ (1,p)\ } w$.

$\partial w/\partial x = J_4(J_1 + J_3 J_2) = \underbrace{J_4 J_1}_{\text{path 1}} + \underbrace{J_4 J_3 J_2}_{\text{path 2}}$
This leads to an insight about autodiff on a graph:
Jacobian of y wrt x is the sum of all path-products from x to y.
A path product is the product of all Jacobians along a path.
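This rule can be turned into a tiny recursive computation. Below is a toy sketch for the graph above; the edge encoding and the helper jacobian_wrt_x are illustrative inventions, not part of any library. It accumulates the sum of path-products node by node and checks the result against the expanded form:

import torch

n, m, p = 4, 3, 2
J1 = torch.randn(p, n, dtype=torch.float64)  # x -> z
J2 = torch.randn(m, n, dtype=torch.float64)  # x -> y
J3 = torch.randn(p, m, dtype=torch.float64)  # y -> z
J4 = torch.randn(1, p, dtype=torch.float64)  # z -> w

# edges[node] lists (parent, Jacobian of node wrt that parent)
edges = {"y": [("x", J2)], "z": [("x", J1), ("y", J3)], "w": [("z", J4)]}

def jacobian_wrt_x(node):
    # sum over incoming edges: edge Jacobian times the parent's Jacobian wrt x
    if node == "x":
        return torch.eye(n, dtype=torch.float64)
    return sum(J @ jacobian_wrt_x(parent) for parent, J in edges[node])

assert torch.allclose(jacobian_wrt_x("w"), J4 @ J1 + J4 @ J3 @ J2)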
0.2 Chain rule for variance
Let's revisit the DAG of the variance function:
Nodes: $x \in \mathbb{R}^2$, the mean $\mu$, the centered vector $y = x - \mu\mathbf{1} \in \mathbb{R}^2$, $\lVert y \rVert^2$, and $\sigma^2$. Edges: $x \xrightarrow{\ J_1,\ (1,2)\ } \mu$, $\mu \xrightarrow{\ J_2,\ (2,1)\ } y$, $x \xrightarrow{\ J_3,\ (2,2)\ } y$, $y \xrightarrow{\ J_4,\ (1,2)\ } \lVert y \rVert^2$, $\lVert y \rVert^2 \xrightarrow{\ J_5,\ (1,1)\ } \sigma^2$.
There are two paths from $x$ to $\sigma^2$. We calculate the Jacobians along each path and add them up to get the Jacobian of $\sigma^2$ wrt $x$:

$J_5 = \tfrac{1}{2},\quad J_4 = 2y^T,\quad J_3 = \begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix},\quad J_2 = -\begin{bmatrix}1\\ 1\end{bmatrix},\quad J_1 = \begin{bmatrix}\tfrac{1}{2} & \tfrac{1}{2}\end{bmatrix}$

$\partial \sigma^2/\partial x = J_5 J_4 J_2 J_1 + J_5 J_4 J_3 \in \mathbb{R}^{1\times 2} = -y^T\begin{bmatrix}1\\ 1\end{bmatrix}\begin{bmatrix}\tfrac{1}{2} & \tfrac{1}{2}\end{bmatrix} + y^T\begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix} = y^T\begin{bmatrix}\tfrac{1}{2} & -\tfrac{1}{2}\\ -\tfrac{1}{2} & \tfrac{1}{2}\end{bmatrix} = \tfrac{1}{2}x^T\begin{bmatrix}1 & -1\\ -1 & 1\end{bmatrix}$

(the last step substitutes $y = x - \mu\mathbf{1}$). This is a linear-algebraic way of representing:
the variance function $f(x_1, x_2) = \tfrac{1}{4}(x_1^2 + x_2^2) - \tfrac{1}{2}x_1 x_2$
the partial derivative wrt $x_1$, which is $\tfrac{1}{2}(x_1 - x_2)$
the partial derivative wrt $x_2$, which is $\tfrac{1}{2}(x_2 - x_1)$
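A quick numerical check of this example with PyTorch; the variable names below are arbitrary, and the computation is written to follow the DAG above:

import torch

x = torch.tensor([1.5, -0.5], dtype=torch.float64, requires_grad=True)

# the variance DAG: mean -> centered values -> squared norm -> variance
mu = x.mean()
y = x - mu
sigma2 = (y @ y) / 2
sigma2.backward()

# compare autograd's gradient with the closed form (1/2) x^T [[1, -1], [-1, 1]]
M = torch.tensor([[1.0, -1.0], [-1.0, 1.0]], dtype=torch.float64)
assert torch.allclose(x.grad, 0.5 * (M @ x.detach()))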
1. Two ways to calculate gradients
Consider multiplication of three matrices $J_1, J_2, J_3$ of shapes $(m,n)$, $(p,m)$, $(q,p)$. There are two ways to multiply them:
$(\underbrace{J_3}_{(q,p)}\ \underbrace{J_2}_{(p,m)})\ \underbrace{J_1}_{(m,n)}$: cost $qpm + qmn$
Backward differentiation, aka backpropagation
$\underbrace{J_3}_{(q,p)}\ (\underbrace{J_2}_{(p,m)}\ \underbrace{J_1}_{(m,n)})$: cost $qpn + pmn$
Forward differentiation
One can think of them as Jacobians corresponding to functions $a \xrightarrow{J_1} b$, $b \xrightarrow{J_2} c$, $c \xrightarrow{J_3} d$. Their product is then the Jacobian corresponding to the composite map from $a \in \mathbb{R}^n$ to $d \in \mathbb{R}^q$.
In most cases in scientific computing, the input dimension $n$ is large and the output dimension $q$ is small (often just 1). In this case, the first way (multiplying from the output towards the input) is more efficient. In cases where the output dimension is large or comparable to the input dimension (e.g. when calculating a Hessian), the second way (multiplying from the input towards the output) is more efficient.
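The cost difference is easy to see with a quick count. This sketch is pure Python bookkeeping; the shapes are made-up values chosen to mimic a large input and a scalar output:

def matmul_cost(shape_a, shape_b):
    # number of scalar multiplications for an (a, b) times (b, c) product
    (a, b), (b2, c) = shape_a, shape_b
    assert b == b2
    return a * b * c

n, m, p, q = 1000, 800, 600, 1  # large input dimension, scalar output

# backward: (J3 J2) J1 -> q*p*m + q*m*n
backward_cost = matmul_cost((q, p), (p, m)) + matmul_cost((q, m), (m, n))
# forward: J3 (J2 J1) -> p*m*n + q*p*n
forward_cost = matmul_cost((p, m), (m, n)) + matmul_cost((q, p), (p, n))

print(backward_cost, forward_cost)  # 1_280_000 vs 480_600_000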
2. Jacobians are overkill for autodiff
This is something that is often skipped or just skimmed over in a lot of texts on autodiff. While we assumed the existence of Jacobians to explain the ideas of differentiation on a graph, in practice it is neither cheap nor necessary to actually calculate and store the entire Jacobian matrix for each edge. Consider the following function to illustrate the issue:

$x \in \mathbb{R}^{n\times n} \xrightarrow{\ J_1\ } y\ (= x^{-1}) \xrightarrow{\ J_2\ } z\ (= \lVert y \rVert_F)$

Here the Jacobian $J_1$ is of shape $(n^2, n^2)$ and the Jacobian $J_2$ is of shape $(1, n^2)$. Let's calculate $\partial z/\partial x$ just by computing and then multiplying the Jacobians. This is not efficient, but we do it to see exactly why it is not efficient.
2.1 Calculating the Jacobians
We know that:

$dy = -x^{-1}(dx)\,x^{-1} \implies \operatorname{vec}(dy) = \underbrace{-\big((x^{-1})^T \otimes x^{-1}\big)}_{J_1}\operatorname{vec}(dx)$

Similarly:

$\frac{\partial z}{\partial y} = \frac{y}{z} \implies J_2 = \frac{1}{z}\big(\operatorname{vec}(y)\big)^T$

Now that we have both Jacobian matrices, we can multiply them to get the Jacobian of $z$ wrt $x$. The product will be of shape $(1, n^2)$ and needs to be reshaped into an $(n,n)$ matrix. In the next part, we will implement it in a more efficient way.
import torch

# perform backprop to calculate `x.grad`
x = torch.randn(100,100, requires_grad=True, dtype=torch.float64)
y = torch.linalg.inv(x)
z = torch.linalg.norm(y)
z.backward()

# calculate the Jacobians explicitly
# note: j1 has shape (n^2, n^2) = (10000, 10000), about 800 MB in float64
x_inv = torch.linalg.inv(x).contiguous()
x_inv_t = x_inv.t()
j1 = torch.kron(-x_inv_t, x_inv)

# need to transpose `y` first
# since pytorch uses row-major order
j2 = (y.t()/z).reshape(1,-1)

# calculate the jacobian and compare with actual gradient
j3 = (j2 @ j1).reshape(100,100).t()
assert torch.allclose(j3, x.grad)