Pseudocode for the core logic of DataLoader and Dataset, for example:

import torch 

class Dataset(object):
    def __init__(self):
        pass

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self,index):
        raise NotImplementedError


class DataLoader(object):
    def __init__(self, dataset, batch_size, collate_fn, shuffle=True, drop_last=False):
        self.dataset = dataset
        self.collate_fn = collate_fn
        # pick the index sampler: random order when shuffling, sequential otherwise
        self.sampler = torch.utils.data.RandomSampler if shuffle else \
            torch.utils.data.SequentialSampler

        # BatchSampler groups the sampled indices into batches of batch_size
        self.batch_sampler = torch.utils.data.BatchSampler
        self.sample_iter = iter(self.batch_sampler(
            self.sampler(range(len(dataset))),
            batch_size=batch_size, drop_last=drop_last))

    def __iter__(self):
        return self

    def __next__(self):
        indices = next(self.sample_iter)
        batch = self.collate_fn([self.dataset[i] for i in indices])
        return batch

1. DataLoader

https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

torch.utils.data.DataLoader(dataset, 
                            batch_size=1, 
                            shuffle=False, 
                            sampler=None, 
                            batch_sampler=None, 
                            num_workers=0, 
                            collate_fn=None, 
                            pin_memory=False, 
                            drop_last=False, 
                            timeout=0, 
                            worker_init_fn=None, 
                            multiprocessing_context=None, 
                            generator=None, 
                            *, 
                            prefetch_factor=2, 
                            persistent_workers=False)

Among the parameters:

  • num_workers: the number of CPU worker processes used for data loading. A common recommendation is the total number of CPU cores on the machine.
  • pin_memory: whether to load the data into pinned (page-locked) host memory before transferring it to the GPU. Recommended to enable. Defaults to False. Pinned memory is never swapped out to virtual memory (disk), so copies from pinned memory to the GPU are faster.
  • drop_last: generally recommended to enable; if the last batch is smaller than batch_size it is dropped, which makes training more stable.

A common rule of thumb is to set the number of workers to four times the number of available GPUs; values noticeably larger or smaller than this tend to slow training down.

Note that increasing num_workers also increases CPU memory consumption.
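
For example, a DataLoader configuration that follows the suggestions above (the dataset and the concrete values here are illustrative):

import torch

# illustrative dataset: 1000 random "images" with integer labels
train_dataset = torch.utils.data.TensorDataset(
    torch.randn(1000, 3, 224, 224), torch.randint(0, 10, (1000,)))

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,       # e.g. 4 workers per available GPU
    pin_memory=True,     # page-locked host memory, faster host-to-GPU copies
    drop_last=True)      # drop the last incomplete batch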

Using a large batch size may lead to solutions that generalize worse than those found with a small batch size.

With pin_memory=True, it is also recommended to pass non_blocking=True when moving data to the GPU. The host-to-device copy then becomes asynchronous and can overlap with subsequent computation instead of blocking, which can be noticeably faster; and even if the data is later moved back from the GPU, e.g. with .cpu(), the worst case is performance comparable to non_blocking=False. For example:

# to GPU: asynchronous copies from pinned host memory
image = image.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True).float()

# predict
prediction = net(image)

1.1. Data Transfer between CPU and GPU

During training, minimize frequent data transfers between the CPU and GPU as much as possible.

Frequently moving tensors between the GPU and CPU with .cpu() / .cuda() is relatively expensive.

Prefer torch.as_tensor() over torch.tensor(), because torch.tensor() always copies the data. To convert a numpy array, use torch.as_tensor() or torch.from_numpy() to avoid copying the data.
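
A small sketch of the difference (the array here is illustrative):

import numpy as np
import torch

a = np.zeros(3, dtype=np.float32)

t_copy = torch.tensor(a)        # always copies the data
t_shared = torch.from_numpy(a)  # shares memory with the numpy array
t_as = torch.as_tensor(a)       # also shares memory when dtype/device allow it

a[0] = 1.0
print(t_copy[0].item())    # 0.0 -- owns its own copy, unaffected
print(t_shared[0].item())  # 1.0 -- no copy was made
print(t_as[0].item())      # 1.0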

2. Accelerating Data Reading with LMDB

During training, if the batch size is large or each batch requires heavy preprocessing, PyTorch's DataLoader may take a long time to produce a batch; the GPU then sits idle waiting for the CPU, and training efficiency drops.

LMDB is a key-value database that supports multi-process access, has a simple API, does not require reading all files into memory, and is very fast.
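
A minimal sketch of the key-value API, assuming the lmdb Python package (the path and keys here are illustrative):

import lmdb

# open (and create if needed) an LMDB environment; map_size is the maximum DB size in bytes
env = lmdb.open('demo_lmdb', map_size=1 << 30)

# writes go through a write transaction
with env.begin(write=True) as txn:
    txn.put(b'hello', b'world')

# reads go through read-only transactions, which multiple processes can open concurrently
with env.begin(write=False) as txn:
    print(txn.get(b'hello'))  # b'world'

env.close()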

2.1. Creating an LMDB

Caffe - Creating LMDB/HDF5 format data with Python - AIUAI

First, convert the raw dataset into LMDB format so that it can be read later during training:

Github - xunge/pytorch_lmdb_imagenet

import os
import lmdb
import pickle
import numpy as np


def dumps_data(obj):
    """
    Serialize an object.
    Returns:
        Implementation-dependent bytes-like object
    """
    return pickle.dumps(obj)


def convert_to_lmdb(datas, db_path='data_lmdb', name='data', write_frequency=50):
    """
    datas: an iterable of image tensors (e.g. a Dataset)
    db_path: the directory in which to save the lmdb file
    name: basename of the generated lmdb file
    write_frequency: commit the write transaction every `write_frequency` samples
    """
    if not os.path.exists(db_path):
        os.makedirs(db_path)
    lmdb_path = os.path.join(db_path, "%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)

    txn = db.begin(write=True)
    for idx, data in enumerate(datas):
        # get data from dataloader
        pic = data

        # put data to lmdb dataset
        # {idx, (in_LDRs, in_HDRs, ref_HDRs)}
        txn.put(u'{}'.format(idx).encode('ascii'), dumps_data((pic.numpy())))
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(dataloader)))
            txn.commit()
            txn = db.begin(write=True)

    # finish iterating through dataset
    txn.commit()
    keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dumps_data(keys))
        txn.put(b'__len__', dumps_data(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()
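
A hypothetical call, assuming the samples are already image tensors (the data below is illustrative):

import torch

datas = [torch.randn(3, 32, 32) for _ in range(100)]  # illustrative "images"
convert_to_lmdb(datas, db_path='data_lmdb', name='data', write_frequency=50)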

2.2. Reading the LMDB in a DataLoader

For example:

import os.path as osp
import six
import lmdb
import pickle
import numpy as np
from PIL import Image
import torch.utils.data as data


def loads_data(buf):
    """
    Args:
        buf: the output of `dumps`.
    """
    return pickle.loads(buf)


class ImageFolderLMDB(data.Dataset):
    def __init__(self, db_path, transform=None, target_transform=None):
        self.db_path = db_path
        self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        
        with self.env.begin(write=False) as txn:
            self.length = loads_data(txn.get(b'__len__'))
            self.keys = loads_data(txn.get(b'__keys__'))

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])

        unpacked = loads_data(byteflow)

        # load img
        imgbuf = unpacked[0]
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        img = Image.open(buf).convert('RGB')

        # load label
        target = unpacked[1]

        if self.transform is not None:
            img = self.transform(img)

        im2arr = np.array(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # return img, target
        return im2arr, target

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
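
A hypothetical usage example; it assumes the LMDB stores (raw image bytes, label) pairs as the class above expects, and that the path below is where your LMDB file lives:

import torch
import torchvision.transforms as T

dataset = ImageFolderLMDB('data/train.lmdb',
                          transform=T.Compose([T.Resize(256), T.CenterCrop(224)]))
loader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True,
    num_workers=4, pin_memory=True)

for imgs, targets in loader:
    pass  # training step goes here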

3. Accelerating Data Reading with MongoDB

When a large number of image files are stored on disk, occasional read-time corruption shows up over time;

a database-based solution is worth trying, both for faster data reading and for data integrity.

import os
import json
from io import BytesIO
from PIL import Image

from pymongo import MongoClient
from torch.utils.data import Dataset


class MongoDataSet(Dataset):
    """
    Mongo Dataset read image path from img_source
    img_source: list of img_path and label
    """

    def __init__(self, img_source, transforms=None, is_train=False, mode="RGB"):
        self.mode = mode
        self.transforms = transforms
        assert os.path.exists(img_source), f"{img_source} NOT found."
        self.img_source = img_source
        with open(img_source) as fp:
            datas_dict = json.load(fp) 
    
        self.label_list = datas_dict['label_list']
        self.path_list = datas_dict['path_list']

        client = MongoClient('192.168.100.100:2100', connect=False)
        self.coll = client['db']['test_coll']

    def __len__(self):
        return len(self.label_list)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|"

    def __getitem__(self, index):
        img_name = self.path_list[index]
        label = self.label_list[index]
        
        doc = self.coll.find_one({'img_name': img_name})
        img_bytes = doc['bson_img']
        
        # check whether the JPEG bytes are truncated (missing the EOI marker) and repair them
        if not img_bytes.endswith(b'\xff\xd9'):
            img_bytes = img_bytes + b'\xff\xd9'
        img = Image.open(BytesIO(img_bytes))
        
        if self.transforms is not None:
            img = self.transforms(img)
            
        return img, label, index
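
A hypothetical sketch of how such a collection might be populated, using the same 'img_name' and 'bson_img' fields that __getitem__ reads (the server address and directory below are illustrative):

import os
from pymongo import MongoClient
from bson.binary import Binary

client = MongoClient('192.168.100.100:2100')
coll = client['db']['test_coll']

img_dir = 'images/'  # illustrative directory of JPEG files
for fname in os.listdir(img_dir):
    with open(os.path.join(img_dir, fname), 'rb') as f:
        img_bytes = f.read()
    # store the raw JPEG bytes under the field names the dataset expects
    coll.insert_one({'img_name': fname, 'bson_img': Binary(img_bytes)})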