Javascript類型推斷(3) - 算法模型解析

Javascript類型推斷(3) - 算法模型解析

構建訓練模型

上一節(jié)我們介紹了生成訓練集舔哪,測試集,驗證集的方法槽棍,以及生成詞表的方法捉蚤。
這5個文件構成了訓練的基本素材:

files = {
    'train': { 'file': 'data/train.ctf', 'location': 0 },
    'valid': { 'file': 'data/valid.ctf', 'location': 0 },
    'test': { 'file': 'data/test.ctf', 'location': 0 },
    'source': { 'file': 'data/source_wl', 'location': 1 },
    'target': { 'file': 'data/target_wl', 'location': 1 }
}

詞表我們需要轉(zhuǎn)換一下格式抬驴,放到哈希表里:

# load dictionaries
source_wl = [line.rstrip('\n') for line in open(files['source']['file'])]
target_wl = [line.rstrip('\n') for line in open(files['target']['file'])]
source_dict = {source_wl[i]:i for i in range(len(source_wl))}
target_dict = {target_wl[i]:i for i in range(len(target_wl))}

下面是一些全局參數(shù):

# number of words in vocab, slot labels, and intent labels
vocab_size = len(source_dict)
num_labels = len(target_dict)
epoch_size = 17.955*1000*1000
minibatch_size = 5000
emb_dim = 300
hidden_dim = 650
num_epochs = 10

下面我們定義x,y,t三個值,分別與輸入詞表缆巧、輸出標簽數(shù)和隱藏層有關

# Create the containers for input feature (x) and the label (y)
x = C.sequence.input_variable(vocab_size, name="x")
y = C.sequence.input_variable(num_labels, name="y")
t = C.sequence.input_variable(hidden_dim, name="t")

好布持,我們開始看下訓練的流程:

model = create_model()
enc, dec = model(x, t)
trainer = create_trainer()
train()

訓練模型

首先是一個詞嵌入層:

def create_model():
    embed = C.layers.Embedding(emb_dim, name='embed')

然后是兩個雙向的循環(huán)神經(jīng)網(wǎng)絡(使用GRU),一個全連接網(wǎng)絡陕悬,和一個dropout:

    encoder = BiRecurrence(C.layers.GRU(hidden_dim//2), C.layers.GRU(hidden_dim//2))
    recoder = BiRecurrence(C.layers.GRU(hidden_dim//2), C.layers.GRU(hidden_()dim//2))
    project = C.layers.Dense(num_labels, name='classify')
    do = C.layers.Dropout(0.5)

然后把上面的四項組合起來:

    def recode(x, t):
        inp = embed(x)
        inp = C.layers.LayerNormalization()(inp)
        
        enc = encoder(inp)
        rec = recoder(enc + t)
        proj = project(do(rec))
        
        dec = C.ops.softmax(proj)
        return enc, dec
    return recode

其中雙向循環(huán)神經(jīng)網(wǎng)絡定義如下:

def BiRecurrence(fwd, bwd):
    F = C.layers.Recurrence(fwd)
    G = C.layers.Recurrence(bwd, go_backwards=True)
    x = C.placeholder()
    apply_x = C.splice(F(x), G(x))
    return apply_x

構建訓練過程

首先定義下?lián)p失函數(shù)题暖,由兩部分組成,一部分是loss捉超,另一部分是分類錯誤:

def criterion(model, labels):
    ce   = -C.reduce_sum(labels*C.ops.log(model))
    errs = C.classification_error(model, labels)
    return ce, errs

有了損失函數(shù)之后胧卤,我們使用帶動量的Adam算法進行梯度下降訓練:

def create_trainer():
    masked_dec = dec*C.ops.clip(C.ops.argmax(y), 0, 1)
    loss, label_error = criterion(masked_dec, y)
    loss *= C.ops.clip(C.ops.argmax(y), 0, 1)

    lr_schedule = C.learning_parameter_schedule_per_sample([1e-3]*2 + [5e-4]*2 + [1e-4], epoch_size=int(epoch_size))
    momentum_as_time_constant = C.momentum_as_time_constant_schedule(1000)
    learner = C.adam(parameters=dec.parameters,
                         lr=lr_schedule,
                         momentum=momentum_as_time_constant,
                         gradient_clipping_threshold_per_sample=15, 
                         gradient_clipping_with_truncation=True)

    progress_printer = C.logging.ProgressPrinter(tag='Training', num_epochs=num_epochs)
    trainer = C.Trainer(dec, (loss, label_error), learner, progress_printer)
    C.logging.log_number_of_parameters(dec)
    return trainer

訓練

定義好模型之后,我們就可以訓練了拼岳。
首先我們可以利用CNTK.io包的功能定義一個數(shù)據(jù)的讀取器:

def create_reader(path, is_training):
    return C.io.MinibatchSource(C.io.CTFDeserializer(path, C.io.StreamDefs(
            source      = C.io.StreamDef(field='S0', shape=vocab_size, is_sparse=True), 
            slot_labels = C.io.StreamDef(field='S1', shape=num_labels, is_sparse=True)
    )), randomize=is_training, max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1)

然后我們就可以利用這個讀取器讀取數(shù)據(jù)開始訓練了:

def train():
    train_reader = create_reader(files['train']['file'], is_training=True)
    step = 0
    pp = C.logging.ProgressPrinter(freq=10, tag='Training')
    for epoch in range(num_epochs):
        epoch_end = (epoch+1) * epoch_size
        while step < epoch_end:
            data = train_reader.next_minibatch(minibatch_size, input_map={
                x: train_reader.streams.source,
                y: train_reader.streams.slot_labels
            })
            # Enhance data
            enhance_data(data, enc)
            # Train model
            trainer.train_minibatch(data)
            pp.update_with_trainer(trainer, with_metric=True)
            step += data[y].num_samples
        pp.epoch_summary(with_metric=True)
        trainer.save_checkpoint("models/model-" + str(epoch + 1) + ".cntk")
        validate()
        evaluate()

上面的代碼中枝誊,enhance_data需要解釋一下。
我們的數(shù)據(jù)并非是完全線性的數(shù)據(jù)惜纸,還需要進行一個數(shù)據(jù)增強的處理過程:

def enhance_data(data, enc):
    guesses = enc.eval({x: data[x]})
    inputs = C.ops.argmax(x).eval({x: data[x]})
    tables = []
    for i in range(len(inputs)):
        ts = []
        table = {}
        counts = {}
        for j in range(len(inputs[i])):
            inp = int(inputs[i][j])
            if inp not in table:
                table[inp] = guesses[i][j]
                counts[inp] = 1
            else:
                table[inp] += guesses[i][j]
                counts[inp] += 1
        for inp in table:
            table[inp] /= counts[inp]
        for j in range(len(inputs[i])):
            inp = int(inputs[i][j])
            ts.append(table[inp])
        tables.append(np.array(np.float32(ts)))
    s = C.io.MinibatchSourceFromData(dict(t=(tables, C.layers.typing.Sequence[C.layers.typing.tensor])))
    mems = s.next_minibatch(minibatch_size)
    data[t] = mems[s.streams['t']]

測試和驗證

測試和驗證的過程中叶撒,也需要我們上面介紹的數(shù)據(jù)增強的過程:

def validate():
    valid_reader = create_reader(files['valid']['file'], is_training=False)
    while True:
        data = valid_reader.next_minibatch(minibatch_size, input_map={
                x: valid_reader.streams.source,
                y: valid_reader.streams.slot_labels
        })
        if not data:
            break
        enhance_data(data, enc)
        trainer.test_minibatch(data)
    trainer.summarize_test_progress()

evaluate與validate邏輯完全一樣,只是讀取的文件不同:

def evaluate():
    test_reader = create_reader(files['test']['file'], is_training=False)
    while True:
        data = test_reader.next_minibatch(minibatch_size, input_map={
            x: test_reader.streams.source,
            y: test_reader.streams.slot_labels
        })
        if not data:
            break
        # Enhance data
        enhance_data(data, enc)
        # Test model
        trainer.test_minibatch(data)
    trainer.summarize_test_progress()
?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末耐版,一起剝皮案震驚了整個濱河市痊乾,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌椭更,老刑警劉巖,帶你破解...
    沈念sama閱讀 211,265評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件蛾魄,死亡現(xiàn)場離奇詭異虑瀑,居然都是意外死亡,警方通過查閱死者的電腦和手機滴须,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,078評論 2 385
  • 文/潘曉璐 我一進店門舌狗,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人扔水,你說我怎么就攤上這事痛侍。” “怎么了魔市?”我有些...
    開封第一講書人閱讀 156,852評論 0 347
  • 文/不壞的土叔 我叫張陵主届,是天一觀的道長。 經(jīng)常有香客問我待德,道長君丁,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 56,408評論 1 283
  • 正文 為了忘掉前任将宪,我火速辦了婚禮绘闷,結果婚禮上橡庞,老公的妹妹穿的比我還像新娘。我一直安慰自己印蔗,他們只是感情好扒最,可當我...
    茶點故事閱讀 65,445評論 5 384
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著华嘹,像睡著了一般吧趣。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上除呵,一...
    開封第一講書人閱讀 49,772評論 1 290
  • 那天再菊,我揣著相機與錄音,去河邊找鬼颜曾。 笑死纠拔,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的泛豪。 我是一名探鬼主播稠诲,決...
    沈念sama閱讀 38,921評論 3 406
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼诡曙!你這毒婦竟也來了臀叙?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 37,688評論 0 266
  • 序言:老撾萬榮一對情侶失蹤价卤,失蹤者是張志新(化名)和其女友劉穎劝萤,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體慎璧,經(jīng)...
    沈念sama閱讀 44,130評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡床嫌,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,467評論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了胸私。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片厌处。...
    茶點故事閱讀 38,617評論 1 340
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖岁疼,靈堂內(nèi)的尸體忽然破棺而出阔涉,到底是詐尸還是另有隱情,我是刑警寧澤捷绒,帶...
    沈念sama閱讀 34,276評論 4 329
  • 正文 年R本政府宣布瑰排,位于F島的核電站,受9級特大地震影響疙驾,放射性物質(zhì)發(fā)生泄漏凶伙。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 39,882評論 3 312
  • 文/蒙蒙 一它碎、第九天 我趴在偏房一處隱蔽的房頂上張望函荣。 院中可真熱鬧显押,春花似錦、人聲如沸傻挂。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,740評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽金拒。三九已至兽肤,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間绪抛,已是汗流浹背资铡。 一陣腳步聲響...
    開封第一講書人閱讀 31,967評論 1 265
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留幢码,地道東北人笤休。 一個月前我還...
    沈念sama閱讀 46,315評論 2 360
  • 正文 我出身青樓,卻偏偏與公主長得像症副,于是被迫代替她去往敵國和親店雅。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 43,486評論 2 348

推薦閱讀更多精彩內(nèi)容