
[2024-07-29] discrete diffusion was always going to suck

Relevant:
  1. arxiv.org/abs/2107.03006, Structured Denoising Diffusion Models in Discrete State-Spaces
  2. arxiv.org/abs/2310.16834, Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution
  3. arxiv.org/abs/2406.07524, Simple and Effective Masked Diffusion Language Models

Before we ask, "why use diffusion in discrete spaces?", we have to settle, "why use diffusion in continuous spaces?"

Ultimately, a generative model is a function which maps noise to data. When modeling real world continuous data, like images, this mapping will almost always be fairly non-smooth and... somewhat questionable to directly regress with a neural network. Hence, unstable GANs and blurry VAEs.

questionable := neural networks learn low-frequency functions first, so regressing anything non-smooth puts you in overfitting territory. see: arxiv.org/abs/1806.08734 and arxiv.org/abs/2403.02241
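
This claim is easy to poke at directly. A minimal sketch (mine, not from either paper; the architecture and hyperparameters are arbitrary illustrative choices): fit a small MLP to a square wave and watch the low-frequency bins of its output spectrum converge first.

  import math
  import numpy as np
  import torch
  import torch.nn as nn

  torch.manual_seed(0)
  x = torch.linspace(-1, 1, 512).unsqueeze(1)
  y = torch.sign(torch.sin(4 * math.pi * x))  # non-smooth target: a square wave

  model = nn.Sequential(nn.Linear(1, 256), nn.ReLU(),
                        nn.Linear(256, 256), nn.ReLU(),
                        nn.Linear(256, 1))
  opt = torch.optim.Adam(model.parameters(), lr=1e-3)

  for step in range(5001):
      loss = ((model(x) - y) ** 2).mean()
      opt.zero_grad()
      loss.backward()
      opt.step()
      if step % 1000 == 0:
          # spectrum of the current fit: the low-frequency bins fill in first,
          # while the sharp edges (high-frequency energy) arrive last
          spec = np.abs(np.fft.rfft(model(x).detach().squeeze().numpy()))
          print(step, round(loss.item(), 4), spec[:6].round(1))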

Diffusion works, practically speaking, because the model itself is almost always regressing something smooth.

To see this, note that:

  1. The score of a single Gaussian with fixed variance is a linear function of the input.
  2. At each timestep, the diffusion model is regressing a weighted combination of these Gaussian scores, i.e. of linear functions, one for each point in the training dataset (spelled out just after this list).
  3. Each weight is the normalized Gaussian density at that point, i.e. a softmax over scaled distances, which is itself smooth.
  4. The variances of these Gaussians are, for typical noise schedules, also smooth in time.
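
Spelling out point 2 (a sketch; notation assumes the standard forward process where x_t is a data point scaled by alpha_t plus Gaussian noise scaled by sigma_t, over an empirical dataset of N points):

  p_t(x) = \frac{1}{N} \sum_{i=1}^{N} \mathcal{N}\big(x;\ \alpha_t x^{(i)},\ \sigma_t^2 I\big)

  \nabla_x \log p_t(x) = \sum_{i=1}^{N} w_i(x)\, \frac{\alpha_t x^{(i)} - x}{\sigma_t^2},
  \qquad
  w_i(x) = \operatorname{softmax}_i\!\left(-\frac{\lVert x - \alpha_t x^{(i)} \rVert^2}{2\sigma_t^2}\right)

Every term in the sum is linear in x, and the weights are a softmax over scaled squared distances, smooth everywhere so long as sigma_t stays away from zero.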

Of course, this no longer holds towards the end of the reverse process, where the variance approaches zero. But by that point the step size is small enough that, even if the model is underfit there, the impact on the resulting generations is relatively small.

I'd hazard a guess that practically any approach that can decompose a generative process for a non-smooth distribution into a series of effectively smooth functions would perform as well as diffusion. yadda yadda cold diffusion

Diffusion has never really brought anything new to discrete spaces (read: it's just BERT with extra steps) because this smoothness advantage doesn't transfer.
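
To make the "BERT with extra steps" part concrete, here's a rough sketch of an MDLM-style training step (paper 3), paraphrased by me rather than lifted from the paper. `model` is any encoder mapping token ids to per-position logits; the linear masking schedule and `mask_id` are assumptions.

  import torch
  import torch.nn.functional as F

  def masked_diffusion_loss(model, tokens, mask_id, vocab_size):
      B, L = tokens.shape
      t = torch.rand(B, 1)              # diffusion "time", one per sequence
      is_masked = torch.rand(B, L) < t  # linear schedule: mask prob = t
      noised = torch.where(is_masked, torch.full_like(tokens, mask_id), tokens)

      logits = model(noised)            # (B, L, vocab_size)
      ce = F.cross_entropy(logits.reshape(-1, vocab_size),
                           tokens.reshape(-1),
                           reduction="none").reshape(B, L)
      # BERT would stop at the mean CE over masked positions; the "diffusion"
      # part is just this time-dependent reweighting (1/t for a linear schedule)
      weight = 1.0 / t.clamp(min=1e-3)
      return (weight * ce * is_masked).sum() / is_masked.sum().clamp(min=1)

Fix t at a constant and drop the weight and you have the BERT objective verbatim; nothing about the regression target got any smoother.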

As for why it doesn't transfer: it's generally possible to model a distribution over discrete vocabularies with a smooth-ish function already. Focusing on transformer-based LLMs, the input embedding layer means unique tokens can be mapped to arbitrarily distant regions of the model's input space. If no two inputs are close together, then it's already possible to fit the data with a smooth function.
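
A quick numerical illustration of "no two inputs are close" (mine; the vocab size and dimension are arbitrary): even random, untrained embeddings put tokens nearly equidistant from each other, at distances comparable to their norms. Trained embeddings are free to do at least this well.

  import torch

  torch.manual_seed(0)
  vocab, dim = 50_000, 4_096
  emb = torch.randn(vocab, dim) / dim ** 0.5   # rows have norm ~= 1

  sample = emb[torch.randperm(vocab)[:2_000]]  # subsample so cdist stays cheap
  d = torch.cdist(sample, sample)
  off_diag = d[~torch.eye(len(sample), dtype=torch.bool)]
  # mean pairwise distance lands near sqrt(2) times the typical norm,
  # and even the minimum stays far from zero
  print(emb.norm(dim=1).mean(), off_diag.min(), off_diag.mean())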

Of course, this only really applies to vocabulary-style discrete inputs, i.e. each dimension is a categorical variable. For ordinal variables, where some sense of input distance returns, diffusion could still have a place. Images are technically discrete variables with 256 bins, after all.