Hessian and the bilinear form
The Hessian of a function is the Jacobian of its Jacobian. A function \( f(x): R^n \rightarrow R \) has a Jacobian \( J_f \) of shape \( (1,n) \), which is itself a function of \( x \). Thus one can imagine a mapping \( x \rightarrow J_f \), i.e. \( g(x): R^n \rightarrow R^n \). The Jacobian of \( g \) has shape \( (n,n) \); this is the Hessian of the function \( f \).
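As a quick sanity check of this picture, one can literally take the Jacobian of the Jacobian in PyTorch; a minimal sketch with a made-up cubic test function:
import torch
from torch.autograd.functional import jacobian, hessian
def f(x):
    return (x ** 3).sum()  # made-up scalar-valued test function
x = torch.randn(5)
grad_f = lambda x: jacobian(f, x, create_graph=True)  # x -> J_f(x), shape (5,)
h = jacobian(grad_f, x)  # Jacobian of the Jacobian, shape (5, 5)
assert torch.allclose(h, hessian(f, x))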
We start with the calculation of the entire Jacobian matrix - something we have avoided so far.
0. Calculating the Jacobian matrix
We use the inverse function \( f(A) = A^{-1} \) to calculate the Jacobian matrix. First let's calculate the Jacobian using PyTorch as a reference:
import torch
from torch.autograd.functional import jacobian
def inverse(x):
    return torch.linalg.inv(x)
x = torch.randn(100, 100)
y = inverse(x)  # x^{-1}; reused below in the manual JVP/VJP formulas
j_x = jacobian(inverse, x)  # reference Jacobian, shape (100, 100, 100, 100)
Now recall the JVP of the inverse function:$$ jvp_f(x,\amber{v}) = -x^{-1}\amber{v}x^{-1} $$This is the compressed (reshaped) form of the vectorized JVP. Note that \( \underbrace{J e_i}_{\text{jvp}(x, e_i)} = \text{i-th column of } J \). Using this trick, we can calculate the i-th column of the Jacobian matrix:
i = 1234
# construct a basis vector
basis_vector_i = torch.zeros(100,100)
row = i % 100
col = i // 100
basis_vector_i[row,col] = 1
jacobian_column_i = -y @ basis_vector_i @ y
# need to transpose since pytorch follows row-major order
# also need to transpose the indices of `j_x`
assert torch.allclose(jacobian_column_i, j_x[col, row].t())
One could then iterate over the columns to calculate the full Jacobian matrix.
0.1 Faster calculation by batching
A faster way to calculate the Jacobian matrix is to use batching and calculate \( J [e_1, e_2, ..., e_n] = J \ I = J \) in one go.
eye = torch.eye(10000,10000).view(10000,100,100)
j_manual = -y.unsqueeze(0) @ eye @ y.unsqueeze(0)
# the first two dimensions of `j_x` seem to be arranged in column-major order
# `j_manual.view(100,100, ...)` reshapes in row-major order
# so we need to transpose the first two dimensions
j_manual = j_manual.view(100,100,100,100).transpose(0,1)
assert torch.allclose(j_manual, j_x.transpose(-1,-2))
0.2 Calculating rows vs columns
We used the JVP to calculate the Jacobian one column at a time. One could also calculate it one row at a time using the VJP, since \( e_i^T J = \text{i-th row of the Jacobian} \).
Remember that the VJP of the inverse function is:$$ vjp_f(x,\amber{v}) = -x^{-T}\amber{v}x^{-T} $$One can then use the same procedure as above to calculate the rows of the Jacobian matrix (try it yourself).
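Here is a sketch of that exercise, reusing `x`, `y`, and `j_x` from above and the same indexing convention for the basis vector (which now indexes an output element):
i = 1234
basis_vector_i = torch.zeros(100, 100)
row = i % 100
col = i // 100
basis_vector_i[row, col] = 1
# the VJP with a basis vector picks out one row of the Jacobian:
# the derivatives of output element (row, col) with respect to every input element
jacobian_row_i = -y.t() @ basis_vector_i @ y.t()
# `j_x[row, col]` is exactly d out[row, col] / d in, so no transpose is needed here
assert torch.allclose(jacobian_row_i, j_x[row, col])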
1. Hessian matrix
Let \( g(x) \in R^n \) be the gradient of \( f(x): R^n \rightarrow R \). Consider the function \( h: R^n \rightarrow R^n \) that maps \( x \) to \( g(x) \), written in matrix notation:$$
\begin{bmatrix}
x_1 \\
x_2 \\
\vdots \\
x_n
\end{bmatrix}
\xrightarrow{h}
\begin{bmatrix}
\amber{g_1} \\
\rose{g_2} \\
\vdots \\
\zinc{g_n}
\end{bmatrix}
=
\begin{bmatrix}
\amber{\partial f / \partial x_1} \\
\rose{\partial f / \partial x_2} \\
\vdots \\
\zinc{\partial f / \partial x_n}
\end{bmatrix}
$$The Jacobian of \( h \) is the Hessian of \( f \):$$
J_h =
\underbrace{
\begin{bmatrix}
\partial \amber{g_1} / \partial x_1 & \partial \amber{g_1} / \partial x_2 & \cdots & \partial \amber{g_1} / \partial x_n \\
\partial \rose{g_2} / \partial x_1 & \partial \rose{g_2} / \partial x_2 & \cdots & \partial \rose{g_2} / \partial x_n \\
\vdots & \vdots & \ddots & \vdots \\
\partial \zinc{g_n} / \partial x_1 & \partial \zinc{g_n} / \partial x_2 & \cdots & \partial \zinc{g_n} / \partial x_n
\end{bmatrix}
}_{\amber{\text{Hessian of } f}}
$$We could expand each term \( \amber{\partial g_i} / \partial x_j = \amber{\partial ^2 f} / \partial x_j \amber{\partial x_i} \) in the above matrix to get a more standard form.
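As a concrete example, take the function used in the PyTorch section below, \( f(x) = x_1^2 x_2 + \sin(x_1 + x_2) \):$$
g(x) =
\begin{bmatrix}
2 x_1 x_2 + \cos(x_1 + x_2) \\
x_1^2 + \cos(x_1 + x_2)
\end{bmatrix}
, \qquad
H =
\begin{bmatrix}
2 x_2 - \sin(x_1 + x_2) & 2 x_1 - \sin(x_1 + x_2) \\
2 x_1 - \sin(x_1 + x_2) & -\sin(x_1 + x_2)
\end{bmatrix}
$$Note that the two off-diagonal entries are already equal, which is exactly the symmetry discussed next.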
1.1 Hessian is symmetric
There are two ways to look at it:
The second derivative is symmetric:$$ \underbrace{\frac{\partial ^2 f}{\partial \amber{x_j} \partial \rose{x_i}}}_{\zinc{H_{ij}}} = \underbrace{\frac{\partial ^2 f}{\partial \rose{x_i} \partial \amber{x_j}}}_{\zinc{H_{ji}}} $$
We can calculate how a change in \( g \) affects a change in \( f \), working to first order:$$
\begin{align}
dg &= g(x + \amber{dx_g}) - g(x) \\
\Rightarrow \underbrace{H(\amber{dx_g})}_{\in \amber{R^n}} &= g(x + \amber{dx_g}) - g(x) \\
\Rightarrow \underbrace{\rose{dx_f^T}H(\amber{dx_g})}_{\in R} &=
\underbrace{\langle \rose{dx_f}, g(x + \amber{dx_g}) \rangle}_{\text{change in} \ f(x + \amber{dx_g})}
-
\underbrace{\langle \rose{dx_f}, g(x) \rangle}_{\text{change in} \ f(x)}
\\
&= \underbrace{f(x + \amber{dx_g} + \rose{dx_f})}_{\text{symmetric in} \ \amber{dx_g}, \rose{dx_f}}
-
\underbrace{f(x + \amber{dx_g}) - f(x + \rose{dx_f})}_{\text{symmetric in} \ \amber{dx_g}, \rose{dx_f}}
+ f(x) \\
&= \amber{dx_g^T}H(\rose{dx_f})
\end{align}
$$But since a scalar is its own transpose, we get: \( H = H^T \).
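A quick numerical check of this four-point identity and of the symmetry (a minimal sketch; the test function is the one from the PyTorch section below, while the point and step sizes are made up, and double precision keeps the finite differences clean):
import torch
from torch.autograd.functional import hessian
def f(x):
    return x[0]**2 * x[1] + torch.sin(x.sum())
x = torch.tensor([0.3, 0.5], dtype=torch.float64)
dx_g = torch.tensor([1e-4, 2e-4], dtype=torch.float64)
dx_f = torch.tensor([-3e-4, 1e-4], dtype=torch.float64)
H = hessian(f, x)
# dx_f^T H dx_g equals the four-point difference up to higher-order terms
four_point = f(x + dx_g + dx_f) - f(x + dx_g) - f(x + dx_f) + f(x)
assert torch.allclose(four_point, dx_f @ H @ dx_g, rtol=1e-3)
# swapping dx_f and dx_g changes nothing, consistent with H = H^T
assert torch.allclose(dx_f @ H @ dx_g, dx_g @ H @ dx_f)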
1.2 Bilinear form
Now let's take a look at the LHS of the above expression:$$
\rose{dx_f^T}H(\amber{dx_g}) =
\begin{bmatrix} \cdots & dx_f^T & \cdots \end{bmatrix}
\begin{bmatrix}
\ & \vdots & \ \\
\cdots & H & \cdots \\
\ & \vdots & \ \\
\end{bmatrix}
\begin{bmatrix} \vdots \\ dx_g \\ \vdots \end{bmatrix}
$$The obvious interpretation of this value is that it is the change in (the change in \( f \)) between the two points \( x \) and \( x + \amber{dx_g} \). At each point you change the input by \( \rose{dx_f} \) and measure the change in output: \( df_1 \) at \( x \) and \( df_2 \) at \( x + \amber{dx_g} \). The bilinear form is the difference between the two: \( df_2 - df_1 \).
Since \( \amber{dx_g}, \rose{dx_f} \) are interchangeable, we can swap them in the above argument.
Figure: the bilinear form seen as the difference \( df_2 - df_1 \) between the two measured changes.
The Hessian is used in optimization via the second-order approximation:$$ f(x + dx) \approx f(x) + \underbrace{\amber{\nabla f(x)^T}}_{\amber{\text{gradient}}} dx + \frac{1}{2} dx^T \underbrace{\amber{H}}_{\amber{\text{Hessian}}} dx $$In machine learning, it is also used in the quantization of large models; check out optimal brain compression for more information.
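As a small illustration of the optimization use, minimizing the second-order model over \( dx \) gives the Newton step \( dx = -H^{-1} \nabla f(x) \); a minimal sketch with a made-up quadratic objective, for which one Newton step lands exactly on the minimizer:
import torch
from torch.autograd import grad
from torch.autograd.functional import hessian
def f(x):
    return (x[0] - 1.0)**2 + 10.0 * (x[1] + 0.5)**2
x = torch.tensor([3.0, 2.0], requires_grad=True)
g = grad(f(x), x)[0]  # gradient at x
H = hessian(f, x)  # Hessian at x
x_new = x - torch.linalg.solve(H, g)  # Newton step: solve H dx = gradient
print(x_new)  # approximately (1.0, -0.5), the minimizer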
2. Calculating a Hessian in PyTorch
Since the Hessian is the gradient of the gradient of a scalar-valued function, one can compute it by simply taking the gradient twice. The only trick to be aware of is that PyTorch needs to record a graph for the function that maps the input to the gradient (hence `create_graph=True` below). Once that is taken care of, the rest is standard PyTorch `autograd.grad` stuff.
In the example below, we calculate the Hessian of \( f(x,y) = x^2y + \sin(x+y) \):
import torch
from torch.autograd import grad
from torch.autograd.functional import hessian
def f(x):
    return x[0]**2 * x[1] + torch.sin(x.sum())
x = torch.rand(2, requires_grad=True)
y = f(x)
# `create_graph=True` is important: it records the graph of the gradient itself
g = grad(y, x, create_graph=True)[0]
first_row_of_hessian = grad(g, x, grad_outputs=torch.tensor([1.0, 0.0]), retain_graph=True)[0]
second_row_of_hessian = grad(g, x, grad_outputs=torch.tensor([0.0, 1.0]))[0]
hessian_manual = torch.stack([first_row_of_hessian, second_row_of_hessian], dim=0)
hessian_torch = hessian(f, x)
assert torch.allclose(hessian_manual, hessian_torch)
One may have noticed that, since we calculated the Hessian one row at a time, we actually used the VJP to compute the gradient of the gradient of a scalar-valued function. Since the Hessian is a square matrix, one could equally use the JVP to calculate it one column at a time, which avoids the memory overhead of storing intermediate activations for a backward pass. The JAX documentation provides more details on how and when to use the jvp and vjp functions.
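For completeness, here is a sketch of the column-at-a-time (forward-over-reverse) approach, assuming a recent PyTorch where `torch.func` is available:
import torch
from torch.func import grad as func_grad, jvp
def f(x):
    return x[0]**2 * x[1] + torch.sin(x.sum())
x = torch.rand(2)
e0 = torch.tensor([1.0, 0.0])
# push a basis vector through the gradient function (forward over reverse):
# the JVP of grad(f) at x in direction e0 is H @ e0, the first column of the Hessian
_, first_column_of_hessian = jvp(func_grad(f), (x,), (e0,))
print(first_column_of_hessian)
Since the Hessian is symmetric, this column is also the first row computed above.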