前言
本文记录出行在深度学习中的各类模型存储格式,不定期持续更新。
存储格式
PyTorch .pth/.pt/.pkl/.pth.tar
二进制文件,用于保存模型参数,后缀不同在保存上并没有区别。
模型保存
1 | model_saver = {'model':model.state_dict()} # 也可以添加其他项目,如epoch,optimizer, scheduler |
模型读取
1 | model_saver = torch.load('storage_name.pth') |
也可以直接保存model.state_dict()而不是字典。这样在读取时,也不需要用键在字典中查询。
也可以直接保存整个模型,这样不用重新初始化模型,但会占用更多存储,增加读取时间。
模型保存
1 | torch.save(model,'storage_name.pt') |
模型读取
1 | model = torch.load('storage_name.pt') |
ONNX
ONNX(Open Neural Network Exchange)开放式神经网络交换格式,用于统一多种训练框架导出的模型,如PyTorch,TensorFlow,Scikit-learn。