[2024-05-23] Next-token prediction == contrastive learning
So you know SimSiam, right?
- Dataset: distribution of text sequences
- View 1 augmentation distribution: always returns the first N-1 tokens
- View 2 augmentation distribution: always returns the last (Nth) token
- Encoder network: the token embedding matrix, assuming tied input/output embeddings
- Projector network: N/A
- Predictor network: the entire transformer, excluding the embeddings
- Loss function: defined over a batch, i.e. a contrastive variant with the same softmax formulation as CLIP
- The "batch": every possible token serves as a view-2 candidate; the view-1 side is estimated with just the one observed prefix
- Stop grads: gone, but it's fine, because the contrastive negatives already prevent collapse (see the sketch below)
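
Writing the mapping out as code makes the equivalence concrete. A minimal PyTorch sketch under the assumptions above; `TinyLM`, its sizes, and `next_token_loss_as_contrastive` are made-up names for illustration, not anybody's actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    def __init__(self, vocab_size=1000, d_model=64, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        # "Encoder": the (tied) token embedding matrix.
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Parameter(torch.zeros(1, max_len, d_model))
        # "Predictor": the transformer body, excluding the embeddings.
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.body = nn.TransformerEncoder(layer, n_layers)

    def forward(self, tokens):
        x = self.embed(tokens) + self.pos[:, :tokens.size(1)]
        mask = nn.Transformer.generate_square_subsequent_mask(tokens.size(1))
        return self.body(x, mask=mask)

def next_token_loss_as_contrastive(model, tokens):
    view1 = tokens[:, :-1]            # augmentation 1: the first N-1 tokens
    view2 = tokens[:, -1]             # augmentation 2: the last token
    h = model(view1)[:, -1]           # predictor output for view 1
    # Score the prediction against *every* token embedding: the full
    # vocabulary plays the role of the contrastive batch for view 2.
    logits = h @ model.embed.weight.T
    # CLIP-style softmax over those candidates is just cross-entropy.
    return F.cross_entropy(logits, view2)

# Toy usage: 8 random sequences of length 16.
model = TinyLM()
tokens = torch.randint(0, 1000, (8, 16))
loss = next_token_loss_as_contrastive(model, tokens)
```

The one asymmetry vs. CLIP: negatives come only from the vocabulary side, since the prefix side contributes a single sample per sequence.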
👍