High-dimensional limit theorems for SGD: Effective dynamics and critical scaling
NeurIPS 2022
Abstract
We study the scaling limits of stochastic gradient descent (SGD) with
constant step-size in the high-dimensional regime. We prove limit theorems for
the trajectories of summary statistics (i.e., finite-dimensional functions) of
SGD as the dimension goes to infinity. Our approach allows one to choose the
summary statistics that are tracked, the initialization, and the step-size. It
yields both ballistic (ODE) and diffusive (SDE) limits, with the limit
depending dramatically on these choices. We identify a critical scaling regime
for the step-size: below this scaling the effective ballistic dynamics matches
gradient flow for the population loss, while at the critical scaling a new
correction term appears that changes the phase diagram. Around the fixed points
of this effective dynamics, the corresponding diffusive limits can be quite complex and
even degenerate. We demonstrate our approach on popular examples including
estimation for spiked matrix and tensor models and classification via two-layer
networks for binary and XOR-type Gaussian mixture models. These examples
exhibit surprising phenomena including multimodal timescales to convergence as
well as convergence to sub-optimal solutions with probability bounded away from
zero under random (e.g., Gaussian) initializations. At the same time, we
demonstrate the benefit of overparametrization by showing that the probability
of converging to a sub-optimal solution goes to zero as the second-layer width grows.
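As a rough illustration of the kind of object the abstract refers to, the sketch below is an assumption-laden toy example, not the paper's construction: it runs online SGD with constant step-size delta/n on a planted single-spike regression model and tracks the overlap with the planted direction as a one-dimensional summary statistic, which over O(n) steps traces out an approximately ballistic (ODE-type) trajectory. The model, loss, step-size scaling, and all names (snr, delta, etc.) are illustrative choices.

```python
# Minimal illustrative sketch (assumptions, not the paper's construction):
# online SGD with constant step-size for a toy planted single-spike model,
# tracking the overlap with the planted direction as a one-dimensional
# summary statistic. The model, loss, and delta/n step-size scaling below
# are illustrative choices.
import numpy as np

rng = np.random.default_rng(0)

n = 1000            # ambient dimension
snr = 3.0           # assumed signal-to-noise ratio of the planted spike
delta = 0.5         # constant step-size parameter; effective step is delta / n
steps = 20 * n      # O(n) online steps correspond to O(1) limiting time

v = rng.standard_normal(n)
v /= np.linalg.norm(v)          # planted unit-norm direction
x = rng.standard_normal(n)
x /= np.linalg.norm(x)          # random (Gaussian-direction) initialization

overlap = []                    # summary statistic m_t = <x_t, v>
for _ in range(steps):
    a = rng.standard_normal(n)                  # fresh sample at every step
    y = snr * (v @ a) + rng.standard_normal()   # noisy planted observation
    grad = -(y - x @ a) * a                     # gradient of 0.5 * (y - <x, a>)^2
    x -= (delta / n) * grad                     # constant step-size, scaled by 1/n
    x /= np.linalg.norm(x)                      # keep the iterate on the sphere
    overlap.append(abs(x @ v))

# In this toy run the overlap trajectory is close to a deterministic curve,
# in the spirit of the ballistic (ODE) limits described in the abstract;
# fluctuations near its fixed points are where diffusive (SDE) behavior enters.
print([round(m, 3) for m in overlap[:: steps // 10]])
```

Increasing n in this sketch makes the printed trajectory concentrate more tightly around a single curve, which is the intuition behind tracking low-dimensional summary statistics of a high-dimensional iterate.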