Caffe 在训练图像分类模型时, 为了提高效率, 通常不直接从图片列表读取图片, 而是先将训练数据转换为 LMDB 或 HDF5 格式.
LMDB格式的优点:
- 基于内存映射文件 I/O (memory-mapped), 数据读取速率更高
- 对大规模数据集更有效.
HDF5的特点:
- 易于读取
- 类似于mat数据,但数据压缩性能更强
- Caffe 读取 HDF5 时需要将整个文件读进内存, 故单个 HDF5 文件大小不能超过内存; 数据量较大时可以拆分成多个 HDF5 文件, 并将各 HDF5 子文件的路径写入一个 txt 列表文件中.
- I/O速率不如LMDB.
1. LMDB创建
import numpy as np
import lmdb
import caffe

lmdb_file = '/path/to/data_lmdb'
N = 1000

# Prepare the data and the labels.
X = np.zeros((N, 3, 224, 224), dtype=np.uint8)  # data, indexed as (sample, channel, height, width)
y = np.zeros(N, dtype=np.int64)  # labels

# map_size is the upper bound the database file may grow to, not the amount
# allocated up front.
env = lmdb.open(lmdb_file, map_size=int(1e12))

# Using the transaction as a context manager commits it when the block exits
# normally and aborts it on an exception. Without a commit nothing that was
# put() would be persisted.
with env.begin(write=True) as txn:
    for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = X.shape[1]
        datum.height = X.shape[2]
        datum.width = X.shape[3]
        datum.data = X[i].tobytes()  # or .tostring() if numpy < 1.9
        datum.label = int(y[i])
        # The five lines above can also be written as:
        #   datum = caffe.io.array_to_datum(X[i], int(y[i]))

        # Zero-padded keys keep the records in insertion order, because LMDB
        # sorts keys lexicographically.
        str_id = '{:08}'.format(i)

        # LMDB keys must be bytes; encoding works on both Python 2 and 3.
        txn.put(str_id.encode('ascii'), datum.SerializeToString())

env.close()
2. LMDB读取
import numpy as np
import lmdb
import caffe
# Open the database read-only; no write lock is needed just to iterate.
env = lmdb.open('data_lmdb', readonly=True)

# A single Datum message is reused for every record.
datum = caffe.proto.caffe_pb2.Datum()

# The context manager releases the read transaction when iteration is done.
with env.begin() as txn:
    lmdb_cursor = txn.cursor()
    for key, value in lmdb_cursor:
        # print() as a function works on both Python 2 and Python 3.
        print('{},{}'.format(key, value))

        datum.ParseFromString(value)

        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        # The resulting array is a read-only view; call .copy() before
        # modifying it in place.
        flat_data = np.frombuffer(datum.data, dtype=np.uint8)
        data = flat_data.reshape(datum.channels, datum.height, datum.width)
        # or: data = caffe.io.datum_to_array(datum)
        labels = datum.label

env.close()
3. HDF5创建和读取
方式一 (W1): 通过下标赋值写入数据集
import h5py
import numpy as np
# Create the HDF5 file.
imgsData = np.zeros((10, 3, 224, 224))  # Images
labels = np.arange(10)  # Labels

# The context manager closes the file even if a write fails.
with h5py.File('HDF5_FILE.h5', 'w') as f:
    f['datas'] = imgsData  # write the image data
    f['labels'] = labels  # write the label data

# Read the HDF5 file back.
with h5py.File('HDF5_FILE.h5', 'r') as f:
    # Materialize the key view now; it cannot be used once the file is closed.
    f_keys = list(f.keys())
    # [:] copies each dataset into an in-memory NumPy array so the data stays
    # available after the file is closed.
    imgsData = f['datas'][:]
    labels = f['labels'][:]
方式二 (W2): 通过 create_dataset 写入数据集
import h5py
import numpy as np

# NOTE: `datas` alone is 100 * 1000 * 1000 float32 values (about 400 MB).
datas = np.random.rand(100, 1000, 1000).astype('float32')
labels = np.random.rand(1, 1000, 1000).astype('float32')

# Create a new file; the context manager guarantees it is closed.
with h5py.File('data.h5', 'w') as f:
    f.create_dataset('datas', data=datas)
    f.create_dataset('labels', data=labels)

# Load the HDF5 datasets.
with h5py.File('data.h5', 'r') as f:
    # f['name'] is only a handle into the open file and becomes invalid once
    # the file is closed. [:] reads the contents into NumPy arrays so X and Y
    # remain usable afterwards.
    X = f['datas'][:]
    Y = f['labels'][:]
4. LMDB 数据集创建
"""
a modified version of CRNN torch repository
https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py
"""
import fire
import os
import lmdb
import cv2
import numpy as np
def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to a non-empty image.

    ARGS:
        imageBin : raw encoded image bytes (e.g. the content of an image file)

    RETURNS:
        False when ``imageBin`` is None or empty, when OpenCV cannot decode
        it, or when the decoded image has zero height or width; True
        otherwise.
    """
    # An empty buffer cannot be decoded; reject it up front instead of
    # letting cv2.imdecode raise.
    if imageBin is None or len(imageBin) == 0:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode signals undecodable data by returning None rather than
    # raising, so it must be checked before touching img.shape.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0
def writeCache(env, cache):
    """Store every key/value pair of ``cache`` in ``env`` within one write transaction.

    ARGS:
        env   : object whose ``begin(write=True)`` yields a transaction context manager
        cache : dict mapping keys to the values to store
    """
    write_txn = env.begin(write=True)
    with write_txn as txn:
        for key in cache:
            txn.put(key, cache[key])
def createDataset(inputPath, gtFile, outputPath, checkValid=True):
    """
    Create LMDB dataset for training and evaluation.

    Each accepted sample is stored under two keys, ``image-%09d`` (the raw
    image file bytes) and ``label-%09d`` (the UTF-8 encoded label), numbered
    from 1. The total number of stored samples is written under
    ``num-samples``.

    ARGS:
        inputPath  : input folder path that every imagePath is relative to
        gtFile     : UTF-8 text file with one "imagePath<TAB>label" per line
        outputPath : LMDB output path (created if it does not exist)
        checkValid : if true, check the validity of every image
    """
    os.makedirs(outputPath, exist_ok=True)
    env = lmdb.open(outputPath, map_size=1099511627776)  # 1 TiB upper bound
    cache = {}
    cnt = 1

    with open(gtFile, 'r', encoding='utf-8') as data:
        datalist = data.readlines()
    nSamples = len(datalist)

    try:
        for i, line in enumerate(datalist):
            line = line.strip('\n')
            # Blank lines (typically a trailing newline at the end of the
            # list) carry no sample.
            if not line.strip():
                continue

            # Split on the first tab only, so a malformed line is reported
            # and skipped instead of crashing the whole run on unpacking.
            fields = line.split('\t', 1)
            if len(fields) != 2:
                print('line %d of %s is malformed, skipped' % (i, gtFile))
                continue
            imagePath, label = fields
            imagePath = os.path.join(inputPath, imagePath)

            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()

            if checkValid:
                try:
                    if not checkImageIsValid(imageBin):
                        print('%s is not a valid image' % imagePath)
                        continue
                except Exception:
                    # Exception (not a bare except) so that Ctrl-C and
                    # SystemExit still stop the run.
                    print('error occured', i)
                    log_path = os.path.join(outputPath, 'error_image_log.txt')
                    with open(log_path, 'a') as log:
                        log.write('%s-th image data occured error\n' % str(i))
                    continue

            imageKey = 'image-%09d'.encode() % cnt
            labelKey = 'label-%09d'.encode() % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label.encode()

            # Flush in batches of 1000 samples to bound memory use.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1

        # cnt is one past the last sample that was actually stored.
        nSamples = cnt - 1
        cache['num-samples'.encode()] = str(nSamples).encode()
        writeCache(env, cache)
        print('Created dataset with %d samples' % nSamples)
    finally:
        # Release the environment even when the run is aborted.
        env.close()


if __name__ == '__main__':
    # Expose createDataset's parameters as command-line arguments.
    fire.Fire(createDataset)
5. LMDB 数据集读取
import lmdb
import numpy as np
import cv2
lmdb_file = "/path/to/lmdb"

# Viewing only needs read access.
lmdb_env = lmdb.open(lmdb_file, readonly=True)

# The context manager releases the read transaction when iteration ends.
with lmdb_env.begin() as lmdb_txn:
    lmdb_cursor = lmdb_txn.cursor()
    for key, value in lmdb_cursor:
        print('[INFO]', key)

        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        buf = np.frombuffer(value, np.uint8)

        # Flag 3 == cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH.
        # An empty value cannot be decoded, so it is treated like any other
        # non-image record.
        img = cv2.imdecode(buf, 3) if buf.size else None

        # Records that are not encoded images (for example label or
        # bookkeeping entries) decode to None; skip them instead of crashing
        # in cv2.imshow.
        if img is None:
            continue

        cv2.imshow("demo", img)
        cv2.waitKey(0)  # wait for a key press before showing the next image

lmdb_env.close()
One comment
[...]后端前端人工智能DevOps移动端测试程序人生 Search 人工智能Python主要数据读写方式 2019年10月20日 Leave a Commentpython常用的训练数据的格式读写汇总HDF5HDF5的特点:易于读取类似于mat数据,但数据压缩性能更强需要全部读进内存里,故HDF5文件大小不能超过内存,可以分成多个HDF5文件,将HDF5子文件路径写入txt中.I/O速率不如L[...]