Deeplearning4j是一個java編寫的深度學(xué)習(xí)商業(yè)框架,可以通過他們提供的API快速搭建出神經(jīng)網(wǎng)絡(luò)摔握。根據(jù)官網(wǎng)的下載的demo中的MLPClassiferLinear(多層感知線性分類器),訓(xùn)練Titanic數(shù)據(jù)塞耕,求生存分析焕济。
MLPClassiferLinear
Titanic數(shù)據(jù)萌壳,數(shù)據(jù)集來自kaggle
訓(xùn)練集:
測試集:
然后依葫蘆畫瓢,照著MLP的樣子打代碼摹芙。
一開始時設(shè)置神經(jīng)網(wǎng)絡(luò)的參數(shù)灼狰,根據(jù)官網(wǎng)介紹這些參數(shù)設(shè)置和最終訓(xùn)練的結(jié)果有比較大的關(guān)聯(lián)。batchSize表示每一步抓取的數(shù)據(jù)量浮禾。
要搞清楚這里的參數(shù)是什么意思得先看神經(jīng)網(wǎng)絡(luò)的原理圖
可以看到信息經(jīng)過一層一層的處理最終變?yōu)檩敵鼋慌撸恳粚佣加幸粋€輸入(除了第一層)都有一個輸出(除了最后一層)份汗,也就是說第一層我們輸入原始的數(shù)據(jù),然后經(jīng)過N層的隱藏層得到結(jié)果蝴簇。numInput表示原始數(shù)據(jù)的維數(shù)杯活,例如在titanic中將數(shù)據(jù)處理為[sex,survived]則數(shù)據(jù)為1維,[sex,parch,survived]則數(shù)據(jù)為二維的(即我們用幾個數(shù)據(jù)預(yù)測survived這個變量军熏,f(x)=y,f(x,z)=y,可以這么理解)轩猩。然后這n維數(shù)據(jù)經(jīng)過一定的組合變成numHiddenNodes維,距離f(x,z)=y->f(x1,x2,z1,z2...)=y,最終輸出結(jié)果荡澎。numOutput表示y的取值可能性均践,在本例中suvived即為生或死,表示為0或1所以為2摩幔。
參數(shù)設(shè)置好了后可以加載數(shù)據(jù)了彤委,但是這里的數(shù)據(jù)不能是上面展示的原始數(shù)據(jù),需要對數(shù)據(jù)進行處理或衡。因為神經(jīng)網(wǎng)絡(luò)接受的輸入是向量類型的焦影,即不可以出現(xiàn)字符,要把所有的信息都轉(zhuǎn)換為數(shù)字封断,比如男斯辰,女可以表示為0,1等坡疼。deeplearning4j自己提供了data2vec工具類彬呻,在機器學(xué)習(xí)中數(shù)據(jù)的質(zhì)量對結(jié)果有著格外的影響,好的數(shù)據(jù)處理能讓預(yù)測結(jié)果更準確柄瑰。這里我處理數(shù)據(jù)的方式采用的方法不科學(xué)且粗獷不可取闸氮。為了方便我直接把所有出現(xiàn)字符的行列都刪除,最終得到的數(shù)據(jù)如下
第一列是suvived教沾,后面的依次為pclass,sibSp蒲跨,parch
kaggle提供的測試即未提供suvived列,故不能用來評估授翻,我直接把數(shù)據(jù)裁取好后在第一列中簡單粗暴的加上分類標(biāo)簽或悲,隨機的0,1。
構(gòu)造方法參數(shù)的意思堪唐,lableIndex表示在你的數(shù)據(jù)中標(biāo)簽列的索引號(本例中即survived巡语,就是要預(yù)測的那一列)numPossibleLabels表示標(biāo)簽列可能的值的個數(shù)(本例中即生或死,為2)
然后可以搭建神經(jīng)網(wǎng)絡(luò)羔杨,代碼如上每需要一層則用.layer(...)添加捌臊,layer即英文中神經(jīng)網(wǎng)絡(luò)層的意思。上圖中搭建的是一個兩層的網(wǎng)絡(luò)兜材,第一層接受numInputs的向量為參數(shù)然后輸出numHiddenNodes的向量為輸出理澎,第二層接受上一層的輸出為輸入并且輸出結(jié)果numOutputs逞力。然后創(chuàng)建模型對象并且用訓(xùn)練集訓(xùn)練模型。最終得到model對象糠爬。
這樣一個神經(jīng)網(wǎng)絡(luò)就訓(xùn)練好啦寇荧!
但是我們并不知道這個網(wǎng)絡(luò)的效果如何,這時測試數(shù)據(jù)集及登場了执隧。(原本應(yīng)該是測試數(shù)據(jù)集是一組正確的數(shù)據(jù)揩抡,用來評估模型的預(yù)測準確率的,但這里測試數(shù)據(jù)集的數(shù)據(jù)并不是正確的而是自己杜撰的)
寫好后峦嗤,點擊運行。過一會屋摔,就能預(yù)測輸出結(jié)果了烁设。
可以看到評估指標(biāo)中又正確率,召回率等數(shù)據(jù)钓试,其結(jié)果均為50%左右装黑。也就是我們這模型可能和瞎猜的結(jié)果差不多。這個應(yīng)該是數(shù)據(jù)處理的問題弓熏,用demo中的數(shù)據(jù)跑精確率可以達到99%以上恋谭。
deeplearning4j給廣大java程序源提供一個很好進一步接觸人工智能的機會。有興趣的花可以自己扒數(shù)據(jù)挽鞠,嘗試這它預(yù)測諸如股票疚颊、天氣等數(shù)據(jù),看看自己訓(xùn)練的網(wǎng)絡(luò)的準確性滞谢、實用性串稀。
ps.未來應(yīng)該是一個人工智能的時代除抛,越來越多的框架讓我們能更簡潔的接觸利用人工智能狮杨,這是很好的時代普通人一可以用這些看似高大上(其實也高達上)的東西做一些自己的想法。官網(wǎng)的demo還有許多其他例子到忽,有時間可以拿出來研究橄教。