6. 模塊,層和模型

為了使用TensorFlow進(jìn)行機(jī)器學(xué)習(xí),你需要學(xué)會(huì)如何定義巾兆,保存和重載一個(gè)模型。
一個(gè)模型抽象上說有以下內(nèi)容:

  • 用于在張量上進(jìn)行計(jì)算的函數(shù)
  • 在訓(xùn)練過程中被不斷更新的變量
    本章內(nèi)容中梢褐,你會(huì)開始了解到Keras的底層挑社,并且明白TensorFlow的模型是如何定義的,即TensorFlow如何組織變量和模型顿痪,使得模型可以被保存和重載镊辕。

模型和層在TensorFlow中的定義

大多數(shù)的模型都是有層疊加而成的。層是一種包含有可訓(xùn)練變量及運(yùn)算的可重用的數(shù)據(jù)結(jié)構(gòu)蚁袭。TensorFlow中的大多數(shù)高級(jí)實(shí)現(xiàn)(如Keras中的層和模型)都是基于tf.Module類征懈。
模塊和層都是深度學(xué)習(xí)中的術(shù)語,用于描述有內(nèi)部狀態(tài)及定義這些內(nèi)部狀態(tài)上的操作的對(duì)象揩悄。你可以自由設(shè)置變量是否可以訓(xùn)練卖哎,這樣方便調(diào)優(yōu)。
使用tf.Module作為父類删性,子類將自動(dòng)組織所有的tf.Variable和tf.Module對(duì)象亏娜。這使得你可以使用tf.Module中的方法來保存和重載變量。

class SimpleModule(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.a_variable = tf.Variable(5.0, name="train_me")
        self.non_trainable_variable = tf.Variable(5.0, trainable=False, name="do_not_train_me")

    def __call__(self, x):
        return self.a_variable * x + self.non_trainable_variable


simple_module = SimpleModule(name="simple")

simple_module(tf.constant(5.0))

print("trainable variables:", simple_module.trainable_variables)
print("all variables:", simple_module.variables)

結(jié)果是:

trainable variables: (<tf.Variable 'train_me:0' shape=() dtype=float32, numpy=5.0>,)
all variables: (<tf.Variable 'train_me:0' shape=() dtype=float32, numpy=5.0>, <tf.Variable 'do_not_train_me:0' shape=() dtype=float32, numpy=5.0>)

下面是一個(gè)雙層神經(jīng)網(wǎng)絡(luò)的代碼:

class Dense(tf.Module):
    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features], name="w")
        )
        self.b = tf.Variable(
            tf.zeros([out_features], name="b")
        )

    def __call__(self, x):
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)


class SequentialModule(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.dense_1 = Dense(in_features=3, out_features=3)
        self.dense_2 = Dense(in_features=3, out_features=2)

    def __call__(self, x):
        x = self.dense_1(x)
        return self.dense_2(x)


my_model = SequentialModule(name="the_model")

print("Model results:", my_model(tf.constant([[2., 2., 2.], [3., 3., 3.]])))

tf.Module的實(shí)現(xiàn)類會(huì)自動(dòng)遞歸地收集鄒游的tf.Variable和tf.Module對(duì)象蹬挺。

延后創(chuàng)建變量

上述代碼中维贺,在模型初始化的時(shí)候便指定了輸入和輸出的維度,即W和b是已知維度的變量巴帮。這在大多數(shù)情況下造成了模型的局限性幸缕,很多情況下群发,在模型創(chuàng)建之前并不知道具體的維度。若能在變量第一次輸入時(shí)推斷出具體的輸入維度发乔,你就不需要人為指定輸入維度熟妓。這樣的代碼更加具有靈活性。
修改之后的代碼如下:

class FlexibleDenseModule(tf.Module):
    def __init__(self, out_features, name=None):
        super().__init__(name=name)
        self.is_built = False
        self.out_features = out_features

    def __call__(self, x):
        if not self.is_built:
            self.w = tf.Variable(tf.random.normal([x.shape[-1], self.out_features]), name="w")
            self.b = tf.Variable(tf.zeros([self.out_features]), name="b")
            self.is_built = True
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)


class FlexibleSequentialModule(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.dense_1 = FlexibleDenseModule(out_features=3)
        self.dense_2 = FlexibleDenseModule(out_features=2)

    def __call__(self, x):
        x = self.dense_1(x)
        return self.dense_2(x)

變量保存

你可以將tf.Module保存為檢查點(diǎn)(checkpoint)或存儲(chǔ)模型(savedModel)栏尚。檢查點(diǎn)只保存模型及子模型的變量起愈。檢查點(diǎn)保存時(shí)會(huì)生成兩類文件:數(shù)據(jù)文件和元數(shù)據(jù)索引文件。索引文件內(nèi)保存了檢查點(diǎn)的編號(hào)译仗,并記錄了什么數(shù)據(jù)被保存了起來抬虽。數(shù)據(jù)文件記錄了變量值和查找路徑纵菌。你可以通過查看checkpoint的內(nèi)容來確認(rèn)所有變量已經(jīng)保存成功。

chkp_path = "my_checkpoint"
checkpoint = tf.train.Checkpoint(model=my_model)
checkpoint.write(chkp_path)

tf.train.list_variables(chkp_path)

模型保存

TensorFlow可以在沒有Python源碼的情況下運(yùn)行模型笛辟,這使得你可以直接從TensorFlowHub上下載已經(jīng)訓(xùn)練好的模型來使用手幢。TensorFlow需要在沒有源碼的情況下知道Python代碼中描述的計(jì)算流程围来,我們可以使用圖來達(dá)到這個(gè)目標(biāo)监透。圖記錄了所有構(gòu)成目標(biāo)函數(shù)的計(jì)算過程航唆。關(guān)于使用tf.function將python函數(shù)轉(zhuǎn)換為圖佛点,不再贅敘超营。

Keras的模型和層

你可以基于tf.Module來創(chuàng)建屬于你的高級(jí)API阅虫,而Keras也是這么做的颓帝。

Keras的層

tf.keras.lays.Layer是所有keras的層的基類,而這個(gè)基類繼承自tf.Module類虐译。
若你需要將基于tf.Module的類轉(zhuǎn)換為基于keras層的類漆诽,則你只需要將父類更換一下并將call更換為call即可(keras類的call方法有自己的用處)厢拭。

class MyDense(tf.keras.layers.Layer):
    def __init__(self, in_features, out_features, **kwargs):
        super().__init__(**kwargs)

        self.w = tf.Variable(
            tf.random.normal([in_features, out_features]), name="w"
        )
        self.b = tf.Variable(
            tf.zeros([out_features]), name="b"
        )

    def call(self, x):
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)


simple_layer = MyDense(name="simple", in_features=3, out_features=3)

print(simple_layer([[2., 2., 2.]]))

結(jié)果為:

tf.Tensor([[0.01792851 2.781715   0.        ]], shape=(1, 3), dtype=float32)

build階段

前面提到過供鸠,若是在確定了輸入維度之后再創(chuàng)建內(nèi)部變量的話楞捂,是十分靈活的泡一。
keras的層新增了一個(gè)新的生命周期(稱為build)的步驟鼻忠,讓你更加靈活的定義和使用層杈绸。build只會(huì)被調(diào)用一次瞳脓,在這個(gè)階段確定輸入的維度劫侧,多用于創(chuàng)建內(nèi)部變量烧栋。將上面的MyDense增加build階段审姓,則代碼如下:

class FlexibleDense(tf.keras.layers.Layer):
    def __init__(self, out_features, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features

    def build(self, input_shape):
        self.w = tf.Variable(tf.random.normal([input_shape[-1], self.out_features]), name="w")
        self.b = tf.Variable(tf.zeros([self.out_features]), name="b")

    def call(self, x):
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)


flexible_dense = FlexibleDense(out_features=3)

print("Model results:", flexible_dense(tf.constant([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])))

注意魔吐,由于build只會(huì)調(diào)用一次,因此若是輸入與build階段的維度不同時(shí)奥溺,會(huì)報(bào)錯(cuò)谚赎。

Keras模型

Keras的tf.keras.Model類提供了模型的全部特性 壶唤,它繼承自tf.keras.layers.Layer,因此keras模型可以像keras的層一樣被重用和保存闸盔。 除此之外迎吵,keras的模型還提供了額外的功能击费,用于訓(xùn)練蔫巩,求解圆仔,保存和重載坪郭,甚至于提供了分布式訓(xùn)練功能歪沃。

class FlexibleDense(tf.keras.layers.Layer):
    def __init__(self, out_features, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features

    def build(self, input_shape):
        self.w = tf.Variable(tf.random.normal([input_shape[-1], self.out_features]), name="w")
        self.b = tf.Variable(tf.zeros([self.out_features]), name="b")

    def call(self, x):
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)


class MySequentialModel(tf.keras.Model):
    def __init__(self, name=None, **kwargs):
        super().__init__(**kwargs)

        self.dense_1 = FlexibleDense(out_features=3)
        self.dense_2 = FlexibleDense(out_features=3)

    def call(self, x):
        x = self.dense_1(x)
        return self.dense_2(x)


my_sequentail_model = MySequentialModel(name="the_model")

print("Model resutls:", my_sequentail_model(tf.constant([[2., 2., 2.]])))


print(my_sequentail_model.variables)

運(yùn)行結(jié)果為:

Model resutls: tf.Tensor([[0. 0. 0.]], shape=(1, 3), dtype=float32)
[<tf.Variable 'my_sequential_model/flexible_dense/w:0' shape=(3, 3) dtype=float32, numpy=
array([[-1.0733198 , -2.5860493 ,  0.42328298],
       [-2.0001495 ,  1.5054438 , -0.9208656 ],
       [ 1.6782132 ,  0.72947365, -0.08435281]], dtype=float32)>, <tf.Variable 'my_sequential_model/flexible_dense/b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Variable 'my_sequential_model/flexible_dense_1/w:0' shape=(3, 3) dtype=float32, numpy=
array([[ 0.37761664, -1.0342877 , -0.8181074 ],
       [ 0.6091555 ,  0.97727245,  0.11385015],
       [ 1.2820089 , -0.39806262, -0.28293946]], dtype=float32)>, <tf.Variable 'my_sequential_model/flexible_dense_1/b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末豆瘫,一起剝皮案震驚了整個(gè)濱河市外驱,隨后出現(xiàn)的幾起案子昵宇,更是在濱河造成了極大的恐慌瓦哎,老刑警劉巖蒋譬,帶你破解...
    沈念sama閱讀 217,509評(píng)論 6 504
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件癣漆,死亡現(xiàn)場(chǎng)離奇詭異惠爽,居然都是意外死亡瞬哼,警方通過查閱死者的電腦和手機(jī)较性,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,806評(píng)論 3 394
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來人弓,“玉大人崔赌,你說我怎么就攤上這事健芭〈嚷酰” “怎么了痒留?”我有些...
    開封第一講書人閱讀 163,875評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵匾效,是天一觀的道長(zhǎng)面哼。 經(jīng)常有香客問我魔策,道長(zhǎng),這世上最難降的妖魔是什么搁吓? 我笑而不...
    開封第一講書人閱讀 58,441評(píng)論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮摩骨,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘哭懈。我一直安慰自己遣总,他們只是感情好容达,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,488評(píng)論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著算芯,像睡著了一般。 火紅的嫁衣襯著肌膚如雪诈嘿。 梳的紋絲不亂的頭發(fā)上奖亚,一...
    開封第一講書人閱讀 51,365評(píng)論 1 302
  • 那天首繁,我揣著相機(jī)與錄音夹攒,去河邊找鬼咏尝。 笑死编检,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播谤专,決...
    沈念sama閱讀 40,190評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼蜡坊,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼秕衙!你這毒婦竟也來了鹦牛?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,062評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤晶伦,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后近忙,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,500評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡智润,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,706評(píng)論 3 335
  • 正文 我和宋清朗相戀三年及舍,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片窟绷。...
    茶點(diǎn)故事閱讀 39,834評(píng)論 1 347
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡锯玛,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出兼蜈,到底是詐尸還是另有隱情攘残,我是刑警寧澤,帶...
    沈念sama閱讀 35,559評(píng)論 5 345
  • 正文 年R本政府宣布为狸,位于F島的核電站歼郭,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏辐棒。R本人自食惡果不足惜病曾,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,167評(píng)論 3 328
  • 文/蒙蒙 一牍蜂、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧泰涂,春花似錦鲫竞、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,779評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至是牢,卻和暖如春僵井,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背妖泄。 一陣腳步聲響...
    開封第一講書人閱讀 32,912評(píng)論 1 269
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留艘策,地道東北人蹈胡。 一個(gè)月前我還...
    沈念sama閱讀 47,958評(píng)論 2 370
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像朋蔫,于是被迫代替她去往敵國(guó)和親罚渐。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,779評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容