Scaling Transformer-XL to 128 GPUs

Yaroslav Bulatov
May 6, 2019


Yaroslav Bulatov, Ben Mann, Darius Lam

TL;DR: we made Transformer-XL train efficiently on 128 GPUs on the Amazon cloud. The code is available at


One of the difficulties of researching language models is that you often don’t know if your ideas work until you try them on real-world datasets. However, training on such datasets on one machine can take weeks.

Fortunately there’s a straightforward recipe to speed up this process:

Step 1. Get a good single machine model

Step 2. Run N copies of the model on N machines in parallel, synchronizing at each step

Step 3. Solve all remaining technical challenges

We used this recipe to reduce ImageNet training time from 2 weeks to 18 minutes. You could also apply the same optimization to train a model in 2 weeks that would originally require 4 years, so you can choose to scale up your research in scope instead of iteration time.

With minutes to train, image model training is “mostly solved” while language model training time remains an obstacle to innovation. Our goal was to see how much we could increase training throughput by distributing our training in the cloud.

For the single machine model, we settled on Transformer-XL’s official implementation. This architecture achieved several state of the art results in language modeling. The authors made it easy to reproduce their results by releasing code and instructions.

Once we reproduced the accuracy in their README, we iterated on performance to get about 2.6x improvement in throughput without hurting quality.

We then extended this model to train on multiple machines, using all-reduce to synchronize the gradients. Distributed language models send at least 10x more data per step than typical image models, so we used AWS p3dn instances for their 100 Gbps network connectivity. Those instances had 32GB of GPU memory, so we were able to increase training efficiency by another 25% by increasing the per-GPU batch size. The net result: a 64-GPU version of the small Transformer-XL model trains about 44x faster than the original “slow” 4-GPU implementation. Scaling up an optimized medium-sized Transformer-XL model, the 128-GPU version trains 13.2x faster than the 8-GPU version.
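As a toy, single-process illustration of "run N copies and synchronize at each step", the sketch below averages per-worker gradients and applies the same update on every replica. The function names and the numbers are made up for this example; the real system uses NCCL all-reduce across GPUs rather than Python lists.

```python
from typing import List


def all_reduce_mean(worker_grads: List[List[float]]) -> List[float]:
    """Average gradients element-wise across workers.

    This is the result every worker would hold after an averaging all-reduce.
    """
    n_workers = len(worker_grads)
    return [sum(values) / n_workers for values in zip(*worker_grads)]


def sgd_step(weights: List[float], grad: List[float], lr: float) -> List[float]:
    """Plain SGD update, used here only to keep the example small."""
    return [w - lr * g for w, g in zip(weights, grad)]


# Two workers start from identical weights but compute gradients on
# different shards of the batch (values are made up for illustration).
weights = [1.0, -2.0]
worker_grads = [[0.25, 0.5], [0.75, 0.0]]

avg_grad = all_reduce_mean(worker_grads)
# Every worker applies the same averaged gradient, so replicas stay in sync.
replicas = [sgd_step(weights, avg_grad, lr=0.5) for _ in worker_grads]
print(avg_grad, replicas)
```

Because every replica applies the identical averaged gradient, the copies never drift apart, which is what makes the N-machine run behave like one machine with an N-times larger batch.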

The training procedure required changes to prevent numerical divergence at larger batch sizes, so we followed the recipe provided by Google in their 1024 TPU scaling paper. We’ve open sourced this optimizer in a standalone repo.

Technical details


We primarily experimented with Wikitext-103, a standard language modeling dataset composed of a pre-tokenized subset of Wikipedia. The training set contains approximately 100M tokens and 267K unique words in 515MB. It uses a modified Moses tokenization, with 0.4% of words replaced with <UNK> symbols to limit vocab size. The Transformer-XL codebase loads the entire dataset into GPU memory at the start of training.

To speed things up at training time, we pre-encoded our datasets and baked them into our AWS disk images.
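Pre-encoding amounts to mapping tokens to integer ids once, ahead of time, so training jobs can load the ids directly. Here is a minimal sketch of the idea; the `min_count` threshold and the helper names are assumptions for illustration, not the dataset's or the codebase's actual settings.

```python
from collections import Counter
from typing import Dict, List

UNK = "<UNK>"


def build_vocab(tokens: List[str], min_count: int) -> Dict[str, int]:
    """Assign ids to tokens seen at least min_count times; id 0 is <UNK>."""
    counts = Counter(tokens)
    vocab = {UNK: 0}
    for word, count in counts.most_common():
        if count >= min_count and word not in vocab:
            vocab[word] = len(vocab)
    return vocab


def encode(tokens: List[str], vocab: Dict[str, int]) -> List[int]:
    """Map tokens to ids, sending rare or unseen tokens to <UNK>."""
    return [vocab.get(token, vocab[UNK]) for token in tokens]


corpus = "the cat sat on the mat the end".split()
vocab = build_vocab(corpus, min_count=2)  # hypothetical threshold
ids = encode(corpus, vocab)
print(vocab, ids)
```

The resulting id list can be serialized once and shipped with the machine image, so no worker spends startup time re-tokenizing.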


We’ve automated many common tasks that come up when prototyping a distributed training system and wrapped them into a tool, ncluster. The philosophy was that any frequent task should be a single command-line invocation. Some examples:

1. Bring up a number of instances for a test run at spot prices

2. Fix a bug and relaunch a multi-machine experiment in under 30 seconds

3. Log into the most recent machine to see what it’s doing

4. Keep a folder on the remote machine synchronized with a local folder

5. Log into another user’s machine to troubleshoot an issue


We tried various network configurations to find settings that would optimize throughput. In iperf3 tests we saw 93 Gbps throughput between machines, while NCCL all-reduce performance dropped to about 50 Gbps. One feature of the AWS network setup is that each network flow is limited to about 10 Gbps, so we used multiple NCCL rings to enable multiple flows. Because of a bug in a recent version of NCCL, we had to revert to an older version which enabled multiple flows but didn’t use the bandwidth efficiently. In particular, the number of bytes transferred grew 4x as we increased the number of GPUs, which is not expected from a bandwidth-constant algorithm like ring-allreduce. However, once we got LAMB working and our local batch size saturated the 32GB of GPU memory, our network requirements dropped enough that even with excessive per-machine communication load, we did not hit the bandwidth limit.
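To see why a 4x growth is surprising, recall that in a textbook ring all-reduce each worker sends about 2(N-1)/N times the payload size, which approaches 2x the payload and stays there as N grows. The sketch below just evaluates that formula; the 1 GB payload is a made-up figure, not our measured gradient size.

```python
def ring_allreduce_bytes_per_worker(payload_bytes: float, n_workers: int) -> float:
    """Bytes each worker sends in a textbook ring all-reduce.

    The reduce-scatter and all-gather phases each send (N-1)/N of the payload.
    """
    if n_workers < 2:
        return 0.0  # nothing to exchange with a single worker
    return 2 * (n_workers - 1) / n_workers * payload_bytes


payload = 1e9  # hypothetical gradient size in bytes
for n in (8, 16, 64, 128):
    sent = ring_allreduce_bytes_per_worker(payload, n)
    print(f"{n:4d} workers: {sent / 1e9:.3f} GB sent per worker")
```

Going from 8 to 128 workers raises the per-worker traffic by less than 15% under this model, so a 4x increase points at the implementation rather than the algorithm.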

At some point, we noticed that our most bandwidth efficient runs were underperforming, with performance dropping as much as 30% (below left chart). Some investigation revealed that running too close to the memory limit was causing the CUDA caching allocator to repeatedly run out of memory and cause extra synchronizations during garbage collection (below right chart). The solution was to reduce the memory usage slightly by dropping the batch size, which made the problem disappear.

Left: backwards time in ms, right: torch.cuda.memory_cached()

We found some throughput variability between runs. For instance, an unlucky 64-GPU run would take 2.5 minutes to go through one epoch of wt103, while a “lucky” 128-GPU run could do it in 1 minute. We believe this is due to AWS cluster placement groups being approximate: AWS tries to place all instances in a placement group onto a locally connected configuration, or “brick,” but when this is impossible, some instances will spill over to neighboring bricks.

Mixed precision

Relying on 16-bit precision enabled us to use less memory and improve processing speed by utilizing Volta Tensor Cores. We tested two forms of fp16 training: pure fp16 and mixed-precision training. In pure fp16 mode, all operations are done in reduced precision. We could not match the accuracy numbers in pure fp16.

In mixed precision training, we keep a master copy of the weights in fp32 and make updates to that master copy at each optimizer step, then cast weights to fp16 for the forward pass. This allowed us to maintain precision while saving space. Dynamic loss scaling increased the amount of gradient information propagated while maintaining numerical stability. To implement these, we used the code from the Apex and Megatron packages developed by NVIDIA.
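The dynamic loss scaling logic can be sketched in a few lines. This is a simplified illustration of the general scheme (shrink the scale when gradients overflow, grow it after a run of clean steps), not the Apex implementation; the class name and default constants are assumptions.

```python
import math
from typing import Iterable


class DynamicLossScaler:
    """Toy loss scaler: halve on overflow, double after a streak of good steps."""

    def __init__(self, init_scale: float = 2.0 ** 16,
                 growth_interval: int = 1000, factor: float = 2.0):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.factor = factor
        self.good_steps = 0

    def update(self, grads: Iterable[float]) -> bool:
        """Inspect scaled gradients; return True if the step should be applied."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            # Skip this step and retry later with a smaller scale.
            self.scale /= self.factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps % self.growth_interval == 0:
            # Gradients have been stable for a while: try a larger scale.
            self.scale *= self.factor
        return True


scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
for grads in ([0.1], [0.2], [float("inf")]):
    applied = scaler.update(grads)
    print(f"applied={applied} scale={scaler.scale}")
```

In a real training loop the loss is multiplied by `scale` before the backward pass and the gradients are divided by it before being applied to the fp32 master weights.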

In synthetic experiments, we discovered that Tensor Cores were not fully utilized unless dimensions were multiples of 8.

Figure 1: Time to multiply n x n matrices. Orange is fp32, blue is fp16. The Y axis shows log seconds averaged over three runs. Using Tensor Cores, fp16 matrix multiplication is almost 10x faster for large n. Note that for small n, the performance gains from using half-precision are negligible.

Figure 2: However, when n is shifted by one, so that n is not a multiple of 8, fp16 performs on par with fp32, i.e. there is no performance increase at all and we simply save memory.
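One simple way to act on this is to round layer and vocabulary dimensions up to the next multiple of 8 when choosing model sizes. The helper below is a generic sketch with a made-up name, not code from our repository.

```python
def round_up_to_multiple(n: int, multiple: int = 8) -> int:
    """Smallest multiple of `multiple` that is greater than or equal to n."""
    return ((n + multiple - 1) // multiple) * multiple


for dim in (1024, 1025, 50):
    print(dim, "->", round_up_to_multiple(dim))
```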

By modifying our model sizes we are able to get significant improvements in speed, from initial 1.2x to 2. by converting to mixed-precision and increasing the batch size.


The LAMB optimizer uses a simple tweak that allows the learning rate to scale linearly with global batch size across workers.

Recall that LAMB multiplies the learning rate for each layer weight by the ratio (r) of the norm of the weights (r1) to the norm of the Adam step (r2).
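A minimal sketch of that per-layer ratio, using plain Python lists in place of tensors. The clip value of 10 matches the one we mention for the trust ratio; the handling of zero norms (falling back to a ratio of 1) and the function names are assumptions made for this example, and the exact place where our optimizer applies the clip may differ.

```python
import math
from typing import List


def l2_norm(values: List[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def lamb_trust_ratio(weights: List[float], adam_step: List[float],
                     clip: float = 10.0) -> float:
    """r = r1 / r2, the weight norm over the Adam step norm, clipped."""
    r1 = l2_norm(weights)
    r2 = l2_norm(adam_step)
    if r1 == 0.0 or r2 == 0.0:
        return 1.0  # assumption: fall back to the unscaled Adam step
    return min(r1 / r2, clip)


def lamb_update(weights: List[float], adam_step: List[float],
                lr: float, clip: float = 10.0) -> List[float]:
    """Apply the Adam step for one layer, rescaled by the trust ratio."""
    r = lamb_trust_ratio(weights, adam_step, clip)
    return [w - lr * r * s for w, s in zip(weights, adam_step)]


layer = [3.0, 4.0]                                # r1 = 5
print(lamb_trust_ratio(layer, [6.0, 8.0]))        # large step: ratio below 1
print(lamb_trust_ratio(layer, [0.0, 0.25]))       # tiny step: ratio hits the clip
print(lamb_update(layer, [6.0, 8.0], lr=0.1))
```

A large Adam step relative to the weights shrinks the effective learning rate for that layer, while a small step lets it grow up to the clip.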

How to interpret the left charts below: the Y axis shows the timestep in number of tokens processed with the first at the top. The X axis is histogram buckets, e.g. N samples had a value between 0 and 1, 1 and 2, etc. The Z axis is histogram frequency for the values in each parameter in the network, e.g. X% of the layer weights fell in the 0 to 1 bucket. You can think of this as one histogram of values at each timestep, stacked in time, so we can see how the histogram changes over time. The right charts show the same data, but as an area chart. Darker colors mean more of the histogram bins had values in the shaded area.

Most of the parameter weights (r1, upper right) are near 0 the entire time, but a few grow and spread, some getting quite large (6.5 max bucket).

The Adam step values above are very large in the beginning, but quickly stabilize to between 0 and .5.

The above two effects cause the trust ratio (r, upper left) to encourage some parameters to continue growing for a long time, maxing out at the clip value of 10.

In our tests, we found that LAMB made a difference even on a single p3dn.24xlarge machine with global batch size 768, while Adam diverged immediately. We also found that we could entirely eliminate warmup in our learning rate schedule since at the beginning of training, gradients are large, which causes the LAMB trust ratio to be small and reduce the learning rate for the layers that are most unstable.

Byte Pair Encoding vs Adaptive Embedding

Wikitext-103 has a vocab size of 267K, which takes a lot of model space to represent. We tried switching to byte pair encoding (BPE), which has shown performance benefits in machine translation. We used OpenAI’s BPE implementation since it can handle out-of-vocabulary sequences.

Transformer-XL has a very large vocab size (267K), so most of the params are in the token embeddings. We were able to reduce the model size using BPE with a vocab size of 50K, which reduced our small model configuration from 186M to 75M parameters with the same representational capacity. The original Transformer-XL code uses adaptive input and softmax to reduce the embedding size.

One slightly tricky detail is that BPE produced tokenizations ~15% longer than Wikitext-103. This is expected because Wikitext-103 is already tokenized (e.g. splitting “don’t” into “do n’t”). If it weren’t, BPE might handle “don’t” as a single token, while Wikitext-103’s tokenizer would split it. In BPE, out-of-vocabulary words get broken up into sub-word pieces that are in the vocabulary; even arbitrary Unicode can be handled by splitting characters into individual bytes. This behavior is the source of the 15% increase in tokenization length.

In our ablation studies, we found that performance, as measured in training loss per token, was nearly identical after accounting for tokenization length. Tokens/second increased by ~10% due to smaller model size and GPU utilization dropped significantly for the same batch size from ~90% to ~55%. This meant that when we scaled to larger models, we could keep a higher batch size per GPU.

For a fairer fight, we also compared BPE to adaptive embeddings on our large model configuration:

Converting the validation perplexity to account for the difference in tokenization length gives math.exp(math.log(26.25) * 1950 / 1699) = 42.5, so it’s impressively equivalent! This is despite the fact that the BPE encoding we used was trained on a general sample of the web (OpenAI’s Webtext dataset) rather than specifically Wikitext-103.
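The conversion works because total log-likelihood of a text does not depend on how it is split into tokens: multiply the per-token loss by the number of tokens in one tokenization, then divide by the number in the other. The sketch below wraps the expression above in a function; it assumes 1950 and 1699 are the token counts of the same text under the two tokenizations, and the function name is ours.

```python
import math


def convert_perplexity(ppl: float, n_tokens_from: float, n_tokens_to: float) -> float:
    """Re-express a perplexity measured per `from` token as per `to` token."""
    total_nll = math.log(ppl) * n_tokens_from   # total negative log-likelihood
    return math.exp(total_nll / n_tokens_to)    # spread over the other token count


converted = convert_perplexity(26.25, 1950, 1699)
print(round(converted, 1))  # 42.5
```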

These two techniques are compatible, so it might be interesting to combine them. That said, since BPE is so flexible and the code is simpler, I’m inclined to use only it in future experiments.

Learning rate finder

After learning about this idea in part 2, Ben was excited to see it work. With Adam it worked well:

The minimum of the graph above happens to be exactly the learning rate used in the paper (0.00025)! And we only had to run it for a few minutes to discover this, which is far cheaper than doing a traditional complete hyperparameter search in which you’d run for at least an epoch with the single LR schedule you intend to use later.
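For readers who have not seen the technique, here is a toy version of a learning rate range test on a one-dimensional quadratic, using plain SGD rather than Adam. The learning rate is swept exponentially, one update per value, and we read off the rate at the lowest recorded loss. The function names, the loss, and the sweep bounds are all illustrative assumptions.

```python
import math
from typing import Callable, List, Tuple


def lr_range_test(loss_fn: Callable[[float], float],
                  grad_fn: Callable[[float], float],
                  w0: float, lr_min: float, lr_max: float,
                  n_steps: int) -> List[Tuple[float, float]]:
    """Sweep lr exponentially from lr_min to lr_max, one SGD step per value."""
    w = w0
    history = []
    for i in range(n_steps):
        lr = lr_min * (lr_max / lr_min) ** (i / (n_steps - 1))
        w = w - lr * grad_fn(w)
        loss = loss_fn(w)
        if not math.isfinite(loss):
            break  # training has blown up; stop the sweep
        history.append((lr, loss))
    return history


def lr_at_min_loss(history: List[Tuple[float, float]]) -> float:
    """Learning rate at which the recorded loss was lowest."""
    return min(history, key=lambda point: point[1])[0]


a = 4.0  # curvature of the toy loss 0.5 * a * w^2; SGD diverges for lr > 2 / a
history = lr_range_test(lambda w: 0.5 * a * w * w, lambda w: a * w,
                        w0=1.0, lr_min=1e-4, lr_max=10.0, n_steps=100)
best_lr = lr_at_min_loss(history)
print(f"loss minimum at lr={best_lr:.3f}, divergence threshold {2 / a}")
```

On this toy problem the minimum lands just below the divergence threshold, which is why the sweep only needs a short run to be informative.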

When applied to LAMB, the finder significantly overestimated the learning rate.

The learning rate at the gray curve’s minimum loss was 0.0263; orange (using weight decay 1e-4) was 0.0435. But when we ran it even with a learning rate of 0.01, it still diverged. We had to go all the way down to 0.005 to get stable convergence. We think this is because LAMB causes the actual per-layer learning rate to be quite different at different times in training, causing the learning rate finder to be out of distribution. We could probably get around this by running the learning rate finder for longer, or on a checkpoint further into training. Instead, we chose to do a more traditional hyperparameter search for the learning rate.

Batch scaling

In many of our experiments, we found that when we increased the global batch size by using more GPUs, convergence speed decreased in the short term. To experiment with this effect, we tried scheduling the batch size on a single machine. We found that although it helped in the short term, the benefits vanished as we trained longer.

One explanation for increased initial convergence speed is that we had used constant learning rate across different batch sizes, which resulted in higher learning rate per example. The gap between the blue and gray lines at the end is about 100M tokens (1h), so it did help by about 10% even though we weren’t using the whole GPU capacity.

However, if we keep the base learning rate constant and apply linear learning rate scaling to adjust the learning rate across different batch sizes, training diverges.
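For reference, linear learning rate scaling just multiplies the base rate by the ratio of batch sizes. The base batch size of 64 below is a hypothetical number chosen for the example.

```python
def linearly_scaled_lr(base_lr: float, base_batch_size: int, batch_size: int) -> float:
    """Scale the learning rate in proportion to the batch size."""
    return base_lr * batch_size / base_batch_size


# Doubling a hypothetical base batch of 64 doubles the learning rate.
print(linearly_scaled_lr(0.00025, base_batch_size=64, batch_size=128))
```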

All of Wikipedia

The graphs below show training loss vs validation loss for the large model configuration on Wikitext-103. In just 800M tokens (1.5h), our validation loss stopped decreasing, suggesting we overfit the dataset.

For this reason we were excited to try training on all of NVIDIA’s Wikipedia dump, which is ~25X larger than Wikitext-103. It came with a few complications:

  • It’s 12GB of text, so we can’t load it all into GPU memory
  • It’s stored as JSON, so it first needs to be flattened with article boundaries preserved
  • It contains arbitrary unique tokens, so some tricks are needed to limit vocab size
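The first two points can be illustrated with a small sketch that flattens JSON records into one article per line and then streams articles from disk shard by shard instead of loading everything at once. The JSON-lines layout and the `text` field name are assumptions for this example; the actual dump format may differ.

```python
import json
import os
import tempfile
from typing import Iterator, List


def flatten_articles(json_lines_path: str, out_path: str,
                     text_field: str = "text") -> int:
    """Write one article per line so article boundaries survive flattening."""
    n_articles = 0
    with open(json_lines_path, encoding="utf-8") as src, \
            open(out_path, "w", encoding="utf-8") as dst:
        for line in src:
            line = line.strip()
            if not line:
                continue
            article = json.loads(line)
            # Collapse internal whitespace so each article occupies one line.
            dst.write(" ".join(article[text_field].split()) + "\n")
            n_articles += 1
    return n_articles


def stream_articles(shard_paths: List[str]) -> Iterator[List[str]]:
    """Yield one tokenized article at a time, reading shards lazily."""
    for path in shard_paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                yield line.split()


with tempfile.TemporaryDirectory() as tmp:
    raw = os.path.join(tmp, "dump.jsonl")
    flat = os.path.join(tmp, "shard0.txt")
    with open(raw, "w", encoding="utf-8") as f:
        f.write(json.dumps({"text": "Hello\nworld"}) + "\n")
        f.write(json.dumps({"text": "Second  article here"}) + "\n")
    count = flatten_articles(raw, flat)
    articles = list(stream_articles([flat]))

print(count, articles)
```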

Luckily Transformer-XL already contained a multi-file data loader for the 1 Billion Words corpus, so it only took a little massaging to get it working. Note that because the tokenization is different, the loss numbers shouldn’t be compared. Below, we show two runs using 64 GPUs on all of Wikipedia at ~600M tokens per hour. We found that the validation loss is very sensitive to the learning rate schedule.

We hypothesize that if we hadn’t cycled back up on the blue run, validation loss might have continued to decrease.


The work was a joint effort with contributions by:

Ben — data pipelines, LAMB, parameter searching and tuning

Darius — 2x speedup on single machine models using mixed precision training

Yaroslav — experiment infrastructure, distributed training

Additional thanks to AWS for providing compute resources for this work.