Longformer is a 2020 attempt to address the efficiency problem of self-attention. Self-attention, where each token of an
input sequence attends to every other token, has an inherent quadratic time and memory complexity. Longformer addresses
this by attending mostly locally.
This enables much longer context windows, which is desirable when a long input document needs to be considered as a
whole to achieve some task - for example in the case of Question Answering. Simple approaches include chopping a long
sequence into (overlapping) chunks and somehow combining the activations, or using a two-stage system where the first
stage retrieves the relevant passages. Another general approach (which includes Longformer) is to replace the
full self-attention with a sparse, more efficient version. The authors point to Sparse Transformer (Child et al.,
2019), BP-Transformer (Ye et al., 2019), and Blockwise attention (Qiu et al., 2019).
Longformer employs the following techniques:
- Sliding Window attention - Each token attends locally to a window of constant size \(w\), i.e. to \(w/2\) tokens
on the left and \(w/2\) tokens on the right. This has inherent complexity of \(O(nw)\) instead of \(O(n^2)\) (where
\(n\) is the size of the input sequence). Best results were achieved when \(w\) increases from the bottom to the
top layer (from 32 to 512).
- Dilated attention - To increase the receptive field, a token only attends to every \((d+1)\)-th token to the left
and to the right, where \(d\) is the dilation factor (see the image below). The non-dilated sliding window corresponds
to \(d = 0\). Similarly to the window size, the dilation factor increases from the bottom to the top layer (from 0 to 3).
- Global attention - It still helps to have some global attention between tokens. The authors implement it so that any
token can be specified to have global attention, in which case it attends to every other token and every other token
attends to it (it’s symmetric). The idea is that the number of globally-attending tokens is relatively small - e.g. the
question tokens in Question Answering or just the [CLS] token in classification. The justification is that “the windowed
and dilated attention are not flexible enough”, which is confirmed by ablations.
The paper contains the following awesome illustrations for these concepts (figure 2 from the paper), where both the
horizontal and the vertical axis represent the token sequence.
Figure 2 from the paper: full \(n^2\) attention, sliding window attention, dilated sliding window attention, and global + sliding window attention.
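To make these patterns concrete, here is a minimal illustrative sketch (my own, not the authors’ code) that materialises the combined boolean mask for a toy sequence. Note that building a full \(n \times n\) mask like this would defeat the whole purpose of Longformer - it is only meant to visualize which pairs of tokens may attend to each other.

```python
import torch

def longformer_pattern(n, w, d=0, global_idx=()):
    """Toy (n, n) boolean mask of the Longformer attention pattern.

    mask[i, j] is True where query token i may attend to key token j.
    n: sequence length, w: window size (even), d: dilation factor,
    global_idx: positions given global attention (e.g. the [CLS] token).
    """
    i = torch.arange(n).unsqueeze(1)   # query positions, shape (n, 1)
    j = torch.arange(n).unsqueeze(0)   # key positions, shape (1, n)
    offset = j - i
    # sliding window: |offset| <= w/2; dilation keeps only every (d+1)-th position
    mask = (offset.abs() <= w // 2) & (offset % (d + 1) == 0)
    for g in global_idx:
        mask[g, :] = True              # a global token attends to everything ...
        mask[:, g] = True              # ... and everything attends to it (symmetric)
    return mask

print(longformer_pattern(8, w=4, d=0, global_idx=(0,)).int())
```

With \(d = 0\) and no global tokens this reduces to the plain sliding window from the second panel of the figure.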
As pointed out by the authors, these concepts have a direct analogy in Convolutional Networks - a large receptive field
is obtained by stacking small local operations across layers. And when we need a very long receptive field (like in speech
recognition, e.g. WaveNet), we make the convolution dilated, so that it “sees” further, albeit with lower precision.
Implementation
The authors provide an efficient implementation of Longformer. At a high level, they modify the RoBERTa implementation from
PyTorch Transformers by replacing all the self-attentions with a custom LongformerSelfAttention. You can find all the
code at github.com/allenai/…/longformer.py.
There are the following 4 implementations of the sliding window attention, which is effectively just a matrix
multiplication from which we only need a small number of diagonals (see the images above). The whole point is that
this should be more efficient than a standard full matrix multiplication (a rough sketch of the naive banded computation
follows the list).
- Longformer-loop - A non-vectorized Python loop.
- Longformer-chunks - This splits the token sequence into chunks of size \(2w\) with an overlap of size \(w\). Now when
we perform a full matrix multiplication on a chunk, we get the correct self-attention values for the \(w\) middle tokens
(the seq[w / 2:-w / 2 + 1] tokens), because all of them “see” their whole context window of size \(w\). The issue is
that the other \(w\) values of the self-attention can’t be computed, because some of their context is outside the
boundary of the chunk. This means that this consumes 2x the theoretical minimum needed memory. But it’s time
efficient, since it uses the highly-optimized standard matrix multiplication. Another issue is that this approach
doesn’t support dilated attention.
- Longformer-chunks (non-overlapping) - This modified version of chunks is not listed in the paper, but comments in the
code claim that it’s 30% faster and uses 95% of the memory. It accomplishes the job with really just one call to
torch.einsum using the template bcxhd,bcyhde->bcxhey (how cool is that?), which I have yet to understand. This version
doesn’t support dilated attention either.
- Longformer-cuda - This is a custom CUDA kernel called diagonal_mm, which does support dilated attention. It’s
implemented using TVM (Tensor Virtual Machine), which is, quoting the paper’s authors, “a deep learning compiler stack
that compiles high level description of a function into optimized device-specific code”. This makes the kernel more
flexible and maintainable than if it were written in C++ and designed for a specific version of PyTorch.
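To make the “band of diagonals” idea concrete, here is a rough sketch in the spirit of Longformer-loop (my own simplification for a single head, not the repository’s code): for each query we compute scores only against the \(w + 1\) keys inside its window, so the result is an \(n \times (w + 1)\) band instead of an \(n \times n\) matrix.

```python
import torch

def sliding_window_scores(q, k, w):
    """Naive sliding-window attention scores for a single head.

    q, k: (seq_len, head_dim) tensors.
    w: window size (even); each token sees w/2 neighbours on each side.
    Returns a (seq_len, w + 1) band: scores[i, j] is the score between token i
    and token i - w/2 + j; positions falling outside the sequence stay at -inf.
    """
    n, dim = q.shape
    half = w // 2
    scores = q.new_full((n, w + 1), float("-inf"))
    for i in range(n):
        lo, hi = max(0, i - half), min(n, i + half + 1)
        start = lo - (i - half)  # where the first valid neighbour lands in the band row
        scores[i, start:start + (hi - lo)] = q[i] @ k[lo:hi].T / dim ** 0.5
    return scores

q, k = torch.randn(16, 64), torch.randn(16, 64)
band = sliding_window_scores(q, k, w=4)   # shape (16, 5) instead of (16, 16)
```

The chunked implementations compute the same band, just via a few large matrix multiplications over (overlapping) blocks instead of a Python loop, and the CUDA kernel computes it directly on the GPU.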
Runtime and memory efficiency are compared in Figure 1 from the paper (excluding the non-overlapping version of
Longformer-chunks).
Figure 1 from the paper: Runtime and memory of different implementations of Longformer attention.
Tasks
Longformer is tested on character-level language modeling, question answering, classification and coreference
resolution. It’s pretrained using Masked Language Modeling (MLM).
Training
The authors use a kind of curriculum learning (staged training) where the model is first taught to use local context and
then to leverage the long receptive field. Concretely, they train the model in 5 phases, where each phase doubles the
window size and sequence length and halves the learning rate (see the table below). Because MLM pretraining is expensive,
they start training from the released RoBERTa checkpoint. Since RoBERTa uses (absolute) position embeddings with a maximum
position of 512 (in contrast to Longformer’s 4096), RoBERTa’s position embeddings are simply copied 8 times.
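A minimal sketch of that position-embedding copy (names and sizes are my stand-ins, not the exact RoBERTa/Hugging Face attribute names):

```python
import torch
from torch import nn

old_max_pos, new_max_pos, dim = 512, 4096, 768
old_pos_embed = nn.Embedding(old_max_pos, dim)   # stand-in for RoBERTa's pretrained position embeddings

# tile the 512 learned position vectors 8 times to initialise all 4096 positions
new_weight = old_pos_embed.weight.detach().repeat(new_max_pos // old_max_pos, 1)  # (4096, 768)
new_pos_embed = nn.Embedding.from_pretrained(new_weight, freeze=False)
```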
The authors train 2 different model sizes - small (12 layers, 512 hidden size) and large (30 layers, 512 hidden size). They
experiment with using fp16 (floating point half precision) and fp32 (single precision) operations, even employing
mixed-precision training
using the apex.amp
tool. Concretely, they use a mode where weights are
stored in fp32, but many operations are performed in fp16 (but attention and batchnorms are still done in fp32). To
further reduce memory usage, they use the ingenious gradient checkpointing technique (Chen et al., 2016), which enables
training of \(n\)-layer-deep networks with \(O(\sqrt{n})\) memory at the cost of doing one more forward pass per batch.
Basically, instead of saving all \(n\) activations for later use in backpropagation, only every \(\sqrt{n}\)-th
“checkpoint” activation is saved and the remaining ones are recomputed on the fly - see the
blogpost accompanying the gradient checkpointing paper.
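A minimal sketch of gradient checkpointing with PyTorch’s built-in utility (illustrative only, not the authors’ exact training setup):

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

n_layers, dim = 12, 512
model = torch.nn.Sequential(*[torch.nn.Linear(dim, dim) for _ in range(n_layers)])

x = torch.randn(8, dim, requires_grad=True)
segments = 4                               # keep roughly sqrt(n_layers) checkpoint activations
out = checkpoint_sequential(model, segments, x)
out.sum().backward()                       # the non-checkpointed activations are recomputed here
```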
The experiments were not done on some sort of supercomputer - for the large model, they ran experiments on 8 RTX8000
GPUs for 13 days.
Hyperparameter search was done by training for a smaller number of steps (about 150K). The resulting hyperparameters for
character-level language modeling on the text8 dataset are listed in Table 12 from the paper.
Table 12 from the paper: Hyperparameters for the best performing model for character-level language modeling.
I highly encourage you to read the paper, which I found extremely well written.