Pseudocode for the core logic of DataLoader and Dataset, e.g.:
import torch

class Dataset(object):
    def __init__(self):
        pass

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError


class DataLoader(object):
    def __init__(self, dataset, batch_size, collate_fn, shuffle=True, drop_last=False):
        self.dataset = dataset
        self.collate_fn = collate_fn
        # choose the index sampler, then batch the sampled indices
        self.sampler = torch.utils.data.RandomSampler if shuffle else \
            torch.utils.data.SequentialSampler
        self.batch_sampler = torch.utils.data.BatchSampler
        self.sample_iter = iter(self.batch_sampler(
            self.sampler(range(len(dataset))),
            batch_size=batch_size, drop_last=drop_last))

    def __iter__(self):
        return self

    def __next__(self):
        indices = next(self.sample_iter)
        batch = self.collate_fn([self.dataset[i] for i in indices])
        return batch
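A minimal usage sketch of this simplified loader, with an illustrative in-memory ToyDataset and torch.stack as the collate function:
class ToyDataset(Dataset):
    def __init__(self, n=10):
        self.data = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


loader = DataLoader(ToyDataset(), batch_size=4,
                    collate_fn=torch.stack, shuffle=True)
for batch in loader:
    print(batch.shape)  # torch.Size([4]), torch.Size([4]), torch.Size([2])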
1. DataLoader
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor=2,
persistent_workers=False)
Where:
- num_workers: number of worker processes used for data loading. A common recommendation is the total number of CPU cores on the machine.
- pin_memory: whether to load data into pinned (page-locked) memory before transferring it to the GPU. Recommended to enable. Defaults to False. Pinned memory is never swapped out to virtual memory (disk), so copies from pinned memory to the GPU are faster.
- drop_last: generally recommended to enable; if the last batch is smaller than batch_size it is dropped, which makes training more stable.
Another common rule of thumb for the number of workers is to set it to four times the number of available GPUs; values noticeably larger or smaller tend to slow down training.
Note that increasing num_workers also increases CPU memory consumption.
Using a large batch size may lead to worse generalization than using a small batch size.
A typical construction with these recommendations is sketched below.
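A minimal sketch of a DataLoader configured along these lines (the random TensorDataset below is just a placeholder):
import torch
from torch.utils.data import DataLoader, TensorDataset

# placeholder dataset: 1000 fake images with integer labels
dataset = TensorDataset(torch.randn(1000, 3, 224, 224),
                        torch.randint(0, 10, (1000,)))

loader = DataLoader(dataset,
                    batch_size=64,
                    shuffle=True,
                    num_workers=8,     # e.g. the number of CPU cores, or 4x the GPU count
                    pin_memory=True,   # page-locked host memory for faster copies to GPU
                    drop_last=True)    # drop the incomplete final batch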
If pin_memory=True, it is also recommended to pass non_blocking=True when moving data to the GPU. If the data is only pushed to the GPU and not pulled back for computation on the CPU, this is much faster; and even if the data is later pulled back from the GPU (e.g., via .cpu()), the worst case is on par with non_blocking=False. For example:
# move inputs to GPU (asynchronous copy from pinned memory)
image = image.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True).float()
# predict
prediction = net(image)
1.1. Data Transfer Between CPU and GPU
During training, minimize frequent data transfers between the CPU and GPU, because repeatedly moving tensors back and forth with .cpu() or .cuda() is relatively expensive.
Prefer torch.as_tensor() over torch.tensor(), because torch.tensor() always copies the data. When converting a numpy array, use torch.as_tensor() or torch.from_numpy() to avoid the copy.
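A small sketch illustrating the difference (array contents are arbitrary):
import numpy as np
import torch

arr = np.zeros(3, dtype=np.float32)

copied = torch.tensor(arr)       # always copies the data
shared = torch.as_tensor(arr)    # shares memory with the numpy array
also_shared = torch.from_numpy(arr)

arr[0] = 1.0
print(copied[0].item())       # 0.0 -- the copy does not see the change
print(shared[0].item())       # 1.0 -- shared memory reflects it
print(also_shared[0].item())  # 1.0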
2. Using LMDB to Speed Up Data Loading
During training, if the batch size is large or each batch needs a lot of preprocessing, PyTorch's DataLoader can take a long time to produce a batch; the GPU then sits idle waiting for the CPU, and training efficiency drops.
LMDB is a database that supports multi-process access, has a simple access interface, does not require reading all files into memory, and is very fast.
2.1. Creating an LMDB
First, convert the original dataset into LMDB format so it can be read during training:
Github - xunge/pytorch_lmdb_imagenet
import os
import lmdb
import pickle
import numpy as np

def dumps_data(obj):
    """
    Serialize an object.

    Returns:
        Implementation-dependent bytes-like object
    """
    return pickle.dumps(obj)

def convert_to_lmdb(datas, db_path='data_lmdb', name='data', write_frequency=50):
    """
    db_path: directory where the LMDB file will be saved
    name: file name of the generated LMDB
    write_frequency: commit the write transaction every `write_frequency` records
    """
    if not os.path.exists(db_path):
        os.makedirs(db_path)
    lmdb_path = os.path.join(db_path, "%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)
    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)

    txn = db.begin(write=True)
    for idx, data in enumerate(datas):
        # get one sample (e.g. from a dataset or dataloader)
        pic = data
        # put the serialized sample into the LMDB, keyed by its index
        txn.put(u'{}'.format(idx).encode('ascii'), dumps_data(pic.numpy()))
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(datas)))
            txn.commit()
            txn = db.begin(write=True)

    # finish iterating through the dataset
    txn.commit()
    keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dumps_data(keys))
        txn.put(b'__len__', dumps_data(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()
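A hypothetical call to the function above; `image_dataset` is assumed to be an existing Dataset that yields image tensors:
from torch.utils.data import DataLoader

# iterate sample by sample; the lambda unwraps the single-element batch
loader = DataLoader(image_dataset, batch_size=1, shuffle=False,
                    collate_fn=lambda samples: samples[0])
convert_to_lmdb(loader, db_path='data_lmdb', name='train', write_frequency=50)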
2.2. Reading the LMDB with a DataLoader
For example:
import os.path as osp
import six
import lmdb
import pickle
import numpy as np
from PIL import Image
import torch.utils.data as data

def loads_data(buf):
    """
    Args:
        buf: the output of `dumps_data`.
    """
    return pickle.loads(buf)

class ImageFolderLMDB(data.Dataset):
    def __init__(self, db_path, transform=None, target_transform=None):
        self.db_path = db_path
        self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_data(txn.get(b'__len__'))
            self.keys = loads_data(txn.get(b'__keys__'))

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        # each stored record is assumed to be a (raw image bytes, label) tuple
        unpacked = loads_data(byteflow)

        # load img
        imgbuf = unpacked[0]
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        img = Image.open(buf).convert('RGB')

        # load label
        target = unpacked[1]

        if self.transform is not None:
            img = self.transform(img)
        im2arr = np.array(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # return img, target
        return im2arr, target

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
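A hypothetical usage of the class above; the LMDB path and transforms are placeholders:
from torch.utils.data import DataLoader
import torchvision.transforms as T

dataset = ImageFolderLMDB('data_lmdb/train.lmdb',
                          transform=T.Compose([T.Resize(256), T.CenterCrop(224)]))
loader = DataLoader(dataset, batch_size=64, shuffle=True,
                    num_workers=4, pin_memory=True, drop_last=True)

for imgs, targets in loader:
    pass  # training step goes here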
3. Using MongoDB to Speed Up Data Loading
When a large number of image files are stored on disk, reads occasionally hit corrupted files over time;
a database-backed approach is worth trying for both faster data loading and data integrity.
import os
import json
from io import BytesIO
from PIL import Image
from pymongo import MongoClient
from torch.utils.data import Dataset

class MongoDataSet(Dataset):
    """
    Mongo Dataset: reads image paths and labels from img_source,
    then fetches the image bytes from MongoDB.

    img_source: JSON file holding the list of image paths and labels
    """
    def __init__(self, img_source, transforms=None, is_train=False, mode="RGB"):
        self.mode = mode
        self.transforms = transforms
        assert os.path.exists(img_source), f"{img_source} NOT found."
        self.img_source = img_source

        with open(img_source) as fp:
            datas_dict = json.load(fp)
        self.label_list = datas_dict['label_list']
        self.path_list = datas_dict['path_list']

        client = MongoClient('192.168.100.100:2100', connect=False)
        self.coll = client['db']['test_coll']

    def __len__(self):
        return len(self.label_list)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|"

    def __getitem__(self, index):
        img_name = self.path_list[index]
        label = self.label_list[index]

        doc = self.coll.find_one({'img_name': img_name})
        img_bytes = doc['bson_img']

        # check whether the JPEG is truncated (missing the 0xFFD9 end-of-image marker)
        if not img_bytes.endswith(b'\xff\xd9'):
            img_bytes = img_bytes + b'\xff\xd9'

        img = Image.open(BytesIO(img_bytes))
        if self.transforms is not None:
            img = self.transforms(img)

        return img, label, index
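For reference, a hypothetical sketch of how the collection read above could be populated; the host, database, collection, and field names mirror the reader, and the image directory is a placeholder:
import os
from pymongo import MongoClient
from bson.binary import Binary

client = MongoClient('192.168.100.100:2100')
coll = client['db']['test_coll']

img_dir = 'images'  # placeholder directory containing JPEG files
for img_name in os.listdir(img_dir):
    with open(os.path.join(img_dir, img_name), 'rb') as f:
        img_bytes = f.read()
    # store raw JPEG bytes under the same fields the dataset reads
    coll.insert_one({'img_name': img_name, 'bson_img': Binary(img_bytes)})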