[2024-07-30] BYOL learning dynamics are simple and well understood
This post is a simplified version of one of my earlier posts and is meant to help build some basic intuition about the learning dynamics of BYOL. The concepts covered here are not novel. I still see many people citing BYOL as some kind of unexplainable arcane wizardry, or describing it as a distillation-based method or as a method that relies on a non-symmetric architecture. The latter two are technically true, but fail to capture what's actually going on. This post is an explanation of how BYOL works, as fast as possible.
A quick refresher on Bootstrap Your Own Latent (BYOL) (arxiv.org/abs/2006.07733):
- the encoder maps an image to a vector embedding \(=: z\)
- the predictor takes an embedding, \(z\), of an augmentation of an image and regresses the average embedding over all augmentations of that image \(=: \overline{z}\)
- loss = the predictor's error (see the sketch after this list)
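To make the setup concrete, here is a minimal PyTorch-style sketch of the BYOL loss for one pair of augmented views. Everything in it is a hypothetical stand-in: toy `online_encoder`, `predictor`, and `target_encoder` modules, no projector head, no real augmentations, and no loss symmetrization, all of which the actual method uses.

```python
import copy
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for BYOL's networks (toy sizes, no projector head).
online_encoder = torch.nn.Linear(32, 16)            # image -> embedding z
predictor = torch.nn.Sequential(                    # z -> prediction of z_bar
    torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16)
)
target_encoder = copy.deepcopy(online_encoder)      # slow copy, produces the target z_bar

def byol_loss(view1: torch.Tensor, view2: torch.Tensor) -> torch.Tensor:
    z = online_encoder(view1)          # embedding of one augmentation
    p = predictor(z)                   # prediction of that image's target embedding
    with torch.no_grad():              # stop-gradient: the target branch is not trained directly
        z_bar = target_encoder(view2)  # stands in for the average embedding over augmentations
    # Loss = the predictor's error (BYOL uses a normalized-MSE / cosine form).
    return (2 - 2 * F.cosine_similarity(p, z_bar, dim=-1)).mean()

# Usage on random "views" of a batch of 4 fake images.
view1, view2 = torch.randn(4, 32), torch.randn(4, 32)
loss = byol_loss(view1, view2)
loss.backward()                        # gradients reach only the online encoder and predictor
```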
Somehow, the encoder doesn't learn to map everything to the same vector, which would technically make the predictor's loss minimal. Why?
We can assume the predictor is nearly optimal. (This is reasonable since, in practice, the target for the predictor changes more slowly than the encoder does, due to momentum. The BYOL authors observed similar behavior when removing momentum and instead increasing the predictor's learning rate by 10x, which keeps the predictor close to optimal in a similar way.)
On any particular image input, the encoder is incentivized to make the predictor more accurate. Due to momentum / the stop-gradient (see arxiv.org/abs/2011.10566), the predictor is, from the encoder's perspective, a fixed function during each update.
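As a sketch of that mechanic (same hypothetical toy modules as above): the target encoder is an exponential moving average of the online encoder, so the predictor's regression target drifts on a timescale of roughly 1/(1 - momentum) steps, while the online encoder and predictor move every step. The 0.99 below is a placeholder; the paper uses a momentum schedule.

```python
import copy
import torch

# Hypothetical toy encoders, as in the sketch above.
online_encoder = torch.nn.Linear(32, 16)
target_encoder = copy.deepcopy(online_encoder)

@torch.no_grad()
def update_target(momentum: float = 0.99) -> None:
    # EMA / momentum update: the target is a running average of past online
    # encoders, so the predictor's regression target is nearly stationary and
    # the predictor has time to stay close to optimal for it.
    for t, o in zip(target_encoder.parameters(), online_encoder.parameters()):
        t.mul_(momentum).add_(o, alpha=1.0 - momentum)
```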
Intuitively, an optimal predictor will generally have lower error on inputs where it is obvious which image in the training dataset an embedding came from. (This is... not always true, but has value as intuition.) For example, if only one image in the training dataset has augmentations that ever map to a given point in embedding space, then the optimal predictor will perfectly predict that image's average embedding, and no update to the encoder can result in better loss for that input.
If we take the above as true, then, during training, the encoder is incentivized to map augmentations of distinct images to distinct regions of embedding space, avoiding collapse. See my previous post for more details.
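A toy numeric illustration of that incentive (purely illustrative numbers; the "optimal predictor" here is simply the conditional mean target for each embedding): if augmentations of two different images share an embedding, the best possible predictor is left with irreducible error, while separating the two images drives that error to zero.

```python
import numpy as np

# Two training images, A and B, with fixed target mean embeddings.
# (The targets move slowly due to momentum, so treat them as constants here.)
z_bar_A = np.array([1.0, 0.0])
z_bar_B = np.array([-1.0, 0.0])
targets = [z_bar_A, z_bar_A, z_bar_B, z_bar_B]   # one target per (image, augmentation) pair

def optimal_predictor_error(embeddings, targets):
    """Mean squared error of the best possible predictor, which maps each
    embedding to the average target over every sample sharing that embedding."""
    total = 0.0
    for z, t in zip(embeddings, targets):
        shared = [t2 for z2, t2 in zip(embeddings, targets) if np.allclose(z2, z)]
        prediction = np.mean(shared, axis=0)     # conditional mean = optimal regressor
        total += float(np.sum((prediction - t) ** 2))
    return total / len(embeddings)

# Collapsed encoder: augmentations of A and B all share one embedding.
collapsed = [np.zeros(2)] * 4
print(optimal_predictor_error(collapsed, targets))   # 1.0 -> irreducible error

# Separating encoder: A's and B's augmentations get distinct embeddings.
separated = [np.array([0.5, 0.5]), np.array([0.5, 0.5]),
             np.array([-0.5, -0.5]), np.array([-0.5, -0.5])]
print(optimal_predictor_error(separated, targets))   # 0.0 -> the predictor can be exact
```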
Additional Resources:
- arxiv.org/abs/2006.07733 - BYOL, and its various ablations:
  - arxiv.org/abs/2011.10566 - SimSiam, or BYOL without momentum
  - arxiv.org/abs/2010.10241 - BYOL without batchnorm
- arxiv.org/abs/2212.03319 - Related: non-collapse w/ linear predictors
  - BYOL with a linear predictor recovers a spectral decomposition of the augmentation graph
- arxiv.org/abs/2301.08243 - Related: JEPA
- arxiv.org/abs/2008.01064 - Related: SSL and conditional independence
  - Here, \(Y\) can be seen as the instance id of each training image