keras實(shí)現(xiàn)權(quán)重共享網(wǎng)絡(luò)

介紹

CNN中有權(quán)重共享的概念,是一張圖片中的多個(gè)區(qū)域共享同一個(gè)卷積核的參數(shù)乎澄。與前者不同突硝,這里討論的是權(quán)重共享網(wǎng)絡(luò),即有多個(gè)不同的輸入置济,經(jīng)過具有相同參數(shù)的一個(gè)網(wǎng)絡(luò)解恰,得到兩個(gè)不同的結(jié)果。
有一類網(wǎng)絡(luò)叫孿生網(wǎng)絡(luò)(如MatchNet)就具有權(quán)重共享的特性浙于。這種網(wǎng)絡(luò)可以用來衡量?jī)蓚€(gè)樣本的相似性护盈,用于判定兩個(gè)樣本是不是同一個(gè)類。在人臉識(shí)別羞酗、簽名識(shí)別腐宋、語(yǔ)言識(shí)別任務(wù)中有一些應(yīng)用。
keras實(shí)現(xiàn)孿生網(wǎng)絡(luò)的方式很簡(jiǎn)單檀轨,只要將需要共享的地方組成一個(gè)Model胸竞。在后續(xù)網(wǎng)絡(luò)調(diào)用Model即可。最后整個(gè)網(wǎng)絡(luò)再組成一個(gè)Model裤园。相當(dāng)于整體Model包含共享的Model撤师。。拧揽。Model套Model

keras官方文檔示例

下面ClassFilerNet1是共享權(quán)重的網(wǎng)絡(luò)剃盾,最后summary的參數(shù)量是4萬(wàn)多腺占;ClassFilerNet2沒有共享權(quán)重,summary的參數(shù)量是9萬(wàn)多痒谴。
共享的部分是conv2d-conv2d-maxpooling2d-flatten
第一個(gè)網(wǎng)絡(luò)將共享的部分組成了vision_model衰伯,后續(xù)像Layer一樣函數(shù)式調(diào)用即可。
第二個(gè)網(wǎng)絡(luò)重新創(chuàng)建了不同的Layer积蔚,沒有組成一個(gè)Model意鲸,所以參數(shù)會(huì)增加。

import keras
from keras.layers import Conv2D, MaxPooling2D, Input, Dense, Flatten
from keras.models import Model
def ClassiFilerNet1():
    # First, define the vision modules
    digit_input = Input(shape=(27, 27, 1))
    x = Conv2D(64, (3, 3))(digit_input)
    x = Conv2D(64, (3, 3))(x)
    x = MaxPooling2D((2, 2))(x)
    out = Flatten()(x)

    vision_model = Model(digit_input, out)

    # Then define the tell-digits-apart model
    digit_a = Input(shape=(27, 27, 1))
    digit_b = Input(shape=(27, 27, 1))

    # The vision model will be shared, weights and all
    out_a = vision_model(digit_a)
    out_b = vision_model(digit_b)

    concatenated = keras.layers.concatenate([out_a, out_b])
    out = Dense(1, activation='sigmoid')(concatenated)

    classification_model = Model([digit_a, digit_b], out)
    return classification_model
def ClassiFilerNet2():
    digit_a = Input(shape=(27, 27, 1))
    x = Conv2D(64, (3, 3))(digit_a)
    x = Conv2D(64, (3, 3))(x)
    x = MaxPooling2D((2, 2))(x)
    out1 = Flatten()(x)

    digit_b = Input(shape=(27, 27, 1))
    x = Conv2D(64, (3, 3))(digit_b)
    x = Conv2D(64, (3, 3))(x)
    x = MaxPooling2D((2, 2))(x)
    out2 = Flatten()(x)

    concatenated = keras.layers.concatenate([out1, out2])
    out = Dense(1, activation='sigmoid')(concatenated)

    classification_model = Model([digit_a, digit_b], out)
    return classification_model

MatchNet 出自這里

代碼是直接粘貼過來的尽爆,實(shí)際上核心用法看keras官方文檔最簡(jiǎn)單也最直觀怎顾。粘貼過來只是做個(gè)記錄。
MatchNet包含兩塊:FeatureExtract漱贱,Classification槐雾。
FeatureExtract部分,對(duì)于多個(gè)輸入應(yīng)該采用相同的權(quán)重進(jìn)行處理幅狮,需要用到共享權(quán)重募强。
Classification對(duì)提取的特征進(jìn)行分類,判別是不是同一類崇摄。
下面第一段代碼是非共享權(quán)重擎值,參數(shù)量為4.8M;第二段是共享權(quán)重逐抑,參數(shù)量為3.4M鸠儿。
可以放到tensorboard查看網(wǎng)絡(luò)結(jié)構(gòu),有一些差異厕氨。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ---------------------函數(shù)功能區(qū)-------------------------
def FeatureNetwork():
    """生成特征提取網(wǎng)絡(luò)"""
    """這是根據(jù)捆交,MNIST數(shù)據(jù)調(diào)整的網(wǎng)絡(luò)結(jié)構(gòu),下面注釋掉的部分是腐巢,原始的Matchnet網(wǎng)絡(luò)中feature network結(jié)構(gòu)"""
    inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
    models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
    models = Activation('relu')(models)
    models = MaxPool2D(pool_size=(3, 3))(models)

    models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)
    models = Flatten()(models)
    models = Dense(512)(models)
    models = Activation('relu')(models)
    model = Model(inputs=inp, outputs=models)
    return model

def ClassiFilerNet():  # add classifier Net
    """生成度量網(wǎng)絡(luò)和決策網(wǎng)絡(luò),其實(shí)maychnet是兩個(gè)網(wǎng)絡(luò)結(jié)構(gòu)玄括,一個(gè)是特征提取層(孿生)冯丙,一個(gè)度量層+匹配層(統(tǒng)稱為決策層)"""
    input1 = FeatureNetwork()                     # 孿生網(wǎng)絡(luò)中的一個(gè)特征提取
    input2 = FeatureNetwork()                     # 孿生網(wǎng)絡(luò)中的另一個(gè)特征提取
    for layer in input2.layers:                   # 這個(gè)for循環(huán)一定要加,否則網(wǎng)絡(luò)重名會(huì)出錯(cuò)遭京。
        layer.name = layer.name + str("_2")
    inp1 = input1.input
    inp2 = input2.input
    merge_layers = concatenate([input1.output, input2.output])        # 進(jìn)行融合胃惜,使用的是默認(rèn)的sum,即簡(jiǎn)單的相加
    fc1 = Dense(1024, activation='relu')(merge_layers)
    fc2 = Dense(1024, activation='relu')(fc1)
    fc3 = Dense(2, activation='softmax')(fc2)

    class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
    print("1111")
    return class_models
from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ----------------函數(shù)功能區(qū)-----------------------
def FeatureNetwork():
    """生成特征提取網(wǎng)絡(luò)"""
    """這是根據(jù)哪雕,MNIST數(shù)據(jù)調(diào)整的網(wǎng)絡(luò)結(jié)構(gòu)船殉,下面注釋掉的部分是,原始的Matchnet網(wǎng)絡(luò)中feature network結(jié)構(gòu)"""
    inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
    models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
    models = Activation('relu')(models)
    models = MaxPool2D(pool_size=(3, 3))(models)

    models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
    # models = Activation('relu')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Flatten()(models)
    models = Dense(512)(models)
    models = Activation('relu')(models)
    model = Model(inputs=inp, outputs=models)
    return model

def ClassiFilerNet(reuse=True):  # add classifier Net
    """生成度量網(wǎng)絡(luò)和決策網(wǎng)絡(luò)斯嚎,其實(shí)maychnet是兩個(gè)網(wǎng)絡(luò)結(jié)構(gòu)利虫,一個(gè)是特征提取層(孿生)挨厚,一個(gè)度量層+匹配層(統(tǒng)稱為決策層)"""

    if reuse:
        inp = Input(shape=(28, 28, 1), name='FeatureNet_ImageInput')
        models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
        models = Activation('relu')(models)
        models = MaxPool2D(pool_size=(3, 3))(models)

        models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
        # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
        models = Activation('relu')(models)

        models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
        models = Activation('relu')(models)

        models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
        models = Activation('relu')(models)

        # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
        # models = Activation('relu')(models)
        # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
        models = Flatten()(models)
        models = Dense(512)(models)
        models = Activation('relu')(models)
        model = Model(inputs=inp, outputs=models)

        inp1 = Input(shape=(28, 28, 1))  # 創(chuàng)建輸入
        inp2 = Input(shape=(28, 28, 1))  # 創(chuàng)建輸入2
        model_1 = model(inp1)  # 孿生網(wǎng)絡(luò)中的一個(gè)特征提取分支
        model_2 = model(inp2)  # 孿生網(wǎng)絡(luò)中的另一個(gè)特征提取分支
        merge_layers = concatenate([model_1, model_2])  # 進(jìn)行融合,使用的是默認(rèn)的sum糠惫,即簡(jiǎn)單的相加

    else:
        input1 = FeatureNetwork()                     # 孿生網(wǎng)絡(luò)中的一個(gè)特征提取
        input2 = FeatureNetwork()                     # 孿生網(wǎng)絡(luò)中的另一個(gè)特征提取
        for layer in input2.layers:                   # 這個(gè)for循環(huán)一定要加疫剃,否則網(wǎng)絡(luò)重名會(huì)出錯(cuò)。
            layer.name = layer.name + str("_2")
        inp1 = input1.input
        inp2 = input2.input
        merge_layers = concatenate([input1.output, input2.output])        # 進(jìn)行融合硼讽,使用的是默認(rèn)的sum巢价,即簡(jiǎn)單的相加
    fc1 = Dense(1024, activation='relu')(merge_layers)
    fc2 = Dense(1024, activation='relu')(fc1)
    fc3 = Dense(2, activation='softmax')(fc2)

    class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
    print("22222")
    return class_models
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市固阁,隨后出現(xiàn)的幾起案子壤躲,更是在濱河造成了極大的恐慌,老刑警劉巖备燃,帶你破解...
    沈念sama閱讀 212,816評(píng)論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件碉克,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡赚爵,警方通過查閱死者的電腦和手機(jī)棉胀,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,729評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來冀膝,“玉大人唁奢,你說我怎么就攤上這事∥哑剩” “怎么了麻掸?”我有些...
    開封第一講書人閱讀 158,300評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)赐纱。 經(jīng)常有香客問我脊奋,道長(zhǎng),這世上最難降的妖魔是什么疙描? 我笑而不...
    開封第一講書人閱讀 56,780評(píng)論 1 285
  • 正文 為了忘掉前任诚隙,我火速辦了婚禮,結(jié)果婚禮上起胰,老公的妹妹穿的比我還像新娘久又。我一直安慰自己,他們只是感情好效五,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,890評(píng)論 6 385
  • 文/花漫 我一把揭開白布地消。 她就那樣靜靜地躺著,像睡著了一般畏妖。 火紅的嫁衣襯著肌膚如雪脉执。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 50,084評(píng)論 1 291
  • 那天戒劫,我揣著相機(jī)與錄音半夷,去河邊找鬼婆廊。 笑死,一個(gè)胖子當(dāng)著我的面吹牛玻熙,可吹牛的內(nèi)容都是我干的否彩。 我是一名探鬼主播,決...
    沈念sama閱讀 39,151評(píng)論 3 410
  • 文/蒼蘭香墨 我猛地睜開眼嗦随,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼列荔!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起枚尼,我...
    開封第一講書人閱讀 37,912評(píng)論 0 268
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤贴浙,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后署恍,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體崎溃,經(jīng)...
    沈念sama閱讀 44,355評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,666評(píng)論 2 327
  • 正文 我和宋清朗相戀三年盯质,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了袁串。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,809評(píng)論 1 341
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡呼巷,死狀恐怖囱修,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情王悍,我是刑警寧澤破镰,帶...
    沈念sama閱讀 34,504評(píng)論 4 334
  • 正文 年R本政府宣布,位于F島的核電站压储,受9級(jí)特大地震影響鲜漩,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜集惋,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 40,150評(píng)論 3 317
  • 文/蒙蒙 一孕似、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧刮刑,春花似錦鳞青、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,882評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)厚脉。三九已至习寸,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間傻工,已是汗流浹背霞溪。 一陣腳步聲響...
    開封第一講書人閱讀 32,121評(píng)論 1 267
  • 我被黑心中介騙來泰國(guó)打工孵滞, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人鸯匹。 一個(gè)月前我還...
    沈念sama閱讀 46,628評(píng)論 2 362
  • 正文 我出身青樓坊饶,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親殴蓬。 傳聞我的和親對(duì)象是個(gè)殘疾皇子匿级,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,724評(píng)論 2 351

推薦閱讀更多精彩內(nèi)容