Neural Networks that Learn Algorithms Implicitly

The remarkable ability of Transformers to exhibit reasoning and few-shot learning capabilities without any fine-tuning is widely conjectured to stem from their capacity to implicitly simulate multi-step algorithms, such as gradient descent, within a single forward pass. Recently, there has been progress in understanding this phenomenon from an expressivity point of view, by demonstrating that Transformers can express such multi-step algorithms. However, our knowledge of the more fundamental question of learnability, beyond single-layer models, remains very limited. In particular, can training Transformers converge to algorithmic solutions? In [Gatmiry et al., ICML 2024], we resolve this question for in-context linear regression with linear looped Transformers, a multi-layer architecture with weight sharing that is conjectured to have an inductive bias toward learning fixed-point iterative algorithms. More specifically, for this setting we show that the global minimizer of the population training loss implements multi-step preconditioned gradient descent, with a preconditioner that adapts to the data distribution. Furthermore, we prove fast convergence of gradient flow on the regression loss, despite the non-convexity of the landscape, by establishing a novel gradient dominance condition. To our knowledge, this is the first theoretical analysis of a multi-layer Transformer in this setting.
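To make the correspondence concrete, here is a minimal NumPy sketch (ours, not code from the paper) comparing L steps of preconditioned gradient descent on an in-context least-squares problem with L applications of a single shared linear attention layer under the standard (P, Q, mask) weight construction from the linear-attention literature. The preconditioner A, step size, and problem dimensions are illustrative assumptions, not the values characterized in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, L = 5, 20, 10          # feature dim, context length, number of loops

# In-context linear regression prompt: (x_i, y_i) pairs plus a query x_q.
w_star = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w_star
x_q = rng.normal(size=d)

# Symmetric preconditioner; a scaled identity is an illustrative choice only.
A = 0.1 * np.eye(d)

# --- Reference: L steps of preconditioned gradient descent from w = 0 ---
w = np.zeros(d)
for _ in range(L):
    grad = X.T @ (X @ w - y) / n        # gradient of (1/2n)||Xw - y||^2
    w = w - A @ grad
gd_prediction = x_q @ w

# --- Looped linear attention with the standard weight construction ---
# Tokens are columns [x_i; y_i]; the query column is [x_q; 0].
Z = np.column_stack([np.vstack([X.T, y]), np.append(x_q, 0.0)])   # (d+1, n+1)
P = np.zeros((d + 1, d + 1)); P[d, d] = 1.0                        # write only to the y-coordinate
Q = np.zeros((d + 1, d + 1)); Q[:d, :d] = -A                       # read only the x-coordinates
M = np.eye(n + 1); M[n, n] = 0.0                                   # mask the query token as a value

for _ in range(L):                      # the same layer applied L times (weight sharing)
    Z = Z + P @ Z @ M @ (Z.T @ Q @ Z) / n

looped_prediction = -Z[d, n]            # prediction read off the query's last coordinate

print(gd_prediction, looped_prediction)  # the two numbers agree up to floating-point error
```

With this construction, the query token's last coordinate after L loops equals minus the prediction of L preconditioned gradient descent steps. In the paper's result, the preconditioner at the global minimizer additionally adapts to the covariance of the data distribution rather than being a fixed scaled identity as in this sketch.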

Team Members

Stefanie Jegelka [1]
Yusu Wang [2]

Collaborators

Sashank Reddi [3]
Sanjiv Kumar [3]

1. MIT
2. UC San Diego
3. Google

Publications

Gatmiry et al. Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning? ICML 2024.