Scaling Transformer-XL to 128 GPUs


One of the difficulties of researching language models is that you often don’t know whether your ideas work until you try them on real-world datasets. However, training on such datasets on a single machine can take weeks.

Technical details


We primarily experimented with Wikitext-103, a standard language modeling dataset composed of a pre-tokenized subset of Wikipedia. The training set contains approximately 100M tokens and 267K unique words in 515MB. It uses a modified Moses tokenization, with 0.4% of words replaced with <UNK> symbols to limit vocab size. The Transformer-XL codebase loads the entire dataset into GPU memory at the start of training.
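As a rough illustration of how a small pre-tokenized corpus can live entirely on the GPU, the stream of token ids can be stored as one long tensor and sliced into contiguous per-worker streams. This is a simplified sketch in PyTorch, not the Transformer-XL codebase’s exact iterator:

```python
import torch

def batchify(tokens, batch_size):
    """Reshape a flat token stream into batch_size parallel streams.

    Each column is a contiguous slice of the corpus, so recurrent state
    (Transformer-XL's memory) can be carried across consecutive batches.
    The resulting tensor can be moved to GPU once, up front.
    """
    n_steps = len(tokens) // batch_size
    # Trim leftover tokens that don't divide evenly, then lay out columns.
    data = torch.as_tensor(tokens[: n_steps * batch_size])
    return data.view(batch_size, n_steps).t().contiguous()

def get_seq(data, i, seq_len):
    """Slice an (input, target) pair of length seq_len starting at step i."""
    seq_len = min(seq_len, data.size(0) - 1 - i)
    x = data[i : i + seq_len]
    y = data[i + 1 : i + 1 + seq_len]  # targets are inputs shifted by one
    return x, y
```

Because the columns are contiguous, batch `i+1` continues exactly where batch `i` left off in the corpus.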


We’ve automated many of the common tasks that come up when prototyping a distributed training system and wrapped them into a tool, ncluster. The philosophy was that any frequent task should be a single command-line invocation.


We tried various network configurations to find settings that would optimize throughput. Running iperf3 tests, we measured 93 Gbps of throughput between machines, while NCCL all-reduce performance dropped to about 50 Gbps. One feature of the AWS network setup is that each network flow is limited to about 10 Gbps, so we used multiple NCCL rings to enable multiple flows. Because of a bug in a recent version of NCCL, we had to revert to an older version, which enabled multiple flows but didn’t use the bandwidth efficiently. In particular, the number of bytes transferred grew 4x as we increased the number of GPUs, which is not expected from a bandwidth-constant algorithm like ring all-reduce. However, once we got LAMB working and our local batch size saturated the GPUs’ 32GB of memory, our network requirements dropped enough that, even with the excessive per-machine communication load, we did not hit the bandwidth limit.
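For illustration, the number of rings can be requested through NCCL’s environment variables. This is a hypothetical launch sketch: `train.py` and the exact values are placeholders, and `NCCL_MIN_NRINGS` is the knob in the older NCCL 2.x releases we reverted to:

```shell
# Spread all-reduce traffic over several rings so it maps onto several
# network flows (each AWS flow is capped at ~10 Gbps).
export NCCL_MIN_NRINGS=8    # ask NCCL for at least 8 rings
export NCCL_DEBUG=INFO      # print the chosen ring layout at startup

# Hypothetical per-machine launch: one process per GPU on an 8-GPU node.
python -m torch.distributed.launch --nproc_per_node=8 train.py
```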

Mixed precision

Relying on 16-bit precision let us use less memory and improve processing speed by utilizing Volta Tensor Cores. We tested two forms of fp16 training: pure fp16 and mixed-precision training. In pure fp16 mode, all operations are done in reduced precision, and we could not match the baseline accuracy numbers. In mixed-precision training, the forward and backward passes run in fp16, but updates are accumulated into an fp32 master copy of the weights, and the loss is scaled up before backpropagation so small gradients don’t underflow in fp16.
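The classic loss-scaling recipe can be sketched as below. This is a minimal illustration of the general technique, not the exact implementation used in our training run; the fixed `loss_scale` value is a placeholder (production setups typically use dynamic scaling):

```python
import torch

def mixed_precision_step(model, loss, master_params, optimizer, loss_scale=128.0):
    """One optimizer step of classic mixed-precision training (a sketch).

    `model` holds fp16 weights; `master_params` are fp32 copies that the
    optimizer actually updates. The loss is multiplied by `loss_scale`
    before backward so tiny gradients survive fp16, then unscaled in fp32.
    """
    model.zero_grad()
    (loss * loss_scale).backward()                       # fp16 backward pass
    for master, param in zip(master_params, model.parameters()):
        master.grad = param.grad.float() / loss_scale    # unscale in fp32
    optimizer.step()                                     # update fp32 masters
    for master, param in zip(master_params, model.parameters()):
        param.data.copy_(master.data)                    # cast back to fp16
```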


The LAMB optimizer uses a simple per-layer tweak: each layer’s update is rescaled by the ratio of that layer’s weight norm to its update norm. This trust ratio is what allows the learning rate to scale linearly with global batch size across workers.
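The trust-ratio step can be sketched as follows. This is a simplified illustration of the core idea applied to a precomputed Adam-style update, not a full LAMB implementation:

```python
import torch

def lamb_update(param, adam_step, lr, eps=1e-6):
    """Apply LAMB's layer-wise trust ratio to an Adam-style update (sketch).

    The update for each layer is rescaled so its magnitude is proportional
    to the magnitude of that layer's weights, which keeps large-batch
    training stable as the learning rate is scaled up.
    """
    w_norm = param.data.norm()
    u_norm = adam_step.norm()
    if w_norm > 0 and u_norm > 0:
        trust_ratio = w_norm / (u_norm + eps)  # layer-wise rescaling factor
    else:
        trust_ratio = 1.0
    param.data.add_(adam_step, alpha=-lr * trust_ratio)
    return trust_ratio
```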

Byte Pair Encoding vs Adaptive Embedding

Wikitext-103 has a vocab size of 267K, which takes a lot of model space to represent. We tried switching to byte pair encoding (BPE), which has shown performance benefits in machine translation. We used OpenAI’s BPE implementation since it can handle out-of-vocabulary sequences.

Learning rate finder

After learning about this idea in part 2, Ben was excited to see it work. The learning rate finder trains briefly while sweeping the learning rate across several orders of magnitude, recording the loss at each rate; you then pick a rate just below the point where the loss starts to diverge. With Adam it worked well.
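The sweep itself is just a geometric progression of learning rates. A minimal sketch, with placeholder start and end rates:

```python
def lr_finder_schedule(start_lr, end_lr, n_steps):
    """Exponential learning-rate sweep for an LR range test (sketch).

    Returns n_steps learning rates growing geometrically from start_lr to
    end_lr; train one step at each and record the loss, then choose an LR
    somewhat below where the loss blows up.
    """
    ratio = (end_lr / start_lr) ** (1 / (n_steps - 1))
    return [start_lr * ratio**i for i in range(n_steps)]
```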

Batch scaling

In many of our experiments, we found that when we increased the global batch size by using more GPUs, convergence speed decreased in the short term. To experiment with this effect, we tried scheduling the batch size on a single machine. We found that although it helped in the short term, the benefits vanished as we trained longer.
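A batch-size schedule of this kind can be expressed as a piecewise-constant lookup. The boundaries and sizes below are hypothetical examples, not the exact schedule we used:

```python
def batch_size_at(step, schedule=((0, 32), (1000, 64), (2000, 128))):
    """Return the batch size for a training step under a piecewise-constant
    warmup schedule: start small for fast early convergence, then grow.

    `schedule` is a sequence of (step_boundary, batch_size) pairs in
    increasing order of boundary.
    """
    size = schedule[0][1]
    for boundary, bs in schedule:
        if step >= boundary:
            size = bs  # latest boundary we've passed wins
    return size
```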

All of Wikipedia

The chart below shows training loss vs. validation loss for the large model configuration on Wikitext-103. After just 800M tokens (1.5h), our validation loss stopped decreasing, suggesting we had overfit the dataset. This motivated moving to a larger dataset, all of Wikipedia, which introduces several challenges:

  • It’s 12GB of text, so we can’t load it all into GPU memory
  • It’s stored as JSON, so it first needs to be flattened with article boundaries preserved
  • It contains arbitrary unique tokens, so some tricks are needed to limit vocab size
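The flattening step can be sketched as below. The JSON-lines layout and the `"text"` field name are assumptions about the dump format, and the boundary token is a placeholder:

```python
import json

def flatten_wiki_dump(json_lines, eos_token="<eos>"):
    """Flatten a JSON-lines Wikipedia dump into one token stream (sketch).

    Article boundaries are preserved as explicit end-of-document tokens so
    the language model never attends across unrelated articles.
    """
    tokens = []
    for line in json_lines:
        article = json.loads(line)
        tokens.extend(article["text"].split())
        tokens.append(eos_token)  # mark the article boundary
    return tokens
```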


The work was a joint effort with contributions by


