
[2024-11-25] Training for Infinite Contexts via Gradient Bootstrapping

With recurrent models, it's easy to condition on long contexts during training and inference: simply save the (detached) recurrent state and pass it in along with the next input segment.
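
For concreteness, here's roughly what that looks like in PyTorch (a sketch only; `model`, `segments`, and `optimizer` are stand-ins, and `model(segment, state) -> (logits, new_state)` isn't any particular library's interface):

```python
import torch
import torch.nn.functional as F

def train_segments(model, segments, optimizer):
    """Stream segments through a recurrent model, carrying a detached state.

    `model(segment, state) -> (logits, new_state)` is a stand-in interface,
    not any particular library's API.
    """
    state = None
    for segment, targets in segments:
        logits, state = model(segment, state)
        loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Detach: the next segment is conditioned on this state, but no
        # gradient flows back through it to earlier segments.
        state = state.detach()
```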

However, such approaches don't produce a learning signal that encourages the model to effectively compress states. Learning how to use information which has already been stored is one thing; learning what information to store is another. The latter is arguably less important for transformers, since they do no compression of the state at all. Of course, transformers were doomed from the start when it comes to doing inference with infinite context. Once the KV cache exhausts available memory, we're back to the problem of deciding how to compress the state. This post is about learning how to compress, for cases where we may want to hold onto information for longer than the total number of model inputs that can be held in memory during training.

The obvious way to train for a context length longer than the number of tokens you can store in memory is reinforcement learning, which has always had to deal with the problem of long-horizon credit assignment.

For recurrent or state-space models, you could imagine treating the hidden state update as an action, and optimizing against a reward signal of negative next-token cross entropy.

Using the empirical return (the actual cumulative discounted cross entropy of the model's future predictions) to optimize the actions is out of the question for long context lengths, since we would need to hold onto the inputs we want to do gradient descent with until the future cross entropy has actually been observed and the loss can be computed. A typical solution is value function approximation, trained via bootstrapping. For those unfamiliar with reinforcement learning: if we let \(x_t\) denote the current state, we can learn the (on-policy) value \(V(x_t)\) by taking a step, observing a reward \(R_t\) (e.g. negative cross entropy), evaluating \(V(x_{t + 1})\), and training to minimize \((V(x_t) - R_t - \gamma V(x_{t + 1}))^2\). The discount factor \(\gamma\), \(0 \lt \gamma \lt 1\), weights rewards in the immediate future more than rewards in the distant future, and also bounds the value target (which would be an infinite sum for infinite-length contexts) to finite values.
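
In code, one bootstrapped value update looks something like this (a minimal TD(0) sketch; `value_net` and the states are placeholders):

```python
import torch

def td0_update(value_net, x_t, reward_t, x_next, optimizer, gamma=0.99):
    """One bootstrapped value update: fit V(x_t) to R_t + gamma * V(x_{t+1})."""
    with torch.no_grad():
        # Bootstrapped target, using the current value estimate at the next state.
        target = reward_t + gamma * value_net(x_next)
    loss = ((value_net(x_t) - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```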

Unfortunately, using reinforcement learning directly in this case is a bit terrible. The model needs to do exploration over a distribution of actions (updates to the recurrent state) until it randomly stumbles upon a "better" state update, but it also needs to be consistent enough in its behavior that the value function can provide reliable estimates about which state updates are actually beneficial. Considering that our action space has the same dimensionality as the recurrent state of our model (for smaller Mamba models, the hidden state has, what, a few hundred thousand effective dimensions?), the sample efficiency here would almost certainly be abysmal.

"big dimensionality -> use gradient??"

What a good idea. Nobody's ever heard that one before.

The idea here is that, for any given prefix of the input sequence, the only learning signal from the distant future that applies to the computation we perform on this prefix (the signal we would have gotten by simply training with a longer sequence length) is captured in the gradient of the hidden state generated after the last input in the prefix. Put another way: when we train with a long sequence length, we're using the gradient of the future losses w.r.t. the parameters. The computation graph that produces those gradients can be cut at the hidden states. So, if we have the gradient of the future losses w.r.t. the hidden state at timestep \(t\), that's enough to compute the gradient of those future losses w.r.t. the parameters for the inputs we provided before \(t\), by simply backpropagating.
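
In PyTorch terms: if someone handed us the gradient of the future losses w.r.t. the hidden state, we could feed it straight into `backward` (a sketch; `model`, `prefix`, and `grad_future` are stand-ins):

```python
import torch

def backprop_prefix(model, prefix, grad_future):
    """Inject a given gradient-w.r.t.-hidden-state into the prefix's graph.

    `model(prefix) -> (outputs, h_t)` is a stand-in interface; `grad_future`
    plays the role of d(future losses)/d(h_t).
    """
    _, h_t = model(prefix)
    # Backpropagating grad_future through h_t produces exactly the parameter
    # gradients that the future losses would have contributed, without ever
    # materializing the future part of the computation graph.
    h_t.backward(gradient=grad_future)
```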

It follows that if we can estimate the expected gradient of the future loss w.r.t. the current hidden state, then we can estimate the expected gradient of the future loss w.r.t. the parameters used to compute the current hidden state, emulating the expected parameter update we would have performed if we had actually trained with a sample of future inputs.

Estimating the expected gradient of future losses w.r.t. the current hidden state is a very similar problem to estimating the expected future reward given the current state in reinforcement learning. We can adapt the idea of value bootstrapping to bootstrap the gradients, training an estimator for the gradient using constant memory.

Focusing on the case of a single (real-valued) Mamba layer as an example, recall that the state update rule is \[h_t = \overline{A}_t h_{t-1} + \overline{B}_t x_t,\] where each \(\overline{A}_t \in \mathbb{R}^{N \times N}\) is a diagonal matrix with eigenvalues in \((0, 1)\). Outputs \(y\) are computed as \[y_t = C_t h_t.\]
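
For reference, a naive sequential implementation of this recurrence (shapes here are illustrative, not Mamba's actual layout):

```python
import torch

def ssm_scan(A_bar, B_bar, C, x, h0):
    """Sequential (unoptimized) form of the diagonal SSM recurrence.

    A_bar: (T, N)    diagonal entries of each A_bar_t, in (0, 1)
    B_bar: (T, N, D) input matrices
    C:     (T, M, N) output matrices
    x:     (T, D)    inputs
    h0:    (N,)      initial hidden state
    """
    h, ys = h0, []
    for t in range(x.shape[0]):
        h = A_bar[t] * h + B_bar[t] @ x[t]  # h_t = A_bar_t h_{t-1} + B_bar_t x_t
        ys.append(C[t] @ h)                 # y_t = C_t h_t
    return torch.stack(ys), h
```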

We have some loss function of \(y_t\), e.g. something like next-token cross entropy, which we will write as \(\mathcal{L}(y_t)\). The expected cumulative sum of \(\mathcal{L}\) over time is the "true" objective we're trying to optimize, akin to the (negative) return in RL.

Note that the Jacobian of the hidden state update is given by \[\frac{\partial h_t}{\partial h_{t-1}} = \overline{A}_t\] and similarly \[\frac{\partial h_{t + T}}{\partial h_t} = \prod_{s=t + 1}^{t + T} \overline{A}_s\] (abusing notation since \(\overline{A}\) is a diagonal matrix). So, letting \(\mathcal{L}_{t : \infty}\) denote the future cumulative loss, we have that \[ \begin{align*} \nabla_{h_t} \mathcal{L}_{t : \infty} &= \nabla_{h_t} \mathcal L(y_t) + \frac{\partial h_{t + 1}}{\partial h_t} \nabla_{h_{t + 1}} \mathcal{L}_{(t + 1) : \infty} \\ &= \nabla_{h_t} \mathcal L(y_t) + \overline{A}_{t + 1} \nabla_{h_{t + 1}} \mathcal{L}_{(t + 1) : \infty} \end{align*} \] An estimator \(g: \mathbb{R}^N \to \mathbb{R}^N\) for the gradient \(\nabla_{h_t} \mathcal{L}_{t : \infty}\) can be learned via bootstrapping \[g(h_t) \leftarrow \nabla_{h_t} \mathcal{L}(y_t) + \overline{A}_{t + 1} g(h_{t + 1}),\] which converges to the true expected gradient for the same reasons that the value function in reinforcement learning converges (something something Bellman contractivity; idk, i'm not a theorist).
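
A sketch of one bootstrapping step for \(g\) (here `g` is some small network \(\mathbb{R}^N \to \mathbb{R}^N\); the hidden states and local gradient are assumed to be precomputed and detached):

```python
import torch

def bootstrap_g_step(g, h_t, h_next, grad_local, A_bar_next, optimizer):
    """One bootstrapped update of the gradient estimator g.

    h_t, h_next: detached hidden states h_t and h_{t+1}, shape (N,)
    grad_local:  d L(y_t) / d h_t, available from ordinary backprop at step t
    A_bar_next:  diagonal entries of A_bar_{t+1}, shape (N,)
    """
    with torch.no_grad():
        # Bootstrapped target: local gradient plus the A_bar-damped estimate of
        # the gradient flowing back from the next hidden state.
        target = grad_local + A_bar_next * g(h_next)
    loss = ((g(h_t) - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```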

Once we have learned \(g\), we can add the expected contribution of future loss by emulating backpropagation: to make the gradient flowing into \(h_{t + T}\) take on the value of \(g({h_{t + T}})\), we simply add the dot product \(\langle h_{t + T}, \text{sg} \{ g({h_{t + T}}) \}\rangle\) to the training loss, obtaining \[\langle h_{t + T}, \text{sg} \{ g({h_{t + T}}) \}\rangle + \sum_{s=t}^{t + T} \mathcal L(y_s),\] and then standard backpropagation with e.g. PyTorch will compute the correct gradient w.r.t. the parameters. In practice, I imagine that the gradient estimator \(g\) and the model will be trained in parallel, like with online reinforcement learning.
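
The stop-gradient maps directly onto `detach()` in PyTorch; a sketch of the loss for one training window:

```python
import torch

def window_loss(step_losses, h_last, g):
    """Training loss for one window of inputs, t through t+T.

    step_losses: per-step losses L(y_s) for s = t, ..., t+T (in the graph)
    h_last:      final hidden state h_{t+T} (in the graph)
    g:           gradient estimator; its output is stop-gradiented via detach
    """
    # <h_{t+T}, sg{g(h_{t+T})}>: its gradient w.r.t. h_last is exactly g(h_last),
    # so standard backprop injects the estimated future-loss gradient into the
    # computation that produced the window.
    future_term = (h_last * g(h_last).detach()).sum()
    return sum(step_losses) + future_term
```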

Does this actually work in practice? idk, i came up with all of this, like, last night. hit me up with funding and compute if you want to find out ig

EDIT (2024-11-29): lol, turns out this idea isn't original at all, see Decoupled Neural Interfaces using Synthetic Gradients (Jaderberg et al. 2016)