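When a model trained with a newer PyTorch version (e.g., 1.10) is loaded in an older PyTorch environment (e.g., 1.0), the following error is raised: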

xxx is a zip archive (did you mean to use torch.jit.load()?)
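The error says the file is a zip archive. The snippet below uses only the standard library to check whether a checkpoint is in that zip-based format before you try to load it in an old environment. It is a minimal sketch, and the helper name `uses_zip_serialization` is hypothetical, not part of PyTorch:

```python
import zipfile


def uses_zip_serialization(path: str) -> bool:
    """Return True if the checkpoint at `path` is a zip archive.

    Hypothetical helper, not a PyTorch API. torch.save() in PyTorch >= 1.6
    writes zip archives by default; the legacy (pre-1.6) format is not a zip.
    """
    return zipfile.is_zipfile(path)
```

For example, `uses_zip_serialization('pytorch_model.pth')` returning True means an older PyTorch such as 1.0 will likely refuse the file with the error above.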

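Cause: since PyTorch 1.6, torch.save() writes checkpoints in a new zipfile-based format by default, which older versions cannot read. The fix is to re-save the model in the newer environment with the argument _use_new_zipfile_serialization=False: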

https://pytorch.org/docs/stable/generated/torch.save.html

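import torch

# Run this in the newer PyTorch environment (>= 1.6) that produced the checkpoint.
# map_location="cpu" lets the load succeed even on a machine without a GPU.
state_dict = torch.load('pytorch_model.pth', map_location="cpu")
#state_dict = torch.load('pytorch_model.bin', map_location="cpu")

# Re-save in the legacy (pre-1.6, non-zip) format that older PyTorch versions can read.
torch.save(state_dict, 'pytorch_model.pth', _use_new_zipfile_serialization=False)
#torch.save(state_dict, 'pytorch_model.bin', _use_new_zipfile_serialization=False)

# Then, in the older environment, load it as usual (map_location goes inside torch.load):
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.load_state_dict(torch.load(save_path, map_location=device))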
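The load-and-re-save steps above can be wrapped into one conversion function that also verifies the result. This is a sketch, assuming it runs under PyTorch >= 1.6 and that the checkpoint is a plain state_dict; the function name `convert_to_legacy_format` is hypothetical:

```python
import zipfile

import torch


def convert_to_legacy_format(src: str, dst: str) -> None:
    """Re-save checkpoint `src` as `dst` in the pre-1.6 (non-zip) format.

    Hypothetical helper, not a PyTorch API. Run it with a PyTorch version
    that can read `src` (>= 1.6 for zip-format files).
    """
    # Load onto the CPU so the conversion also works without a GPU.
    state_dict = torch.load(src, map_location="cpu")
    # Write the legacy serialization format that older PyTorch can read.
    torch.save(state_dict, dst, _use_new_zipfile_serialization=False)
    # Sanity check: a legacy-format file should not be a zip archive.
    if zipfile.is_zipfile(dst):
        raise RuntimeError(f"{dst} was still written as a zip archive")
```

Usage: `convert_to_legacy_format('pytorch_model.pth', 'pytorch_model_legacy.pth')`. Writing to a separate file keeps the original checkpoint intact in case the conversion needs to be redone.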
Last modification: December 28th, 2021 at 02:05 pm