from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
class MyLayer(Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyLayer, self).__init__(**kwargs)
def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[1], self.output_dim),
initializer='uniform',
trainable=True)
super(MyLayer, self).build(input_shape) # Be sure to call this somewhere!
def call(self, x):
return K.dot(x, self.kernel)
def compute_output_shape(self, input_shape):
return (input_shape[0], self.output_dim)
-
build(input_shape)
:這是定義權(quán)重的方法,可訓(xùn)練的權(quán)應(yīng)該在這里被加入列表self.trainable_weights
中涛浙。其他的屬性還包括self.non_trainabe_weights
(列表)和self.updates
(需要更新的形如(tensor, new_tensor)的tuple的列表)康辑。你可以參考BatchNormalization
層的實(shí)現(xiàn)來學(xué)習(xí)如何使用上面兩個(gè)屬性。這個(gè)方法必須設(shè)置self.built = True
蝗拿,可通過調(diào)用super([layer],self).build()
實(shí)現(xiàn) -
call(x)
:這是定義層功能的方法晾捏,除非你希望你寫的層支持masking,否則你只需要關(guān)心call
的第一個(gè)參數(shù):輸入張量 -
compute_output_shape(input_shape)
:如果你的層修改了輸入數(shù)據(jù)的shape哀托,你應(yīng)該在這里指定shape變化的方法惦辛,這個(gè)函數(shù)使得Keras可以做自動(dòng)shape推斷
參考keras文檔