Working with PyTorch involves frequent conversions between Variable, Tensor and NumPy arrays, as well as moving data between CUDA (GPU) and CPU memory.
1.1 Variable to NumPy
import torch
from torch.autograd import Variable
var = Variable(torch.FloatTensor(2,3))  # uninitialized 2x3 tensor: values are whatever was in memory
# var = tensor(1.00000e-03 *
# [[ 1.1476, 0.0000, 0.0000],
# [ 0.0000, 0.0000, 0.0000]])
var_numpy = var.data.numpy()
# array([[1.1476139e-03, 4.5816855e-41, 3.7984453e-37],
# [0.0000000e+00, 4.4841551e-44, 0.0000000e+00]], dtype=float32)
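Since PyTorch 0.4, Variable has been merged into Tensor, and the more common case is a tensor created with requires_grad=True. Calling .numpy() on such a tensor raises a RuntimeError. The sketch below, assuming PyTorch 0.4 or later, uses .detach() instead (.data also works, but changes made through it are invisible to autograd's checks):

```python
import torch

# A tensor that tracks gradients (Variable and Tensor are the same class since PyTorch 0.4)
w = torch.ones(2, 3, requires_grad=True)

try:
    w.numpy()  # not allowed: w requires grad
except RuntimeError as err:
    print("RuntimeError:", err)

# detach() returns a tensor that shares w's storage but does not require grad
w_numpy = w.detach().numpy()
print(w_numpy.shape, w_numpy.dtype)  # (2, 3) float32
```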
1.2 NumPy to Variable
import torch
from torch.autograd import Variable
import numpy as np
var_numpy = np.random.randn(2, 3)
# array([[-0.27443182, 1.18369008, -0.24645608],
# [-0.99800364, 0.58202014, -0.84904032]])
var = Variable(torch.from_numpy(var_numpy))
# tensor([[-0.2744, 1.1837, -0.2465],
# [-0.9980, 0.5820, -0.8490]], dtype=torch.float64)
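torch.from_numpy keeps the NumPy dtype, which is why the result above is torch.float64. Layers in torch.nn use float32 parameters by default, so a float64 input usually needs converting. A minimal sketch, assuming PyTorch 0.4 or later, where Variable(...) is deprecated and simply returns a Tensor, so gradient tracking can be enabled on the tensor directly:

```python
import numpy as np
import torch

x_np = np.random.randn(2, 3)            # NumPy defaults to float64
x = torch.from_numpy(x_np)              # dtype preserved: torch.float64
x32 = torch.from_numpy(x_np).float()    # converted copy: torch.float32

# Replaces Variable(..., requires_grad=True) in PyTorch >= 0.4
x32.requires_grad_()
print(x.dtype, x32.dtype, x32.requires_grad)  # torch.float64 torch.float32 True
```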
1.3 Tensor to NumPy
import torch
var_tensor = torch.FloatTensor(2,3)  # uninitialized, like torch.empty(2, 3)
# tensor(1.00000e-03 *
# [[ 1.1476, 0.0000, 1.1476],
# [ 0.0000, 0.0000, 0.0000]])
var_numpy = var_tensor.numpy()
# array([[1.1476139e-03, 4.5816855e-41, 1.1476139e-03],
# [4.5816855e-41, 4.4841551e-44, 0.0000000e+00]], dtype=float32)
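For CPU tensors, .numpy() does not copy: the array and the tensor share the same memory, so in-place changes to one show up in the other. The sketch below contrasts this with an explicit .copy():

```python
import torch

t = torch.zeros(2, 3)
shared = t.numpy()               # same memory as t
independent = t.numpy().copy()   # separate NumPy buffer

t.add_(1)                        # in-place update of the tensor
print(shared[0, 0], independent[0, 0])  # 1.0 0.0
```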
1.4 NumPy to Tensor
import torch
import numpy as np
var_numpy = np.ones((2, 3))  # np.ones requires a shape argument
var_tensor = torch.from_numpy(var_numpy)
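torch.from_numpy likewise shares memory with the source array, while torch.tensor always copies. Which one to use depends on whether later changes to the array should be visible in the tensor:

```python
import numpy as np
import torch

arr = np.ones((2, 3))
shared = torch.from_numpy(arr)   # shares memory with arr, dtype torch.float64
copied = torch.tensor(arr)       # independent copy of the data

arr[0, 0] = 5.0
print(shared[0, 0].item(), copied[0, 0].item())  # 5.0 1.0
```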
1.5 .cuda()
PyTorch can copy models and data from host memory into GPU memory so that computation runs on the GPU.
import torch
torch.cuda.device_count()  # number of available GPUs
var_tensor = torch.FloatTensor(2,3)
if torch.cuda.is_available():  # check whether a GPU is available
    var_tensor = var_tensor.cuda()  # use .cuda(device_id) to select a specific GPU
    # tensor(1.00000e-03 *
    # [[ 1.1476, 0.0000, 1.1476],
    # [ 0.0000, 0.0000, 0.0000]], device='cuda:0')
# a model is moved to the GPU the same way:
# model = model.cuda()
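Since PyTorch 0.4, .to(device) is a common device-agnostic alternative to calling .cuda() inside an if-branch. A minimal sketch; the nn.Linear(3, 1) model and the input shape are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

# Pick the first GPU when available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = nn.Linear(3, 1).to(device)      # moves the parameters and returns the module
x = torch.randn(4, 3, device=device)    # create the data directly on the chosen device
y = model(x)
print(y.shape, y.device)
```

Model and inputs must live on the same device; mixing a CUDA model with a CPU tensor raises an error at the forward call.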
1.6 .cpu()
In PyTorch, if you read data directly from a CUDA tensor, e.g. with var_tensor.cuda().data.numpy():
import torch
var_tensor = torch.FloatTensor(2,3)
if torch.cuda.is_available():  # check whether a GPU is available
    var_tensor.cuda().data.numpy()
you will get an error similar to the following:
TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
The tensor must first be copied back to host memory: var_tensor.cuda().data.cpu().numpy().
import torch
var_tensor = torch.FloatTensor(2,3)
if torch.cuda.is_available():  # check whether a GPU is available
    print(var_tensor.cuda().data.cpu().numpy())
    # array([[1.1476139e-03, 4.5816855e-41, 1.1476139e-03],
    # [4.5816855e-41, 4.4841551e-44, 0.0000000e+00]], dtype=float32)
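In newer PyTorch code the same chain is usually written as .detach().cpu().numpy(). The helper below, to_numpy, is a hypothetical name used only for illustration; it works whether the tensor lives on the GPU or the CPU and whether or not it requires grad:

```python
import torch

def to_numpy(t: torch.Tensor):
    """Hypothetical helper: drop autograd tracking, move to host memory (no-op on CPU), convert."""
    return t.detach().cpu().numpy()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.ones(2, 3, device=device, requires_grad=True)
result = to_numpy(x * 2)  # x * 2 requires grad, so detach() is needed
print(result)
```

Note that for a CPU tensor .cpu() returns the tensor itself, so the resulting array may share memory with it; add .copy() if an independent array is needed.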