深度学习模型存储格式

前言

本文记录出行在深度学习中的各类模型存储格式,不定期持续更新。

存储格式

PyTorch .pth/.pt/.pkl/.pth.tar

二进制文件,用于保存模型参数,后缀不同在保存上并没有区别。

模型保存

1
2
3
model_saver = {'model':model.state_dict()} # 也可以添加其他项目,如epoch,optimizer, scheduler
# model_saver.update({'epoch':10})
torch.save(model_saver,'storage_name.pth')

模型读取

1
2
3
model_saver = torch.load('storage_name.pth')
model = NNModel(*args, **kwargs) # 初始化模型
model.load_state_dict(model_saver['model'])

也可以直接保存model.state_dict()而不是字典。这样在读取时,也不需要用键在字典中查询。

也可以直接保存整个模型,这样不用重新初始化模型,但会占用更多存储,增加读取时间。

模型保存

1
torch.save(model,'storage_name.pt')

模型读取

1
model = torch.load('storage_name.pt')

ONNX

ONNX(Open Neural Network Exchange)开放式神经网络交换格式,用于统一多种训练框架导出的模型,如PyTorch,TensorFlow,Scikit-learn。