Github - bnu-wangxun/Deep_Metric

Deep Metric Learning in PyTorch

Learn deep metric for image retrieval or other information retrieval.

代码学习

项目主要是关于深度度量学习.

整理的思维导图如下:

Xmind 文件:Deep_Metric_xmind.doc(下载后去掉文件名末尾的 .doc 后缀,改为 .xmind 即可打开)。

1. 训练

1.1. 数据读取

以 In-shop-clothes 数据集的读取为例:

def default_loader(path):
    """Open the image file at ``path`` and return it converted to RGB mode."""
    image = Image.open(path)
    return image.convert('RGB')
  
  
class MyData(data.Dataset):
    """Image dataset described by a text file of ``<image path> <label>`` lines.

    Every non-blank line of the label file holds an image path (joined to
    ``root`` when the image is loaded) and an integer label, separated by a
    space or a tab.  Columns after the second one are ignored.

    Attributes set by the constructor:
        root: directory that image paths are joined to.
        images: image paths exactly as read from the label file.
        labels: integer labels, parallel to ``images``.
        classes: list of the distinct labels (order not guaranteed).
        transform: callable applied to each loaded image.
        Index: ``defaultdict(list)`` mapping each label to the positions of
            its samples in ``images``.
        loader: callable that turns a full path into an image.
    """

    def __init__(self, root=None, label_txt=None,
                 transform=None, loader=default_loader):
        """Read the label file and build the per-label index.

        Args:
            root: image root directory.  When ``None`` it falls back to
                ``"/home/xunwang"`` and ``label_txt`` is replaced by
                ``<root>/train.txt``.
            label_txt: path of the label file.  When ``None`` it defaults to
                ``<root>/train.txt``.
            transform: callable applied to each image.  When ``None`` a
                resize / random-resized-crop / horizontal-flip / to-tensor /
                normalise pipeline is used.
            loader: callable that loads an image from a path.

        Raises:
            ValueError: if a non-blank line has fewer than two columns or
                its second column is not an integer.
        """
        if root is None:
            root = "/home/xunwang"
            label_txt = os.path.join(root, 'train.txt')
        elif label_txt is None:
            # Same default file name as the root-less case, instead of
            # failing inside open(None).
            label_txt = os.path.join(root, 'train.txt')

        if transform is None:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomResizedCrop(scale=(0.16, 1), size=224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])

        # Read image paths and labels from the txt file.  The context manager
        # guarantees the handle is closed even if a line fails to parse.
        images = []
        labels = []
        with open(label_txt) as label_file:
            for line in label_file:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Accept either a space or a tab as the column separator.
                img, label = line.replace(' ', '\t').split('\t')[:2]
                images.append(img)
                labels.append(int(label))

        classes = list(set(labels))

        # Index dictionary: label -> list of sample positions with that label.
        Index = defaultdict(list)
        for i, label in enumerate(labels):
            Index[label].append(i)

        self.root = root
        self.images = images
        self.labels = labels
        self.classes = classes
        self.transform = transform
        self.Index = Index
        self.loader = loader

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(transformed image, label)``."""
        fn, label = self.images[index], self.labels[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of samples listed in the label file."""
        return len(self.images)
      
#
class InShopClothes:
    """Holder for the In-shop-clothes training split.

    After construction ``self.train`` is a ``MyData`` dataset built from
    ``<root>/train.txt`` with the ``'rand-crop'`` transform.
    """

    def __init__(self, root=None, crop=False, origin_width=256, width=224, ratio=0.16):
        """Build the training dataset.

        Args:
            root: dataset directory; defaults to ``'data/In_shop_clothes'``.
            crop: accepted but not used by this constructor.
            origin_width, width, ratio: forwarded unchanged to
                ``Generate_transform_Dict``.
        """
        if root is None:
            root = 'data/In_shop_clothes'

        available_transforms = Generate_transform_Dict(
            origin_width=origin_width, width=width, ratio=ratio)

        self.train = MyData(
            root,
            label_txt=os.path.join(root, 'train.txt'),
            transform=available_transforms['rand-crop'],
        )

1.2. 数据采样

FastRandomIdentitySampler 采样函数:

from torch.utils.data.sampler import (
  Sampler, 
  SequentialSampler, 
  RandomSampler, 
  SubsetRandomSampler,
  WeightedRandomSampler)

class FastRandomIdentitySampler(Sampler):
    """Sampler that draws ``num_instances`` sample indices for every identity.

    The identity -> sample-indices mapping is taken from the ``Index``
    attribute of ``data_source`` rather than being rebuilt by iterating over
    the dataset.
    """

    def __init__(self, data_source, num_instances=1):
        """Store the data source and cache its identity list.

        Args:
            data_source: object exposing an ``Index`` mapping from identity
                to a list of sample indices.
            num_instances: number of indices drawn per identity.
        """
        self.data_source = data_source
        self.num_instances = num_instances
        # Reuse the mapping the data source has already built.
        self.index_dic = data_source.Index
        self.pids = list(self.index_dic.keys())
        self.num_samples = len(self.pids)

    def __len__(self):
        """Return the total number of indices produced by one iteration."""
        return self.num_instances * self.num_samples

    def __iter__(self):
        """Yield indices grouped by identity, identities in random order."""
        sampled = []
        for position in torch.randperm(self.num_samples):
            candidates = self.index_dic[self.pids[position]]
            # Sample with replacement only when the identity has fewer
            # samples than requested.
            with_replacement = len(candidates) < self.num_instances
            picks = np.random.choice(candidates, size=self.num_instances,
                                     replace=with_replacement)
            sampled.extend(picks)
        return iter(sampled)

1.3. 网络输出层微调

# Fine-tuning on top of a pre-trained classification model.
# Parameters are split into two optimiser groups: the classifier head
# ('lr_mult' 1.0) and every other parameter ('lr_mult' 0.0).
# NOTE(review): the original comment said the pre-trained parameters use 1/10
# of the learning rate, but the multiplier written here is 0.0 - confirm
# which one is intended.
new_param_ids = {id(param) for param in model.module.classifier.parameters()}
new_params = list(filter(lambda param: id(param) in new_param_ids,
                         model.module.parameters()))
base_params = list(filter(lambda param: id(param) not in new_param_ids,
                          model.module.parameters()))
param_groups = [
    {'params': base_params, 'lr_mult': 0.0},
    {'params': new_params, 'lr_mult': 1.0},
]

1.4. 模型训练

单个 epoch 的训练函数 train 如下(见 trainer.py):

# coding=utf-8
from __future__ import print_function, absolute_import
import time
from utils import AverageMeter, orth_reg
import  torch
from torch.autograd import Variable
from torch.backends import cudnn

cudnn.benchmark = True


def train(epoch, model, criterion, optimizer, train_loader, args):
    """Run one pass over ``train_loader`` and print running statistics.

    Args:
        epoch: zero-based epoch number (printed as ``epoch + 1``).
        model: network called on each input batch to get embeddings.
        criterion: callable returning a 4-tuple
            ``(loss, inter_, dist_ap, dist_an)`` for ``(embeddings, labels)``.
        optimizer: optimiser stepped once per batch.
        train_loader: iterable of ``(inputs, labels)`` batches with a length.
        args: namespace providing ``print_freq`` and ``orth_reg``.
    """
    # Running meters for the values reported in the log line.
    batch_time = AverageMeter()
    losses = AverageMeter()
    accuracy = AverageMeter()
    pos_sims = AverageMeter()
    neg_sims = AverageMeter()

    log_template = ('Epoch: [{0:03d}][{1}/{2}]\t'
                    'Time {batch_time.avg:.3f}\t'
                    'Loss {loss.avg:.4f} \t'
                    'Accuracy {accuracy.avg:.4f} \t'
                    'Pos {pos.avg:.4f}\t'
                    'Neg {neg.avg:.4f} \t')

    end = time.time()
    # Never wait longer than one full epoch between two log lines.
    freq = min(args.print_freq, len(train_loader))

    for step, (inputs, labels) in enumerate(train_loader):
        # Wrap the batch and move it to the GPU.
        inputs = Variable(inputs).cuda()
        labels = Variable(labels).cuda()

        optimizer.zero_grad()
        embed_feat = model(inputs)
        loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels)

        # Optional regulariser, enabled by a non-zero coefficient.
        if args.orth_reg != 0:
            loss = orth_reg(net=model, loss=loss, cof=args.orth_reg)

        # Back-propagate and update the weights.
        loss.backward()
        optimizer.step()

        # Time spent on this batch.
        batch_time.update(time.time() - end)
        end = time.time()

        losses.update(loss.item())
        accuracy.update(inter_)
        pos_sims.update(dist_ap)
        neg_sims.update(dist_an)

        seen = step + 1
        if seen % freq == 0 or seen == len(train_loader):
            print(log_template.format(
                epoch + 1, seen, len(train_loader),
                batch_time=batch_time, loss=losses, accuracy=accuracy,
                pos=pos_sims, neg=neg_sims))

        if epoch == 0 and step == 0:
            print('-- HA-HA-HA-HA-AH-AH-AH-AH --')

2. 测试

2.1. 特征提取

cnn.py - extract_cnn_feature:

def to_torch(ndarray):
    """Return ``ndarray`` as a torch tensor.

    Objects whose type is defined in the ``numpy`` module are converted with
    ``torch.from_numpy``; torch tensors are returned unchanged.

    Raises:
        ValueError: if the argument is neither of the above.
    """
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    if torch.is_tensor(ndarray):
        return ndarray
    message = "Cannot convert {} to torch tensor".format(type(ndarray))
    raise ValueError(message)

def extract_cnn_feature(model, inputs, pool_feature=False):
    """Run ``model`` on one batch in eval mode with gradients disabled.

    Args:
        model: network to evaluate.  When ``pool_feature`` is enabled it must
            expose ``model.module`` with a sub-module named ``'features'``.
        inputs: batch as a numpy array or torch tensor; it is converted with
            ``to_torch`` and moved to the GPU.
        pool_feature: when exactly ``False`` the model output is returned.
            Otherwise the output of the ``'features'`` sub-module is captured
            with a forward hook and returned flattened to one row per sample.

    Returns:
        The model output, or the flattened ``'features'`` output.
    """
    model.eval()
    with torch.no_grad():
        inputs = Variable(to_torch(inputs)).cuda()

        if pool_feature is False:
            return model(inputs)

        captured = {}

        def capture_pooled(module, hook_inputs, hook_output):
            # Flatten to (batch, -1).  The batch size is read from the hooked
            # output itself; the previous code used an undefined name ``n``
            # here and raised NameError on every call.
            captured['pool_feature'] = hook_output.data.view(
                hook_output.size(0), -1)

        hook = model.module._modules.get('features').register_forward_hook(
            capture_pooled)
        try:
            model(inputs)
        finally:
            # Always detach the hook so a failed forward pass does not leave
            # it registered on the model.
            hook.remove()
        return captured['pool_feature']

extract_feature.py

def extract_features(model, data_loader, print_freq=1, metric=None, pool_feature=False):
    """Extract features for every batch of ``data_loader``.

    Features are accumulated in a GPU buffer and moved to the CPU whenever
    the buffer holds more than 1e4 rows, and once more after the last batch.
    A progress line is printed at each such transfer.

    Args:
        model: network passed on to ``extract_cnn_feature``.
        data_loader: iterable of ``(imgs, pids)`` batches with a length.
        print_freq: accepted but not used by this function.
        metric: accepted but not used by this function.
        pool_feature: forwarded to ``extract_cnn_feature``.

    Returns:
        ``(feature_cpu, labels)``: the concatenated CPU feature tensor and
        the list of all label entries in loader order.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # Number of buffered rows above which the GPU buffer is flushed.
    flush_threshold = 1e4
    feature_cpu = torch.FloatTensor()
    feature_gpu = torch.FloatTensor().cuda()
    labels = []

    progress_template = ('Extract Features: [{}/{}]\t'
                         'Time {:.3f} ({:.3f})\t'
                         'Data {:.3f} ({:.3f})\t')

    end = time.time()
    for batch_idx, (imgs, pids) in enumerate(data_loader):
        outputs = extract_cnn_feature(model, imgs, pool_feature=pool_feature)
        feature_gpu = torch.cat((feature_gpu, outputs.data), 0)
        labels.extend(pids)

        buffered_rows = feature_gpu.size(0)
        if buffered_rows > flush_threshold or batch_idx == len(data_loader) - 1:
            # Time spent since the previous flush (or since the start).
            data_time.update(time.time() - end)
            end = time.time()

            # Move the buffered rows to the CPU and start a fresh buffer.
            feature_cpu = torch.cat((feature_cpu, feature_gpu.cpu()), 0)
            feature_gpu = torch.FloatTensor().cuda()

            # Time spent on the transfer itself.
            batch_time.update(time.time() - end)
            print(progress_template.format(
                batch_idx + 1, len(data_loader),
                batch_time.val, batch_time.avg,
                data_time.val, data_time.avg))

            end = time.time()
        del outputs

    return feature_cpu, labels

2.2. 相似性计算

sim_mat = pairwise_similarity(query_feature, gallery_feature)

pairwise_similarity 函数:

def normalize(x, eps=1e-12):
    """Scale every row of ``x`` to unit L2 norm.

    Args:
        x: tensor whose rows (dimension 1) are normalised.
        eps: lower bound applied to each row norm before dividing, so that an
            all-zero row yields zeros instead of NaN.

    Returns:
        A tensor of the same shape as ``x`` with L2-normalised rows.
    """
    norm = x.norm(p=2, dim=1, keepdim=True)
    # Clamp instead of dividing by the raw norm: 0 / 0 would produce NaN that
    # then spreads into every similarity computed from this row.
    return x / norm.clamp(min=eps)
  
def pairwise_similarity(x, y=None):
    """Return the matrix of dot products between normalised rows.

    Both inputs are passed through ``normalize`` first.  When ``y`` is
    omitted, ``x`` is compared against itself.

    Returns:
        ``torch.mm(normalize(x), normalize(y).t())``.
    """
    gallery = x if y is None else y
    # Normalise both sides before taking dot products.
    gallery = normalize(gallery)
    queries = normalize(x)
    return torch.mm(queries, gallery.t())
  
#
def pairwise_distance(features, metric=None):
    """Return the n x n matrix of Euclidean distances between feature rows.

    Rows are passed through ``normalize`` first and, when ``metric`` is
    given, through ``metric.transform``.

    Args:
        features: 2-D tensor with one row per sample.
        metric: optional object with a ``transform(x)`` method applied to the
            normalised features.

    Returns:
        Tensor of shape ``(n, n)`` holding pairwise distances.
    """
    n = features.size(0)
    x = normalize(features)
    if metric is not None:
        x = metric.transform(x)

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
    sq_norms = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(n, n)
    sq_dist = sq_norms + sq_norms.t() - 2 * torch.mm(x, x.t())

    # Rounding error can leave values that should be zero (e.g. the diagonal)
    # slightly negative; clamp so the square root does not return NaN.
    return torch.sqrt(sq_dist.clamp(min=0))

2.3. Recall 计算

Recall_at_ks 函数:

def to_numpy(tensor):
    """Return ``tensor`` as a numpy object.

    Torch tensors are moved to the CPU and converted with ``.numpy()``;
    objects whose type is defined in the ``numpy`` module are returned
    unchanged.

    Raises:
        ValueError: if the argument is neither of the above.
    """
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    if type(tensor).__module__ == 'numpy':
        return tensor
    message = "Cannot convert {} to numpy array".format(type(tensor))
    raise ValueError(message)
  
def Recall_at_ks(sim_mat, data='product', query_ids=None, gallery_ids=None):
    """Compute Recall@k from a query x gallery similarity matrix.

    A query counts as a hit at ``k`` when fewer than ``k`` gallery entries
    score strictly higher than its best-scoring same-id gallery entry.

    :param sim_mat: similarity matrix with one row per query (tensor or numpy)
    :param data: ``'product'`` for k in [1, 10, 100, 1000] or ``'shop'`` for
        k in [1, 10, 20, 30, 40, 50]; ``None`` means ``'product'``
    :param query_ids: id of each query; ``None`` reuses ``gallery_ids``
    :param gallery_ids: id of each gallery entry
    :return: numpy array with one recall value per k
    """
    ks_dict = {
        'product': [1, 10, 100, 1000],
        'shop': [1, 10, 20, 30, 40, 50],
    }
    k_s = ks_dict['product' if data is None else data]

    sim_mat = to_numpy(sim_mat)
    m, n = sim_mat.shape
    gallery_ids = np.asarray(gallery_ids)
    query_ids = gallery_ids if query_ids is None else np.asarray(query_ids)

    # Evaluate on a random subset of queries when there are too many.
    num_max = int(1e6)
    if m > num_max:
        samples = list(range(m))
        random.shuffle(samples)
        samples = samples[:num_max]
        sim_mat = sim_mat[samples, :]
        query_ids = [query_ids[k] for k in samples]
        m = num_max

    # For each query: how many gallery entries beat its best positive match.
    neg_nums = np.zeros(m)
    for row in range(m):
        scores = sim_mat[row]
        best_positive = np.max(scores[gallery_ids == query_ids[row]])
        neg_nums[row] = np.sum(scores > best_positive)

    # Number of queries whose best positive lands inside the top k.
    num_valid = np.zeros(len(k_s))
    for pos, k in enumerate(k_s):
        num_valid[pos] = np.sum(neg_nums < k)

    return num_valid / float(m)

2.4. topK 计算

Compute_top_k 函数:

import heapq

def Compute_top_k(sim_mat, k=10):
    """Return the gallery indices of the ``k`` highest scores for each query.

    :param sim_mat: query x gallery similarity matrix (tensor or numpy)
    :param k: number of gallery indices kept per query
    :return: numpy array of shape ``(queries, k)``; each row lists gallery
        indices ordered from highest to lowest similarity
    """
    scores = to_numpy(sim_mat)
    num_queries, num_gallery = scores.shape
    print('query number is %d' % num_queries)
    print('gallery number is %d' % num_gallery)

    top_k = np.zeros([num_queries, k])
    for row, query_scores in enumerate(scores):
        # Rank gallery positions by their similarity to this query.
        top_k[row] = heapq.nlargest(
            k, range(len(query_scores)), query_scores.take)
    return top_k
Last modification:April 13th, 2021 at 03:23 pm