蘋果在去年推出了CoreML機器學習模型,今年在XCode10中提供的CreateML framework洞豁,可以創(chuàng)建CoreML模型绩蜻。
使用CreateML創(chuàng)建CoreML模型時蹲坷,僅需編寫少量的代碼。
準備工作
1惰爬、XCode10(目前是beta版本)
2、MacOS Mojave(目前也是beta版本)
3惫企、訓練數據:在同一個目錄下撕瞧,以文件夾作為分類,各個文件夾下存放對應分類的圖片
4狞尔、測試數據:和訓練數據一樣丛版,并且文件夾分類的名稱要和訓練數據的名稱一致
說明:
1、訓練數據可以自己準備偏序,也可以從網上找一些页畦,例如:Kaggle Cats and Dogs Dataset(本文是以Pets-100目錄下的圖片進行的訓練)
2、訓練數據數量越大研儒,訓練的模型越準確豫缨,訓練的時間也就越長
創(chuàng)建圖像分類CoreML模型
1、運行XCode10端朵,創(chuàng)建一個空的playground工程好芭,清除所有代碼,然后將下面的代碼拷貝在playground中
import CreateMLUI
let builder = MLImageClassifierBuilder()
builder.showInLiveView()
2冲呢、切換顯示XCode的assistant editor舍败,再點擊運行
3、此時,XCode的assistant editor中邻薯,會顯示MLImageClassifierBuilder的live view裙戏,將訓練數據的目錄拖拽進來,XCode便開始訓練CoreML模型了
4弛说、將訓練后的模型挽懦,保存到文件
5、應用創(chuàng)建的模型進行預測:將想要預測的圖片(或目錄)拖拽到模型上木人,進行預測信柿。例如,將Pets-1000目錄拖拽到Live view上醒第,預測的準確率如下
說明:除了在Live view中進行預測外渔嚷,也可以將保存后的模型導入到app中使用。參見Classifying Images with Vision and Core ML
創(chuàng)建文本分類模型
創(chuàng)建文本分類ML模型稠曼,可以使用MLDataTable和MLTextClassifier類形病。步驟如下:
1、創(chuàng)建一個MLDataTable對象霞幅,讀取訓練數據(可以是JSON或CSV格式漠吻、或者Dictionary)
2、創(chuàng)建一個MLTextClassifier對象司恳,使用MLDataTable對象中的數據進行訓練
3途乃、通過MLTextClassifier對象的write(to:metadata:)方法,將模型保存到磁盤
csv文件格式示例:
title,author,pageCount,genre
Alice in Wonderland,Lewis Carroll,124,Fantasy
Hamlet,William Shakespeare,98,Drama
Treasure Island,Robert L. Stevenson,280,Adventure
Peter Pan,J. M. Barrie,94,Fantasy
JSON文件格式示例:
[
{
"title": "Alice in Wonderland",
"author": "Lewis Carroll",
"pageCount": 124,
"genre": "Fantasy"
},
{
"title": "Hamlet",
"author": "William Shakespeare",
"pageCount": 98,
"genre": "Drama"
}, ...
]
//Dictionary數據示例
let data: [String: MLDataValueConvertible] = [
??? "title": ["Alice in Wonderland", "Hamlet", "Treasure Island", "Peter Pan"],
??? "author": ["Lewis Carroll", "William Shakespeare", "Robert L. Stevenson", "J. M. Barrie"],
??? "pageCount": [124, 98, 280, 94],
??? "genre": ["Fantasy", "Drama", "Adventure", "Fantasy"]
]
let bookTable = try MLDataTable(dictionary: data)
示例代碼
在XCode創(chuàng)建一個空的playground工程扔傅,在資源中添加訓練使用的數據spam-sms.csv耍共,然后將下面的代碼粘貼到工程中
import Foundation
import CreateML
//獲取csv文件路徑
guard let trainingCSV = Bundle.main.url(forResource: "spam-sms", withExtension: "csv") else {
? ? fatalError()
}
//將csv文件內容加載到MLDataTable中
var spamData = try MLDataTable(contentsOf: trainingCSV)
let (trainingData, testData) = spamData.randomSplit(by: 0.8, seed: 0)
//創(chuàng)建文本分類器,進行訓練
//message和label分別對應csv文件中的短信內容列猎塞、短信標簽列
let predictor = try MLTextClassifier(trainingData: trainingData, textColumn: "message", labelColumn: "label")
//在測試數據集上驗證
let metrics = predictor.evaluation(on: testData)
說明:
使用400條中文短信內容的csv试读,訓練模型時,內存占用十分嚴重荠耽,超過Mac系統(tǒng)的物理內存钩骇,訓練卡在解析短信的步驟,未能訓練出模型铝量。
使用英文短信內容進行訓練時沐序,沒有內存問題瑞筐,可以訓練出模型。
其它
MLClassifier是一個通用的分類模型,MLRegressor是一個回歸模型恼蓬,給定訓練模型(MLDataTable)中的特征列和結果列后柠偶,就可以對這兩種模型進行訓練别伏。
缺點
模型訓練好后吐根,如果增加了數據集请梢,必須重新開始訓練,即無法在訓練好的模型上應用新的數據進行訓練力穗。
模型優(yōu)化
提高訓練數據集上的準確率(Training Accuracy)
對于MLImageClassifierBuilder毅弧,可以將訓練的迭代次數調整成20次
對于自然語言的分類器,可以嘗試不同的算法(MLTextClassifier.ModelAlgorithmType)
對于MLClassifier和MLRegressor当窗,則可以嘗試選用不同的模型進行訓練
提高驗證數據集上的準確率(Validation Accuracy)
對于擬合不足的問題够坐,可以通過增加訓練數據集來進行優(yōu)化。例如崖面,對于圖像分類器元咙,可以在訓練時勾選Augmentation(增加)選項:
對于過擬合的問題,則可以嘗試減少迭代次數進行優(yōu)化巫员。
提高測試數據集上的準確率(Evaluation Accuracy)
如果訓練數據集庶香、驗證數據集上的準確率,高于測試數據集上的準確率简识,原因通常是訓練數據和測試數據存在比較明顯的差異導致赶掖,這種情況下,可以嘗試在訓練數據集中使用更多的不同的數據七扰。