Core ML框架詳細(xì)解析(十四) —— 使用Keras和Core ML開始機(jī)器學(xué)習(xí)(二)

版本記錄

版本號 時(shí)間
V1.0 2018.10.16 星期二

前言

目前世界上科技界的所有大佬一致認(rèn)為人工智能是下一代科技革命捞镰,蘋果作為科技界的巨頭性锭,當(dāng)然也會緊跟新的科技革命的步伐列牺,其中ios API 就新出了一個(gè)框架Core ML贬堵。ML是Machine Learning的縮寫带族,也就是機(jī)器學(xué)習(xí)锁荔,這正是現(xiàn)在很火的一個(gè)技術(shù),它也是人工智能最核心的內(nèi)容蝙砌。感興趣的可以看我寫的下面幾篇阳堕。
1. Core ML框架詳細(xì)解析(一) —— Core ML基本概覽
2. Core ML框架詳細(xì)解析(二) —— 獲取模型并集成到APP中
3. Core ML框架詳細(xì)解析(三) —— 利用Vision和Core ML對圖像進(jìn)行分類
4. Core ML框架詳細(xì)解析(四) —— 將訓(xùn)練模型轉(zhuǎn)化為Core ML
5. Core ML框架詳細(xì)解析(五) —— 一個(gè)Core ML簡單示例(一)
6. Core ML框架詳細(xì)解析(六) —— 一個(gè)Core ML簡單示例(二)
7. Core ML框架詳細(xì)解析(七) —— 減少Core ML應(yīng)用程序的大小(一)
8. Core ML框架詳細(xì)解析(八) —— 在用戶設(shè)備上下載和編譯模型(一)
9. Core ML框架詳細(xì)解析(九) —— 用一系列輸入進(jìn)行預(yù)測(一)
10. Core ML框架詳細(xì)解析(十) —— 集成自定義圖層(一)
11. Core ML框架詳細(xì)解析(十一) —— 創(chuàng)建自定義圖層(一)
12. Core ML框架詳細(xì)解析(十二) —— 用scikit-learn開始機(jī)器學(xué)習(xí)(一)
13. Core ML框架詳細(xì)解析(十三) —— 使用Keras和Core ML開始機(jī)器學(xué)習(xí)(一)

Train the Model - 訓(xùn)練模型

1. Define Callbacks List - 定義回調(diào)列表

callbacksfit函數(shù)的可選參數(shù)择克,因此首先定義callbacks_list恬总。

輸入以下代碼,然后運(yùn)行它肚邢。

callbacks_list = [
    keras.callbacks.ModelCheckpoint(
        filepath='best_model.{epoch:02d}-{val_loss:.2f}.h5',
        monitor='val_loss', save_best_only=True),
    keras.callbacks.EarlyStopping(monitor='acc', patience=1)
]

一個(gè)epoch是完整傳遞數(shù)據(jù)集中的所有小批量壹堰。

ModelCheckpoint回調(diào)監(jiān)視驗(yàn)證丟失值,使用文件編號和文件名中的驗(yàn)證丟失將文件中的最低值保存骡湖。

EarlyStopping回調(diào)監(jiān)控訓(xùn)練準(zhǔn)確性:如果連續(xù)兩個(gè)epochs未能改善贱纠,則訓(xùn)練提前停止。在我的實(shí)驗(yàn)中响蕴,這種情況從未發(fā)生過:如果acc在一個(gè)epoch內(nèi)逐漸消失谆焊,它總會在下一個(gè)時(shí)代恢復(fù)。

2. Compile & Fit Model - 編譯和擬合模型

除非您可以訪問GPU换途,否則我建議您使用Malireddimodel_m進(jìn)行此步驟懊渡,因?yàn)樗倪\(yùn)行速度比Chollet的model_c快得多:在我的MacBook Pro上刽射,76-106s / epoch與246-309s / epoch相比,或者大約15分鐘vs 剃执。 45分鐘誓禁。

注意:如果在第一個(gè)epoch完成后notebook中沒有出現(xiàn).h5文件,請單擊stop button以中斷內(nèi)核肾档,單擊save button摹恰,然后注銷。在終端中怒见,按Control-C停止服務(wù)器俗慈,然后重新運(yùn)行docker run命令。將URL或令牌粘貼到瀏覽器或登錄頁面遣耍,導(dǎo)航到notebook闺阱,然后單擊Not Trusted button按鈕。選擇此單元格舵变,然后從菜單中選擇Cell \ Run All Above酣溃。

輸入以下代碼,然后運(yùn)行它纪隙。這將花費(fèi)很長時(shí)間赊豌,所以在等待時(shí)閱讀Explanations部分。但是幾分鐘后檢查Finder绵咱,以確保notebook正在保存.h5文件碘饼。

注意:此單元格顯示多行函數(shù)調(diào)用的兩種縮進(jìn)類型,具體取決于您編寫第一個(gè)參數(shù)的位置悲伶。如果它甚至被一個(gè)空格輸出艾恼,那么這是一個(gè)語法錯(cuò)誤。

model_m.compile(loss='categorical_crossentropy',
                optimizer='adam', metrics=['accuracy'])

# Hyper-parameters
batch_size = 200
epochs = 10

# Enable validation to use ModelCheckpoint and EarlyStopping callbacks.
model_m.fit(
    x_train, y_train, batch_size=batch_size, epochs=epochs,
    callbacks=callbacks_list, validation_data=(x_val, y_val), verbose=1)

Convolutional Neural Network: Explanations - 卷積神經(jīng)網(wǎng)絡(luò):解釋

您可以使用幾乎任何ML方法來創(chuàng)建MNIST分類器拢切,但本教程使用卷積神經(jīng)網(wǎng)絡(luò)(CNN)蒂萎,因?yàn)檫@是TensorFlowKeras的關(guān)鍵優(yōu)勢。

卷積神經(jīng)網(wǎng)絡(luò)假設(shè)輸入是圖像淮椰,并在三個(gè)維度上排列神經(jīng)元:寬度五慈,高度,深度主穗。 CNN由卷積層組成泻拦,每個(gè)卷層檢測訓(xùn)練圖像的更高級特征:第一層可以訓(xùn)練濾波器以檢測各種角度的短線或弧線;第二層訓(xùn)練濾波器以檢測這些線的重要組合忽媒;最后一層的過濾器構(gòu)建在前面的圖層上以對圖像進(jìn)行分類争拐。

每個(gè)卷積層在輸入上傳遞一個(gè)小方塊的kernel權(quán)重 - 1×1,3×35×5 ,計(jì)算內(nèi)核下輸入單元的加權(quán)和晦雨。 這是卷積過程架曹。

每個(gè)神經(jīng)元僅連接到前一層中的1個(gè)隘冲,9個(gè)或25個(gè)神經(jīng)元,因此存在co-adapting的危險(xiǎn) - 過多地依賴于少數(shù)輸入 - 這可能導(dǎo)致過度擬合绑雄。 因此展辞,CNN包括poolingdropout層,以抵消co-adapting和過度擬合万牺。 我在下面解釋這些罗珍。

Sample Model - 樣本模型

這是Malireddi的模型:

model_m = Sequential()
model_m.add(Conv2D(32, (5, 5), input_shape=input_shape, activation='relu'))
model_m.add(MaxPooling2D(pool_size=(2, 2)))
model_m.add(Dropout(0.5))
model_m.add(Conv2D(64, (3, 3), activation='relu'))
model_m.add(MaxPooling2D(pool_size=(2, 2)))
model_m.add(Dropout(0.2))
model_m.add(Conv2D(128, (1, 1), activation='relu'))
model_m.add(MaxPooling2D(pool_size=(2, 2)))
model_m.add(Dropout(0.2))
model_m.add(Flatten())
model_m.add(Dense(128, activation='relu'))
model_m.add(Dense(num_classes, activation='softmax'))

1. Sequential

首先創(chuàng)建一個(gè)空的Sequential模型,然后添加一個(gè)線性的圖層堆棧:這些圖層按照它們添加到模型的順序運(yùn)行脚粟。 Keras文檔有幾個(gè)examples of Sequential models覆旱。

注意:Keras還具有用于定義復(fù)雜模型的函數(shù)API,例如多輸出模型核无,有向非循環(huán)圖或具有共享層的模型扣唱。 Google的InceptionMicrosoft Research AsiaResidual Networks是具有非線性連接結(jié)構(gòu)的復(fù)雜模型的示例。

第一層必須具有關(guān)于輸入形狀的信息厕宗,對于MNIST(28,28,1)画舌。 其他層從前一層的輸出形狀推斷出它們的輸入形狀堕担。 這是模型摘要的輸出形狀部分:

Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_6 (Conv2D)            (None, 24, 24, 32)        832       
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 12, 12, 32)        0         
_________________________________________________________________
dropout_6 (Dropout)          (None, 12, 12, 32)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 10, 10, 64)        18496     
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
dropout_7 (Dropout)          (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 5, 5, 128)         8320      
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 2, 2, 128)         0         
_________________________________________________________________
dropout_8 (Dropout)          (None, 2, 2, 128)         0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 512)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 128)               65664     
_________________________________________________________________
dense_6 (Dense)              (None, 10)                1290      

2. Conv2D

該模型有三個(gè)Conv2D層:

Conv2D(32, (5, 5), input_shape=input_shape, activation='relu')
Conv2D(64, (3, 3), activation='relu')
Conv2D(128, (1, 1), activation='relu')
  • 第一個(gè)參數(shù) - 32,64,128 - 是您要訓(xùn)練此圖層檢測的過濾器或要素的數(shù)量已慢。 這也是輸出形狀的深度 - 最后一個(gè)維度。
  • 第二個(gè)參數(shù) - (5,5),(3,3),(1,1) - 是內(nèi)核大信骸:一個(gè)元組佑惠,指定在輸入空間上滑動(dòng)的卷積窗口的寬度和高度,計(jì)算加權(quán)和 - dot 內(nèi)核權(quán)重和輸入單位值的乘積齐疙。
  • 第三個(gè)參數(shù)activation ='relu'指定ReLU(Rectified Linear Unit)(整流線性單元)激活功能膜楷。 當(dāng)內(nèi)核以輸入單元為中心時(shí),如果加權(quán)和大于閾值贞奋,則稱該單元激活或觸發(fā):weighted_sum> threshold赌厅。 偏差值為-threshold:如果weighted_sum + bias> 0,則單位觸發(fā)轿塔。訓(xùn)練模型計(jì)算每個(gè)濾波器的內(nèi)核權(quán)重和偏差值特愿。 ReLU是深度神經(jīng)網(wǎng)絡(luò)中最受歡迎的激活函數(shù)。

3. MaxPooling2D

MaxPooling2D(pool_size=(2, 2))

pooling層在前一層上通過m列過濾器滑動(dòng)n行勾缭,將n x m值替換為其最大值揍障。pooling濾器通常是方形的:n = m。 如下所示俩由,最常用的2 x 2 pooling濾器將前一層的寬度和高度減半毒嫡,從而減少了參數(shù)的數(shù)量,從而有助于控制過度擬合幻梯。

Malireddi的模型在每個(gè)卷積層之后都有一個(gè)pooling層兜畸,這大大減少了最終的模型大小和訓(xùn)練時(shí)間努释。

Chollet的模型在pooling之前有兩個(gè)卷積層。這建議用于較大的網(wǎng)絡(luò)咬摇,因?yàn)樗试S卷積層在pooling之前開發(fā)更復(fù)雜的特征洽洁,丟棄75%的值。

Conv2DMaxPooling2D參數(shù)確定每個(gè)圖層的輸出形狀和可訓(xùn)練參數(shù)的數(shù)量:

Output Shape = (input width – kernel width + 1, input height – kernel height + 1, number of filters)

您不能將3×3內(nèi)核置于每行和每列的第一個(gè)和最后一個(gè)單元的中心菲嘴,因此輸出寬度和高度比輸入小2個(gè)像素饿自。 5×5內(nèi)核可將輸出寬度和高度減少4個(gè)像素。

  • Conv2D(32龄坪,(5,5)昭雌,input_shape =(28,28,1)):( 28-4,28-4,32)=(24,24,32)
  • MaxPooling2D將輸入寬度和高度減半:(24 / 2,24 / 2,32)=(12,12,32)
  • Conv2D(64,(3,3)):( 12-2,12-2,64)=(10,10,64)
  • MaxPooling2D將輸入寬度和高度減半:(10 / 2,10 / 2,64)=(5,5,64)
  • Conv2D(128健田,(1,1)):( 5-0,5-0,128)=(5,5,128)

Param # = number of filters x (kernel width x kernel height x input depth + 1 bias)

  • Conv2D(32烛卧,(5,5),input_shape =(28,28,1)):32 x(5x5x1 + 1)= 832
  • Conv2D(64妓局,(3,3)):64 x(3x3x32 + 1)= 18,496
  • Conv2D(128总放,(1,1)):128 x(1x1x64 + 1)= 8320

Challenge:計(jì)算Chollet架構(gòu)model_c的輸出形狀和參數(shù)編號。

Output Shape = (input width – kernel width + 1, input height – kernel height + 1, number of filters)

  • Conv2D(32, (3, 3), input_shape=(28, 28, 1)): (28-2, 28-2, 32) = (26, 26, 32)
  • Conv2D(64, (3, 3)): (26-2, 26-2, 64) = (24, 24, 64)
  • MaxPooling2D halves the input width and height: (24/2, 24/2, 64) = (12, 12, 64)

Param # = number of filters x (kernel width x kernel height x input depth + 1 bias)

  • Conv2D(32, (3, 3), input_shape=(28, 28, 1)): 32 x (3x3x1 + 1) = 320
  • Conv2D(64, (3, 3)): 64 x (3x3x32 + 1) = 18,496

4. Dropout

Dropout(0.5)
Dropout(0.2)

dropout層通常與pooling層配對好爬。 它將輸入單位的一小部分隨機(jī)設(shè)置為0局雄。這是控制過度擬合的另一種方法:神經(jīng)元不太可能受到相鄰神經(jīng)元的過多影響,因?yàn)樗鼈冎械娜魏我粋€(gè)都可能隨機(jī)掉出網(wǎng)絡(luò)存炮。 這使得網(wǎng)絡(luò)對輸入中的微小變化不太敏感炬搭,因此更有可能推廣到新輸入。

Hands-on Machine Learning with Scikit-Learn & TensorFlowAurélienGéron將其與工作場所進(jìn)行比較穆桂,在任何一天宫盔,某些人可能無法上班:每個(gè)人都必須能夠完成關(guān)鍵任務(wù), 并且必須與更多的同事合作享完。 這將使公司更具彈性灼芭,減少對任何單個(gè)工人的依賴。

5. Flatten

在將卷積層傳遞到完全連接的密集層之前般又,必須使卷積層的權(quán)重為1彼绷。

model_m.add(Dropout(0.2))
model_m.add(Flatten())
model_m.add(Dense(128, activation='relu'))

前一層的輸出形狀為(2,2,128),因此Flatten()的輸出是一個(gè)包含512個(gè)元素的數(shù)組倒源。

6. Dense

Dense(128, activation='relu')
Dense(num_classes, activation='softmax')

卷積層中的每個(gè)神經(jīng)元使用前一層中僅少數(shù)神經(jīng)元的值苛预。 完全連接層中的每個(gè)神經(jīng)元使用前一層中所有神經(jīng)元的值。 此類圖層的Keras名稱為Dense笋熬。

看看上面的模型摘要热某,Malireddi的第一個(gè)Dense層有512個(gè)神經(jīng)元,而Chollet有9216個(gè)。兩者都產(chǎn)生128個(gè)神經(jīng)元輸出層昔馋,但Chollet必須計(jì)算的參數(shù)比Malireddi的多18倍筹吐。 這是使用大部分額外訓(xùn)練時(shí)間的原因。

大多數(shù)CNN架構(gòu)以一個(gè)或多個(gè)Dense層結(jié)束秘遏,然后是輸出層丘薛。

第一個(gè)參數(shù)是圖層的輸出大小。 最終輸出層的輸出大小為10邦危,對應(yīng)于10個(gè)數(shù)字類洋侨。

softmax激活函數(shù)在10個(gè)輸出類別上產(chǎn)生概率分布。 它是sigmoid函數(shù)的推廣倦蚪,它將其輸入值縮放到[0,1]范圍內(nèi)希坚。 對于您的MNIST分類器,softmax將10個(gè)值中的每一個(gè)都縮放為[0,1]陵且,這樣它們總計(jì)為1裁僧。

您可以將sigmoid函數(shù)用于單個(gè)輸出類:例如,這是一張好狗照片的概率是多少慕购?

7. Compile

model_m.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

分類交叉熵(categorical crossentropy)損失函數(shù)測量由CNN計(jì)算的概率分布與標(biāo)簽的真實(shí)分布之間的距離聊疲。

優(yōu)化器(optimizer)是隨機(jī)梯度下降算法,它試圖通過以恰當(dāng)?shù)乃俣雀S梯度來最小化損失函數(shù)沪悲。

準(zhǔn)確度(Accuracy) - 正確分類的圖像的分?jǐn)?shù) - 是在訓(xùn)練和測試期間監(jiān)控的最常見度量获洲。

8. Fit

batch_size = 256
epochs = 10
model_m.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, callbacks=callbacks_list,
            validation_data=(x_val, y_val), verbose=1)

批量大小(Batch size)是用于小批量隨機(jī)梯度擬合的數(shù)據(jù)項(xiàng)的數(shù)量。選擇批量大小是一個(gè)試驗(yàn)和錯(cuò)誤的問題可训,一骰子昌妹。較小的值使得epoch需要更長的時(shí)間;較大的值可以更好地利用GPU并行性握截,并減少數(shù)據(jù)傳輸時(shí)間,但過大可能會導(dǎo)致內(nèi)存不足烂叔。

epoch的數(shù)量也是擲骰子谨胞。每個(gè)epoch都應(yīng)該改善損失和準(zhǔn)確度測量。更多epoch應(yīng)該產(chǎn)生更準(zhǔn)確的模型蒜鸡,但訓(xùn)練需要更長時(shí)間胯努。太多的epoch可能導(dǎo)致過度擬合。如果模型在完成所有epoch之前停止改進(jìn)逢防,則設(shè)置回調(diào)以提前停止叶沛。在notebook中,您可以重新運(yùn)行fit的單元格以繼續(xù)改進(jìn)模型忘朝。

加載數(shù)據(jù)時(shí)灰署,將10000個(gè)項(xiàng)目設(shè)置為驗(yàn)證數(shù)據(jù)。通過此參數(shù)可以在訓(xùn)練時(shí)進(jìn)行驗(yàn)證,因此您可以監(jiān)控驗(yàn)證損失和準(zhǔn)確性溉箕。如果這些值比訓(xùn)練損失和準(zhǔn)確度差晦墙,則表明該模型過度擬合。

9. Verbose

0 = silent, 1 = progress bar, 2 = one line per epoch.

Results - 結(jié)果

以下是我的一次訓(xùn)練結(jié)果:

Epoch 1/10
60000/60000 [==============================] - 106s - loss: 0.0284 - acc: 0.9909 - val_loss: 0.0216 - val_acc: 0.9940
Epoch 2/10
60000/60000 [==============================] - 100s - loss: 0.0271 - acc: 0.9911 - val_loss: 0.0199 - val_acc: 0.9942
Epoch 3/10
60000/60000 [==============================] - 102s - loss: 0.0260 - acc: 0.9914 - val_loss: 0.0228 - val_acc: 0.9931
Epoch 4/10
60000/60000 [==============================] - 101s - loss: 0.0257 - acc: 0.9913 - val_loss: 0.0211 - val_acc: 0.9935
Epoch 5/10
60000/60000 [==============================] - 101s - loss: 0.0256 - acc: 0.9916 - val_loss: 0.0222 - val_acc: 0.9928
Epoch 6/10
60000/60000 [==============================] - 100s - loss: 0.0263 - acc: 0.9913 - val_loss: 0.0178 - val_acc: 0.9950
Epoch 7/10
60000/60000 [==============================] - 87s - loss: 0.0231 - acc: 0.9920 - val_loss: 0.0212 - val_acc: 0.9932
Epoch 8/10
60000/60000 [==============================] - 76s - loss: 0.0240 - acc: 0.9922 - val_loss: 0.0212 - val_acc: 0.9935
Epoch 9/10
60000/60000 [==============================] - 76s - loss: 0.0261 - acc: 0.9916 - val_loss: 0.0220 - val_acc: 0.9934
Epoch 10/10
60000/60000 [==============================] - 76s - loss: 0.0231 - acc: 0.9925 - val_loss: 0.0203 - val_acc: 0.9935

在每個(gè)epoch肴茄,損失值應(yīng)該減少晌畅,準(zhǔn)確度值應(yīng)該增加。 ModelCheckpoint回調(diào)保存了epoch1,2和6寡痰,因?yàn)?code>epoch3,4和5中的驗(yàn)證損失值高于epoch2抗楔,并且在epoch6之后驗(yàn)證損失沒有改善。訓(xùn)練不會提前停止拦坠,因?yàn)橛?xùn)練準(zhǔn)確性從未在連續(xù)兩個(gè)epoch內(nèi)減少谓谦。

注意:實(shí)際上,這些結(jié)果來自20或30個(gè)epoch:我在不重置模型的情況下不止一次地運(yùn)行fit單元格贪婉,因此即使在第1epoch中反粥,損失和準(zhǔn)確度值也已經(jīng)非常好。但是您在測量中看到一些波動(dòng)疲迂。例如才顿,在epoch4,6和9中精度降低。

到目前為止尤蒿,您的模型已經(jīng)完成訓(xùn)練郑气,所以回到編碼!


Convert to Core ML Model - 轉(zhuǎn)換為Core ML模型

訓(xùn)練步驟完成后腰池,您應(yīng)該在notebook中保存一些模型尾组。 具有最高epoch數(shù)(和最低驗(yàn)證損失)的那個(gè)是最佳模型,因此在convert函數(shù)中使用該文件名示弓。

輸入以下代碼讳侨,然后運(yùn)行它。

output_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# For the first argument, use the filename of the newest .h5 file in the notebook folder.
coreml_mnist = coremltools.converters.keras.convert(
    'best_model.09-0.03.h5', input_names=['image'], output_names=['output'], 
    class_labels=output_labels, image_input_names='image')

在這里奏属,您在數(shù)組中設(shè)置10個(gè)輸出標(biāo)簽跨跨,并將其作為class_labels參數(shù)傳遞。 如果訓(xùn)練具有大量輸出類的模型囱皿,請將標(biāo)簽放在文本文件中勇婴,每行一個(gè)標(biāo)簽,并將class_labels參數(shù)設(shè)置為文件名嘱腥。

在參數(shù)列表中耕渴,您提供輸入和輸出名稱,并設(shè)置image_input_names ='image'齿兔,以便Core ML模型接受圖像作為輸入橱脸,而不是多數(shù)組础米。

1. Inspect Core ML model - 檢查Core ML模型

輸入此行,然后運(yùn)行它以查看打印輸出慰技。

print(coreml_mnist)

只需檢查輸入類型是imageType椭盏,而不是多數(shù)組:

input {
  name: "image"
  shortDescription: "Digit image"
  type {
    imageType {
      width: 28
      height: 28
      colorSpace: GRAYSCALE
    }
  }
}

2. Add Metadata for Xcode - 為Xcode添加元數(shù)據(jù)

現(xiàn)在添加以下內(nèi)容,替換前兩個(gè)項(xiàng)目的自己的名稱和許可證信息吻商,然后運(yùn)行它掏颊。

coreml_mnist.author = 'raywenderlich.com'
coreml_mnist.license = 'Razeware'
coreml_mnist.short_description = 'Image based digit recognition (MNIST)'
coreml_mnist.input_description['image'] = 'Digit image'
coreml_mnist.output_description['output'] = 'Probability of each digit'
coreml_mnist.output_description['classLabel'] = 'Labels of digits'

在Xcode的項(xiàng)目導(dǎo)航器中選擇模型時(shí)會出現(xiàn)此信息。

3. Save the Core ML Model - 保存Core ML模型

最后艾帐,添加以下內(nèi)容并運(yùn)行它乌叶。

coreml_mnist.save('MNISTClassifier.mlmodel')

這會將mlmodel文件保存在notebook文件夾中。

恭喜柒爸,您現(xiàn)在擁有一個(gè)Core ML模型准浴,可以對手寫數(shù)字進(jìn)行分類! 是時(shí)候在iOS應(yīng)用程序中使用它了捎稚。


Use Model in iOS App - 在iOS App中使用Model

1. Step 1. Drag the model into the app - 步驟1.將模型拖到應(yīng)用程序中:

在Xcode中打開入門應(yīng)用程序乐横,并將Finders中的MNISTClassifier.mlmodel拖到項(xiàng)目的Project導(dǎo)航器中。 選擇它以查看您添加的元數(shù)據(jù):

如果不是Automatically generated Swift model class今野,而是建立項(xiàng)目來生成模型類葡公,請繼續(xù)執(zhí)行此操作。

2. Step 2. Import the CoreML and Vision frameworks: - 步驟2.導(dǎo)入CoreML和Vision框架:

打開ViewController.swift条霜,導(dǎo)入兩個(gè)框架催什,就在導(dǎo)入UIKit下面:

import CoreML
import Vision

3. Step 3. Create VNCoreMLModel and VNCoreMLRequest objects: - 步驟3.創(chuàng)建VNCoreMLModel和VNCoreMLRequest對象:

outlets下面添加以下代碼:

lazy var classificationRequest: VNCoreMLRequest = {
  // Load the ML model through its generated class and create a Vision request for it.
  do {
    let model = try VNCoreMLModel(for: MNISTClassifier().model)
    return VNCoreMLRequest(model: model, completionHandler: handleClassification)
  } catch {
    fatalError("Can't load Vision ML model: \(error).")
  }
}()

func handleClassification(request: VNRequest, error: Error?) {
  guard let observations = request.results as? [VNClassificationObservation]
    else { fatalError("Unexpected result type from VNCoreMLRequest.") }
  guard let best = observations.first
    else { fatalError("Can't get best result.") }

  DispatchQueue.main.async {
    self.predictLabel.text = best.identifier
    self.predictLabel.isHidden = false
  }
}

請求對象適用于步驟4中的處理程序傳遞給它的任何圖像,因此您只需將其定義一次宰睡,作為一個(gè)lazy var蒲凶。

請求對象的完成處理程序接收requesterror對象。 您檢查request.results是一個(gè)VNClassificationObservation對象的數(shù)組拆内,這是當(dāng)Core ML模型是分類器而不是預(yù)測器或圖像處理器時(shí)Vision框架返回的對象旋圆。

VNClassificationObservation對象有兩個(gè)屬性:identifier - 一個(gè)String - 和confidence - 一個(gè)介于0和1之間的數(shù)字 - 分類正確的概率。 您獲取第一個(gè)結(jié)果矛纹,該結(jié)果具有最高置信度值臂聋,并調(diào)度回主隊(duì)列以更新predictLabel。 分類工作發(fā)生在主隊(duì)列之外或南,因?yàn)樗赡芎苈?/p>

4. Step 4. Create and run a VNImageRequestHandler: - 步驟4.創(chuàng)建并運(yùn)行VNImageRequestHandler:

找到predictTapped(),并使用以下代碼替換print語句:

let ciImage = CIImage(cgImage: inputImage)
let handler = VNImageRequestHandler(ciImage: ciImage)
do {
  try handler.perform([classificationRequest])
} catch {
  print(error)
}

您可以從inputImage創(chuàng)建CIImage艾君,然后為此ciImage創(chuàng)建VNImageRequestHandler對象采够,并在VNCoreMLRequest對象數(shù)組上運(yùn)行處理程序 - 在本例中,只是您在步驟3中創(chuàng)建的一個(gè)請求對象冰垄。

建立并運(yùn)行蹬癌。 在繪圖區(qū)域的中心繪制一個(gè)數(shù)字,然后點(diǎn)擊Predict。 點(diǎn)按Clear再試一次逝薪。

較大的繪制往往效果更好隅要,但模型常常遇到'7'和'4'的問題。 毫不奇怪董济,因?yàn)?code>MNIST數(shù)據(jù)的PCA visualization顯示7s和4s聚集在9s:

注意:Malireddi表示Vision框架使用了20%的CPU步清,因此his app包含一個(gè)將UIImage對象轉(zhuǎn)換為CVPixelBuffer格式的擴(kuò)展。

如果您不使用Vision虏肾,請?jiān)趯?code>Keras模型轉(zhuǎn)換為Core ML時(shí)將image_scale = 1 / 255.0作為參數(shù):Keras模型訓(xùn)練灰度值在[0,1]范圍內(nèi)的圖像廓啊,CVPixelBuffer值為 在[0,255]范圍內(nèi)。

感謝 Sri Raghu M, Matthijs HollemansHon Weng Chong的有益討論封豪!

資源

進(jìn)一步閱讀


源碼

1. Swift

看下工程文檔結(jié)構(gòu)

接著谴轮,看一下sb內(nèi)容

1. ViewController.swift
import UIKit
import CoreML
import Vision

class ViewController: UIViewController {

  @IBOutlet weak var drawView: DrawView!
  @IBOutlet weak var predictLabel: UILabel!

  // DONE: Define lazy var classificationRequest
  lazy var classificationRequest: VNCoreMLRequest = {
    // Load the ML model through its generated class and create a Vision request for it.
    do {
      let model = try VNCoreMLModel(for: MNISTClassifier().model)
      return VNCoreMLRequest(model: model, completionHandler: self.handleClassification)
    } catch {
      fatalError("Can't load Vision ML model: \(error).")
    }
  }()

  func handleClassification(request: VNRequest, error: Error?) {
    guard let observations = request.results as? [VNClassificationObservation]
      else { fatalError("Unexpected result type from VNCoreMLRequest.") }
    guard let best = observations.first
      else { fatalError("Can't get best result.") }

    DispatchQueue.main.async {
      self.predictLabel.text = best.identifier
      self.predictLabel.isHidden = false
    }
  }

  override func viewDidLoad() {
    super.viewDidLoad()
    predictLabel.isHidden = true
  }

  @IBAction func clearTapped() {
    drawView.lines = []
    drawView.setNeedsDisplay()
    predictLabel.isHidden = true
  }

  @IBAction func predictTapped() {
    guard let context = drawView.getViewContext(),
      let inputImage = context.makeImage()
      else { fatalError("Get context or make image failed.") }
    // DONE: Perform request on model
    let ciImage = CIImage(cgImage: inputImage)
    let handler = VNImageRequestHandler(ciImage: ciImage)
    do {
      try handler.perform([classificationRequest])
    } catch {
      print(error)
    }
  }

}
2. DrawView.swift
// Code taken with inspiration from Apple's Metal-2 sample MPSCNNHelloWorld
import UIKit

/**
 This class is used to handle the drawing in the DigitView so we can get user input digit,
 This class doesn't really have an MPS or Metal going in it, it is just used to get user input
 */
class DrawView: UIView {
    
    // some parameters of how thick a line to draw 15 seems to work
    // and we have white drawings on black background just like MNIST needs its input
    var linewidth = CGFloat(15) { didSet { setNeedsDisplay() } }
    var color = UIColor.white { didSet { setNeedsDisplay() } }
    
    // we will keep touches made by user in view in these as a record so we can draw them.
    var lines: [Line] = []
    var lastPoint: CGPoint!
    
    override func touchesBegan(_ touches: Set<UITouch>, with event: UIEvent?) {
        lastPoint = touches.first!.location(in: self)
    }
    
    override func touchesMoved(_ touches: Set<UITouch>, with event: UIEvent?) {
        let newPoint = touches.first!.location(in: self)
        // keep all lines drawn by user as touch in record so we can draw them in view
        lines.append(Line(start: lastPoint, end: newPoint))
        lastPoint = newPoint
        // make a draw call
        setNeedsDisplay()
    }
    
    override func draw(_ rect: CGRect) {
        super.draw(rect)
        
        let drawPath = UIBezierPath()
        drawPath.lineCapStyle = .round
        
        for line in lines{
            drawPath.move(to: line.start)
            drawPath.addLine(to: line.end)
        }
        
        drawPath.lineWidth = linewidth
        color.set()
        drawPath.stroke()
    }
    
    
    /**
     This function gets the pixel data of the view so we can put it in MTLTexture
     
     - Returns:
     Void
     */
    func getViewContext() -> CGContext? {
        // our network takes in only grayscale images as input
        let colorSpace:CGColorSpace = CGColorSpaceCreateDeviceGray()
        
        // we have 3 channels no alpha value put in the network
        let bitmapInfo = CGImageAlphaInfo.none.rawValue
        
        // this is where our view pixel data will go in once we make the render call
        let context = CGContext(data: nil, width: 28, height: 28, bitsPerComponent: 8, bytesPerRow: 28, space: colorSpace, bitmapInfo: bitmapInfo)
        
        // scale and translate so we have the full digit and in MNIST standard size 28x28
        context!.translateBy(x: 0 , y: 28)
        context!.scaleBy(x: 28/self.frame.size.width, y: -28/self.frame.size.height)
        
        // put view pixel data in context
        self.layer.render(in: context!)
        
        return context
    }
}

/**
 2 points can give a line and this class is just for that purpose, it keeps a record of a line
 */
class Line{
    var start, end: CGPoint
    
    init(start: CGPoint, end: CGPoint) {
        self.start = start
        self.end   = end
    }
}

后記

本篇主要講述了使用Keras和Core ML開始機(jī)器學(xué)習(xí),感興趣的給個(gè)贊或者關(guān)注~~~

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末吹埠,一起剝皮案震驚了整個(gè)濱河市第步,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌缘琅,老刑警劉巖粘都,帶你破解...
    沈念sama閱讀 217,277評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異胯杭,居然都是意外死亡驯杜,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評論 3 393
  • 文/潘曉璐 我一進(jìn)店門做个,熙熙樓的掌柜王于貴愁眉苦臉地迎上來鸽心,“玉大人,你說我怎么就攤上這事居暖⊥缙担” “怎么了?”我有些...
    開封第一講書人閱讀 163,624評論 0 353
  • 文/不壞的土叔 我叫張陵太闺,是天一觀的道長糯景。 經(jīng)常有香客問我,道長省骂,這世上最難降的妖魔是什么蟀淮? 我笑而不...
    開封第一講書人閱讀 58,356評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮钞澳,結(jié)果婚禮上怠惶,老公的妹妹穿的比我還像新娘。我一直安慰自己轧粟,他們只是感情好策治,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,402評論 6 392
  • 文/花漫 我一把揭開白布脓魏。 她就那樣靜靜地躺著,像睡著了一般通惫。 火紅的嫁衣襯著肌膚如雪茂翔。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,292評論 1 301
  • 那天履腋,我揣著相機(jī)與錄音珊燎,去河邊找鬼。 笑死府树,一個(gè)胖子當(dāng)著我的面吹牛俐末,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播奄侠,決...
    沈念sama閱讀 40,135評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼卓箫,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了垄潮?” 一聲冷哼從身側(cè)響起烹卒,我...
    開封第一講書人閱讀 38,992評論 0 275
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎弯洗,沒想到半個(gè)月后旅急,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,429評論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡牡整,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,636評論 3 334
  • 正文 我和宋清朗相戀三年藐吮,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片逃贝。...
    茶點(diǎn)故事閱讀 39,785評論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡谣辞,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出沐扳,到底是詐尸還是另有隱情泥从,我是刑警寧澤,帶...
    沈念sama閱讀 35,492評論 5 345
  • 正文 年R本政府宣布沪摄,位于F島的核電站躯嫉,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏杨拐。R本人自食惡果不足惜祈餐,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,092評論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望哄陶。 院中可真熱鬧昼弟,春花似錦、人聲如沸奕筐。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,723評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽离赫。三九已至芭逝,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間渊胸,已是汗流浹背旬盯。 一陣腳步聲響...
    開封第一講書人閱讀 32,858評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留翎猛,地道東北人胖翰。 一個(gè)月前我還...
    沈念sama閱讀 47,891評論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像切厘,于是被迫代替她去往敵國和親萨咳。 傳聞我的和親對象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,713評論 2 354

推薦閱讀更多精彩內(nèi)容