How many matmuls are needed to compute Hessian-vector products?
Suppose you have a simple composition of d dense functions. Computing the Jacobian needs d matrix multiplications. What about computing a Hessian-vector product (HVP)?
You can calculate it manually by differentiating the function composition twice, grouping shared work together in temporary messages, and then counting the number of matrix multiplications. One trick is that the equations you get from the full multivariate chain rule have the same form as those in the scalar case, provided you don’t treat multiplication as commutative. So you can simply take the derivative in the scalar case, and the number of multiplications you obtain corresponds to the number of matrix multiplications in the matrix case.
Here’s the calculation worked out for the simple example above:
You can see there are up to D temporary messages, and computing all the messages requires 5 multiplications, marked in purple in the computation above.
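As a rough sketch of how such a count can arise, here is one way to group the scalar-case work into messages. The symbols h_i, a_i, b_i, u_i, g_i and z_i are notation introduced here for illustration and may not match the grouping used in the original calculation. Products are written in a fixed order, since multiplication is not treated as commutative.

```latex
\begin{aligned}
& h_0 = x, \qquad h_i = f_i(h_{i-1}), \qquad y = h_d, \\
& a_i = f_i'(h_{i-1}), \qquad b_i = f_i''(h_{i-1}), \\
& y' = a_d\, a_{d-1} \cdots a_1, \\
& y''\, v = \sum_{i=1}^{d} (a_d \cdots a_{i+1})\; b_i\; (a_{i-1} \cdots a_1 v)\; (a_{i-1} \cdots a_1), \\[4pt]
& \text{forward messages:} \quad u_0 = v, \qquad u_i = a_i\, u_{i-1}, \\
& \text{backward messages:} \quad g_d = 1, \qquad g_i = g_{i+1}\, a_{i+1}, \\
& \text{accumulation:} \quad z_{d+1} = 0, \qquad z_i = z_{i+1}\, a_i + (g_i\, b_i)\, u_{i-1}, \\
& y''\, v = z_1 .
\end{aligned}
```

Under this grouping each layer contributes five products: a_i u_{i-1}, g_{i+1} a_{i+1}, g_i b_i, (g_i b_i) u_{i-1}, and z_{i+1} a_i. That gives roughly 5 multiplications per layer, with a few saved at the boundaries where g_d = 1 and z_{d+1} = 0.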
You can check this in PyTorch: create a linear neural network with many layers. The forward pass requires d matrix multiplications, whereas the HVP needs 5*d. In practice, the HVP takes 5-5.3x longer than the forward pass, which is close to what’s expected.
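A minimal sketch of such a comparison, not the original benchmark. It assumes the HVP is taken with respect to the weights of a bias-free deep linear network with a squared-norm loss, and computes it by double backward. The helper names (`make_weights`, `hvp`, `fd_hvp`, `timed`) and all sizes are hypothetical choices made here. The printed ratio depends on hardware, layer sizes, and which variable the HVP is taken with respect to, so it need not match the 5-5.3x figure quoted above.

```python
import time
import torch

torch.manual_seed(0)

def make_weights(d, n):
    # d square weight matrices; the 1/sqrt(n) scale keeps activations O(1)
    return [(torch.randn(n, n, dtype=torch.float64) / n ** 0.5).requires_grad_()
            for _ in range(d)]

def forward(ws, x):
    # deep linear network: d matrix multiplications, no biases or nonlinearities
    h = x
    for w in ws:
        h = h @ w
    return h

def loss_fn(ws, x):
    # assumed loss: half the squared norm of the output
    return 0.5 * forward(ws, x).pow(2).sum()

def hvp(ws, x, vs):
    # double backward: differentiate <grad, v> with respect to the weights
    grads = torch.autograd.grad(loss_fn(ws, x), ws, create_graph=True)
    dot = sum((g * v).sum() for g, v in zip(grads, vs))
    return torch.autograd.grad(dot, ws)

def fd_hvp(ws, x, vs, eps=1e-6):
    # central finite difference of the gradient, used only as a sanity check
    def grad_at(sign):
        shifted = [(w + sign * eps * v).detach().requires_grad_()
                   for w, v in zip(ws, vs)]
        return torch.autograd.grad(loss_fn(shifted, x), shifted)
    return [(gp - gm) / (2 * eps) for gp, gm in zip(grad_at(1.0), grad_at(-1.0))]

def rel_error(a, b):
    # relative L2 error over all parameter tensors
    num = sum((p - q).pow(2).sum() for p, q in zip(a, b)).sqrt()
    den = sum(q.pow(2).sum() for q in b).sqrt()
    return (num / den).item()

def timed(fn, repeats=5):
    fn()  # warm-up call, excluded from the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats

# Small problem: check the double-backward HVP against finite differences.
ws = make_weights(d=6, n=8)
x = torch.randn(4, 8, dtype=torch.float64)
vs = [torch.randn_like(w) for w in ws]
hv = hvp(ws, x, vs)
err = rel_error(hv, fd_hvp(ws, x, vs))
print(f"relative error vs finite differences: {err:.2e}")

# Larger problem: compare wall-clock time of the HVP and the forward pass.
ws_big = make_weights(d=20, n=256)
x_big = torch.randn(256, 256, dtype=torch.float64)
vs_big = [torch.randn_like(w) for w in ws_big]

def fwd_only():
    with torch.no_grad():
        forward(ws_big, x_big)

t_fwd = timed(fwd_only)
t_hvp = timed(lambda: hvp(ws_big, x_big, vs_big))
print(f"HVP / forward time ratio: {t_hvp / t_fwd:.1f}")
```

For a meaningful ratio the layers should be large enough that matrix multiplications dominate the runtime; with tiny layers, framework overhead can swamp the count.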