Pseudocode for the XBM implementation, together with how it is constructed and used during PyTorch training, is shown below (reference implementation):
https://github.com/msight-tech/research-xbm/blob/master/ret_benchmark/modeling/xbm.py
import torch


class XBM:
    def __init__(self, K):
        self.K = K                                                    # memory (queue) size
        self.feats = torch.zeros(self.K, 128).cuda()                  # stored features (dim hard-coded to 128)
        self.targets = torch.zeros(self.K, dtype=torch.long).cuda()   # stored labels
        self.ptr = 0                                                  # queue pointer

    @property
    def is_full(self):
        # If the last slot still holds its zero initialization, the memory is not yet full
        return self.targets[-1].item() != 0

    def get(self):
        # Return everything stored so far (the whole memory once it is full)
        if self.is_full:
            return self.feats, self.targets
        else:
            return self.feats[:self.ptr], self.targets[:self.ptr]

    def enqueue_dequeue(self, feats, targets):
        q_size = len(targets)
        if self.ptr + q_size > self.K:
            # Not enough room at the tail: overwrite the last q_size slots and wrap the pointer
            self.feats[-q_size:] = feats
            self.targets[-q_size:] = targets
            self.ptr = 0
        else:
            # Write the batch at the current pointer position and advance it
            self.feats[self.ptr: self.ptr + q_size] = feats
            self.targets[self.ptr: self.ptr + q_size] = targets
            self.ptr += q_size
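The training snippet below assumes a pair-based metric-learning loss whose forward pass takes both the current batch and a set of reference embeddings, i.e. criterion(embeddings, labels, ref_embeddings, ref_labels). The document does not define this criterion, so the following is only a minimal illustrative contrastive loss with that signature; the class name ContrastiveLoss and the margin value are assumptions, not taken from the XBM paper or repository.

import torch
import torch.nn.functional as F

class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, emb, labels, ref_emb, ref_labels):
        # Cosine similarities between the current batch and the reference set
        emb = F.normalize(emb, dim=1)
        ref_emb = F.normalize(ref_emb, dim=1)
        sim = emb @ ref_emb.t()                                    # shape (B, M)
        pos_mask = labels.unsqueeze(1) == ref_labels.unsqueeze(0)  # same-label pairs
        neg_mask = ~pos_mask
        loss = sim.new_zeros(())
        if pos_mask.any():
            # Pull positive pairs toward similarity 1
            loss = loss + (1.0 - sim[pos_mask]).clamp(min=0).mean()
        if neg_mask.any():
            # Push negative pairs below the margin
            loss = loss + (sim[neg_mask] - self.margin).clamp(min=0).mean()
        return loss

With criterion = ContrastiveLoss(), this covers both calls made below: criterion(feats, targets, feats, targets) for the in-batch loss and criterion(feats, targets, xbm_feats, xbm_targets) for the cross-batch memory loss.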
# In the training loop
print("[INFO]>>> use XBM")

# Initialize the XBM memory bank
xbm = XBM(K=1000)

# feats / targets come from the current mini-batch (illustrative names):
feats = model(images)   # embeddings produced by the model for this batch
targets = labels        # corresponding class labels

# Enqueue the current batch into the memory bank
xbm.enqueue_dequeue(feats.detach(), targets.detach())
# Note: .detach() stops gradients from flowing back through the stored features

# In-batch loss: the current batch compared against itself
loss = criterion(feats, targets, feats, targets)
log_info["batch_loss"] = loss.item()

# Read all features / labels currently held in the memory bank
xbm_feats, xbm_targets = xbm.get()
# Cross-batch loss: the current batch compared against the whole memory
xbm_loss = criterion(feats, targets, xbm_feats, xbm_targets)
log_info["xbm_loss"] = xbm_loss.item()

xbm_weight = 1.0   # weight of the memory-based loss, e.g. 1.0
loss = loss + xbm_weight * xbm_loss
#
optimizer.zero_grad()
loss.backward()
optimizer.step()
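In practice the memory-based loss is usually enabled only after the network has been warmed up for a number of iterations, so that the embeddings stored in the memory are not too stale at the start of training; the reference implementation linked above likewise activates the memory only after a configurable start iteration. A minimal sketch of such gating is shown below; warmup_iters and its value are assumptions, not values from the paper or repository.

warmup_iters = 1000   # assumed warm-up length

def training_step(iteration, feats, targets, criterion, xbm, xbm_weight=1.0):
    # Always compute the in-batch loss
    loss = criterion(feats, targets, feats, targets)
    # Only use the cross-batch memory after the warm-up phase
    if iteration > warmup_iters:
        xbm.enqueue_dequeue(feats.detach(), targets.detach())
        xbm_feats, xbm_targets = xbm.get()
        loss = loss + xbm_weight * criterion(feats, targets, xbm_feats, xbm_targets)
    return loss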