Original post: Gradient backpropagation with torch.distributed.all_gather - 2021.02.07

Author: Jianfeng Wang

This post is about how to use torch.distributed.all_gather and how to make sure gradients are computed correctly.

torch.distributed.all_gather is defined as follows:

https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#all_gather

def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.
    Complex tensors are supported.

    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4])] # Rank 0
        [tensor([1, 2]), tensor([3, 4])] # Rank 1

        >>> # All tensors below are of torch.cfloat dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
        >>> tensor_list
        [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j]) # Rank 0
        tensor([3.+3.j, 4.+4.j]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1

    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()

1. all_gather does not backpropagate gradients

First of all, torch.distributed.all_gather itself does not backpropagate gradients.

For example, test.py:

import os
import torch
from torch import nn

batch_size = 16
rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', '0'))
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', '1'))
bs_each = batch_size // world_size
device_id = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
torch.cuda.set_device(device_id)
torch.distributed.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:12345',
    rank=rank,
    world_size=world_size,
)

model = nn.Linear(1, 1, bias=False)
model.weight.data[:] = 1.
model = model.cuda()

x = torch.ones((bs_each, 1), requires_grad=True).cuda()
y = model(x)

# Gather the outputs of all ranks into ys.
ys = [torch.zeros_like(y) for _ in range(world_size)]
torch.distributed.all_gather(ys, y)

print(y.grad_fn)
# <MmBackward object at 0x7f2073fc3ba8>
for sub_y in ys:
    print(sub_y.grad_fn)
    # None

Running this code, it first prints y.grad_fn, the true gradient function of the output before all_gather. After calling all_gather, the tensors in ys have no grad_fn, which means no gradients will backpropagate through them.

In practice, it is recommended to wrap the all_gather call in torch.no_grad() to state explicitly that no gradients propagate through it.
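For example, the all_gather call in test.py above could be written as (a minimal sketch, reusing y, ys and world_size from that snippet):

# Wrapping the collective in no_grad makes it explicit that the gathered
# tensors carry no gradient history.
with torch.no_grad():
    ys = [torch.zeros_like(y) for _ in range(world_size)]
    torch.distributed.all_gather(ys, y)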

2. How to backpropagate through all_gather

Since all_gather does not backpropagate gradients, do practical scenarios need this? And if they do, how can it be implemented?

A typical setup is that each GPU computes its own output, and the loss is computed from the outputs of all GPUs, rather than each GPU computing a loss independently.

We would like an implementation that ensures:

[1] - no gradient propagation through all_gather is required;

[2] - the loss remains easy to compute.

Concretely:

Suppose the output of the $i$-th GPU is $x_i$, and the loss depends on all $x_i$, i.e.:

$$ f(x_1, x_2, ..., x_n) $$

In the conventional setting, each GPU computes the loss independently, and the loss takes the form:

$$ f(x_1, x_2, ..., x_n) = \frac{1}{n} \sum_i g(x_i) $$

Each GPU computes the loss $g(x_i)$, and autograd computes the gradients of all parameters. This is normally combined with DistributedDataParallel, which averages the gradients across GPUs automatically. In this case there is no need to gather the outputs of the other GPUs.
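A minimal sketch of this conventional case (an assumption for illustration: a classification task with cross-entropy as $g$, reusing device_id and bs_each from test.py above, with the NCCL process group already initialized):

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

# Each replica holds a DDP-wrapped copy of the model.
model = DDP(torch.nn.Linear(1, 2).cuda(), device_ids=[device_id])

x = torch.randn(bs_each, 1).cuda()                        # local mini-batch
labels = torch.zeros(bs_each, dtype=torch.long).cuda()    # placeholder labels

logits = model(x)
loss = F.cross_entropy(logits, labels)   # g(x_i), computed independently per GPU
loss.backward()                          # DDP averages the gradients across GPUs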

But what if the loss cannot be computed independently on each GPU? In that case every GPU needs to gather the features from all GPUs. Since the gathered outputs carry no gradient (no grad_fn), the following approach can be used:

with torch.no_grad():
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(all_x, x)
all_x[rank] = x

all_x now contains the x from all GPUs. None of them has a grad_fn, except the one produced on the current GPU, because of all_x[rank] = x.

The loss $f$ can then be computed from all_x.
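For example (a sketch; compute_loss stands for a hypothetical, task-dependent implementation of $f$ over the concatenated features):

# Only all_x[rank] contributes gradients; the other entries act as constants.
features = torch.cat(all_x, dim=0)
loss = compute_loss(features)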

The gradient computed on the $i$-th GPU is then:

$$ \frac{\partial f}{\partial x_i} \frac{\partial x_i}{\partial \theta} $$

where $\theta$ denotes the network parameters.

Note that only $x_i$ carries a gradient; the other $x$ do not, because the tensors obtained from all_gather have no gradient.
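This can be verified with a small single-process analogue (an assumption for illustration: the gathered tensors are emulated with detached copies, which behave exactly like all_gather outputs as far as autograd is concerned):

import torch

x_local = torch.ones(2, requires_grad=True)               # output of the current GPU
gathered = [x_local.detach().clone() for _ in range(2)]   # emulated all_gather results
gathered[0] = x_local                                      # pretend rank == 0

loss = torch.cat(gathered).sum()
loss.backward()
print(x_local.grad)  # tensor([1., 1.]) -- only the local slot contributes gradients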

Combined with DistributedDataParallel, the gradient of each parameter is computed as:

$$ \frac{1}{N} \sum_i \frac{\partial f}{\partial x_i} \frac{\partial x_i}{\partial \theta} $$

where $N$ is world_size.

However, this is not what we want: the target loss is $f$, which has no $\frac{1}{N}$ factor. To compensate, each GPU can compute $N f$ instead of $f$.
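In code this is a one-line change (again a sketch, with compute_loss the same hypothetical function as above):

# Multiplying by world_size cancels the 1/N from DDP's gradient averaging,
# so the effective objective is f rather than f / N.
loss = compute_loss(torch.cat(all_x, dim=0)) * world_size
loss.backward()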

3. Summary

Although all_gather does not propagate gradients, the true gradient can still be computed correctly. The main steps are:

[1] - Use all_gather to collect the outputs of all replicas, then overwrite the current replica's slot with its own output, which does carry gradients.

[2] - Compute the loss and multiply it by world_size (a complete sketch of the recipe follows).
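Putting both steps together (a minimal sketch under the same assumptions as above: the NCCL process group is initialized, model is wrapped in DistributedDataParallel, and batch / compute_loss are placeholders):

import torch
import torch.distributed as dist

rank = dist.get_rank()
world_size = dist.get_world_size()

x = model(batch)  # local features, with grad_fn

# Step 1: gather the features of every replica (no gradients flow here) ...
with torch.no_grad():
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(all_x, x)
# ... and put back the local features, which do carry gradients.
all_x[rank] = x

# Step 2: compute the loss over all features and scale by world_size
# to cancel DDP's 1/N gradient averaging.
loss = compute_loss(torch.cat(all_x, dim=0)) * world_size
loss.backward()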

4. Related material

Multi-gpu training:

https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10

[1] - all_gather embeddings from all replicas.

[2] - Because gathered tensors have no gradients, we overwrite the gathered embeddings tensor from the current replica with the embeddings tensor produced on that replica, which has gradients to the encoder.

[3] - Concatenate the list of embeddings before computing the loss.

Pseudocode of the implementation:

import torch
import torch.distributed as dist

# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)

# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)

# Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica
# with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings

# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)

# I didn't demonstrate how to generate the labels, this will be task-dependent.
loss = some_contrastive_loss(embeddings, labels)