Original article: Gradient backpropagation with torch.distributed.all_gather - 2021.02.07
This post is mainly about how to use torch.distributed.all_gather and how to make sure the gradient is computed correctly.
torch.distributed.all_gather is defined as follows:
https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#all_gather
def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.

    Complex tensors are supported.

    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4])] # Rank 0
        [tensor([1, 2]), tensor([3, 4])] # Rank 1

        >>> # All tensors below are of torch.cfloat dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
        >>> tensor_list
        [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j]) # Rank 0
        tensor([3.+3.j, 4.+4.j]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1
    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()
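To try the docstring example above end to end, here is a minimal, self-contained sketch. The gloo backend, the torch.multiprocessing.spawn launcher, and the port number are my own choices for a CPU-only run; they are not part of the quoted source.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # Each spawned process joins the same process group.
    dist.init_process_group(
        backend='gloo',
        init_method='tcp://localhost:23456',
        rank=rank,
        world_size=world_size,
    )
    # Same values as the docstring example: rank 0 holds [1, 2], rank 1 holds [3, 4].
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
    tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    print(rank, tensor_list)  # every rank prints [tensor([1, 2]), tensor([3, 4])]
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)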
1. all_gather does not backpropagate gradients
First, torch.distributed.all_gather itself does not backpropagate any gradient.
For example, test.py:
import os
import torch
from torch import nn

batch_size = 16
rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', '0'))
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', '1'))
bs_each = batch_size // world_size
device_id = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
torch.cuda.set_device(device_id)
torch.distributed.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:12345',
    rank=rank,
    world_size=world_size,
)

model = nn.Linear(1, 1, bias=False)
model.weight.data[:] = 1.
model = model.cuda()
x = torch.ones((bs_each, 1), requires_grad=True).cuda()
y = model(x)
ys = [torch.zeros_like(y) for _ in range(world_size)]

torch.distributed.all_gather(ys, y)
print(y.grad_fn)
# <MmBackward object at 0x7f2073fc3ba8>
for sub_y in ys:
    print(sub_y.grad_fn)
# None
Running this code, it first prints the real gradient function y.grad_fn of the output that has not gone through all_gather. After calling all_gather, however, the tensors in ys have no grad_fn, which means no gradient is backpropagated through them.
In practice, it is recommended to wrap the all_gather call in torch.no_grad() to make it explicit that no gradient flows through it.
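A minimal sketch of this, reusing y and world_size from test.py above:
with torch.no_grad():
    # The gather produces plain tensors without grad_fn; no_grad makes that explicit.
    ys = [torch.zeros_like(y) for _ in range(world_size)]
    torch.distributed.all_gather(ys, y)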
2. How to backpropagate through all_gather
Since all_gather does not backpropagate gradients, do practical scenarios actually need it to, and if so, how can it be implemented?
The typical setting is that each GPU computes its output separately and the loss is computed from that GPU's own output only, rather than from the outputs of all GPUs.
This ensures that:
[1] - no gradient propagation through all_gather is needed;
[2] - the loss is very easy to compute.
In more detail:
Suppose the output of the i-th GPU is $x_i$, and the loss depends on all of the $x_i$, i.e.:
$$ f(x_1, x_2, ..., x_n) $$
In the traditional setting where each GPU computes its own loss, the loss takes the form:
$$ f(x_1, x_2, ..., x_n) = \frac{1}{n} \sum_i g(x_i) $$
Each GPU computes the loss $g(x_i)$, and autograd then computes the gradients of all parameters. This is usually combined with DistributedDataParallel, which automatically averages the gradients across GPUs. In this case there is no need to gather the outputs of the other GPUs.
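For illustration, a minimal sketch of this typical per-GPU setting, reusing the test.py setup above; the mean here merely stands in for a hypothetical per-GPU loss $g$:
ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device_id])
y = ddp_model(x)    # x_i: this GPU's own output
loss = y.mean()     # g(x_i): a per-GPU loss, computed without any all_gather
loss.backward()     # DDP averages the parameter gradients across GPUs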
But what if the loss does not decompose per GPU like this? Then each GPU needs to gather the features from all GPUs. Since the gathered outputs carry no gradient (no grad_fn), we can do the following:
with torch.no_grad():
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(all_x, x)
all_x[rank] = x
Now all_x contains the x from every GPU. None of these tensors has a grad_fn, except the one produced by the current GPU, because of the assignment all_x[rank] = x.
The loss can then be computed from all_x with f.
On the i-th GPU, the gradient is then computed as:
$$ \frac{\partial f}{\partial x_i} \frac{\partial x_i}{\partial \theta} $$
where $\theta$ denotes the network parameters.
Note that the gradient only flows into $x_i$; the other $x$'s receive no gradient, because the tensors obtained from all_gather carry none.
Combined with DistributedDataParallel, the gradient of each parameter becomes:
$$ \frac{1}{N} \sum_i \frac{\partial f}{\partial x_i} \frac{\partial x_i}{\partial \theta} $$
where $N$ is the world_size.
But this is not what we want: the target loss is $f$, and there should be no $1/N$ factor. To compensate, each GPU can compute $N \cdot f$ instead of $f$.
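Putting this together, a minimal sketch continuing the all_x snippet above; f stands for whatever loss is computed on the concatenated outputs:
all_x_cat = torch.cat(all_x)      # outputs of all GPUs; only all_x[rank] carries grad_fn
loss = world_size * f(all_x_cat)  # scale by N so that DDP's 1/N averaging yields the true gradient of f
loss.backward()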
3. Summary
Although all_gather itself computes no gradient, the true gradient can still be obtained. The main steps are:
[1] - Use all_gather to collect the outputs of all replicas, then replace the gathered entry for the current replica with its local output, which does carry a gradient.
[2] - Compute the loss and multiply it by world_size.
4. Related material
Multi-gpu training:
https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10
[1] - all_gather embeddings from all replicas.
[2] - Because gathered tensors have no gradients, we overwrite the gathered embeddings tensor from the current replica with the embeddings tensor produced on that replica, which has gradients to the encoder.
[3] - Concatenate the list of embeddings before computing the loss.
Pseudocode for the implementation:
import torch
import torch.distributed as dist
# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)
# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)
# Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica
# with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings
# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)
# I didn't demonstrate how to generate the labels, this will be task-dependent.
loss = some_contrastive_loss(embeddings, labels)
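Note that this pseudocode does not include the world_size factor discussed in section 2. If the model is wrapped in DistributedDataParallel, which averages gradients across replicas, the loss would be scaled accordingly, e.g.:
# Scale by the number of replicas to undo DDP's 1/N gradient averaging (see section 2).
loss = dist.get_world_size() * some_contrastive_loss(embeddings, labels)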