How to do matrix derivatives
Suppose you have the following scalar function of a matrix variable W:
What’s the derivative with respect to the matrix W? Define the “matrix derivative” Df as “the thing that you subtract from your variable to go in the steepest-descent direction”. That is, your gradient descent update would use Df as follows, for some step size α:
$$W \leftarrow W - \alpha\, Df$$
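To make the update concrete, here is a minimal NumPy sketch of repeatedly replacing W with W − α·Df. It assumes a hypothetical objective $f(W) = \|WX - Y\|_F^2$ (chosen for illustration, not the function from this post) whose matrix derivative is $Df = 2(WX - Y)X^\top$; the step size `alpha`, the iteration count, and the random data are also illustrative choices.

```python
import numpy as np

def loss(W, X, Y):
    # Hypothetical objective: squared Frobenius norm of the residual W @ X - Y
    R = W @ X - Y
    return float(np.sum(R * R))

def grad(W, X, Y):
    # Matrix derivative Df for this objective; it has the same shape as W
    return 2.0 * (W @ X - Y) @ X.T

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 20))
Y = rng.standard_normal((2, 20))
W = np.zeros((2, 3))
alpha = 0.005  # illustrative step size, small enough for this data

losses = [loss(W, X, Y)]
for _ in range(1000):
    W = W - alpha * grad(W, X, Y)  # subtract Df to step in the steepest-descent direction
    losses.append(loss(W, X, Y))

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Because Df has the same shape as W, the update is a plain elementwise matrix subtraction; for this least-squares objective the iterates approach $YX^\top(XX^\top)^{-1}$.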
The old-school approach is to write out the objective as a sum, differentiate, then stare at the result until you can see the corresponding matrix notation. This was the approach we used in Andrew Ng’s UFLDL course taught at Google (this course led to the formation of the Google Brain team).
Taking this approach, convert to summation:
Differentiate, then convert back to matrix notation by visually identifying the corresponding matrix expression:
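For a hypothetical objective $f(W) = \sum_{ij}\big(\sum_k W_{ik}X_{kj} - Y_{ij}\big)^2$ (an illustration, not the function above), term-by-term differentiation gives $\partial f/\partial W_{ab} = 2\sum_j R_{aj}X_{bj}$ with $R = WX - Y$, which one then recognizes as the $(a, b)$ entry of $2RX^\top$. The sketch below evaluates the summation formula with explicit loops and checks it against that matrix form.

```python
import numpy as np

def grad_by_summation(W, X, Y):
    # Residual entries R[i][j] = sum_k W[i, k] * X[k, j] - Y[i, j], written as explicit sums
    m, n = W.shape
    p = X.shape[1]
    R = [[sum(W[i, k] * X[k, j] for k in range(n)) - Y[i, j] for j in range(p)]
         for i in range(m)]
    # df / dW[a, b] = 2 * sum_j R[a][j] * X[b, j]
    G = np.zeros_like(W)
    for a in range(m):
        for b in range(n):
            G[a, b] = 2.0 * sum(R[a][j] * X[b, j] for j in range(p))
    return G

def grad_matrix_form(W, X, Y):
    # The matrix form identified "by eye": 2 (W X - Y) X^T
    return 2.0 * (W @ X - Y) @ X.T

rng = np.random.default_rng(1)
W = rng.standard_normal((2, 3))
X = rng.standard_normal((3, 5))
Y = rng.standard_normal((2, 5))
print(np.allclose(grad_by_summation(W, X, Y), grad_matrix_form(W, X, Y)))  # True
```

Even for this simple quadratic the index bookkeeping is noticeable, which is why the approach gets cumbersome for more complicated expressions.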
However, this visual identification approach doesn’t scale. What if we wanted the derivative of the following function in matrix notation?
We know that the derivative is the following (p. 201 of Magnus and Neudecker’s Matrix Differential Calculus book):
How do we derive this expression?
Converting to summation and back is cumbersome. Fortunately, there’s a much better procedure. An end-to-end example motivated by neural network training is in my answer on the math.SE forum, but to summarize, you need to:
- Compute the differential instead of the derivative.
- Manipulate the differential into one of several standard forms.
- Extract the derivative from the differential by looking up the standard form in the “identification table” (pp. 199 and 215 in Magnus).
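As an illustration of the three steps, take a hypothetical objective $f(W) = \|WX - Y\|_F^2$ (not the function above), write $R = WX - Y$, and use $\operatorname{tr}(M^\top) = \operatorname{tr}(M)$ and the cyclic property $\operatorname{tr}(AB) = \operatorname{tr}(BA)$. The last line reads off the derivative with the trace-form rule $df = \operatorname{tr}(A^\top dW) \Rightarrow Df = A$, under this post’s convention that Df has the same shape as W.

$$
\begin{aligned}
f &= \operatorname{tr}(R^\top R), \qquad dR = dW\,X \\
df &= \operatorname{tr}(dR^\top R) + \operatorname{tr}(R^\top dR) = 2\operatorname{tr}(R^\top dW\,X) \\
   &= 2\operatorname{tr}(X R^\top dW) = \operatorname{tr}\big((2RX^\top)^\top dW\big) \\
\Rightarrow\quad Df &= 2RX^\top = 2(WX - Y)X^\top
\end{aligned}
$$

This matches what the summation route produces for the same objective, without writing a single index.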
The differential is closely related to the derivative. In the Taylor series expansion, the differential is the second (first-order) term of the series, while the derivative is the coefficient in front of that term.
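This relationship can be checked numerically: for a small perturbation dW, the change f(W + dW) − f(W) should match the differential $\operatorname{tr}(Df^\top dW)$ up to a second-order error. The sketch assumes the hypothetical objective $f(W) = \|WX - Y\|_F^2$ with $Df = 2(WX - Y)X^\top$; halving the perturbation should roughly quarter the error.

```python
import numpy as np

def f(W, X, Y):
    # Hypothetical scalar objective ||W X - Y||_F^2
    R = W @ X - Y
    return float(np.sum(R * R))

def Df(W, X, Y):
    # Its matrix derivative, same shape as W
    return 2.0 * (W @ X - Y) @ X.T

rng = np.random.default_rng(2)
W = rng.standard_normal((2, 3))
X = rng.standard_normal((3, 4))
Y = rng.standard_normal((2, 4))
E = rng.standard_normal((2, 3))  # fixed perturbation direction

errors = []
for eps in [1e-2, 5e-3, 2.5e-3]:
    dW = eps * E
    actual = f(W + dW, X, Y) - f(W, X, Y)
    differential = float(np.trace(Df(W, X, Y).T @ dW))  # first-order term tr(Df^T dW)
    errors.append(abs(actual - differential))

print(errors)
```

For this particular quadratic the remainder is exactly $\|dW\,X\|_F^2$, so each error is about a quarter of the previous one.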
Because the two are so closely related, rules like the chain rule for derivatives have corresponding differential versions. The reason to use differentials is that, for a matrix-valued function of a matrix variable, the derivative in the expansion above is a rank-4 tensor, while the differential is a matrix, so working with differentials keeps you in familiar matrix land.
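To see the size difference, take the hypothetical matrix-valued function $F(X) = AXB$ (chosen for illustration). Its differential $dF = A\,dX\,B$ is an ordinary matrix, whereas the full derivative, laid out as all partials $\partial F_{ij}/\partial X_{kl} = A_{ik}B_{lj}$, has four indices. The sketch builds that rank-4 array with `np.einsum` and checks that contracting it with dX reproduces the matrix expression.

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((4, 5))
X = rng.standard_normal((3, 4))
dX = rng.standard_normal((3, 4))

def F(M):
    # Hypothetical matrix-valued function of a matrix variable
    return A @ M @ B

# Rank-4 derivative: J[i, j, k, l] = dF[i, j] / dX[k, l] = A[i, k] * B[l, j]
J = np.einsum("ik,lj->ijkl", A, B)

# The differential two ways: contracting the rank-4 tensor, and the matrix rule dF = A dX B
dF_from_tensor = np.einsum("ijkl,kl->ij", J, dX)
dF_matrix = A @ dX @ B

print(J.shape)          # (2, 5, 3, 4)
print(dF_matrix.shape)  # (2, 5)
print(np.allclose(dF_from_tensor, dF_matrix))  # True
```

Since F is linear here, F(X + dX) − F(X) equals dF exactly; for nonlinear functions the two agree only to first order.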
As for standard identification forms, the only entry in the identification table you will probably need is the following:
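A commonly used trace-form rule reads: if $df = \operatorname{tr}(A^\top dW)$, then $Df = A$, with Df shaped like W as in the gradient-descent convention above. As a hedged illustration, applying it to $f(W) = \log|\det W|$ uses the known differential $d\log|\det W| = \operatorname{tr}(W^{-1}dW) = \operatorname{tr}\big((W^{-\top})^\top dW\big)$, so $Df = W^{-\top}$. The sketch checks that read-off against central finite differences; the helper names `logabsdet` and `numerical_derivative` are hypothetical, introduced only here.

```python
import numpy as np

def logabsdet(W):
    # log |det W|, computed stably via slogdet (returns sign and log|det|)
    _, logdet = np.linalg.slogdet(W)
    return logdet

def numerical_derivative(f, W, h=1e-6):
    # Central finite differences, one entry of W at a time; result has W's shape
    G = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        E = np.zeros_like(W)
        E[idx] = h
        G[idx] = (f(W + E) - f(W - E)) / (2 * h)
    return G

rng = np.random.default_rng(4)
W = 3.0 * np.eye(3) + rng.standard_normal((3, 3))  # well away from singular for this seed

# Read off from d log|det W| = tr((W^{-T})^T dW)  =>  Df = W^{-T}
Df_identified = np.linalg.inv(W).T
Df_numeric = numerical_derivative(logabsdet, W)
print(np.allclose(Df_identified, Df_numeric, atol=1e-6))
```

The finite-difference check costs two function evaluations per entry of W, which is why it is useful for verification but not as a replacement for the identified closed form.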
You can also use Mathematica’s matrix differentiation package to compute this expression automatically (example).