Matěj Kripner

Paper notes: Longformer, The Long-Document Transformer

Authors Iz Beltagy, Matthew E. Peters, Arman Cohan
Paper arxiv.org/abs/2004.05150
Code github.com/allenai/longformer
Year 2020

Longformer is a 2020 attempt to address the efficiency problem of self-attention. Self-attention, where each token of an input sequence attends to every other token, has an inherent quadratic time and memory complexity. Longformer addresses this by attending mostly locally.

This enables much longer context windows, which is desirable when a long input document needs to be considered as a whole to solve some task - for example in Question Answering. Simple approaches include chopping a long sequence into (possibly overlapping) chunks and combining the activations in some way, or using a two-stage system where the first stage retrieves the relevant passages. Another general approach (which includes the Longformer) is to replace the full self-attention with a sparse, more efficient version. The authors point to Sparse Transformer (Child et al., 2019), BP-Transformer (Ye et al., 2019), and Blockwise attention (Qiu et al., 2019).
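For illustration, the chunking baseline (not what Longformer does) might look like the following sketch, with an arbitrary chunk size and stride of my own choosing:

```python
def overlapping_chunks(token_ids, chunk_size=512, stride=256):
    """Split a long sequence into overlapping chunks (a simple baseline,
    not Longformer's approach). Consecutive chunks overlap by
    chunk_size - stride tokens; the per-chunk activations then have to be
    combined downstream in some task-specific way."""
    chunks, start = [], 0
    while start < len(token_ids):
        chunks.append(token_ids[start:start + chunk_size])
        if start + chunk_size >= len(token_ids):
            break
        start += stride
    return chunks

# e.g. a 1,000-token document with the defaults yields chunks starting at 0, 256, 512
```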

Longformer employs the following techniques: sliding window attention (each token attends only to a fixed-size window around itself), dilated sliding window attention (the window has gaps, trading precision for a larger receptive field), and global attention (a few selected tokens attend to, and are attended by, the whole sequence).

The paper contains the following awesome illustrations of these concepts (Figure 2 from the paper), where both the horizontal and the vertical axis represent the token sequence.

Figure 2 from the paper: Full \(n^2\) attention, Sliding window attention, Dilated sliding window attention, and Global+sliding attention.

As pointed out by the authors, these concepts have a direct analogy in Convolutional Networks - a global receptive field is obtained by stacking layers of small local operations. And when we need a very long receptive field (like in speech recognition, e.g. WaveNet), we make the convolution "dilated", so that it "sees" further, albeit with lower precision.
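To make Figure 2 concrete, here is a toy sketch (my own code, not the paper's implementation) that builds boolean attention masks for the sliding, dilated, and global+sliding patterns. Note that it materializes the full \(n \times n\) mask, so it only illustrates the pattern, not the memory savings:

```python
import torch

def longformer_style_mask(n, w, dilation=1, global_idx=()):
    """Boolean mask of shape (n, n); True means "query i may attend to key j".
    Each token attends to w positions on either side (spaced `dilation` apart),
    and the tokens in global_idx attend to, and are attended by, everything."""
    i = torch.arange(n).unsqueeze(1)      # query positions (column vector)
    j = torch.arange(n).unsqueeze(0)      # key positions (row vector)
    offset = j - i
    mask = (offset.abs() <= w * dilation) & (offset % dilation == 0)
    if global_idx:
        g = torch.tensor(list(global_idx))
        mask[g, :] = True                 # global tokens attend everywhere
        mask[:, g] = True                 # and every token attends to them
    return mask

# Usage: scores.masked_fill(~longformer_style_mask(n, w=256, global_idx=(0,)),
#                           float("-inf")) before the softmax.
```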

Implementation

The authors provide an efficient implementation of Longformer. At a high level, they modify the RoBERTa implementation from the PyTorch Transformers library by replacing all the self-attention modules with a custom LongformerSelfAttention. You can find all the code at github.com/allenai/…/longformer.py.
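A rough sketch of that swap (attribute paths and config fields depend on the library version; treat this as an assumption-laden illustration of the repo's pattern, not a drop-in recipe):

```python
from transformers import RobertaForMaskedLM
from longformer.longformer import LongformerSelfAttention  # from the repo linked above

model = RobertaForMaskedLM.from_pretrained("roberta-base")

# Replace each layer's standard self-attention with the sliding-window version.
# (The config would also need Longformer-specific fields, e.g. the per-layer
# attention window sizes, set before this loop.)
for i, layer in enumerate(model.roberta.encoder.layer):
    layer.attention.self = LongformerSelfAttention(model.config, layer_id=i)
```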

The sliding window attention is effectively just a matrix multiplication from which we only need a small number of diagonals (see the images above); the whole point is that computing only this band should be more efficient than a full matrix multiplication. There are 4 implementations of it:

Longformer-loop: a naive loop over the diagonals in PyTorch; memory efficient but slow, used only for testing.
Longformer-chunks: splits the sequence into overlapping chunks1 and multiplies them with full matrix multiplications, masking out the entries outside the band.
A non-overlapping variant of Longformer-chunks.
Longformer-cuda: a custom CUDA kernel implemented with TVM, which also supports dilation.
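For intuition about the banded computation, a naive version in the spirit of Longformer-loop (my own toy code, single head, no scaling) could look like this:

```python
import torch

def banded_attention_scores(q, k, w):
    """For each query position i, compute scores only against keys in the
    window [i - w, i + w]. Returns an (n, 2*w + 1) tensor holding the needed
    diagonals instead of the full (n, n) score matrix; out-of-range positions
    are -inf so they vanish under softmax."""
    n, d = q.shape
    scores = torch.full((n, 2 * w + 1), float("-inf"))
    for i in range(n):
        lo, hi = max(0, i - w), min(n, i + w + 1)
        scores[i, lo - i + w : hi - i + w] = q[i] @ k[lo:hi].T
    return scores
```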

Runtime and memory efficiency are compared in Figure 1 from the paper (excluding the non-overlapping version of Longformer-chunks).

Figure 1 from the paper: Runtime and memory of different implementations of Longformer attention.

Tasks

Longformer is tested on character-level language modeling, question answering, classification and coreference resolution. It’s pretrained using Masked Language Modeling (MLM).

Training

The authors use a kind of curriculum learning (staged training) where the model is first taught to use local context and only later to leverage the long receptive field. Concretely, they train the model in 5 phases, where each phase doubles the window size and sequence length and halves the learning rate (see the table below). Because MLM pretraining is expensive, they start training from the released RoBERTa checkpoint. Since RoBERTa uses (absolute) position embeddings with a maximum position of 512 (in contrast to Longformer's 4,096), RoBERTa's position embeddings are simply copied 8 times.
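A minimal sketch of this position-embedding copying (dimensions and tensor names are my stand-ins; the repo's conversion script does the equivalent on the real weight matrix):

```python
import torch

# Stand-in for RoBERTa's learned absolute position embeddings.
# (In practice RoBERTa stores a couple of extra rows due to a padding-index
# offset; that detail is ignored here.)
roberta_pos_emb = torch.randn(512, 768)

# Longformer's 4,096 positions: copy the 512 pretrained embeddings 8 times.
longformer_pos_emb = roberta_pos_emb.repeat(8, 1)
assert longformer_pos_emb.shape == (4096, 768)
```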

The authors train two model sizes - small (12 layers, 512 hidden size) and large (30 layers, 512 hidden size). They experiment with fp16 (half precision) and fp32 (single precision) operations, employing mixed-precision training via the apex.amp tool. Concretely, they use a mode where weights are stored in fp32 but many operations are performed in fp16 (attention and batchnorms are still done in fp32)2.

To further reduce memory usage, they use the ingenious gradient checkpointing technique (Chen et al., 2016), which enables training of \(n\)-layer-deep networks with \(O(\sqrt{n})\) memory at the cost of roughly one extra forward pass per batch. Basically, instead of saving all \(n\) activations for later use in backpropagation, only every \(\sqrt{n}\)-th "checkpoint" activation is saved and the remaining ones are recomputed on the fly - see the blogpost accompanying the gradient checkpointing paper. The experiments were not done on some sort of supercomputer - for the large model, they ran on 8 RTX8000 GPUs for 13 days.
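As an aside, a minimal example of gradient checkpointing in PyTorch (using torch.utils.checkpoint on a toy stack of layers; the authors apply the same idea inside the transformer, not this exact call):

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# Toy 12-layer network. With checkpointing, only the activations at segment
# boundaries are kept during the forward pass; the rest are recomputed during
# backprop, trading an extra forward pass for roughly O(sqrt(n)) memory.
layers = torch.nn.Sequential(*[torch.nn.Linear(512, 512) for _ in range(12)])
x = torch.randn(4, 512, requires_grad=True)

out = checkpoint_sequential(layers, 4, x)  # 4 segments ~ sqrt(12) checkpoints
out.sum().backward()
```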

Hyperparameter search was done by training for a smaller number of steps (about 150K). The hyperparameters of the best performing model for character-level language modeling on the text8 dataset are listed in Table 12 from the paper, reproduced below.

Param Value
Position Embeddings Relative and Sinusoidal as in Dai et al. (2019)
Small model config 12 layers, 8 heads, 512 hidden size as in Dai et al. (2019)
Large model config 30 layers, 8 heads, 512 hidden size as in Child et al. (2019)
Optimizer AdamW
Dropout 0.2 (small model), 0.4 (large model)
Gradient clipping 0.25
Weight Decay 0.01
Layernorm Location pre-layernorm (Xiong et al., 2020)
Activation GeLU
Number of phases 5
Phase 1 window sizes 32 (bottom layer) - 8,192 (top layer)
Phase 5 window sizes 512 (bottom layer) - (top layer)
Phase 1 sequence length 2,048
Phase 5 sequence length 23,040 (gpu memory limit)
Phase 1 LR 0.00025
Phase 5 LR 0.000015625
Batch size per phase 32, 32, 16, 16, 16
#Steps per phase (small) 430K, 50K, 50K, 35K, 5K
#Steps per phase (large) 350K, 25K, 10K, 5K, 5K
Warmup 10% of the phase steps with maximum 10K steps
LR scheduler constant throughout each phase
Dilation (small model) 0 (layers 0-5), 1 (layers 6-7), 2 (layers 8-9), 3 (layers 10-11)
Dilation (large model) 0 (layers 0-14), 1 (layers 15-19), 2 (layers 20-24), 3 (layers 25-29)
Dilation heads 2 heads only

Table 12 from the paper: Hyperparameters for the best performing model for character-level language modeling.
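For a quick sanity check of the schedule implied by the table (my own arithmetic, assuming clean doubling of the sequence length and halving of the learning rate), the per-phase values would be:

```python
base_seq_len, base_lr = 2048, 0.00025  # phase 1 values from the table

for phase in range(1, 6):
    seq_len = base_seq_len * 2 ** (phase - 1)
    lr = base_lr / 2 ** (phase - 1)
    print(f"phase {phase}: seq_len={seq_len:>6}, lr={lr}")

# Phase 5 comes out as seq_len=32768 and lr=1.5625e-05; in practice the
# sequence length is capped at 23,040 by GPU memory, as the table notes.
```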

I highly encourage you to read the paper, which I found extremely well written.

Footnotes

  1. The paper claims to use chunks of size \(w\) with an overlap of size \(w/2\), but that seems incorrect and contradicts the code. 

  2. https://github.com/allenai/…/pretrain.py#L451