Matěj Kripner

Paper notes: σ-GPTs, A New Approach to Autoregressive Models

Authors Arnaud Pannatier, Evann Courdier, François Fleuret
Paper arxiv.org/abs/2404.09562
Code github.com/idiap/sigma-gpt
Demo arnaudpannatier.ch/sigma-gpt/
Year 2024

When using a model like a Transformer to generate a sequence, the traditional approach is to generate it left-to-right. The generation of each token is then conditioned on all tokens that precede it (those that have already been generated). This is what is usually called the autoregressive order.

The idea of σ-GPTs is to instead allow the generation of tokens in any arbitrary order. In other words, to condition on any subset of tokens and generate any other token. Only two changes to a standard GPT-type model are needed to achieve this:

  1. Traditionally, we use positional encoding to inform the model about the position of each input token (because Attention doesn’t take order into account). That is, each token is enriched with encoded information about its position; in this case, the authors use the standard sinusoidal positional encoding. The information about which token to generate next is then implicit: it is simply the next one in the left-to-right order. However, in the case of σ-GPTs, we must tell the model which token to generate next. So the authors include a second positional encoding containing just that: for each input token, the position of the token that follows it in the generation order.
  2. During training, the sequences are shuffled randomly. Otherwise, the model would only see the left-to-right order and produce nonsense when presented with a different order during inference.
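The two changes can be sketched in a few lines of NumPy. This is only an illustration of the idea, not the authors' code: the function names are made up, how the model predicts the very first token is ignored, and concatenating the two encodings is an assumption about how they are combined.

```python
import numpy as np

def make_training_example(tokens, rng):
    """Shuffle a sequence and build the two position streams described above."""
    tokens = np.asarray(tokens)
    order = rng.permutation(len(tokens))  # random generation order
    shuffled = tokens[order]              # tokens in the order they are generated
    inputs = shuffled[:-1]                # token fed to the model at each step
    input_pos = order[:-1]                # original position of each input token
    target_pos = order[1:]                # position the model is asked to predict next
    targets = shuffled[1:]                # token that sits at that position
    return inputs, input_pos, target_pos, targets

def sinusoidal_encoding(positions, dim):
    """Standard sinusoidal positional encoding (dim is assumed to be even)."""
    positions = np.asarray(positions, dtype=np.float64)[:, None]
    freqs = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    angles = positions * freqs
    enc = np.zeros((positions.shape[0], dim))
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles)
    return enc

rng = np.random.default_rng(0)
tokens = [10, 11, 12, 13, 14]
inputs, input_pos, target_pos, targets = make_training_example(tokens, rng)

# Assumption: the two encodings are concatenated along the feature axis.
double_encoding = np.concatenate(
    [sinusoidal_encoding(input_pos, 8), sinusoidal_encoding(target_pos, 8)], axis=1
)
```

Each row of `double_encoding` tells the model both where the current token sits and which position it should predict next.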

Figure 1 from the paper illustrates both changes:

Figure 1 from the paper: Training pipeline, showing random shuffling and two positional encodings.

Note that this figure shows the situation during training when the whole sequence is available. During inference, we would only feed in the already generated tokens and specify which one token to generate next.

The presented approach immediately offers two advantages: infilling and conditional density estimation.

Both of these are illustrated in Figure 2 from the paper shown below. The model only saw the black points and estimated / infilled the rest. [1]

Figure 2 from the paper: Infilling and Conditional Density Estimation.

The paper has equivalent but way cooler examples in 2D mazes, so definitely check it out.

Token-based Rejection Sampling

σ-GPTs enable a faster way to generate tokens than sequentially (one by one). Let’s say we want to generate tokens at positions \(j_1, \ldots, j_l\). Let also \(I\) be the set of already generated positions together with the positions that were given as the input prompt. Traditionally, we’d generate the tokens at positions \(j_1, \ldots, j_l\) sequentially. This would mean sampling from these random variables one by one:

\[\begin{align*} \mathcal{U} = \{\; & p(x_{j_1} \;|\; I), \\ & p(x_{j_2} \;|\; I, j_1), \\ & \ldots, \\ & p(x_{j_l} \;|\; I, j_1, \ldots, j_{l-1}) \;\} \end{align*}\]

However, this is slow since we have to wait for a token to be generated before we can generate the next one. Instead, we’d ideally want to sample all the tokens \(x_{j_1}, \ldots, x_{j_l}\) at once. In other words, to sample from these random variables all at once (in parallel):

\[\begin{align*} \mathcal{V} = \{\; & p(x_{j_1} \;|\; I), \\ & p(x_{j_2} \;|\; I), \\ & \ldots, \\ & p(x_{j_l} \;|\; I) \;\} \end{align*}\]

The problem is that in \(\mathcal{V}\), the sampled tokens do not depend on each other since they are generated all at once. This means that we might end up sampling tokens that all look good (probable) on their own but together they are nonsensical (improbable).
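A tiny made-up example shows how bad this can get. The joint distribution below is hypothetical (two binary tokens where only "00" and "11" are valid); it is not taken from the paper.

```python
import numpy as np

# Hypothetical joint distribution over two binary tokens (x1, x2):
# only "00" and "11" are possible, each with probability 0.5.
joint = np.array([[0.5, 0.0],
                  [0.0, 0.5]])

p_x1 = joint.sum(axis=1)  # marginal distribution of x1
p_x2 = joint.sum(axis=0)  # marginal distribution of x2

# Sampling both tokens in parallel means sampling from the product of marginals.
independent = np.outer(p_x1, p_x2)

# Probability mass that parallel sampling puts on pairs the joint rules out.
invalid_mass = independent[joint == 0].sum()
print(invalid_mass)  # 0.5
```

Each token looks perfectly fine on its own, yet half of the jointly sampled pairs are impossible under the true distribution.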

So we sampled efficiently from \(\mathcal{V}\) but really wanted our samples to come from \(\mathcal{U}\). This is where rejection sampling comes in.

The idea behind rejection sampling is simple. Let’s say we want to sample from a complicated distribution \(P_c\). We cannot use e.g. the inversion method because we are not able to compute the inverse of the cumulative distribution function. So instead, we sample from a simple distribution \(P_s\) (e.g. uniform, Gaussian, …). We then transform samples from \(P_s\) into samples from \(P_c\) by discarding (rejecting) some of them.

Let’s describe the whole process following the image below, which shows a simple density function (“proposal distribution”) and a complicated one (“target distribution”). We first generate a uniform sample from the whole 2D area below the simple function \(f\). We can do that by first sampling the \(x\) value from the simple distribution (for a uniform proposal, this means sampling \(x\) uniformly, which assumes values only within a finite range) and then uniformly sampling the \(y\) value from \([0, f(x)]\) (this assumes that the density function is bounded). Then, we only accept the sample if it also lies below the complicated function (this assumes we’re able to evaluate the complicated function at any given point). Otherwise, we reject it and repeat.

The idea is that if all the samples are uniformly distributed below the simple function, the green (accepted) samples are uniformly distributed below the complicated function. The second idea is that if we then take the \(x\) value of such a uniformly distributed sample, we get a sample from a distribution that has the complicated function as its density function. This is intuitive if the distribution is discrete and the density function is therefore a histogram. And it also works in the limit for continuous distributions.

For this to work, the complicated function has to lie completely below the simple one. This can be achieved by scaling the simple function.

The idea behind rejection sampling.
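The procedure can be sketched as follows. The target density (a Beta(2, 4) density), the uniform proposal, and the scaling constant are choices made for this illustration only; nothing here comes from the paper.

```python
import numpy as np

def target_density(x):
    """Hypothetical 'complicated' density on [0, 1]: Beta(2, 4) = 20 * x * (1 - x)^3."""
    return 20.0 * x * (1.0 - x) ** 3

def rejection_sample(n, rng, scale=2.2):
    """Draw n samples from target_density using a uniform proposal on [0, 1].

    `scale` stretches the proposal density (constant 1) so that it lies above
    the target everywhere; the maximum of the target is about 2.11.
    """
    proposal_density = 1.0
    samples = []
    while len(samples) < n:
        x = rng.uniform(0.0, 1.0)                       # sample x from the proposal
        y = rng.uniform(0.0, scale * proposal_density)  # uniform height under the scaled proposal
        if y < target_density(x):                       # accept if the point lies under the target
            samples.append(x)
    return np.array(samples)

rng = np.random.default_rng(0)
samples = rejection_sample(5000, rng)
print(samples.mean())  # should be close to the Beta(2, 4) mean of 1/3
```

With this scaling, a bit fewer than half of the proposals are accepted on average, since the area under the target is 1 and the area under the scaled proposal is 2.2.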

Returning to the original problem, we basically sample from \(\mathcal{V}\) and use rejection sampling to get samples from \(\mathcal{U}\). Using the notation introduced above, the algorithm is as follows (it’s Algorithm 1 from the paper).

First, compute the distributions over \(j_1, \ldots, j_l\) conditioned on positions \(I\), obtaining \(p_1, \ldots, p_l\). These can all be computed in parallel. Then, sample from these distributions, obtaining tokens \(x_1, \ldots, x_l\). This is our sample from \(\mathcal{V}\), i.e. the proposed output tokens. As a reminder, these tokens are not conditioned on each other and can therefore be nonsensical when put together. The rest of the algorithm is just applying rejection sampling in order to turn it into a sample from \(\mathcal{U}\).

So first we evaluate \(\mathcal{U}\) on \(x_1, \ldots, x_l\). This can again be done in a single pass through the model. Here we use the degree of freedom that the model can generate the tokens \(x_1, \ldots, x_l\) in any order. Let’s choose an arbitrary order (permutation) \(\sigma\) for now. We therefore obtain probabilities \(q_1, \ldots, q_l\) where:

\[q_i = p(x_{\sigma(j_i)} \;|\; I \cup \{\sigma(j_1), \ldots, \sigma(j_{i-1})\})\]

We now perform rejection sampling \(l\) times. That is, for each position \(i \in \{1, \ldots, l\}\), we sample uniformly \(u_i \in [0, 1]\) and accept the \(i\)-th token if:

\[u_i \cdot p_i < q_i\]

More precisely, we have to handle the possibility of \(q_i > p_i\) (that would violate the condition required by rejection sampling). To this end, the authors simply modify the condition like this:

\[u_i \cdot p_i < \min(p_i, q_i)\]

To me, this correction seems wrong, because it actually has no effect. That is, for cases when \(q_i > p_i\), the condition will always be satisfied anyway (because \(u_i < 1\) with probability \(1\)). This is even clearer after dividing both sides by \(p_i\), obtaining the form listed in the paper:

\[u_i < \min \left(1, \frac{q_i}{p_i} \right)\]
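The case analysis behind this observation can be spelled out as follows. It is a sketch that only uses \(u_i < 1\) and \(p_i > 0\), and it says nothing about what happens after a rejection:

\[\begin{align*} q_i \le p_i &: \quad \min(p_i, q_i) = q_i, \text{ so the condition is still } u_i \cdot p_i < q_i, \\ q_i > p_i &: \quad u_i \cdot p_i < p_i = \min(p_i, q_i) \text{ and also } u_i \cdot p_i < p_i < q_i, \text{ so both conditions hold}. \end{align*}\]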

It also seems conceptually wrong to try to make the resulting samples follow a “cropped” target density function instead of the original one (which is just \(q\)).

If you see why this is in fact correct, please let me know. Instead, what would seem right to me is to scale \(p_i\), obtaining this:

\[u_i \cdot p_i \cdot \left( \max_{k \in \{1, \ldots, l\}} \frac{q_k}{p_k} \right) < q_i\]

Anyway, these criteria specify for each token whether it’s rejected or not. The algorithm then accepts as many tokens as possible (going in the order \(\sigma\)) before the first rejection occurs. All tokens after the first rejection are discarded because they are conditioned on a nonsensical (rejected) token. This gives us \(a\) new tokens and we iterate.
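The accept-until-first-rejection step can be sketched like this. The function name and the example numbers are made up, and the criterion used is the one in the form listed in the paper; this is an illustration, not the authors' implementation.

```python
import numpy as np

def count_accepted(p, q, u):
    """Number of proposed tokens accepted before the first rejection.

    p[i]: probability of proposed token i under the parallel proposal
    q[i]: its probability when conditioned on the preceding tokens in the order
    u[i]: uniform sample in [0, 1) used for the accept/reject decision
    All arrays are assumed to be listed in the evaluation order sigma.
    """
    p, q, u = (np.asarray(a, dtype=np.float64) for a in (p, q, u))
    accepted = u < np.minimum(1.0, q / p)  # per-token acceptance criterion
    rejected = np.flatnonzero(~accepted)   # indices of rejected tokens
    return int(rejected[0]) if rejected.size else len(accepted)

# Made-up numbers: the third token becomes very unlikely once the first two are fixed.
p = [0.5, 0.4, 0.3, 0.2]
q = [0.5, 0.6, 0.03, 0.2]
u = [0.9, 0.7, 0.5, 0.1]
a = count_accepted(p, q, u)
print(a)  # 2
```

Note that the fourth token would pass the criterion on its own, but it is discarded anyway because it comes after the first rejection.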

There is one final twist. Remember when we chose the order \(\sigma\) arbitrarily? What the authors actually do is generate \(N_o\) orders randomly and choose the one that yields the highest number of accepted new tokens (the largest \(a\)). What seems important here is that the uniform samples \(u_i\) are generated only once. If instead they were generated for each order independently, we would be “hacking” the rejection criterion by waiting for favorable values of \(u_i\). But even so, it seems to me that this approach might be “hacking” the model by amplifying any imperfections in its probability estimation, as if we are waiting for the model to make a mistake that will allow us to accept more tokens.
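A rough sketch of this order selection is below. The `evaluate_q` callable stands in for a forward pass of the model and is purely hypothetical, as is the toy version used in the example. Tying each \(u_i\) to its token (rather than to its step in the order) is also an assumption.

```python
import numpy as np

def best_order(p, evaluate_q, u, n_orders, rng):
    """Try n_orders random orders and keep the one that accepts the most tokens.

    evaluate_q(order) is a hypothetical stand-in for the model: it returns, for
    each step k, the probability of token order[k] given the tokens order[:k].
    The uniform samples u are drawn once and shared by all orders.
    """
    best_a, best = -1, None
    for _ in range(n_orders):
        order = rng.permutation(len(p))
        q = evaluate_q(order)
        accepted = u[order] < np.minimum(1.0, q / p[order])
        rejected = np.flatnonzero(~accepted)
        a = int(rejected[0]) if rejected.size else len(order)
        if a > best_a:
            best_a, best = a, order
    return best_a, best

p = np.array([0.5, 0.5, 0.5])
u = np.array([0.3, 0.3, 0.3])

def toy_evaluate_q(order):
    """Toy model: token 2 becomes very unlikely if token 0 was generated before it."""
    q = p[order].copy()
    seen_zero = False
    for k, tok in enumerate(order):
        if tok == 2 and seen_zero:
            q[k] = 0.01
        if tok == 0:
            seen_zero = True
    return q

rng = np.random.default_rng(0)
a, order = best_order(p, toy_evaluate_q, u, n_orders=50, rng=rng)
```

In this toy setup, only orders that place token 2 before token 0 accept all three tokens, so trying several orders finds one of them.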

Figure 6 in the paper shows the effect of varying \(N_o\) from \(1\) to (in some cases) \(50\). From it, it’s clear that increasing \(N_o\) decreases the number of iterations needed to produce the sequence. The effect on error rate / perplexity seems minimal. However, it would be interesting to see the effect of increasing \(N_o\) beyond \(4\) in the case of Text Modeling.

Results

For results, please see Chapter 3 of the paper.

In short, σ-GPT reaches performance similar to that of a standard left-to-right GPT (Table 2 in the paper) while requiring an order of magnitude fewer steps during inference (Figure 6). On the other hand, due to the increased difficulty of random-order generation, σ-GPTs require several times more training steps (Table 3), and the dataset size needed to switch from memorization to generalization is higher (Figure 5).

Other notes

Footnotes

  1. I don’t know what CFL means in the left figure and it seems not to be explained in the paper (I only ever encountered the acronym when referring to Context Free Languages).