[back to home]

[2024-05-23] Next-token prediction = contrastive learning

So you know SimSiam, right?

  1. Dataset: distribution of text sequences
  2. View 1 augmentation distribution: always returns the first N-1 tokens
  3. View 2 augmentation distribution: always returns the last token
  4. Encoder network: the token embedding matrix, supposing tied inputs/outputs
  5. Projector network: N/A
  6. Predictor network: the entire transformer, excluding the embeddings
  7. Loss function: defined over a batch, i.e. a contrastive variant; same softmax formulation as CLIP (see the sketch after this list)
  8. A batch?: every possible token for view 2; loss estimated with only one sample for view 1
  9. Stop grads: gone, but it's fine, because contrastive
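
Here's a minimal PyTorch sketch of the correspondence. It's not anyone's real implementation: the transformer "predictor" is swapped for a stand-in GRU, and every name and shape is made up for illustration. The point is just that ordinary next-token cross-entropy is the same softmax contrastive loss as CLIP, with the full vocabulary playing the role of the batch of candidates for view 2:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, d_model = 100, 32

# "Encoder": the tied token embedding matrix, shared by both views.
embed = nn.Embedding(vocab_size, d_model)

# "Predictor": stand-in for the transformer body (any module mapping a
# prefix of embeddings to a single vector works for the analogy).
predictor = nn.GRU(d_model, d_model, batch_first=True)

tokens = torch.randint(0, vocab_size, (4, 16))   # toy batch of sequences
prefix, target = tokens[:, :-1], tokens[:, -1]   # view 1 / view 2

# View 1: embed the first N-1 tokens, run the predictor, keep the last state.
h, _ = predictor(embed(prefix))
z1 = h[:, -1]                                    # (batch, d_model)

# View 2: instead of embedding only the observed next token, score z1
# against *every* token embedding — the vocabulary is the contrastive batch.
logits = z1 @ embed.weight.T                     # (batch, vocab_size)

# CLIP-style softmax contrastive loss: the observed next token is the
# positive, every other token is a negative. With tied weights this is
# exactly the standard next-token cross-entropy.
loss = F.cross_entropy(logits, target)
print(loss.item())
```

And because the loss is contrastive, with every other token acting as a negative, there's nothing to collapse onto, which is why dropping SimSiam's stop-gradient is fine.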

👍