PyTorch 在模型训练时,使用 DataLoader 来生成训练数据. 其支持并行执行,以加快模型训练速度,其主要是基于 python 的 multiprocessing.
每个进程分别生成 batch 数据,并互斥同步地提供给主进程. 假设有 N 个workers,则需要 N 倍的内存. 其计算,比如:
[1] - 假设 batchsize=64, RGB 图像大小为 512x512x3,在 CPU 上进行图像标准化,此时,最终的图像 tensor:512x512x3xsizeof(float32) =512x512x3x4=3,145,728 bytes. 再乘以 batchsize,得到 201326592 bytes,约200M.
[2] - 除了RGB图像,还有 GT 标签,分类问题中较少,不过在分割问题中是不可忽视的.
[3] - 最终,所需的总内存量为 200M x N. 如,N=16,则需要 200Mx16=3200M.