Paper: Stacked Hourglass Networks for Human Pose Estimation - Demo Code

Paper reading - Stacked Hourglass Networks for Human Pose Estimation

The pose-hg-demo code mainly contains: main.lua, util.lua, img.lua, the images/ folder, the annot/ annotation folder, the preds/ output folder, and the pre-trained model umich-stacked-hourglass.t7.

The walkthrough here is based on Docker, Python, and pose-hg-demo.

<h2>1. Pull the Torch7 Image</h2>

sudo nvidia-docker pull registry.cn-hangzhou.aliyuncs.com/docker_learning_aliyun/torch:v1

<h2>2. Run the Demo on the MPII Human Pose Dataset</h2>

Download the MPII Human Pose dataset and put the images in the images folder.

sudo nvidia-docker run -it --rm -v /path/to/pose-hg-demo-master:/media registry.cn-hangzhou.aliyuncs.com/docker_learning_aliyun/torch:v1   # enter the Torch image
root@8f1548fc3b34:~/torch# cd /media        # i.e., pose-hg-demo-master on the host
th main.lua predict-test                    # run pose estimation; results are saved to 'preds/test.h5'
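Before visualizing, the generated file can be sanity-checked with h5py; a minimal sketch (the (N, 16, 2) shape, one (x, y) pair per MPII joint, matches what the visualization script below assumes):

#!/usr/bin/env python
# Quick sanity check of the predictions written by 'th main.lua predict-test'.
import h5py

with h5py.File('./preds/test.h5', 'r') as f:
    preds = f['preds'][()]
print(preds.shape)  # expected (num_test_images, 16, 2): 16 joints, (x, y) each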

The following Python script visualizes the pose estimation results:

#!/usr/bin/env python
import h5py
import scipy.misc as scm
import matplotlib.pyplot as plt

test_images = open('../annot/test_images.txt', 'r').readlines()
images_path = './images/'

# Read the predictions into memory before closing the file
f = h5py.File('./preds/test.h5', 'r')
preds = f['preds'][()]
f.close()

assert len(test_images) == len(preds)

for i in range(len(test_images)):
    filename = images_path + test_images[i].strip()
    im = scm.imread(filename)
    pose = preds[i]
    plt.axis('off')
    plt.imshow(im)
    for j in range(16):
        if pose[j][0] > 0 and pose[j][1] > 0:
            plt.scatter(pose[j][0], pose[j][1], marker='o', color='r', s=15)
    plt.show()
print('Done.')

<h2>3. Pose Estimation on Custom Images</h2>

Since the MPII Human Pose Dataset provides annotations of each person's scale and center, those images can be processed directly with the function provided by pose-hg-demo:

inputImg = crop(img, center, scale, rot, res)
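For reference, crop() maps a (center, scale) annotation to a res×res input patch. A rough NumPy/OpenCV sketch of the idea (the name crop_sketch is hypothetical; it assumes the MPII convention that the person spans roughly 200*scale pixels, and it ignores the rot argument):

import cv2

def crop_sketch(img, center, scale, res=256):
    # Rough equivalent of crop(img, center, scale, rot, res) with rot = 0.
    # Assumed MPII convention: the person occupies a window of about
    # 200 * scale pixels, centered at 'center' = (x, y).
    half = int(100 * scale)
    x, y = int(center[0]), int(center[1])
    # Pad so the window never falls outside the image borders
    padded = cv2.copyMakeBorder(img, half, half, half, half,
                                cv2.BORDER_CONSTANT, value=0)
    patch = padded[y:y + 2 * half, x:x + 2 * half]
    return cv2.resize(patch, (res, res), interpolation=cv2.INTER_LINEAR)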

However, when the person's scale and center are unknown for one or more images, separate handling is needed. The approach here is to first detect the person's bounding box (the implementation is not given here), and then preprocess the images in Python to produce the network input.

  • Python script for image preprocessing
#!/usr/bin/env python
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt

if __name__ == '__main__':
    orig_img_path = '/orig/images/path/'
    new_img_path = '/new/images/path_256/'
    boxsize = 256

    files = os.listdir(orig_img_path)
    for file in files:
        if file[-4:] == '.jpg':
            orig_img_name = orig_img_path + file
            if os.path.isfile(orig_img_name):
                img = cv2.imread(orig_img_name)
                height, width = float(img.shape[0]), float(img.shape[1])
                # Isotropically resize so the longer side equals boxsize
                scale = min(boxsize / height, boxsize / width)
                img_resize = cv2.resize(img, (0, 0), fx=scale, fy=scale,
                                        interpolation=cv2.INTER_LANCZOS4)
                # plt.imshow(img_resize); plt.show()
                # Zero-pad the shorter side so the result is boxsize x boxsize
                h, w = img_resize.shape[0], img_resize.shape[1]
                pad_up = abs(int((boxsize - h) / 2))     # up
                pad_down = abs(boxsize - h - pad_up)     # down
                pad_left = abs(int((boxsize - w) / 2))   # left
                pad_right = abs(boxsize - w - pad_left)  # right
                pad_img = np.lib.pad(img_resize,
                                     ((pad_up, pad_down), (pad_left, pad_right), (0, 0)),
                                     'constant', constant_values=0)
                new_img_name = new_img_path + file
                cv2.imwrite(new_img_name, pad_img)
    print('Done.')

<h3>3.1 Single-Image Human Pose Estimation - demo.lua</h3>

-- Usage (assumed): th demo.lua <image file name under image/>
require 'paths'
paths.dofile('util.lua')
paths.dofile('img.lua')

-- Load pre-trained model
m = torch.load('umich-stacked-hourglass.t7')

-- Set up input image
local im = image.load('image/' .. arg[1])

-- Get network output
local out = m:forward(im:view(1, 3, 256, 256):cuda())
cutorch.synchronize()
local hms = out[#out]:float()
hms[hms:lt(0)] = 0
-- print(hms:size())

-- Get predictions (hm and img refer to the coordinate space)
if hms:size():size() == 3 then
    hms = hms:view(1, hms:size(1), hms:size(2), hms:size(3))
end

-- Get locations of maximum activations
local max, idx = torch.max(hms:view(hms:size(1), hms:size(2), hms:size(3) * hms:size(4)), 3)
local preds = torch.repeatTensor(idx, 1, 1, 2):float()
preds[{{}, {}, 1}]:apply(function(x) return (x - 1) % hms:size(4) + 1 end)
preds[{{}, {}, 2}]:add(-1):div(hms:size(3)):floor():add(.5)

collectgarbage()

-- Save predictions
local predFile = hdf5.open('preds/pred.h5', 'w')
predFile:write('preds', preds)
predFile:write('img', im)
predFile:close()
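The index arithmetic above converts the flattened argmax position back into (x, y) heatmap coordinates. The same decoding for a single image in NumPy (a sketch with the hypothetical name get_preds_sketch, using 0-based indices versus Lua's 1-based):

import numpy as np

def get_preds_sketch(hms):
    # hms: (16, 64, 64) heatmaps for one image
    njoints, h, w = hms.shape
    flat_idx = hms.reshape(njoints, h * w).argmax(axis=1)
    xs = flat_idx % w   # column index -> x
    ys = flat_idx // w  # row index    -> y
    return np.stack([xs, ys], axis=1)  # (16, 2), 0-based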

<h3>3.2 Batch Human Pose Estimation - demo_multi.lua</h3>

This requires adding a new function, loadImageNames, to util.lua:

function loadImageNames(fileName)
    -- Load in image file names, one per line
    local a = {}
    a.images = {}
    local namesFile = io.open(fileName)
    local idxs = 1
    for line in namesFile:lines() do
        print(line)
        a.images[idxs] = line
        idxs = idxs + 1
    end
    namesFile:close()
    a.nsamples = idxs - 1
    return a
end

demo_multi.lua:

require 'paths'
paths.dofile('util.lua')
paths.dofile('img.lua')

--------------------------------------------------------------------------------
-- Initialization
--------------------------------------------------------------------------------
-- arg[1]: text file listing the image names, one per line (see loadImageNames)
a = loadImageNames(arg[1])
m = torch.load('umich-stacked-hourglass.t7')   -- Load pre-trained model

-- Displays a convenient progress bar
idxs = torch.range(1, a.nsamples)
nsamples = idxs:nElement()
xlua.progress(0, nsamples)

preds = torch.Tensor(nsamples, 16, 2)
imgs = torch.Tensor(nsamples, 3, 256, 256)

--------------------------------------------------------------------------------
-- Main loop
--------------------------------------------------------------------------------
for i = 1, nsamples do
    -- Set up input image
    -- print(a.images[i])
    local im = image.load('image/' .. a.images[i])

    -- Get network output
    local out = m:forward(im:view(1, 3, 256, 256):cuda())
    cutorch.synchronize()
    local hms = out[#out]:float()
    hms[hms:lt(0)] = 0

    -- Get predictions (hm and img refer to the coordinate space)
    if hms:size():size() == 3 then
        hms = hms:view(1, hms:size(1), hms:size(2), hms:size(3))
    end

    -- Get locations of maximum activations
    local max, idx = torch.max(hms:view(hms:size(1), hms:size(2), hms:size(3) * hms:size(4)), 3)
    local preds_img = torch.repeatTensor(idx, 1, 1, 2):float()
    preds_img[{{}, {}, 1}]:apply(function(x) return (x - 1) % hms:size(4) + 1 end)
    preds_img[{{}, {}, 2}]:add(-1):div(hms:size(3)):floor():add(.5)

    preds[i]:copy(preds_img)
    imgs[i]:copy(im)

    xlua.progress(i, nsamples)
    collectgarbage()
end

-- Save predictions
local predFile = hdf5.open('preds/preds.h5', 'w')
predFile:write('preds', preds)
predFile:write('imgs', imgs)
predFile:close()
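demo_multi.lua takes as arg[1] a text file with one image name per line, resolved relative to the image/ folder. A small helper to generate such a list (the file name imagelist.txt is just an example):

#!/usr/bin/env python
# Build the image-name list consumed by loadImageNames / demo_multi.lua.
import os

with open('imagelist.txt', 'w') as fp:
    for name in sorted(os.listdir('image/')):
        if name.endswith('.jpg'):
            fp.write(name + '\n')

The demo can then be run with: th demo_multi.lua imagelist.txt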

<h3>3.3 Visualizing the Results with Python</h3>

#!/usr/bin/env python
import h5py
import matplotlib.pyplot as plt

# Read everything into memory before closing the file
f = h5py.File('./preds/preds.h5', 'r')
imgs = f['imgs'][()]
preds = f['preds'][()]
f.close()

assert len(imgs) == len(preds)

for i in range(len(imgs)):
    pose = preds[i] * 4  # input is 256x256, heatmap output is 64x64, so scale coordinates by 4
    img = imgs[i].transpose(1, 2, 0)  # CHW -> HWC
    plt.axis('off')
    plt.imshow(img)
    for j in range(16):
        if pose[j][0] > 0 and pose[j][1] > 0:
            plt.scatter(pose[j][0], pose[j][1], marker='o', color='r', s=15)
    plt.show()
print('Done.')
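These coordinates live in the padded 256×256 input from Section 3. To map a prediction back to the original image, the resize-and-pad preprocessing can be inverted; a minimal sketch (to_original_coords is a hypothetical helper reusing the same scale/padding formulas as the preprocessing script):

def to_original_coords(x, y, orig_h, orig_w, boxsize=256):
    # Invert the resize-and-pad preprocessing: subtract the padding
    # offset, then undo the isotropic scaling.
    scale = min(float(boxsize) / orig_h, float(boxsize) / orig_w)
    h, w = int(orig_h * scale), int(orig_w * scale)
    pad_up = int((boxsize - h) / 2)
    pad_left = int((boxsize - w) / 2)
    return (x - pad_left) / scale, (y - pad_up) / scale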
