Tensorflow estimator 訓(xùn)練和遷移學(xué)習(xí)(一)

以mnist數(shù)據(jù)集做訓(xùn)練

學(xué)習(xí)tensorflow和它的高級API estimator

由于Hnd手寫字母訓(xùn)練集數(shù)量較少,直接訓(xùn)練誤差可能較大芽偏,因此采用訓(xùn)練+遷移+微調(diào)的方式提升準(zhǔn)確率。這是第一部分,在mnist數(shù)據(jù)集上訓(xùn)練。

編寫model_fn跛锌,在mnist數(shù)據(jù)集上訓(xùn)練

import numpy as np
import tensorflow as tf
import os


def cnn_model_no_top(features, mode, trainable):
    """
    :param features: 輸入
    :param mode: estimator模式
    :param trainable: 該層的變量是否可訓(xùn)練
    :return: 不含最上層全連接層的模型
    """
    input_layer = tf.reshape(features, [-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(inputs=input_layer, filters=32, kernel_size=[5, 5], padding="same", activation=tf.nn.relu, trainable=trainable)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    conv2 = tf.layers.conv2d(inputs=pool1, filters=64, kernel_size=[5, 5], padding="same", activation=tf.nn.relu, trainable=trainable)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
    pool2_flat = tf.reshape(pool2, shape=[-1, 7 * 7 * 64])
    dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu, trainable=trainable)
    dropout = tf.layers.dropout(inputs=dense, rate=0.4, training=(mode == tf.estimator.ModeKeys.TRAIN))
    return dropout

def cnn_model_fn(features, labels, mode, params):
    """
    用于構(gòu)造estimator的model_fn
    :param features: 輸入
    :param labels: 標(biāo)簽
    :param mode: 模式
    :param params: 用于訓(xùn)練,遷移學(xué)習(xí)和微調(diào)的dict類型參數(shù)
        nb_classes 輸入的類別數(shù)
    :return: EstimatorSpec
    """
    logits_name = "predictions"
    labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=params["nb_classes"])
    model_no_top = cnn_model_no_top(features["x"], mode, trainable=True)  # mnist是完整的訓(xùn)練
    logits = tf.layers.dense(inputs=model_no_top, units=params["nb_classes"], name=logits_name)
    predictions = {
        "classes": tf.argmax(input=logits, axis=1),
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
        train_op = optimizer.minimize(loss, global_step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels=tf.argmax(labels, 1),
                                        predictions=predictions['classes'],
                                        name='accuracy')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops=eval_metric_ops
    )

開始訓(xùn)練

首先準(zhǔn)備訓(xùn)練數(shù)據(jù)和驗(yàn)證數(shù)據(jù)

mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images  # Returns np.array
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images  # Returns np.array
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

構(gòu)造estimator

mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir="./mnist_model", params={
    "nb_classes": 10
})

開始訓(xùn)練

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": train_data},
    y=train_labels,
    batch_size=100,
    num_epochs=None,
    shuffle=True
)
mnist_classifier.train(input_fn=train_input_fn, steps=2000)

訓(xùn)練結(jié)束后届惋,驗(yàn)證

eval_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": eval_data},
    y=eval_labels,
    num_epochs=1,
    shuffle=False)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print(eval_results)

結(jié)果挺不錯的

{'accuracy': 0.9855, 'loss': 0.043955494, 'global_step': 2000}
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末髓帽,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子脑豹,更是在濱河造成了極大的恐慌郑藏,老刑警劉巖,帶你破解...
    沈念sama閱讀 210,914評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件瘩欺,死亡現(xiàn)場離奇詭異必盖,居然都是意外死亡拌牲,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,935評論 2 383
  • 文/潘曉璐 我一進(jìn)店門歌粥,熙熙樓的掌柜王于貴愁眉苦臉地迎上來塌忽,“玉大人,你說我怎么就攤上這事失驶⊥辆樱” “怎么了?”我有些...
    開封第一講書人閱讀 156,531評論 0 345
  • 文/不壞的土叔 我叫張陵突勇,是天一觀的道長。 經(jīng)常有香客問我坷虑,道長甲馋,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 56,309評論 1 282
  • 正文 為了忘掉前任迄损,我火速辦了婚禮定躏,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘芹敌。我一直安慰自己痊远,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,381評論 5 384
  • 文/花漫 我一把揭開白布氏捞。 她就那樣靜靜地躺著碧聪,像睡著了一般。 火紅的嫁衣襯著肌膚如雪液茎。 梳的紋絲不亂的頭發(fā)上逞姿,一...
    開封第一講書人閱讀 49,730評論 1 289
  • 那天,我揣著相機(jī)與錄音捆等,去河邊找鬼滞造。 笑死,一個胖子當(dāng)著我的面吹牛栋烤,可吹牛的內(nèi)容都是我干的谒养。 我是一名探鬼主播,決...
    沈念sama閱讀 38,882評論 3 404
  • 文/蒼蘭香墨 我猛地睜開眼明郭,長吁一口氣:“原來是場噩夢啊……” “哼买窟!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起薯定,我...
    開封第一講書人閱讀 37,643評論 0 266
  • 序言:老撾萬榮一對情侶失蹤蔑祟,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后沉唠,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體疆虚,經(jīng)...
    沈念sama閱讀 44,095評論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,448評論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了径簿。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片罢屈。...
    茶點(diǎn)故事閱讀 38,566評論 1 339
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖篇亭,靈堂內(nèi)的尸體忽然破棺而出缠捌,到底是詐尸還是另有隱情,我是刑警寧澤译蒂,帶...
    沈念sama閱讀 34,253評論 4 328
  • 正文 年R本政府宣布曼月,位于F島的核電站,受9級特大地震影響柔昼,放射性物質(zhì)發(fā)生泄漏哑芹。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,829評論 3 312
  • 文/蒙蒙 一捕透、第九天 我趴在偏房一處隱蔽的房頂上張望聪姿。 院中可真熱鬧,春花似錦乙嘀、人聲如沸末购。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,715評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽盟榴。三九已至,卻和暖如春婴噩,著一層夾襖步出監(jiān)牢的瞬間曹货,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,945評論 1 264
  • 我被黑心中介騙來泰國打工讳推, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留顶籽,地道東北人。 一個月前我還...
    沈念sama閱讀 46,248評論 2 360
  • 正文 我出身青樓银觅,卻偏偏與公主長得像礼饱,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子究驴,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,440評論 2 348

推薦閱讀更多精彩內(nèi)容