Differentiation on a graph
Here is the setup that we work with: a function is shown as an (input, output) edge, and the Jacobian of this function is written on the edge itself:$$ x \in R^n \xrightarrow[{\amber{\text{Shape (m,n)}}}]{ \amber{\partial y / \partial x} } y \in R^m $$This representation extends naturally to a composition of functions \( \rose{z = g(y)}, y = f(x) \). Next we see what the chain rule looks like on this graph.
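As a concrete (hypothetical) instance: if \( y = Ax \) for some fixed \( A \in R^{m \times n} \), the edge from \( x \) to \( y \) is labelled with \( \partial y / \partial x = A \), which has shape \( (m,n) \).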
0. Chain rule on a graph
Consider two cases:
In the first case, we have a simple composition of functions: \( \rose{z = g(y)}, y = f(x) \). Here the chain rule gives \( \frac{dz}{dx} = \frac{\partial z}{\partial y} \frac{\partial y}{\partial x} \), the product of the Jacobians on the two edges (a quick numerical check follows the two cases).
1. Jacobian wrt input is the product of Jacobians on each edge
Next we look at a DAG where the output \( z \) has two incoming edges: it depends on \( x \) directly as well as through \( y = f(x) \). In this case, the Jacobian wrt input is the sum of Jacobians of each path: \(
\frac{dz}{dx} =
(\underbrace{\frac{\partial z}{\partial y}}_{\amber{J_3}})
(\underbrace{\frac{\partial y}{\partial x}}_{\amber{J_2}})
+
\underbrace{\frac{\partial z}{\partial x}}_{\amber{J_1}}
\). The first path is \( J_2 \rightarrow J_3 \) and the second path is \( J_1 \).
2. Jacobian wrt input is the sum of Jacobians of each path
This is all we need to fully define the chain rule on a graph. In the next example we combine both these cases.
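Here is a quick numerical check of case 1, as a minimal sketch: the maps f and g below are arbitrary choices (not from the text), and torch.autograd.functional.jacobian is used to build the full Jacobians.
import torch
from torch.autograd.functional import jacobian
# arbitrary (hypothetical) maps f: R^3 -> R^2 and g: R^2 -> R^2
def f(x):
    return torch.stack([x.sum(), (x ** 2).sum()])
def g(y):
    return torch.stack([y[0] * y[1], y[0] - y[1]])
x = torch.randn(3, dtype=torch.float64)
y = f(x)
J_f = jacobian(f, x)                    # dy/dx, shape (2, 3)
J_g = jacobian(g, y)                    # dz/dy, shape (2, 2)
J_gf = jacobian(lambda t: g(f(t)), x)   # dz/dx, shape (2, 3)
# Jacobian of the composition = product of the edge Jacobians
assert torch.allclose(J_gf, J_g @ J_f)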
0.1 Chain rule on a larger graph
Consider a DAG of a function \( w = f(x), w \in R, x \in R^n \).
This leads to an insight about autodiff on a graph:
The Jacobian of y wrt x is the sum of all path-products from x to y.
A path-product is the product of all Jacobians along a path.
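A minimal sketch of this insight (the function below, \( z = x \cdot \sin(x) \), is an arbitrary choice, not from the text): \( z \) depends on \( x \) directly and through \( y = \sin(x) \), so the gradient is the sum of the two path-products.
import torch
# z = x . sin(x): two paths from x to z, one direct and one through y = sin(x)
x = torch.randn(3, dtype=torch.float64, requires_grad=True)
y = torch.sin(x)
z = x @ y
z.backward()
# path 1 (x -> z directly): dz/dx with y held fixed = y^T
# path 2 (x -> y -> z):     (dz/dy)(dy/dx) = x^T diag(cos(x))
path1 = y.detach()
path2 = x.detach() * torch.cos(x.detach())
assert torch.allclose(x.grad, path1 + path2)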
0.2 Chain rule for variance
Let's revisit the DAG of the variance function:
There are two paths from \( x \) to \( \sigma^2 \). We calculate the Jacobians along each path and add them up to get the Jacobian of \( \sigma^2 \) wrt \( x \):
\begin{align}
J_5 &= 1/2 \\
J_4 &= 2y^T \\
J_3 &= \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} \\
J_2 &= \begin{bmatrix} -1 \\ -1 \end{bmatrix} \\
J_1 &= \begin{bmatrix} 1/2 & 1/2 \end{bmatrix} \\
\\
\partial \sigma^2 / \partial x &= J_5J_4J_2J_1 + J_5J_4J_3 \amber{\in R^{1 \times 2}} \\
&= y^T \begin{bmatrix} -1 \\ -1 \end{bmatrix} \begin{bmatrix} 1/2 & 1/2 \end{bmatrix} + y^T \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} \\
&= y^T \quad \amber{(\text{since } y_1 + y_2 = 0)} \\
&= \frac{1}{2} x^T \begin{bmatrix} 1 & -1 \\ -1 & 1 \end{bmatrix} \\
\end{align}
This is a linear-algebraic way of representing:
the variance function \( f(x_1, x_2) = \frac{1}{4}(x_1^2 + x_2^2) - \frac{1}{2}x_1x_2 \)
the partial derivative wrt \( x_1 \), which is \( \frac{1}{2} (x_1 - x_2) \)
the partial derivative wrt \( x_2 \), which is \( \frac{1}{2} (x_2 - x_1) \)
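We can check this result with a minimal PyTorch sketch, assuming the same 2-element variance graph; autograd sums the two paths for us and we compare against the closed form above.
import torch
# the 2-element variance graph: x -> mu -> y -> sum of squares -> sigma^2
x = torch.randn(2, dtype=torch.float64, requires_grad=True)
mu = x.mean()            # edge J_1 = [1/2, 1/2]
y = x - mu               # edges J_2 = [-1, -1]^T (wrt mu) and J_3 = I (wrt x)
var = 0.5 * (y @ y)      # edges J_4 = 2y^T and J_5 = 1/2
var.backward()
# closed form from the path-product calculation: 1/2 * x^T [[1, -1], [-1, 1]]
M = torch.tensor([[1.0, -1.0], [-1.0, 1.0]], dtype=torch.float64)
expected = 0.5 * x.detach() @ M
assert torch.allclose(x.grad, expected)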
1. Two ways to calculate gradients
Consider multiplication of three matrices \( J_1, J_2, J_3 \) of shapes \( (m,n), (p,m), (q,p) \). There are two ways to multiply them:
$$
\underbrace{( \underset{\rose{(q,p)}}{J_3} \ \underset{\rose{(p,m)}}{J_2} ) \underset{\rose{(m,n)}}{J_1}}_{\amber{qpm \ + \ qmn}}
$$$$
\underbrace{ \underset{\rose{(q,p)}}{J_3} (\underset{\rose{(p,m)}}{J_2} \ \underset{\rose{(m,n)}}{J_1})}_{\amber{qpn \ + \ pmn}}
$$One can think of them as Jacobians corresponding to functions \( a \xrightarrow{J_1} b, b \xrightarrow{J_2} c, c \xrightarrow{J_3} d \). Their product is then the Jacobian corresponding to \( a \in R^n \rightarrow d \in R^q \).
The first way, multiplying from the output towards the input, is backward differentiation, aka backpropagation. The second way, multiplying from the input towards the output, is forward differentiation. In most cases in scientific computing, the input dimension \( n \) is large and the output dimension \( q \) is small (often 1), so backward differentiation is more efficient. In cases where the output dimension is large or comparable to the input dimension (e.g. when calculating a Hessian), forward differentiation is more efficient.
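A rough illustration of the difference, as a sketch with arbitrarily chosen shapes (\( q = 1 \), \( n = m = p = 2000 \)); the two association orders give the same matrix but at very different costs:
import time
import torch
q, p, m, n = 1, 2000, 2000, 2000
J1 = torch.randn(m, n, dtype=torch.float64)
J2 = torch.randn(p, m, dtype=torch.float64)
J3 = torch.randn(q, p, dtype=torch.float64)
t0 = time.perf_counter()
backward_order = (J3 @ J2) @ J1     # ~ qpm + qmn flops: small since q = 1
t1 = time.perf_counter()
forward_order = J3 @ (J2 @ J1)      # ~ pmn + qpn flops: dominated by the (p,m) x (m,n) product
t2 = time.perf_counter()
print(f"output-to-input (backward): {t1 - t0:.4f} s")
print(f"input-to-output (forward):  {t2 - t1:.4f} s")
assert torch.allclose(backward_order, forward_order)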
2. Jacobians are overkill for autodiff
This is something that is often skipped or just skimmed over in a lot of texts on autodiff. While we assumed the existence of Jacobians to explain the ideas of differentiation on a graph, in practice it is neither cheap nor necessary to actually calculate and store the entire Jacobian matrix for each edge. Let's consider the following function to illustrate the issue:$$
\underbrace{x}_{\rose{(n,n)} } \xrightarrow{J_1}
\underbrace{y}_{\rose{\ (=x^{-1})}} \xrightarrow{J_2}
\underbrace{z}_{\rose{\ (=\|y\|_F)} }
$$Here the Jacobian \( J_1 \) is of shape \( (n^2,n^2) \) and the Jacobian \( J_2 \) is of shape \( (1,n^2) \). Let's try to calculate \( \partial z / \partial x \) just by calculating and then multiplying the Jacobians. This is not efficient, but doing it once makes it clear exactly why.
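To put a number on it: for \( n = 100 \) (as in the code below), \( J_1 \) alone has \( 10^4 \times 10^4 = 10^8 \) entries, about 0.8 GB in double precision.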
2.1 Calculating the Jacobians
We know that:$$
\begin{align}
dy &= -x^{-1}(dx)x^{-1} \\
\Rightarrow \text{vec}(dy) &=
\underbrace{(-{(x^{-1})}^T \otimes {x^{-1}}) }_{\amber{J_1}}
\text{vec}(dx) \\
\end{align}
$$Similarly:$$
\begin{align}
\nabla_y z &= \frac{y}{z} \\
\Rightarrow \amber{J_2} &= \frac{1}{z} (\text{vec}(y))^T
\end{align}
$$Now that we have both Jacobian matrices, we can multiply them to get the Jacobian of \( z \) wrt \( x \). The product \( J_2 J_1 \) is of shape \( (1,n^2) \) and needs to be reshaped into an \( (n,n) \) matrix. In the next part, we will implement it in a more efficient way.
import torch
# perform backprop to calculate `x.grad`
x = torch.randn(100,100, requires_grad=True, dtype=torch.float64)
y = torch.linalg.inv(x)
z = torch.linalg.norm(y)
z.backward()
# calculate the Jacobians
x_inv = torch.linalg.inv(x).contiguous()
x_inv_t = torch.linalg.inv(x).t()
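# J1 = -(x^{-T} kron x^{-1}) from the vec identity above; shape (n^2, n^2) = (10000, 10000)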
j1 = torch.kron(-x_inv_t, x_inv)
# need to transpose `y` first
# since pytorch uses row-major order
j2 = (y.t()/z).reshape(1,-1)
# calculate the jacobian and compare with actual gradient
j3 = (j2 @ j1).reshape(100,100).t()
assert torch.allclose(j3, x.grad)