PyTorch - The Theory behind Distributed Data Parallel

Joel Tokple
Jan 26, 2021


The drive to keep improving the performance of deep neural network models makes it necessary to scale out model training, so that larger training datasets can be processed with more computational resources. The DistributedDataParallel module in PyTorch can be used to perform distributed training across multiple GPUs and machines. Let’s dive into the theory behind this approach.

Photo by Alina Grubnyak on Unsplash

First of all, with its implementation of Distributed Data Parallel, PyTorch offers an extremely non-intrusive way to run your code across multiple nodes. The following is a code snippet of the PyTorch implementation of a simple neural network that runs locally on a single machine.

Source: Modified from [1]

Line 9 initializes our neural network architecture. Line 10 defines our optimization algorithm to be used. In lines 13 to 21, we generate random input samples and labels, perform our forward pass through the neural network, compute the error, compute our gradients, and then update our model parameters.
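For readers who cannot see the original listing, here is a minimal sketch of such a local training step, reconstructed along the lines of the example in [1]. The layer sizes, learning rate and variable names are assumptions, and the line numbers mentioned above refer to the original listing, so they will not match this sketch.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Set up the model and the optimizer.
net = nn.Linear(10, 10)
opt = optim.SGD(net.parameters(), lr=0.01)

# Random inputs and labels stand in for real training data.
inp = torch.randn(20, 10)
exp = torch.randn(20, 10)

# Forward pass and error computation.
out = net(inp)
loss = nn.MSELoss()(out, exp)

# Backward pass: compute the gradients.
opt.zero_grad()
loss.backward()

# Update the model parameters.
before = net.weight.detach().clone()
opt.step()
```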

If you wish to scale out your training to multiple machines and GPUs, the only change to the model code is one added line, provided that torch.distributed has already been initialized with init_process_group.

Source: [1]

In line 11 we merely have to call the constructor for DistributedDataParallel and provide our neural network model as the argument.
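A hedged sketch of what that change might look like is given below. To keep it runnable on a laptop it starts a single CPU process with the Gloo backend; the free_port helper and the tcp:// address are demo-only assumptions, and a real job would start one process per GPU through a launcher.

```python
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port (demo only).
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


# Demo-only setup: one process, CPU, Gloo backend.
dist.init_process_group(
    backend="gloo",
    init_method=f"tcp://127.0.0.1:{free_port()}",
    rank=0,
    world_size=1,
)

net = nn.Linear(10, 10)
net = DDP(net)  # the one added line
opt = optim.SGD(net.parameters(), lr=0.01)

out = net(torch.randn(20, 10))
nn.MSELoss()(out, torch.randn(20, 10)).backward()  # gradients are synchronized here
opt.step()

wrapped = net.module  # the original nn.Linear is still reachable
dist.destroy_process_group()
```

The forward pass, backward pass and optimizer step are written exactly as in the local version; only the wrapping of the model differs.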

The Mechanism behind Distributed Data Parallel

When the constructor for DistributedDataParallel is called, the model state is broadcast from one process to all others, so that every node taking part in distributed training starts from an identical replica of your model. Each node then works on its own share of the training data and performs independent forward passes, error computations and gradient computations. Note that DistributedDataParallel does not split the batches itself; the data is typically partitioned by the data loader, for example with a DistributedSampler. After each node has computed its gradients locally, these gradients are synchronized with the other nodes by averaging them. To do this, PyTorch uses the AllReduce operation, which is supported by the NCCL, Gloo and MPI communication libraries. Subsequently, each node uses these averaged gradients to update its local model parameters. Because all replicas start from the same state and apply the same averaged gradients, the models stay consistent across iterations on every participating node. In this manner your neural network model can iterate through the training data until convergence.
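To see why averaging keeps the replicas consistent, the following sketch simulates two "nodes" inside one process, without any communication library. It assumes equally sized shards and a loss that averages over samples (MSELoss with its default reduction); under these assumptions the average of the local gradients matches the gradient of the full batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
inp, exp = torch.randn(20, 10), torch.randn(20, 10)
loss_fn = nn.MSELoss()  # averages over all elements


def grad_of(model, x, y):
    # Run one forward and backward pass and return the weight gradient.
    model.zero_grad()
    loss_fn(model(x), y).backward()
    return model.weight.grad.clone()


# Gradient of the full batch on a single model.
reference = nn.Linear(10, 10)
full = grad_of(reference, inp, exp)

# Two simulated nodes, each with an identical replica and half of the batch.
shards = [(inp[:10], exp[:10]), (inp[10:], exp[10:])]
local = []
for x, y in shards:
    replica = nn.Linear(10, 10)
    replica.load_state_dict(reference.state_dict())
    local.append(grad_of(replica, x, y))

# What AllReduce followed by division by the number of nodes would produce.
averaged = torch.stack(local).mean(dim=0)
same = torch.allclose(full, averaged, atol=1e-6)
```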

Reducing Training Iteration Time

To keep training iteration time to a minimum, PyTorch uses three techniques: bucketing gradients, overlapping gradient computation with communication, and letting you skip gradient synchronizations. All three concepts are presented below.

Bucketing Gradients

When synchronizing gradients, PyTorch lets you choose how many gradients are communicated to the other nodes per AllReduce operation. Bucket sizes can range anywhere from a single gradient to all gradients computed during an entire backward pass. The following evaluation shows how long the total communication of 60 million torch.float32 gradients takes, depending on the bucket size per AllReduce operation.

Source: [1]

The results show that for 60 million gradients, total communication time decreases steadily as bucket sizes increase when using NCCL (left). Using Gloo (right), the same is true up to a bucket size of about 200 thousand, after which total communication time levels off. Keeping bucket sizes relatively small, however, allows gradient communication and gradient computation to overlap, which shortens the total training iteration time.
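The trend can be made plausible with a rough cost model in which every AllReduce call pays a fixed latency on top of the time needed to move its bytes. The latency and bandwidth values below are invented for illustration and are not measurements from [1]. The last line estimates how many buckets 60 million float32 gradients would fill at the default bucket_cap_mb of 25, ignoring that real buckets hold whole parameter tensors.

```python
import math

NUM_GRADIENTS = 60_000_000
BYTES_PER_GRADIENT = 4  # torch.float32

# Assumed, illustrative values (not measured).
LATENCY_S = 50e-6             # fixed cost per AllReduce call
BANDWIDTH_BYTES_PER_S = 10e9  # sustained transfer rate


def total_comm_time(gradients_per_bucket: int) -> float:
    # Total time = number of calls * latency + total payload / bandwidth.
    calls = math.ceil(NUM_GRADIENTS / gradients_per_bucket)
    payload = NUM_GRADIENTS * BYTES_PER_GRADIENT
    return calls * LATENCY_S + payload / BANDWIDTH_BYTES_PER_S


times = {size: total_comm_time(size) for size in (1_000, 10_000, 100_000, 1_000_000)}
for size, seconds in times.items():
    print(f"{size:>9} gradients per AllReduce: {seconds:.3f} s")

# Rough bucket count at the default cap of 25 MiB per bucket.
buckets_at_default = math.ceil(NUM_GRADIENTS * BYTES_PER_GRADIENT / (25 * 1024 * 1024))
```

In this model the payload term is constant, so larger buckets only save latency, which is one way to explain why the benefit eventually flattens out.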

Overlapping Gradient Computation and Communication

PyTorch uses autograd hooks for gradient bucketing. As soon as a gradient is computed in the backward pass, its hook fires. Once the hooks of all gradients assigned to a bucket have fired, the AllReduce operation is called on that bucket. It can happen, however, that a bucket x on one node becomes ready for communication before bucket y on the same node, while the same buckets on another node become ready in the reverse order. If the AllReduce operation were called on buckets right after they become ready for communication, buckets could mismatch when synchronizing gradients. For this reason a bucketing order is determined, and all nodes have to submit buckets in that same order. The bucketing order is the reverse of the order in which model parameters appear in the forward pass, as this is roughly the order in which gradients are computed in the backward pass.
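The ordering rule can be illustrated with a small pure-Python simulation. This is a simplified model and not PyTorch's actual implementation: the parameter names, sizes and the greedy packing in assign_buckets are assumptions made for the example.

```python
def assign_buckets(param_sizes, cap):
    """Greedily pack parameters into buckets, walking them in reverse order."""
    buckets, current, used = [], [], 0
    for name, size in reversed(list(param_sizes.items())):
        if current and used + size > cap:
            buckets.append(current)
            current, used = [], 0
        current.append(name)
        used += size
    if current:
        buckets.append(current)
    return buckets


def launch_order(buckets, hook_order):
    """Return bucket indices in the order AllReduce would be launched."""
    pending = [set(bucket) for bucket in buckets]
    owner = {name: i for i, bucket in enumerate(buckets) for name in bucket}
    launched, next_bucket = [], 0
    for name in hook_order:
        pending[owner[name]].discard(name)
        # Only the next bucket in the agreed order may be launched;
        # buckets that became ready early wait for their turn.
        while next_bucket < len(buckets) and not pending[next_bucket]:
            launched.append(next_bucket)
            next_bucket += 1
    return launched


# Parameters in forward order (dicts preserve insertion order).
params = {"layer1.weight": 40, "layer1.bias": 10, "layer2.weight": 40, "layer2.bias": 10}
buckets = assign_buckets(params, cap=50)

# Two nodes whose hooks fire in opposite orders.
node_a = launch_order(buckets, ["layer2.bias", "layer2.weight", "layer1.bias", "layer1.weight"])
node_b = launch_order(buckets, ["layer1.bias", "layer1.weight", "layer2.bias", "layer2.weight"])
```

Although the hooks fire in different orders on the two simulated nodes, both submit bucket 0 before bucket 1, so the AllReduce calls line up.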

Another point to consider is that when only a subgraph of the neural network is trained, not all gradients will be computed in the backward pass, so waiting for the buckets that contain them would never terminate. For this reason PyTorch marks the gradients of model parameters that do not participate in the training iteration as ready once the forward pass has completed, and consequently waits only for the autograd hooks of the participating model parameters to fire. This scenario, as well as the former scenario in which buckets across nodes mismatch when synchronizing, is visualized in the following illustration.

Source: [1]
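In code, this behavior is requested through the find_unused_parameters argument of the DistributedDataParallel constructor. The sketch below uses a hypothetical TwoBranch model in which one layer never takes part in the forward pass; as before, the single-process Gloo setup and the free_port helper are demo-only assumptions.

```python
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port (demo only).
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


class TwoBranch(nn.Module):
    # Hypothetical model: self.unused never contributes to the output.
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)

    def forward(self, x):
        return self.used(x)


dist.init_process_group(
    backend="gloo",
    init_method=f"tcp://127.0.0.1:{free_port()}",
    rank=0,
    world_size=1,
)

# Ask DDP to detect parameters that will not receive a gradient
# and to mark them as ready before the backward pass starts.
model = DDP(TwoBranch(), find_unused_parameters=True)
opt = optim.SGD(model.parameters(), lr=0.01)

for _ in range(2):
    opt.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    opt.step()

used_grad = model.module.used.weight.grad
dist.destroy_process_group()
```

The extra graph traversal has a cost, so this flag is usually left off when every parameter participates in every iteration.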

Skipping Gradient Synchronizations

To further speed up training, gradient synchronization can be skipped for a number of iterations. This way, each node conducts n partial training iterations locally (partial, because no parameter updates are performed) before synchronizing its gradients with the other nodes. This can be achieved through the no_sync context manager provided by PyTorch. Inside this context manager, the autograd hooks that normally fire during the backward pass are disabled, so AllReduce is never called. Instead, gradients are accumulated locally across the partial training iterations, and they are synchronized in the first backward pass after the context has been exited.

Source: [1]

The previous image illustrates the usage of the no_sync context manager with an example. The DistributedDataParallel constructor is called with the neural network model as its argument, and then the context is entered. Inside the context, the code iterates over the training batch, performing forward passes, error computations and backward passes. Instead of updating model parameters, gradients are accumulated. After exiting the context, another batch of training data is used for a forward pass, error computation and backward pass, during which the accumulated gradients are synchronized.
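A runnable sketch of this pattern, modeled on the example in [1], might look as follows. The batch contents, model size and the single-process Gloo setup with the free_port helper are assumptions made so that the snippet can be tried on one machine.

```python
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port (demo only).
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


dist.init_process_group(
    backend="gloo",
    init_method=f"tcp://127.0.0.1:{free_port()}",
    rank=0,
    world_size=1,
)

net = nn.Linear(10, 10)
ddp = DDP(net)
opt = optim.SGD(ddp.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Random data stands in for the pieces of a large training batch.
inputs = [torch.randn(20, 10) for _ in range(3)]
expected = [torch.randn(20, 10) for _ in range(3)]
another_inp, another_exp = torch.randn(20, 10), torch.randn(20, 10)

before = net.weight.detach().clone()
opt.zero_grad()
with ddp.no_sync():
    for inp, exp in zip(inputs, expected):
        # No synchronization here: gradients are only accumulated locally.
        loss_fn(ddp(inp), exp).backward()
accumulated = net.weight.grad.clone()

# First backward pass outside the context: gradients are synchronized.
loss_fn(ddp(another_inp), another_exp).backward()
opt.step()
dist.destroy_process_group()
```

Note that both the forward and the backward pass of the accumulated iterations sit inside the context, and that the optimizer step is taken only after the synchronizing backward pass.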

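Settings presented in this article, such as the size of gradient buckets (bucket_cap_mb), can be set as arguments of the DistributedDataParallel constructor. The communication library is selected through the process group, either via the backend passed to init_process_group or via the constructor's process_group argument. For further information, refer to the PyTorch documentation [2].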

References

[1] Li, Shen, et al. “PyTorch Distributed: Experiences on Accelerating Data Parallel Training.” Proceedings of the VLDB Endowment 13.12.

[2] Distributed Data Parallel — PyTorch 1.7.0 documentation
