ICLR optimization papers I: Fluctuation-Dissipation relations for SGD
--
In this series of posts I will talk about optimization papers that caught my eye at ICLR 2019. The first post in the series is an overview of “Fluctuation-dissipation relations for stochastic gradient descent” by Sho Yaida.
This paper uses beautifully simple math to characterize the behavior of SGD at convergence.
Unlike deterministic gradient descent, SGD updates never stop changing the parameters. There have been a number of efforts to characterize what SGD looks like once it has converged, and how the learning rate and momentum affect convergence behavior and time to convergence.
Several recent papers use differential equations: they look at SGD in the limit of the learning rate going to 0 and connect it to known results for the corresponding continuous system. However, it’s not clear how well these results generalize to non-zero learning rates. Sho Yaida derives the corresponding relations for non-zero learning rates from first principles.
The key tool is the following “Master Equation”
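Reconstructed here in standard notation (the paper’s own presentation may differ slightly; here the mini-batch gradient used at each step is written as a gradient with superscript B):

```latex
% Stationarity ("master") equation: for any function O of the parameters,
% one SGD step  theta -> theta - eta * grad f^B(theta)  leaves its
% expectation over steps unchanged.
\big\langle O\big(\theta - \eta \, \nabla f^{\mathcal B}(\theta)\big) \big\rangle
    = \big\langle O(\theta) \big\rangle
```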
This characterizes the behavior of SGD at convergence. Angle brackets denote expectation over steps, theta is the parameter vector, O is an arbitrary function of the parameters, and eta is the learning rate. The equation says: “once you’ve converged, statistics after an SGD step = statistics before the step, on average.” This is the stochastic version of the gradient descent fixed-point equation.
If you let O be the outer product of the parameter vector with itself and rearrange, you get the following equation, FDR1
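A reconstruction in my notation (the paper’s conventions may differ in details), where C is the outer product of the mini-batch gradient with itself:

```latex
% FDR1 (reconstruction): set O = theta theta^T in the master equation and use
% the fact that the mini-batch gradient is unbiased, E_B[grad f^B | theta] = grad f(theta).
% Here C = grad f^B (grad f^B)^T, the outer product of the mini-batch gradient.
\big\langle \theta \, \nabla f^{\mathsf T} + \nabla f \, \theta^{\mathsf T} \big\rangle
    = \eta \, \langle C \rangle ,
\qquad \text{trace form:} \qquad
2 \, \langle \theta \cdot \nabla f \rangle
    = \eta \, \big\langle \| \nabla f^{\mathcal B} \|^{2} \big\rangle
```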
C refers to the empirical Fisher Information Matrix, i.e. the second moment of the mini-batch gradient. Taking the trace of both sides is equivalent to adding up the squares of the gradient components. If our setting is a Noisy Quadratic Model (NQM), we can simplify this further to the following equation, NQM1
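For a quadratic loss with minimum value f*, measuring theta from the minimizer gives theta · grad f = 2(f - f*), so FDR1 collapses to something like this (again a reconstruction, up to the paper’s exact conventions):

```latex
% NQM1 (reconstruction): stationary loss level above the minimum, in terms of
% the learning rate and the mean squared norm of the mini-batch gradient.
\langle f - f^{*} \rangle
    = \frac{\eta}{4} \, \big\langle \| \nabla f^{\mathcal B} \|^{2} \big\rangle
```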
To illustrate this relationship, consider a two-dimensional NQM at convergence.
Because of the noise, the optimization travels in some neighborhood of the true minimum, while the loss stays contained within a band.
Equation NQM1 tells us how to get the width of this band from the magnitudes of the squared gradients and the learning rate.
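As a sanity check, here is a small toy simulation (my own sketch, not code from the paper or the linked notebooks; the Hessian, noise level, and learning rate are arbitrary choices): SGD on a two-dimensional quadratic with additive gradient noise, comparing the average stationary loss against the NQM1 prediction.

```python
# Toy 2-D noisy quadratic: f(theta) = 0.5 * theta^T H theta,
# stochastic gradient g = H theta + noise.  NQM1 predicts that at stationarity
#   mean(f) ~= (eta / 4) * mean(||g||^2).
import numpy as np

rng = np.random.default_rng(0)
H = np.diag([1.0, 10.0])   # Hessian of the quadratic (try worse-conditioned values too)
noise_std = 0.5            # std of additive gradient noise per coordinate
eta = 0.01                 # learning rate

theta = rng.normal(size=2)
losses, grad_sq = [], []
for step in range(200_000):
    g = H @ theta + noise_std * rng.normal(size=2)   # mini-batch gradient surrogate
    if step > 50_000:                                # discard transient, keep stationary part
        losses.append(0.5 * (theta @ H @ theta))
        grad_sq.append(g @ g)
    theta = theta - eta * g                          # plain SGD update

print("mean loss       :", np.mean(losses))
print("NQM1 prediction :", eta / 4 * np.mean(grad_sq))
```

With the settings above the two printed numbers agree to within a few percent; swapping in a worse-conditioned Hessian (keeping eta times the largest eigenvalue below 2 so SGD stays stable) widens the loss band, but the two numbers still track each other.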
An interesting thing to note is that this is valid for any NQM, which may have different shapes depending on the Hessian. Here’s one with an ill-conditioned Hessian
Changing the Hessian increased the oscillations and hence the width of the loss band, yet NQM1 doesn’t depend on the Hessian, which might be surprising.
The equation actually does include the Hessian, masquerading as C. At convergence, the empirical Fisher Information Matrix is, in expectation, proportional to the Fisher Information Matrix, which is equal to the Hessian in the NQM case.
You get analogous gradient-Hessian connections for more general models; see the derivation linked below.
To get higher-order terms, the paper applies a Taylor expansion to the Master Equation, resulting in FDR2 below.
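Schematically, taking O = f and expanding in powers of the learning rate gives something of the following form (my paraphrase; the exact expression and sign conventions are in the paper):

```latex
% FDR2 (schematic): H(theta) is the Hessian of the full loss f; the O(eta^2)
% remainder collects third-and-higher-derivative ("anharmonic") contributions.
\big\langle \| \nabla f \|^{2} \big\rangle
    = \frac{\eta}{2} \, \big\langle (\nabla f^{\mathcal B})^{\mathsf T} H(\theta) \, \nabla f^{\mathcal B} \big\rangle
    + \mathcal{O}(\eta^{2}) \;\; \text{(higher-order terms)}
```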
At lower learning rates the higher-order terms vanish, so you can use this equation to estimate the magnitude of the Hessian.
At higher learning rates, plug in your Hessian estimate, and you get the “higher order terms” part. This term measures the degree of anharmonicity: how much your optimization problem deviates from a noisy quadratic.
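In other words, rearranging the schematic form above: at small eta a gradient-weighted Hessian scale can be read off from measurable averages, and at larger eta the leftover residual quantifies the anharmonicity (again my paraphrase, not the paper’s exact definitions).

```latex
% Small eta: gradient-weighted Hessian magnitude from observable averages.
\big\langle (\nabla f^{\mathcal B})^{\mathsf T} H \, \nabla f^{\mathcal B} \big\rangle
    \approx \frac{2}{\eta} \, \big\langle \| \nabla f \|^{2} \big\rangle
% Larger eta: the residual between the two sides measures anharmonicity.
\text{anharmonicity} \;\sim\;
    \big\langle \| \nabla f \|^{2} \big\rangle
    - \frac{\eta}{2} \, \big\langle (\nabla f^{\mathcal B})^{\mathsf T} H \, \nabla f^{\mathcal B} \big\rangle
```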
Versions of these equations are also derived for the case of SGD with momentum.
In his neural network experiments, Sho Yaida found that SGD for a convolutional network on CIFAR displayed a high degree of anharmonicity, invalidating the use of harmonic approximations for that problem.
Ben Mann has implemented the fluctuation-dissipation relations as a learning rate scheduler for PyTorch (linked below).
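The idea, roughly (this is a minimal sketch of the logic as I read it, not the linked implementation; the function name, tolerance, and decay factor are all hypothetical), is to keep running averages of both sides of FDR1 during training and decay the learning rate once they agree to within some tolerance:

```python
# Sketch of an FDR1-based learning-rate decay check (illustrative only).
# After each backward pass, accumulate exponential moving averages of
#   theta . grad f^B           (mini-batch proxy for the left-hand side)
#   (eta / 2) * ||grad f^B||^2 (right-hand side);
# FDR1 says the two should coincide once SGD has equilibrated at this eta.
import torch

def fdr_running_averages(model, lr, lhs_avg, rhs_avg, beta=0.999):
    """Update the moving averages of the two sides of FDR1."""
    lhs, rhs = 0.0, 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        lhs += torch.sum(p.detach() * p.grad).item()      # theta . grad f^B
        rhs += 0.5 * lr * torch.sum(p.grad ** 2).item()   # (eta/2) ||grad f^B||^2
    lhs_avg = beta * lhs_avg + (1 - beta) * lhs
    rhs_avg = beta * rhs_avg + (1 - beta) * rhs
    return lhs_avg, rhs_avg

# Inside the training loop, after loss.backward():
#   lhs_avg, rhs_avg = fdr_running_averages(model, lr, lhs_avg, rhs_avg)
#   if step > warmup and abs(lhs_avg / rhs_avg - 1.0) < 0.05:
#       lr *= 0.1   # decay once FDR1 is (approximately) satisfied at this lr
```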
We’ve had mixed practical success with this approach. For some problems it detected convergence within a reasonable time frame, whereas for others it required many epochs after the optimization “seemed” to have converged. This also happened when we tried a noisy quadratic problem initialized at the true solution point.
Breaking down the estimator, it turned out that the variance of the left-hand side of FDR1 can be huge, which meant that we hadn’t run the optimization long enough to get an accurate value.
Empirically, this variance seems to grow linearly with the variance of the gradients and in inverse proportion to the learning rate (notebook linked below).
It would be interesting to get a theoretical analysis of estimator variance in terms of learning rate, batch size, and label noise. This would allow us to characterize the set of problems for which this estimator is cheap and accurate.
Thanks to Ben Mann for extensive proof-reading.
Links:
- Fluctuation-Dissipation learning rate scheduler for PyTorch:
- Variance of left hand term: notebook
- Noisy Quadratic in Mathematica: notebook
- Noisy Quadratic in PyTorch: test_fd_quadratic.py
- Hessian/FIM connection for ReLU networks: derivation