前言
在深度學(xué)習(xí)中教馆,模型的保存和加載很重要速兔,當(dāng)我們辛辛苦苦訓(xùn)練好的一個網(wǎng)絡(luò)模型,自然需要將訓(xùn)練好的模型保存為文件活玲。在測試使用時候涣狗,又需要將保存在磁盤的模型文件加載調(diào)用。
在pytorch中網(wǎng)絡(luò)模型定義為torch.nn.Module
的子類的對象舒憾。因此模型的保存與加載涉及到2個重要概念——對象的序列化和反序列化镀钓。
目的
- 理解并掌握對象的序列化,反序列化
- 實現(xiàn)pytorch模型的保存與加載
開發(fā)/測試環(huán)境
- Ubuntu 18.04
- pycharm
- Anaconda3
- pytorch
- IntellJIDEA ,JDK10
對象的序列化與反序列化
序列化和反序列化聽起來感覺高大上镀迂,其實是很常見的操作丁溅,下面舉一個JAVA對象序列化和反序列化的例子,幫助理解探遵。
序列化: 把對象轉(zhuǎn)換為字節(jié)序列的過程稱為對象的序列化窟赏。
序列化的目的:
在很多應(yīng)用中妓柜,需要對某些對象進(jìn)行序列化,讓它們離開內(nèi)存空間涯穷,入住物理硬盤棍掐,以便長期保存。比如最常見的是Web服務(wù)器中的Session對象拷况,當(dāng)有 10萬用戶并發(fā)訪問作煌,就有可能出現(xiàn)10萬個Session對象,內(nèi)存可能吃不消赚瘦,于是Web容器就會把一些seesion先序列化到硬盤中粟誓,等要用了,再把保存在硬盤中的對象還原到內(nèi)存中起意。
反序列化: 把字節(jié)序列恢復(fù)為對象的過程稱為對象的反序列化鹰服。
當(dāng)兩個進(jìn)程在進(jìn)行遠(yuǎn)程通信時,彼此可以發(fā)送各種類型的數(shù)據(jù)揽咕。無論是何種類型的數(shù)據(jù)获诈,都會以二進(jìn)制序列的形式在網(wǎng)絡(luò)上傳送。發(fā)送方需要把這個Java對象轉(zhuǎn)換為字節(jié)序列心褐,才能在網(wǎng)絡(luò)上傳送舔涎;接收方則需要把字節(jié)序列再恢復(fù)為Java對象《旱 當(dāng)兩個進(jìn)程在進(jìn)行遠(yuǎn)程通信時亡嫌,彼此可以發(fā)送各種類型的數(shù)據(jù)。無論是何種類型的數(shù)據(jù)掘而,都會以二進(jìn)制序列的形式在網(wǎng)絡(luò)上傳送挟冠。發(fā)送方需要把這個Java對象轉(zhuǎn)換為字節(jié)序列,才能在網(wǎng)絡(luò)上傳送袍睡;接收方則需要把字節(jié)序列再恢復(fù)為Java對象知染。
首先,定義一個Person類斑胜,實現(xiàn)Serializable
接口
package com.sty;
import java.io.Serializable;
/*
Java對象的序列化
實現(xiàn)Serializable接口
*/
public class Person implements Serializable {
private static final long serialVersionUID = -5809782578272943999L;
private int age;
private String name;
private String sex;
public int getAge() {
return age;
}
public String getName() {
return name;
}
public String getSex() {
return sex;
}
public void setAge(int age) {
this.age = age;
}
public void setSex(String sex) {
this.sex = sex;
}
public void setName(String name) {
this.name = name;
}
}
- 序列化
- 反序列化
package com.sty;
import java.io.*;
//http://www.cnblogs.com/xdp-gacl/p/3777987.html
public class Main {
public static void main(String[] args) throws IOException, ClassNotFoundException {
serializePerson();
Person person = deserializePerson();
System.out.println(person);
}
/*
對象的序列化
*/
private static void serializePerson() throws IOException {
Person person = new Person();
person.setAge(25);
person.setName("LiMing");
person.setSex("male");
/*
ObjectOutputStream 對象輸出流
*/
ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(new File("/home/weipenghui/Person.txt")));
objectOutputStream.writeObject(person);
System.out.println("對象序列化成功");
objectOutputStream.close();
}
/*
對象的反序列化
*/
private static Person deserializePerson() throws IOException, ClassNotFoundException {
ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream("/home/weipenghui/Person.txt"));
Person person = (Person) objectInputStream.readObject();
System.out.println("Person對象序列化成功");
return person;
}
}
通過實現(xiàn)Serializable
接口控淡, 調(diào)用ObjectOutputStream
實現(xiàn)了對象的序列化。Java對象序列化的結(jié)果:
使用python序列化止潘、反序列化對象
python中提供了pickle
包進(jìn)行對象的序列化和反序列化掺炭。
簡單例子,首先定義一個簡單的類Student
, 分別進(jìn)行序列化和反序列化凭戴。
-
pickle.dump()
對象序列化 -
pickle.load()
對象反序列化
import pickle
class Student:
def __init__(self):
self.name = 'aa'
self.age = 10
self.gender = 'male'
def set_name(self, name):
self.name = name
def set_age(self, age):
self.age = age
def set_gender(self, gender):
self.gender = gender
def __str__(self):
return 'Student: name:{}, age:{}, gender:{}'.format(self.name, self.age, self.gender)
stu1 = Student()
stu1.set_age(22)
stu1.set_name('xiaotiantian')
stu1.set_gender('female')
print(stu1)
# 使用pickle序列化對象
# pickle.dump()
pickle_file = open('./data/student1.pkl', 'wb')
pickle.dump(stu1, pickle_file)
pickle_file.close()
# pickle反序列化對象
# pickle.load()
file_stu1 = open('./data/student1.pkl', 'rb')
stu11 = pickle.load(file_stu1)
print(stu11)
直接用文本打開序列化的文件涧狮,發(fā)現(xiàn)是亂碼的,沒事,代碼解析又不是人去解析者冤。
反序列化的結(jié)果肤视,從文件恢復(fù)出一個對象。
pytroch模型的保存與加載
有了上面序列化涉枫, 反序列化的基礎(chǔ)邢滑,很容易理解模型的保存就是序列化過程, 模型加載則是反序列化過程拜银。
When it comes to saving and loading models, there are three core functions to be familiar with:
- torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
- torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into (see Saving & Loading Model Across Devices).
- torch.nn.Module.load_state_dict: Loads a model’s parameter dictionary using a deserialized state_dict. For more information on state_dict, see What is a state_dict?.
模型保存與加載
pytorch中分為2種方法:
- 保存整個模型(包括網(wǎng)絡(luò)結(jié)構(gòu))
- 只保存網(wǎng)絡(luò)的訓(xùn)練參數(shù)
state_dict
與之對應(yīng)殊鞭,模型加載也是2中方法遭垛。
保存,加載整個模型
保存
torch.save(model, PATH)
加載
Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
只保存網(wǎng)絡(luò)的訓(xùn)練參數(shù)
save
torch.save(model.state_dict(), PATH)
laod
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
End
參考:
https://pytorch.org/tutorials/beginner/saving_loading_models.html