學(xué)號:20021110074? ? ?電院? ? 姓名:梁雪玲
轉(zhuǎn)載自:https://blog.csdn.net/heivy/article/details/100512264?utm_medium=distribute.pc_relevant.none-task-blog-title-2&spm=1001.2101.3001.4242
【嵌牛導(dǎo)讀】:人工智能這么火熱烛芬,看了那么多理論后該如何實踐飘千?本文帶你走進(jìn)實踐的殿堂
【嵌牛鼻子】:鳶尾花多分類? ?TensorFlow? ?模型構(gòu)建
【嵌牛提問】:針對鳶尾花多分類的神經(jīng)網(wǎng)絡(luò)如何構(gòu)建模型?模型的詳解?如何擼代碼?
【嵌牛正文】:
人工智能領(lǐng)域分化為兩個陣營:其一是規(guī)則式(rule-based)方法,在人工智能早期占主峰;其二是神經(jīng)網(wǎng)絡(luò)(neural network)方法,后起之秀隐砸。隨著硬件水平的提高,算力的指數(shù)式增長蝙眶;人工智能的重心已經(jīng)從規(guī)則式的專家時代轉(zhuǎn)移到神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)時代季希。
“神經(jīng)網(wǎng)絡(luò)”不選擇把人腦熟稔的邏輯規(guī)則傳授給計算機(jī),而是直接在機(jī)器上重建人腦(類似人腦神經(jīng)元網(wǎng)絡(luò))幽纷。即模仿人腦結(jié)構(gòu)式塌,構(gòu)建類似生物神經(jīng)元網(wǎng)絡(luò)結(jié)構(gòu)來進(jìn)行收發(fā)信息。不同于規(guī)則式方法友浸,人工神經(jīng)元網(wǎng)絡(luò)的建造者一般不會給網(wǎng)絡(luò)設(shè)定決策規(guī)則(即給定網(wǎng)絡(luò)系數(shù))峰尝,而只是把某一現(xiàn)象(圖片、人聲收恢、文本等)的大量實際例子輸入人工神經(jīng)元網(wǎng)絡(luò)武学,并給定這一現(xiàn)象的結(jié)果(此圖片有貓祭往,此人聲是某某文本,此文本屬于正向積極情感等等)讓網(wǎng)絡(luò)從這些數(shù)據(jù)中學(xué)習(xí)(有監(jiān)督火窒、半監(jiān)督)硼补、識別規(guī)律已骇。換言之褪储,神經(jīng)網(wǎng)絡(luò)的原則是來自人的干預(yù)越少越好乱豆。 神經(jīng)網(wǎng)絡(luò)通過把數(shù)百萬張標(biāo)示了“有貓”或“沒有貓”的樣本圖片“喂”給計算機(jī)系統(tǒng)瑟啃,讓它自行從這數(shù)百萬張圖片中去辨察哪些特征和“貓”的標(biāo)簽最密切相關(guān)蛹屿。
根據(jù)神經(jīng)網(wǎng)絡(luò)的特點错负,大批量(百萬級)數(shù)據(jù)”喂“入網(wǎng)絡(luò)犹撒,讓網(wǎng)絡(luò)進(jìn)行特征即規(guī)則的學(xué)習(xí)和提取诚镰,可知數(shù)據(jù)是人工智能時代的核心之一清笨,人工智能的另一個核心則是神經(jīng)網(wǎng)絡(luò)的模型構(gòu)建抠艾。下面我們就鳶尾花多分類問題來詳解簡單的神經(jīng)網(wǎng)絡(luò)模型構(gòu)建及相應(yīng)代碼實現(xiàn)腌歉。
鳶尾花多分類問題是tensorflow 官方文檔里面的一個tensorflow入門教程;選取的是比較典型特點的三種鳶尾花:山鳶尾Iris setosa(0)泥彤、變色鳶尾Iris versicolor (1)、維吉尼亞鳶尾Iris virginica (2) 如圖一所示:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 圖一:鳶尾花(從左到右依次山鳶尾剑逃、維吉尼亞鳶尾、變色鳶尾)
從圖一可以看出三種鳶尾花區(qū)別很明顯萤捆,主要體現(xiàn)在花瓣和花萼上俗或;tensorFlow提供的數(shù)據(jù)集中,每個樣本包含四個特征和一個標(biāo)簽昆雀。這四個特征確定了單株鳶尾花的植物學(xué)特征鳶尾花花瓣(petals)的長度和寬度狞膘、花萼(sepals)的長度和寬度,單位CM辅愿;而標(biāo)簽則確定了此鳶尾花所屬品種:山鳶尾 (0)点待、變色鳶尾 (1)状原、維吉尼亞鳶尾 (2)颠区。數(shù)據(jù)格式如圖二所示,所有數(shù)據(jù)直接用逗號隔開(csv數(shù)據(jù)常用格式)朋截。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 圖二:鳶尾花分類問題訓(xùn)練數(shù)據(jù)集格式
數(shù)據(jù)集包括訓(xùn)練數(shù)據(jù)及測試數(shù)據(jù)稚字,數(shù)據(jù)格式統(tǒng)一為csv格式如圖三所示,,下載地址:
訓(xùn)練數(shù)據(jù)集(iris_training.csv):http://download.tensorflow.org/data/iris_training.csv
測試數(shù)據(jù)集(iris_test.csv):http://download.tensorflow.org/data/iris_test.csv
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?圖三:鳶尾花訓(xùn)練及測試數(shù)據(jù)集
數(shù)據(jù)拿到后,我們首先要進(jìn)行數(shù)據(jù)的清洗(這里數(shù)據(jù)是干凈的,此步可省略)醋闭,數(shù)據(jù)特征和數(shù)據(jù)標(biāo)簽的分割证逻,即把鳶尾花的花瓣長度丈咐、寬度,花萼長度辆影、寬度提取另存,標(biāo)簽提取另存,以便于訓(xùn)練喂入網(wǎng)絡(luò)经备。進(jìn)行csv數(shù)據(jù)的處理,代碼如下,處理結(jié)果如圖五所示:
def parse_csv(line):
? ? # 設(shè)置特征和標(biāo)簽的數(shù)據(jù)接收格式
? ? featlab_types = [[0.], [0.], [0.], [0.], [0]]
? ? # 解析csv數(shù)據(jù)犁功,以featlab_types的格式接收
? ? parsed_line = tf.io.decode_csv(line, featlab_types)
? ? # 提取出特征數(shù)據(jù)案糙,并轉(zhuǎn)化成張量
? ? features = tf.reshape(parsed_line[:-1], shape=(4,))
? ? # 提取出標(biāo)簽數(shù)據(jù)怒医,并轉(zhuǎn)化成張量
? ? label = tf.reshape(parsed_line[-1], shape=())
? ? return features, label
def getFeaturesLables(dataPath):
? ? # 使用TextLineDataset 讀取文件內(nèi)容
? ? FeatLabs = tf.data.TextLineDataset(trainPath)
? ? # 跳過第一行,因為第一行是所有數(shù)據(jù)內(nèi)容的總結(jié),不能用于訓(xùn)練或測試
? ? FeatLabs = FeatLabs.skip(1)
? ? # 把每一行數(shù)據(jù)按照parse_csv的格式報錯
? ? FeatLabs = FeatLabs.map(parse_csv)
? ? # 打亂數(shù)據(jù)原來的存放位置凡桥,
? ? FeatLabs = FeatLabs.shuffle(buffer_size=1000)
? ? # 以float32的格式保存數(shù)據(jù)
? ? FeatLabs = FeatLabs.batch(32)
? ? return FeatLabs
讀取測試數(shù)據(jù)(原數(shù)據(jù)的第一行被丟掉蠢络,因為它不屬于正式測試數(shù)據(jù))
處理后的數(shù)據(jù)張量如圖所示:
標(biāo)簽列表、標(biāo)簽和ID對照字典處理函數(shù):
def readCategory():
? ? """
? ? Args:
? ? ? ? None
? ? Returns:
? ? ? ? categories: a list of labels
? ? ? ? cat_to_id: a dict of label to id
? ? """
? ? categories = ['山鳶尾setosa', '變色鳶尾versicolor', '維吉尼亞鳶尾virginica']
? ? cat_to_id = dict(zip(categories, range(len(categories))))
? ? return categories, cat_to_id
二、網(wǎng)絡(luò)搭建
神經(jīng)網(wǎng)絡(luò)可以建有很多層纵潦,每層用什么網(wǎng)絡(luò)根據(jù)需求邀层;其中最簡單的網(wǎng)絡(luò)結(jié)構(gòu)是三層結(jié)構(gòu):輸入層只磷,隱藏層以及輸出層阿迈。
在這里炭晒,已知輸入是4個特征數(shù)據(jù)(花瓣(petals)的長度和寬度嗤无、花萼(sepals)的長度和寬度)当犯,輸出是3種類別(‘山鳶尾setosa’, ‘變色鳶尾versicolor’, ‘維吉尼亞鳶尾virginica’)設(shè)置的是隱藏層為3層垢村,節(jié)點分布分別是:10、20嚎卫、10嘉栓。
整體網(wǎng)絡(luò)示意圖如圖七所示:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?圖七:神經(jīng)網(wǎng)絡(luò)示意圖
網(wǎng)絡(luò)搭建代碼如圖八所示:共分為網(wǎng)絡(luò)參數(shù)設(shè)置類和網(wǎng)絡(luò)搭建類:
class NNetConfig():
? ? num_classes = 3? # 多分類的種類
? ? num_epochs = 161? # 訓(xùn)練總批次
? ? print_per_epoch = 20? # 每訓(xùn)練多少批次時打印訓(xùn)練損失函數(shù)值和預(yù)測準(zhǔn)確率值
? ? layersls = [4, 10, 20, 10, 3]? # 【輸入,隱藏各層節(jié)點數(shù)驰凛,輸出】
? ? learning_rate = 0.01? # 網(wǎng)絡(luò)學(xué)習(xí)率
? ? train_filename = './data/iris_training.csv'? # 訓(xùn)練數(shù)據(jù)
? ? test_filename = './data/iris_test.csv'? # 測試數(shù)據(jù)
? ? best_model_savepath = "./dnn/best_validation"? # 最好模型的存放文件夾
三胸懈、網(wǎng)絡(luò)訓(xùn)練及驗證
3.1訓(xùn)練代碼:
def iris_train():
? ? # 調(diào)用NNet網(wǎng)絡(luò),搭建自己的神經(jīng)網(wǎng)絡(luò)
? ? nnet = NNet(config)
? ? model = nnet.NNet()
? ? # 獲取訓(xùn)練數(shù)據(jù)
? ? featsLabs = getFeaturesLables(trainPath)
? ? # 定義網(wǎng)絡(luò)優(yōu)化器:梯度下降恰响,以learning_rate 的速率進(jìn)行網(wǎng)絡(luò)的訓(xùn)練優(yōu)化
? ? optimizer = tf.compat.v1.train.GradientDescentOptimizer(config.learning_rate)
? ? # 防止網(wǎng)絡(luò)過擬合的趣钱,當(dāng)準(zhǔn)確率大幅度下降時 停止訓(xùn)練,這里沒有用到
? ? flag_train = False
? ? # 損失函數(shù)胚宦,使用的是交叉熵
? ? def loss(model, x, y):
? ? ? ? y_ = model(x)
? ? ? ? return tf.compat.v1.losses.sparse_softmax_cross_entropy(labels=y, logits=y_)
? ? # 當(dāng)前網(wǎng)絡(luò)梯度
? ? def grad(model, inputs, targets):
? ? ? ? with tfe.GradientTape() as tape:
? ? ? ? ? ? loss_value = loss(model, inputs, targets)
? ? ? ? return tape.gradient(loss_value, model.variables)
? ? best_epoch_accuracy = 0
? ? last_improved = 0
? ? improved_str = ''
? ? for epoch in range(config.num_epochs):
? ? ? ? epoch_loss_avg = tfe.metrics.Mean()
? ? ? ? epoch_accuracy = tfe.metrics.Accuracy()
? ? ? ? # 輪回訓(xùn)練網(wǎng)絡(luò)
? ? ? ? for x, y in tfe.Iterator(featsLabs):
? ? ? ? ? ? # 優(yōu)化網(wǎng)絡(luò)
? ? ? ? ? ? grads = grad(model, x, y)
? ? ? ? ? ? optimizer.apply_gradients(zip(grads, model.variables),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? global_step=tf.compat.v1.train.get_or_create_global_step())
? ? ? ? ? ? # 當(dāng)前批次訓(xùn)練的損失函數(shù)均值
? ? ? ? ? ? epoch_loss_avg(loss(model, x, y))? #
? ? ? ? ? ? # 預(yù)測的標(biāo)簽值和實際標(biāo)簽值進(jìn)行對比首有,得到當(dāng)前的預(yù)測準(zhǔn)確率
? ? ? ? ? ? epoch_accuracy(tf.argmax(model(x), axis=1, output_type=tf.int32), y)
? ? ? ? ? ? # 本批次訓(xùn)練結(jié)束,保存本批次的損失函數(shù)結(jié)果和準(zhǔn)確率結(jié)果
? ? ? ? ? ? train_loss_results.append(epoch_loss_avg.result())
? ? ? ? ? ? train_accuracy_results.append(epoch_accuracy.result())
? ? ? ? ? ? # 每隔print_per_epoch次 輸出損失函數(shù)值枢劝、準(zhǔn)確率值等信息井联,方便監(jiān)控網(wǎng)絡(luò)的訓(xùn)練
? ? ? ? ? ? if epoch % config.print_per_epoch == 0:
? ? ? ? ? ? ? ? if not (epoch_accuracy.result()) > best_epoch_accuracy:
? ? ? ? ? ? ? ? ? ? # flag_train = True? #防止網(wǎng)絡(luò)被過度訓(xùn)練
? ? ? ? ? ? ? ? ? ? # break
? ? ? ? ? ? ? ? ? ? improved_str = ''
? ? ? ? ? ? ? ? ? ? pass
? ? ? ? ? ? ? ? else:
? ? ? ? ? ? ? ? ? ? best_epoch_accuracy = epoch_accuracy.result()
? ? ? ? ? ? ? ? ? ? print("當(dāng)前最高準(zhǔn)確率:%.3f:" % best_epoch_accuracy)
? ? ? ? ? ? ? ? ? ? # 最好模型的保存
? ? ? ? ? ? ? ? ? ? model.save(os.path.join(config.best_model_savepath,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? "model_best.h5"))
? ? ? ? ? ? ? ? ? ? last_improved = epoch
? ? ? ? ? ? ? ? ? ? improved_str = '*'
? ? ? ? ? ? ? ? print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}, isBest:{}".format(epoch,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? epoch_loss_avg.result(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? epoch_accuracy.result(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? improved_str))
? ? ? ? if flag_train:
? ? ? ? ? ? break
訓(xùn)練過程結(jié)果如圖九所示:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 圖九:訓(xùn)練過程數(shù)據(jù)輸出
從輸出第一行可以看到有個警告您旁,大意是:當(dāng)前電腦是支持AVX2 計算的烙常,但當(dāng)前使用的Tensor Flow是不支持AVX2 計算。其他沒有太大影響鹤盒,如果你感覺很扎眼蚕脏,那喔下次告訴你怎么把它藏起來哈~~
訓(xùn)練過程中的損失函數(shù)和準(zhǔn)確率變化如圖十所示,從圖中可以看出侦锯,損失函數(shù)的值在穩(wěn)定下降驼鞭,沒有太大的震蕩,從準(zhǔn)確率變化尺碰,可以看出挣棕,其實在訓(xùn)練到500Eoph時就可以終止訓(xùn)練了译隘,訓(xùn)練結(jié)果已經(jīng)達(dá)到最優(yōu)了。后面如果再加大訓(xùn)練洛心,可以會引起網(wǎng)絡(luò)的過度訓(xùn)練固耘,出現(xiàn)過擬合現(xiàn)象。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?圖十:訓(xùn)練過程中的損失函數(shù)值及準(zhǔn)確率值隨訓(xùn)練批次的變化
3.2測試代碼:
def iris_test():
? ? # 加載已經(jīng)訓(xùn)練好的最優(yōu)模型(包括網(wǎng)絡(luò)結(jié)構(gòu)及網(wǎng)絡(luò)權(quán)值矩陣)
? ? model = tf.keras.models.load_model(
? ? ? ? os.path.join(config.best_model_savepath, "model_best.h5"),
? ? ? ? compile=False)
? ? # 數(shù)據(jù)批量處理皂甘,把測試數(shù)據(jù)進(jìn)行清洗玻驻、結(jié)構(gòu)化、張量化
? ? testFeatsLabs = getFeaturesLables(testPath)
? ? # 計算測試數(shù)據(jù)在此模型下的準(zhǔn)確率
? ? test_accuracy = tfe.metrics.Accuracy()
? ? for x, y in tfe.Iterator(testFeatsLabs):
? ? ? ? # 模型的預(yù)測結(jié)果
? ? ? ? prediction = tf.argmax(model(x), axis=1, output_type=tf.int32)
? ? ? ? test_accuracy(prediction, y)
? ? print("測試數(shù)據(jù)的測試結(jié)果為: {:.3%}".format(test_accuracy.result()))
? ? return test_accuracy.result()
測試數(shù)據(jù)的測試結(jié)果如下圖所示:達(dá)到了97.5%偿枕,比單層網(wǎng)絡(luò)的91%優(yōu)化了很多
3.3預(yù)測代碼:
def iris_prediction(features=[]):
? ? # 預(yù)測函數(shù)
? ? # 加載已經(jīng)訓(xùn)練好的最優(yōu)模型(包括網(wǎng)絡(luò)結(jié)構(gòu)及網(wǎng)絡(luò)權(quán)值矩陣)
? ? # compile=False 表示對此模型璧瞬,我只用不再次訓(xùn)練
? ? model = tf.keras.models.load_model(
? ? ? ? os.path.join(config.best_model_savepath, "model_best.h5"),
? ? ? ? compile=False)
? ? # 當(dāng)預(yù)測特征為空時,使用下面給出的默認(rèn)值進(jìn)行預(yù)測
? ? if (len(features)==0):
? ? ? ? predFeats = tf.convert_to_tensor([
? ? ? ? [5.9, 3.0, 4.2, 1.5],
? ? ? ? [5.1, 3.3, 1.7, 0.5],
? ? ? ? [6.9, 3.1, 5.4, 2.1]
? ? ? ? ])
? ? else:
? ? ? ? # 預(yù)測特征數(shù)據(jù)轉(zhuǎn)換成張量
? ? ? ? predFeats = tf.convert_to_tensor(features)
? ? # 預(yù)測結(jié)果存放列表
? ? cat_probs = []
? ? # 使用已訓(xùn)練模型預(yù)測的預(yù)測結(jié)果渐夸,是張量
? ? y_probs = model(predFeats)
? ? # 取出每條預(yù)測結(jié)果進(jìn)行處理嗤锉,取出其中最大值,即最可能的結(jié)果墓塌,
? ? # 根據(jù)最大值所在下標(biāo)瘟忱,取到cat可讀文本
? ? for prob in y_probs:
? ? ? ? top1 = tf.argmax(prob).numpy()
? ? ? ? cat_probs.append(cat[top1])
? ? return cat_probs
最后,是給了3個鳶尾花的數(shù)據(jù)苫幢,進(jìn)行鳶尾花類別的預(yù)測访诱,
最后的最后,我猜測大家最關(guān)心的是喔的訓(xùn)練韩肝、測試触菜、預(yù)測代碼在哪里吖~
傳送門:鏈接:https://pan.baidu.com/s/1m4TPja9JzkbLplsISC8fxg
提取碼:ngtn