1, 模型結(jié)構(gòu)保存:
-
1器瘪, 包含自定義層使用了Lambda層翠储, model.to_json, Pickle會報錯
解決方式:- 用Layer的層替換
- 提供一個model.to_pickle() 方法-快速解決方案(dill)
可以參考:
https://github.com/keras-team/keras/issues/2582
2,保存模型結(jié)構(gòu)時 跑出 ('Not JSON Serializable:', Dimension(None))
解決方式:
https://blog.csdn.net/Funkdub/article/details/100069905
https://github.com/keras-team/keras/issues/93423橡疼,自定義層的初始化參數(shù)援所,要保存在模型結(jié)構(gòu)中, 需要定義:
def get_config(self):
config = super(DilatedGatedConv1D, self).get_config()
config.update(
{
'o_dim': self.o_dim,
'k_size': self.k_size,
'rate': self.rate,
'skip_connect': self.skip_connect,
'drop_gate': self.drop_gate
}
)
return config
如果初始化參數(shù)也是一個Layer網(wǎng)絡(luò)層欣除, Layer對象本身不能序列化住拭, 這就要求重新實現(xiàn)get_config()和from_config()兩個方法,實現(xiàn)包含層Layer的序列化反序列化历帚, 參考Bidirectional滔岳,Wrapper的實現(xiàn)
def get_config(self):
"""
參數(shù)的序列化操作
:return:
"""
config = super(OurBidirectional, self).get_config()
config.update(
{
'layer': { # 參照Wrapper 不能直接保留類對象
'class_name': self.layer.__class__.__name__,
'config': self.layer.get_config()
}
}
)
return config
@classmethod
def from_config(cls, config, custom_objects=None):
"""
自定義從字典config恢復(fù)實例參數(shù)
:param config:
:param custom_objects:
:return:
"""
layer = deserialize_layer(config.pop('layer'),
custom_objects=custom_objects)
return cls(layer, **config)
- 4, Model結(jié)構(gòu)載入時, 需要用到的自定義層或者第三方類對象傳給custom_objects, 否則會提示找不到類對象
def get_custom_objects(self):
"""
自定義的層或者函數(shù)
:return:
"""
custom_objects = self.embedding.get_custom_objects()
custom_objects['OurMasking'] = OurMasking
custom_objects['CRF'] = CRF
return custom_objects
keras.models.model_from_json(
model_json_str,
custom_objects=model.get_custom_objects()
)
- 5挽牢, 模型結(jié)構(gòu)存儲時谱煤,需要包括:
- model.to_json() :Dict/Str 模型結(jié)構(gòu)參數(shù)
- config:Dict 模型參數(shù)
- class_name: Str 定義的模型類對象 self.__class__name
- module: Str self.module
模型載入時, 可以順序的先動態(tài)import module, 然后反射類對象, 接著帶著config參數(shù)來載入模型, 最后更新每層參數(shù):
import importlib model_module = importlib.import_module((model_info['module'])) modle_class = getattr(model_module, model_info['class_name']) model.model = keras.models.model_from_json( json.dumps(model_info['model']), custom_objects=get_custom_objects() ) model.model.load_weights(os.path.join(model_path, 'best_model_weight.h5')) # 非必須禽拔, 只是如果model類本身有Model繼承類的參數(shù)輸入時刘离,需要更新Model參數(shù)的訓(xùn)練權(quán)重 for l in model.model.layers: print(l.name) return model