Getting started with Fully Sharded Data Parallel

What is FSDP?

Fully Sharded Data Parallel (FSDP) is one of the primary ways of doing distributed training with PyTorch. Given a training setup with several devices (multiple GPUs), possibly spread across multiple nodes, FSDP provides a framework for describing how tensors are laid out and operated on across all of those nodes and devices. Common operations include replicating and sharding tensors across the available devices.

In a previous post, we saw how to implement the ZeRO-1 setup in PyTorch. Building on that, we will emulate ZeRO-2 with FSDP. Along the way, we will see how the FSDP API differs from the previous approach and what advantages it offers.

For this post, we will fix our training setup to a single node with d devices. As in ZeRO-2, every device will have the full model weights available for the forward and backward passes, while the optimizer states and the gradients are sharded across the devices. This is the simplest setup for showcasing the FSDP primitives, and we can build on it towards more advanced sharding strategies, both intra-node and inter-node.
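For a rough sense of what sharding buys us, here is a back-of-the-envelope sketch of per-device memory for model states under each ZeRO stage. The byte counts are an assumption borrowed from the ZeRO paper's mixed-precision Adam accounting: 2 bytes per parameter for bf16 weights, 2 bytes for bf16 gradients and 12 bytes of optimizer state (an fp32 master copy plus Adam's two moments). FSDP2 lays things out differently (for example, its sharded parameters stay in their original dtype), so treat the numbers as illustrative rather than what PyTorch will actually allocate.

```python
def zero_memory_per_device(num_params: int, num_devices: int, stage: int,
                           param_bytes: int = 2, grad_bytes: int = 2,
                           optim_bytes: int = 12) -> float:
    """Approximate bytes of model state held by one device under ZeRO stage 1, 2 or 3.

    Default byte counts follow the ZeRO paper's accounting (an assumption, not FSDP2's layout).
    """
    if stage not in (1, 2, 3):
        raise ValueError("stage must be 1, 2 or 3")
    params = num_params * param_bytes
    grads = num_params * grad_bytes
    optim = num_params * optim_bytes
    d = num_devices
    if stage == 1:   # only the optimizer states are sharded
        return params + grads + optim / d
    if stage == 2:   # optimizer states and gradients are sharded (the setup in this post)
        return params + (grads + optim) / d
    return (params + grads + optim) / d  # stage 3: parameters are sharded as well


# Example: a 7.5B-parameter model on 64 devices
for stage in (1, 2, 3):
    gb = zero_memory_per_device(7_500_000_000, 64, stage) / 1e9
    print(f"ZeRO-{stage}: {gb:.1f} GB per device")
```

Going from stage 1 to stage 2 removes the full gradient copy from each device, which is the saving this post is after; stage 3 would additionally shard the weights.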

Initialisation

The first thing we have to do is set up the topology of our devices. In PyTorch, the DeviceMesh class is used for this purpose. FSDP uses the device mesh to determine how tensors are distributed across the devices and which devices communicate with each other.

In our case, we have a single node with d devices, all of which belong to the same data-parallel group, so the device mesh is a 1-D array. If instead we wanted to shard across some devices and replicate across others, we would define the mesh as a 2-D array. Once the mesh is defined, we can specify the operation (shard or replicate) along each of its dimensions.

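import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Launched with torchrun, which sets LOCAL_RANK for every process
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()

# A 1-D mesh containing every GPU in the job
dp_mesh = init_device_mesh("cuda", (world_size,))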

The code above imports the necessary modules, initializes the process group and creates dp_mesh, a 1-D mesh whose size equals the number of devices. We pass "cuda" as the device type, since we are working with GPUs.
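To make the difference between a 1-D and a 2-D mesh concrete, here is a toy sketch in plain Python (it does not use the DeviceMesh API). It lays ranks out in row-major order, mirroring how init_device_mesh arranges them, and lists the groups of ranks that communicate along each mesh dimension. The 2-D shape (2, 4) is a hypothetical example with 8 GPUs.

```python
from itertools import product


def mesh_groups(mesh_shape: tuple[int, ...]) -> list[list[list[int]]]:
    """For each mesh dimension, return the groups of ranks that differ only along that dimension.

    Ranks are laid out row-major, i.e. like arange(world_size).reshape(mesh_shape).
    """
    ndim = len(mesh_shape)
    # Row-major strides: the last dimension changes fastest
    strides = [1] * ndim
    for i in range(ndim - 2, -1, -1):
        strides[i] = strides[i + 1] * mesh_shape[i + 1]

    groups_per_dim = []
    for dim in range(ndim):
        # Fix every coordinate except `dim`, which is pinned to 0 and then walked below
        fixed = [range(n) if i != dim else range(1) for i, n in enumerate(mesh_shape)]
        groups = []
        for coord in product(*fixed):
            base = sum(c * s for c, s in zip(coord, strides))
            groups.append([base + k * strides[dim] for k in range(mesh_shape[dim])])
        groups_per_dim.append(groups)
    return groups_per_dim


print(mesh_groups((4,)))    # [[[0, 1, 2, 3]]]
print(mesh_groups((2, 4)))  # [[[0, 4], [1, 5], [2, 6], [3, 7]], [[0, 1, 2, 3], [4, 5, 6, 7]]]
```

On the 1-D mesh, every rank is in a single group. For a 2-D mesh, the fully_shard docs describe dimension 0 as the replicate dimension and dimension 1 as the shard dimension, so in this layout each row of four ranks would shard parameters among themselves while each column would hold replicas.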

Mixed Precision operations

With DistributedDataParallel (DDP), mixed precision training is usually done with torch.amp. With FSDP, we can instead use the MixedPrecisionPolicy class, as shown below:

from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32
)

Note the two arguments to MixedPrecisionPolicy:

  • param_dtype is the dtype the unsharded parameters are cast to, and therefore the dtype of the forward and backward computations. We set it to the lower-precision type in which we want to compute, here bfloat16.

  • reduce_dtype sets the dtype in which gradients are reduced (averaged across devices) and accumulated, such as when doing gradient accumulation with microbatches. Setting it to full precision (float32) preserves more information. By default, MixedPrecisionPolicy leaves reduce_dtype unset, in which case gradients are reduced in param_dtype, bfloat16 in this case. In contrast, with torch.amp the backward pass is called outside the autocast context and the parameters themselves stay in their original dtype, so gradient reductions and optimizer steps are carried out in that dtype (typically float32).
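To see why a float32 reduce_dtype can matter, here is a small pure-Python illustration that does not call FSDP at all. It emulates bfloat16 and float32 rounding with the struct module (the rounding helpers are our own, and NaN/inf handling is omitted) and adds 1000 small gradient contributions to a running total of 1.0. In bf16, each addition rounds straight back to 1.0, so every update is lost; kept in float32, the sum lands close to 2.0. Real reductions may sum in a different order, but losing small contributions is the kind of error a higher-precision reduce_dtype helps avoid.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bfloat16 by keeping the top 16 bits of the float32 pattern (round-to-nearest-even).

    Only meant for finite, moderate values; NaN/inf handling is omitted.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(start: float, grads: list[float], round_fn) -> float:
    """Sum gradients one at a time, rounding the running total after every addition."""
    total = round_fn(start)
    for g in grads:
        total = round_fn(total + g)
    return total


# 1000 small gradient contributions, already in bf16 as a bf16 backward pass would produce them
grads = [to_bf16(0.001)] * 1000
bf16_total = accumulate(1.0, grads, to_bf16)
fp32_total = accumulate(1.0, grads, to_fp32)
print(f"accumulated in bf16: {bf16_total}")     # accumulated in bf16: 1.0
print(f"accumulated in fp32: {fp32_total:.4f}")  # accumulated in fp32: 1.9995
```

bf16 keeps only 8 bits of significand, so near 1.0 its spacing is 2^-7; an increment of about 0.001 is below half of that spacing and vanishes on every step.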

Training with FSDP

Once we have a device mesh and a mixed precision policy, we are ready to transform our nn.Module into an FSDPModule. The PyTorch docs provide guidance on how this should be done: “Partitioning the model into multiple groups (“layer by layer”) allows for peak memory savings and communication/computation overlap”. So, once we have our model, we can iterate over its layers and call fully_shard on each of them, and finally on the model itself.

When applying fully_shard to each layer, we also define the sharding behaviour. So far, we have created:

  • the device mesh, which defines the topology of all the devices
  • the mixed precision policy, which determines the data type for the tensor operations, the activations and the gradients.

The remaining setting concerns the weight tensors, that is, how they should be allocated among the devices. In the current setting, we want every device to keep the full weights through both the forward and the backward pass, rather than freeing them in between. This is done by passing the argument reshard_after_forward=False to fully_shard (footnote 1). It instructs the FSDP runtime to keep the all-gathered (unsharded) weights on each device after the forward pass, so the backward pass can reuse them without another all-gather.

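# Shared settings: the 1-D mesh, the mixed precision policy and ZeRO-2-like behaviour
fsdp_kwargs = dict(mesh=dp_mesh, mp_policy=mp_policy, reshard_after_forward=False)
# Wrap each layer separately (assumes the model exposes its blocks as model.layers)
for layer in model.layers:
    fully_shard(layer, **fsdp_kwargs)
# Wrap the root last so the remaining parameters (e.g. embeddings) are also managed by FSDP
fully_shard(model, **fsdp_kwargs)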
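The practical effect of reshard_after_forward can be summarized by counting collectives per wrapped layer in one training step. The sketch below is a simplified model of our understanding of FSDP2's schedule: one all-gather before the forward pass, a second all-gather before the backward pass only if the weights were resharded after forward, and one reduce-scatter of gradients in backward unless gradient sync is disabled. It is not queried from PyTorch, and it ignores details such as FSDP2 special-casing the root module. The model depth is a hypothetical value.

```python
from dataclasses import dataclass


@dataclass
class LayerComm:
    all_gathers: int
    reduce_scatters: int


def collectives_per_layer(reshard_after_forward: bool, sync_grads: bool = True) -> LayerComm:
    """Collectives for one fully_shard-ed layer in one forward + backward pass (simplified model)."""
    all_gathers = 1                 # gather the full weights before the forward pass
    if reshard_after_forward:
        all_gathers += 1            # weights were freed after forward, so gather them again for backward
    reduce_scatters = 1 if sync_grads else 0  # reduce gradients and keep only the local shard
    return LayerComm(all_gathers, reduce_scatters)


num_layers = 32  # hypothetical model depth
for flag in (True, False):
    c = collectives_per_layer(flag)
    print(f"reshard_after_forward={flag}: "
          f"{num_layers * c.all_gathers} all-gathers, {num_layers * c.reduce_scatters} reduce-scatters")
```

The trade-off is memory: with reshard_after_forward=False, each layer's unsharded weights stay allocated from its forward pass until its backward pass, in exchange for skipping one all-gather per layer.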

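At this point, we are ready to train. We create the optimizer after calling fully_shard, since fully_shard replaces the module's parameters with sharded DTensors and the optimizer needs to hold references to those. We then perform the forward pass, the backward pass and the optimizer step, just as we would with any other nn.Module.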
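Because the optimizer is built over sharded parameters, each rank only keeps optimizer state for its own shard. As we understand it, FSDP2 shards every parameter along dim 0 with torch.chunk-style splits, so uneven sizes leave the last ranks with smaller (possibly empty) shards. The sketch below computes per-rank shard sizes and the resulting Adam state under that assumption, plus the assumption that Adam keeps two float32 moments (8 bytes per element); the helper names are hypothetical.

```python
import math


def dim0_shard_sizes(dim0: int, world_size: int) -> list[int]:
    """Rows of a parameter held by each rank when dim 0 is split torch.chunk-style.

    Every rank gets ceil(dim0 / world_size) rows until the rows run out.
    """
    chunk = math.ceil(dim0 / world_size)
    sizes, remaining = [], dim0
    for _ in range(world_size):
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


def adam_state_bytes_per_rank(shape: tuple[int, ...], world_size: int,
                              state_bytes: int = 8) -> list[int]:
    """Bytes of Adam state per rank (exp_avg + exp_avg_sq in float32 = 8 bytes per element)."""
    row_elems = math.prod(shape[1:])  # 1 for a 1-D parameter such as a bias
    return [rows * row_elems * state_bytes for rows in dim0_shard_sizes(shape[0], world_size)]


print(dim0_shard_sizes(10, 4))                    # [3, 3, 3, 1]
print(adam_state_bytes_per_rank((4096, 4096), 8)) # eight equal shares of the full state
```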

If we are using microbatches to accumulate gradients, we can disable gradient synchronization for every microbatch except the last one before each optimizer step. This avoids unnecessary communication among the devices, since the gradients from all devices are only needed when the optimizer step is taken.

if do_optimization_step:
    model.set_requires_gradient_sync(True, recurse=True)
else:
    model.set_requires_gradient_sync(False, recurse=True)
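The do_optimization_step flag has to be known before each microbatch's backward pass, because that is when the gradient reduction would be issued. A minimal way to compute it, assuming a fixed number of accumulation steps (accum_steps is a hypothetical name), is sketched below in plain Python; the comments show where the flag would plug into an FSDP training loop.

```python
def sync_schedule(num_microbatches: int, accum_steps: int) -> list[bool]:
    """For each microbatch, whether to sync gradients (and step the optimizer) after its backward."""
    return [(i + 1) % accum_steps == 0 for i in range(num_microbatches)]


# Sketch of how the flag would be used for each microbatch:
#   model.set_requires_gradient_sync(do_optimization_step)  # before loss.backward()
#   (loss / accum_steps).backward()                         # scale so the accumulated grad is an average
#   if do_optimization_step:
#       optimizer.step(); optimizer.zero_grad()
for step, do_optimization_step in enumerate(sync_schedule(8, accum_steps=4)):
    print(step, do_optimization_step)
```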

Checkpointing

Saving and loading checkpoints works a bit differently with FSDP modules. Since each PyTorch process works on a particular device, it only sees the tensor shards allocated to it. In our setup, even though every device holds the full weights during the forward and backward passes, the optimizer states are sharded across the devices. So, we need an efficient mechanism to save all the different tensors involved. The torch.distributed.checkpoint (dcp) module provides functionality to save and load tensors in a device-mesh-agnostic manner (footnote 2), while still being efficient.

In my previous post on ZeRO, I mentioned that saving the state dict requires a consolidation step on a single rank, which can be time consuming since it is a blocking operation. dcp avoids this by having every rank save its local view, along with metadata that is later used to reconstruct the full state dicts at load time. This post on the PyTorch discuss forum goes into detail about checkpointing and the challenges involved.
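The idea of saving local views plus metadata, and the resharding mentioned in footnote 2, can be illustrated without PyTorch. In the toy sketch below (not dcp's actual file format), each rank "saves" its dim-0 shard of a flat list together with the offset where it starts. At load time, the pieces are placed back using that metadata and split again for a different number of ranks.

```python
import math


def shard(values: list[float], world_size: int) -> list[list[float]]:
    """Split a flat list into torch.chunk-style pieces, one per rank (trailing ranks may be empty)."""
    chunk = math.ceil(len(values) / world_size)
    return [values[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def save_local(rank: int, local_shard: list[float], offset: int) -> dict:
    """What one rank writes: only its own data, plus metadata describing where it belongs."""
    return {"rank": rank, "offset": offset, "length": len(local_shard), "data": list(local_shard)}


def load_and_reshard(saved: list[dict], new_world_size: int) -> list[list[float]]:
    """Rebuild the full tensor from the per-rank entries, then split it for a new mesh size."""
    full = [0.0] * sum(entry["length"] for entry in saved)
    for entry in saved:
        start = entry["offset"]
        full[start:start + entry["length"]] = entry["data"]
    return shard(full, new_world_size)


weights = [float(i) for i in range(10)]
files, offset = [], 0
for rank, local in enumerate(shard(weights, 4)):  # saved from a 4-rank mesh, no consolidation step
    files.append(save_local(rank, local, offset))
    offset += len(local)

print(load_and_reshard(files, new_world_size=2))  # [[0.0, 1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0, 9.0]]
```

To keep the sketch short, it materializes the full tensor when loading; as we understand it, dcp's load planner instead has each rank read only the saved pieces that overlap its new shard.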

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

def save_checkpoint(model, optimizer, ckpt_dir):
    if dist.is_initialized():
        dist.barrier()

    model_sd, optim_sd = get_state_dict(model, optimizer)

    dcp.save(
        state_dict={"model": model_sd, "optimizer": optim_sd},
        checkpoint_id=str(ckpt_dir),
    )

The code snippet above shows the steps involved in saving the model and optimizer state dicts. We first use the get_state_dict function to obtain the model and optimizer state dicts, whose values are the tensors held by the local rank, and then use dcp.save to write them to a directory.

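Loading follows a similar process. We first get the state dicts, which act as templates with the right shapes and shardings; dcp.load then fills in the actual tensor values in place from the given location, and set_state_dict applies them to the model and optimizer.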

def load_checkpoint(model, optimizer, ckpt_dir):

    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.load(
        state_dict={"model": model_sd, "optimizer": optim_sd},
        checkpoint_id=str(ckpt_dir),
    )

    set_state_dict(
        model,
        optimizer,
        model_state_dict=model_sd,
        optim_state_dict=optim_sd,
    )

    if dist.is_initialized():
        dist.barrier()

If you inspect your save location, you will see several files that look like:

 Oct 14 08:51 __0_0.distcp
 Oct 14 08:51 __1_0.distcp
 Oct 14 08:51 __2_0.distcp
 ...

Each rank saved its own view, hence the multiple files. For inference, you might want to load everything on a single rank without FSDP. For this purpose, torch.distributed.checkpoint.format_utils provides a utility function, dcp_to_torch_save, that converts the distributed checkpoint into a single file in the standard torch.save format, which you can then load directly with torch.load.

Conclusion

This post summarizes the end-to-end flow of setting up FSDP training with mixed precision and sharded optimizer states and gradients. We also saw how to save and load distributed checkpoints efficiently. This is one of the simpler setups, and hopefully it provides a gateway to more advanced techniques such as tensor parallelism and context parallelism.

I will admit that the purpose of the device mesh was not made completely clear in this post. However, since it is crucial for understanding other parallelism strategies and for composing them together, I recommend reading the device mesh recipe from PyTorch and the JAX automatic parallelization docs. After that, the tutorial on tensor parallelism is a good next step. I will eventually write a follow-up post explaining multi-node training with different parallelism strategies composed together.


  1. See the torchtitan docs, which provide the FSDP2 settings corresponding to the different ZeRO stages. 

  2. This means you can save the weights from a particular mesh configuration and reload them later in a different mesh configuration.



