论文: Multi-Context Attention for Human Pose Estimation
类似于 Torch实践 - Stacked Hourglass Networks for Human Pose Estimation ,基于Docker-Torch,估计人体关节点.
这里只简单进行测试估计结果,由于显存有限,未能加入所有的 scale_search.
<h2>1. 图片人体姿态估计 - demo.lua</h2>
# 输入参数有两个, 第二个参数默认为 'mean'
th demo.lua imglist.txt 'max'
# or
th demo.lua imglist.txt
require 'paths'
paths.dofile('util.lua')
paths.dofile('img.lua')
--------------------------------------------------------------------------------
-- Initialization
--------------------------------------------------------------------------------
-- Multi-scale pose-estimation demo.
-- Usage: th demo.lua <imglist.txt> [max]
--   arg[1]: text file listing the images to process
--   arg[2]: heatmap fusion mode; 'max' takes the element-wise maximum over
--           scales, anything else (including omitted) takes the mean.
assert(arg[1], 'usage: th demo.lua <imglist.txt> [max]')

-- These top-level names are intentionally left global, as in the original
-- script, so they remain inspectable from an interactive session.
a = loadImageNames(arg[1]) -- read the list of image file names
m = torch.load( '../checkpoints/mpii/crf_parts/model.t7') -- Load pre-trained model
m:cuda()
m:evaluate()

-- Parameters
local isflip = true      -- also run the horizontally flipped input and average
local minusmean = true   -- subtract 0.5 from the input batch before the forward pass
local scale_search = {1.0, 1.1} -- choose according to the available GPU memory
-- local scale_search = {0.7,0.8,0.9,1.0,1.1,1.2} -- used in paper with NVIDIA Titan X (12 GB memory).
local use_max = (arg[2] == 'max')

-- Displays a convenient progress bar
idxs = torch.range(1, a.nsamples)
nsamples = idxs:nElement()
xlua.progress(0,nsamples)

preds = torch.Tensor(nsamples,16,3)      -- per joint: {x, y, score}
imgs = torch.Tensor(nsamples,3,256,256)  -- cropped network input saved per sample
local imgpath = '../data/image/'

--------------------------------------------------------------------------------
-- Main loop
--------------------------------------------------------------------------------
for idx = 1,nsamples do
    -- Set up input image
    -- NOTE(review): assumes loadImageNames returns a table whose `images`
    -- field is indexable by sample number -- confirm against util.lua.
    local imgname = paths.concat(imgpath, a['images'][idx])
    print(imgname)
    local im = image.load(imgname)

    -- Assumes the person has already been cropped and the image resized to 256.
    local original_scale = 256/200
    local center = {128.0, 128.0}

    local hmpyra = torch.zeros(#scale_search, 16, im:size(2), im:size(3))
    local batch = torch.Tensor(#scale_search, 3, 256, 256)

    for is, factor in ipairs(scale_search) do
        local scale = original_scale*factor
        local inp = crop(im, center, scale, 0, 256)
        batch[{is, {}, {}, {}}]:copy(inp)
        -- Overwritten on every scale, so the crop of the last scale is kept.
        imgs[idx]:copy(inp)
    end

    -- minus mean
    if minusmean then
        batch:add(-0.5)
    end

    -- Get network output
    local out = m:forward(batch:cuda())

    -- Get flipped output
    if isflip then
        -- Clone first: the next forward pass reuses the model's output buffers.
        out = applyFn(function (x) return x:clone() end, out)
        local flippedOut = m:forward(flip(batch):cuda())
        flippedOut = applyFn(function (x) return flip(shuffleLR(x)) end, flippedOut)
        out = applyFn(function (x,y) return x:add(y):div(2) end, out, flippedOut)
    end
    cutorch.synchronize()

    -- Use the last network output and clamp negative responses to zero.
    local hm = out[#out]:float()
    hm[hm:lt(0)] = 0

    -- Get heatmaps (original image size)
    for is, factor in ipairs(scale_search) do
        local hm_img = getHeatmaps(im, center, original_scale*factor, 0, 256, hm[is])
        hmpyra[{is, {}, {}, {}}]:copy(hm_img:sub(1, 16))
    end

    -- fuse heatmap over the scale dimension
    local fuseHm
    if use_max then
        fuseHm = hmpyra:max(1)
    else
        fuseHm = hmpyra:mean(1)
    end
    fuseHm = fuseHm[1]
    fuseHm[fuseHm:lt(0)] = 0

    -- get predictions: locate the maximum of each joint's fused heatmap
    for p = 1,16 do
        local maxy, iy = fuseHm[p]:max(2)   -- per-row maximum and its column index
        local maxv, ix = maxy:max(1)        -- row holding the overall maximum
        ix = torch.squeeze(ix)

        preds[idx][p][2] = ix
        preds[idx][p][1] = iy[ix]
        preds[idx][p][3] = maxy[ix]
    end

    xlua.progress(idx, nsamples)
    collectgarbage()
end

-- Save predictions
local predFile = hdf5.open('../preds/preds.h5', 'w')
predFile:write('preds', preds)
predFile:write('imgs', imgs)
predFile:close()
<h2>2. 人体姿态估计可视化 - show.py</h2>
#!/usr/bin/env python
import h5py
import scipy.misc as scm
import matplotlib.pyplot as plt
# Joint name -> row index into the (16, 3) per-image prediction array.
JointsIndex = {'r_ankle': 0, 'r_knee': 1, 'r_hip': 2,
               'l_hip': 3, 'l_knee': 4, 'l_ankle': 5,
               'pelvis': 6, 'thorax': 7, 'neck': 8, 'head': 9,
               'r_wrist': 10, 'r_elbow': 11, 'r_shoulder': 12,
               'l_shoulder': 13, 'l_elbow': 14, 'l_wrist': 15}

# Pairs of joints connected by a limb ("stick") in the drawing.
JointPairs = [['head', 'neck'], ['neck', 'thorax'],
              ['thorax', 'r_shoulder'], ['thorax', 'l_shoulder'],
              ['r_shoulder', 'r_elbow'], ['r_elbow', 'r_wrist'],
              ['l_shoulder', 'l_elbow'], ['l_elbow', 'l_wrist'],
              ['pelvis', 'r_hip'], ['pelvis', 'l_hip'], ['r_hip', 'r_knee'],
              ['r_knee', 'r_ankle'],
              ['l_hip', 'l_knee'], ['l_knee', 'l_ankle'],
              ['thorax', 'pelvis']]

# Matplotlib format string for each entry of JointPairs (same order).
StickType = ['r-', 'r-', 'g-', 'b-', 'g-', 'g-', 'b-', 'b-', 'c-', 'm-',
             'c-', 'c-', 'm-', 'm-', 'r-']

# One image name per line; strip the line ending so the joined path is valid.
with open('../test/imglist.txt', 'r') as list_file:
    imgs = [line.rstrip('\r\n') for line in list_file]
images_path = '../data/image/'

# NOTE(review): demo.lua writes '../preds/preds.h5'; this relative path only
# resolves when the script is run from that directory -- confirm.
f = h5py.File('preds.h5', 'r')
f_keys = list(f.keys())
# imgs = f['imgs'][:]
# Copy the dataset into memory: an h5py dataset is unusable once the file is closed.
preds = f['preds'][:]
f.close()

assert len(imgs) == len(preds)

for i in range(len(imgs)):
    filename = images_path + imgs[i]
    img = scm.imread(filename)
    pose = preds[i]  # columns: x, y, score
    # img = imgs[i].transpose(1, 2, 0)
    plt.axis('off')
    plt.imshow(img)

    # for j in range(16):
    #     if pose[j, 0] > 0 and pose[j, 1] > 0:
    #         plt.scatter(pose[j, 0], pose[j, 1], marker='o', color='r', s=15)
    # plt.show()

    for (name1, name2), stick in zip(JointPairs, StickType):
        idx1 = JointsIndex[name1]
        idx2 = JointsIndex[name2]
        # Draw a limb only when both of its end points have positive coordinates.
        if pose[idx1, 0] > 0 and pose[idx1, 1] > 0 and \
           pose[idx2, 0] > 0 and pose[idx2, 1] > 0:
            joints_x = [pose[idx1, 0], pose[idx2, 0]]
            joints_y = [pose[idx1, 1], pose[idx2, 1]]
            plt.plot(joints_x, joints_y, stick, linewidth=3)
    plt.show()

print('Done.')
<h2>3. Results</h2>
- <p>理想的结果 </p>
- <p>不理想的结果(可能因为scales不足造成) </p>