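This post introduces deeplearnjs (deeplearn.js), the front-end JavaScript deep learning framework Google released in mid-2017. It can train and run neural networks directly in the browser. Like any AI framework, it covers loading data, building a network, training, and prediction. Using the official handwritten digit recognition (MNIST) example, this article shows how deeplearnjs implements each of these steps, and then looks at its performance.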
1. Installation
git clone https://github.com/PAIR-code/deeplearnjs
brew install yarn (skip this if yarn is already installed)
cd deeplearnjs
yarn prep
Install the TSLint extension for VS Code, and clang-format:
https://marketplace.visualstudio.com/items?itemName=eg2.tslint
sudo npm install -g clang-format
Run the handwritten digit recognition demo:
./scripts/watch-demo demos/mnist_eager
2. Directory structure of the handwritten digit recognition demo
Path: deeplearnjs-master/demos/mnist_eager
Files (the demo is written in TypeScript, hence the .ts extensions):
Entry point: mnist_eager.ts
Data: data.ts
UI: ui.ts
Network model: model.ts
3. Loading the training data
The data is loaded in data.ts. Each input is a 28x28 grayscale image (shape [28, 28, 1]), and labels holds the corresponding digit as a 10-element vector (shape [10]):
'data': [
{
'name': 'images',
'path': 'https://storage.googleapis.com/learnjs-data/model-builder/' +
'mnist_images.png',
'dataType': 'png',
'shape': [28, 28, 1]
},
{
'name': 'labels',
'path': 'https://storage.googleapis.com/learnjs-data/model-builder/' +
'mnist_labels_uint8',
'dataType': 'uint8',
'shape': [10]
}
],
this.dataset = new dl.XhrDataset(mnistConfig);
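The shapes in this config can be made concrete with a minimal plain-TypeScript sketch (no deeplearnjs calls) of how one example could be encoded. The 28x28 grayscale image is flattened into 784 values, and the digit becomes a one-hot vector of shape [10]. `flattenImage` and `oneHot` are hypothetical helpers. Dividing by 255 assumes 8-bit pixel values, and the demo's own decoding of mnist_images.png may differ.

```typescript
// Assumed dimensions from the config: 28x28x1 images, 10 label classes.
const IMAGE_SIZE = 28 * 28; // 784 pixels per flattened image
const NUM_CLASSES = 10;     // digits 0-9

// Hypothetical helper: flatten a 28x28 grid of 8-bit pixels (0-255)
// row by row into a length-784 vector with values in [0, 1].
function flattenImage(pixels: number[][]): Float32Array {
  if (pixels.length !== 28 || pixels.some(row => row.length !== 28)) {
    throw new Error('expected a 28x28 image');
  }
  const out = new Float32Array(IMAGE_SIZE);
  for (let r = 0; r < 28; r++) {
    for (let c = 0; c < 28; c++) {
      out[r * 28 + c] = pixels[r][c] / 255;
    }
  }
  return out;
}

// Hypothetical helper: encode a digit 0-9 as a one-hot vector of shape [10].
function oneHot(digit: number): Float32Array {
  if (Math.floor(digit) !== digit || digit < 0 || digit >= NUM_CLASSES) {
    throw new Error(`label out of range: ${digit}`);
  }
  const v = new Float32Array(NUM_CLASSES);
  v[digit] = 1;
  return v;
}
```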
The image data fetched over the network:
https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png
Retrieve the images and labels from the dataset; these arrays supply both the training and the test data:
const [images, labels] =
this.dataset.getData() as [dl.NDArray[], dl.NDArray[]];
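The training loop calls `data.nextTrainBatch(BATCH_SIZE)` on a data object built from these arrays. The following is a hedged sketch of what such an object does. The hypothetical `InMemoryMnist` class splits already-decoded examples into a training part and a test part by a chosen fraction. It then hands out training batches in order, without shuffling, and wraps around at the end. The class, its `trainFraction` parameter and the plain-array representation are assumptions for illustration, not the demo's MnistData.

```typescript
// Hypothetical stand-in for the demo's MnistData (not deeplearnjs API):
// examples are plain arrays that have already been decoded.
interface TrainBatch {
  xs: number[][];     // [batchSize, imageSize]
  labels: number[][]; // [batchSize, numClasses], one-hot
}

class InMemoryMnist {
  private readonly trainXs: number[][];
  private readonly trainLabels: number[][];
  readonly testXs: number[][];
  readonly testLabels: number[][];
  private cursor = 0; // index of the next training example to hand out

  constructor(images: number[][], labels: number[][], trainFraction: number) {
    if (images.length !== labels.length) {
      throw new Error('images and labels must have the same length');
    }
    // The first part becomes training data, the rest test data.
    const numTrain = Math.floor(images.length * trainFraction);
    if (numTrain === 0) throw new Error('no training examples');
    this.trainXs = images.slice(0, numTrain);
    this.trainLabels = labels.slice(0, numTrain);
    this.testXs = images.slice(numTrain);
    this.testLabels = labels.slice(numTrain);
  }

  // Returns the next batchSize training examples, wrapping around at the end.
  nextTrainBatch(batchSize: number): TrainBatch {
    const xs: number[][] = [];
    const labels: number[][] = [];
    for (let i = 0; i < batchSize; i++) {
      const idx = (this.cursor + i) % this.trainXs.length;
      xs.push(this.trainXs[idx]);
      labels.push(this.trainLabels[idx]);
    }
    this.cursor = (this.cursor + batchSize) % this.trainXs.length;
    return { xs, labels };
  }
}
```

A real training set would normally be shuffled before batching. The sketch keeps the order fixed so that its behavior is easy to follow.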
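4. Building the network
The code is in model.ts.
Initialize the weights, an [IMAGE_SIZE, LABELS_SIZE] matrix drawn from a normal distribution with mean 0 and standard deviation 1/sqrt(IMAGE_SIZE):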
const weights = dl.variable(dl.Array2D.randNormal(
[IMAGE_SIZE, LABELS_SIZE], 0, 1 / Math.sqrt(IMAGE_SIZE), 'float32'));
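Each output logit sums IMAGE_SIZE products of an input pixel and a weight. With inputs in [0, 1], a weight variance of 1/IMAGE_SIZE therefore keeps the variance of each initial logit at most about 1. Below is a minimal sketch of that initializer without deeplearnjs. It assumes IMAGE_SIZE = 784 and LABELS_SIZE = 10, and substitutes the Box-Muller transform and a small seeded Park-Miller PRNG for the library's own sampler. `parkMiller`, `randNormal` and `initWeights` are hypothetical names.

```typescript
const IMAGE_SIZE = 784; // assumed: 28 * 28 flattened pixels
const LABELS_SIZE = 10; // assumed: one output per digit

// Park-Miller "minimal standard" generator (multiplier 48271): a tiny seeded
// PRNG returning floats in [0, 1), used so the sketch is reproducible.
function parkMiller(seed: number): () => number {
  let state = seed % 2147483647;
  if (state <= 0) state += 2147483646;
  return () => {
    state = (state * 48271) % 2147483647; // exact: product stays below 2^53
    return (state - 1) / 2147483646;
  };
}

// Box-Muller transform: two uniform samples -> one sample from N(mean, stdDev^2).
function randNormal(rng: () => number, mean: number, stdDev: number): number {
  const u1 = 1 - rng(); // in (0, 1], so Math.log(u1) is finite
  const u2 = rng();
  const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  return mean + stdDev * z;
}

// Hypothetical counterpart of randNormal([rows, cols], 0, 1 / Math.sqrt(rows)).
function initWeights(rows: number, cols: number, rng: () => number): number[][] {
  const stdDev = 1 / Math.sqrt(rows);
  const w: number[][] = [];
  for (let r = 0; r < rows; r++) {
    const row: number[] = [];
    for (let c = 0; c < cols; c++) row.push(randNormal(rng, 0, stdDev));
    w.push(row);
  }
  return w;
}
```

With rows = 784, the standard deviation is 1/28, roughly 0.036.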
Define the model, a single linear layer that multiplies the input batch by the weights (no bias term):
const model = (xs: dl.Array2D<'float32'>): dl.Array2D<'float32'> => {
return math.matMul(xs, weights) as dl.Array2D<'float32'>;
};
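For a batch xs of shape [batch, 784] and weights of shape [784, 10], matMul produces logits of shape [batch, 10], one score per digit. Paired with a softmax cross-entropy loss, this linear model amounts to softmax (multinomial logistic) regression. The plain-array sketch below performs the same computation. `matMul` and `linearModel` are hypothetical stand-ins for `math.matMul` and `model`, not library code.

```typescript
// Plain-array matrix multiply: [m, k] x [k, n] -> [m, n].
function matMul(a: number[][], b: number[][]): number[][] {
  const k = b.length;
  const n = k > 0 ? b[0].length : 0;
  return a.map(row => {
    if (row.length !== k) throw new Error('inner dimensions do not match');
    const out: number[] = [];
    for (let j = 0; j < n; j++) {
      let s = 0;
      for (let t = 0; t < k; t++) s += row[t] * b[t][j];
      out.push(s);
    }
    return out;
  });
}

// The demo's model on plain arrays: logits = xs . weights, no bias term.
function linearModel(xs: number[][], weights: number[][]): number[][] {
  return matMul(xs, weights);
}
```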
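Define the loss function, the mean over the batch of the softmax cross-entropy between the labels and the model's output logits: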
const loss = (labels: dl.Array2D<'float32'>,
ys: dl.Array2D<'float32'>): dl.Scalar => {
return math.mean(math.softmaxCrossEntropyWithLogits(labels, ys)) as dl.Scalar;
};
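For each example, softmax cross-entropy with logits is the negative sum over classes of labels[j] * log(softmax(logits)[j]), and `math.mean` then averages it over the batch. The sketch below computes this quantity on plain arrays. It uses the log-sum-exp trick (subtracting the row maximum) so that large logits do not overflow `Math.exp`. The function names are hypothetical. The sketch follows the standard definition, not deeplearnjs's own implementation.

```typescript
// Numerically stable log-softmax of one row of logits.
function logSoftmax(logits: number[]): number[] {
  const max = logits.reduce((m, v) => Math.max(m, v), -Infinity);
  const sumExp = logits.reduce((s, v) => s + Math.exp(v - max), 0);
  const logZ = max + Math.log(sumExp); // log of the softmax normalizer
  return logits.map(v => v - logZ);
}

// Mean over the batch of -sum_j labels[i][j] * log softmax(logits[i])[j].
function meanSoftmaxCrossEntropy(labels: number[][], logits: number[][]): number {
  if (labels.length !== logits.length || labels.length === 0) {
    throw new Error('labels and logits must have the same non-zero batch size');
  }
  let total = 0;
  for (let i = 0; i < labels.length; i++) {
    const lp = logSoftmax(logits[i]);
    for (let j = 0; j < lp.length; j++) total -= labels[i][j] * lp[j];
  }
  return total / labels.length;
}
```

With all-zero logits over 10 classes, every class gets probability 1/10, so the loss of any one-hot label is ln 10, about 2.30. This is the loss one would expect near the start of training.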
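5. Training the network
Train the network on batches of training data. Here i is the index of the current training step, not a learning rate or step size. In each step, optimizer.minimize evaluates the loss on a batch of BATCH_SIZE examples and updates the weights. Because returnCost is true, it also returns the cost, which is logged. await dl.util.nextFrame() then yields to the browser so the page stays responsive: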
export async function train(data: MnistData, log: (message: string) => void) {
const returnCost = true;
for (let i = 0; i < TRAIN_STEPS; i++) {
const cost = optimizer.minimize(() => {
const batch = data.nextTrainBatch(BATCH_SIZE);
return loss(batch.labels, model(batch.xs));
}, returnCost);
log(`loss[${i}]: ${cost.dataSync()}`);
await dl.util.nextFrame();
}
}
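In the demo, `optimizer.minimize` differentiates the loss automatically and updates the `weights` variable. For this particular model the gradient can also be written by hand. The derivative of the softmax cross-entropy with respect to the logits is softmax(logits) - labels. The gradient with respect to the weights is therefore xs^T (softmax(xs · W) - labels), averaged over the batch. The sketch below substitutes that hand-derived gradient and plain SGD, run on a tiny full-batch toy dataset. `sgdStep` and `trainSoftmaxRegression` are hypothetical names. The browser-specific parts (batch sampling, logging, `await dl.util.nextFrame()`) are left out.

```typescript
// One full-batch SGD step for softmax regression (logits = xs . w).
// Updates w ([d, k]) in place and returns the mean loss before the update.
function sgdStep(w: number[][], xs: number[][], ys: number[][], learningRate: number): number {
  const d = w.length;
  const k = w[0].length;
  const n = xs.length;
  // grad[a][c] accumulates xs[i][a] * (p[i][c] - ys[i][c]) over the batch.
  const grad: number[][] = [];
  for (let a = 0; a < d; a++) {
    const row: number[] = [];
    for (let c = 0; c < k; c++) row.push(0);
    grad.push(row);
  }
  let loss = 0;
  for (let i = 0; i < n; i++) {
    // Forward pass for example i.
    const logits: number[] = [];
    for (let c = 0; c < k; c++) {
      let s = 0;
      for (let a = 0; a < d; a++) s += xs[i][a] * w[a][c];
      logits.push(s);
    }
    // Stable log-sum-exp, then per-class log-probabilities.
    const max = logits.reduce((m, v) => Math.max(m, v), -Infinity);
    const logZ = max + Math.log(logits.reduce((acc, v) => acc + Math.exp(v - max), 0));
    for (let c = 0; c < k; c++) {
      const logP = logits[c] - logZ;
      loss -= ys[i][c] * logP;
      const delta = Math.exp(logP) - ys[i][c]; // dLoss_i / dLogit_c
      for (let a = 0; a < d; a++) grad[a][c] += xs[i][a] * delta;
    }
  }
  // Average the gradient over the batch and step downhill.
  for (let a = 0; a < d; a++) {
    for (let c = 0; c < k; c++) w[a][c] -= (learningRate * grad[a][c]) / n;
  }
  return loss / n;
}

// Starts from all-zero weights and returns the loss recorded at each step.
function trainSoftmaxRegression(xs: number[][], ys: number[][], steps: number, learningRate: number): number[] {
  const d = xs[0].length;
  const k = ys[0].length;
  const w: number[][] = [];
  for (let a = 0; a < d; a++) {
    const row: number[] = [];
    for (let c = 0; c < k; c++) row.push(0);
    w.push(row);
  }
  const history: number[] = [];
  for (let step = 0; step < steps; step++) {
    history.push(sgdStep(w, xs, ys, learningRate));
  }
  return history;
}
```

Take the toy inputs [[1, 0], [0, 1]] with one-hot labels [[1, 0], [0, 1]]. The first recorded loss is ln 2, because all-zero weights give both classes equal probability. The loss then decreases step by step, mirroring the falling loss[i] values the demo logs.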
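6. Prediction
predict runs the model on a batch of images and takes the argmax of the output along axis 1, giving one predicted digit per image (test is left as an empty function):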
export async function test(data: MnistData) {}
// Predict the digit number from a batch of input images.
export function predict(x: dl.Array2D<'float32'>): number[] {
const pred = math.scope(() => {
const axis = 1;
return math.argMax(model(x), axis);
});
return Array.from(pred.dataSync());
}
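Softmax is monotonic, so the argmax of the raw logits equals the argmax of the predicted probabilities. This is why `predict` can skip the softmax. The `test` function is empty in the listing. One plausible body for it is an accuracy measurement, sketched here on plain arrays. `argMaxRows` and `accuracy` are hypothetical helpers, and ties resolve to the first index.

```typescript
// Index of the largest entry in each row (first index wins on ties),
// i.e. argMax along axis 1 for a [batch, classes] matrix.
function argMaxRows(m: number[][]): number[] {
  return m.map(row => {
    let best = 0;
    for (let j = 1; j < row.length; j++) {
      if (row[j] > row[best]) best = j;
    }
    return best;
  });
}

// Hypothetical body for the empty test(): the fraction of examples whose
// predicted digit matches the digit encoded in the one-hot label.
function accuracy(logits: number[][], oneHotLabels: number[][]): number {
  if (logits.length !== oneHotLabels.length) throw new Error('batch size mismatch');
  if (logits.length === 0) return 0;
  const predicted = argMaxRows(logits);
  const actual = argMaxRows(oneHotLabels);
  let correct = 0;
  for (let i = 0; i < predicted.length; i++) {
    if (predicted[i] === actual[i]) correct++;
  }
  return correct / predicted.length;
}
```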
7. Run results and GPU performance
[Figure: the demo's run results]
[Figure: JS functions executed while training the network]
[Figure: GPU consumption of those JS functions during training]
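Summary:
deeplearnjs supports ES5 and can run in the browser on WebGL 2.0, WebGL 1.0, or the CPU. When the browser supports WebGL 2.0, the framework uses it first. Most deep learning frameworks on the market support only NVIDIA GPUs. Because deeplearnjs computes through the browser's WebGL, it can also use AMD GPUs. Its drawback is that it does not yet support distributed GPU computation.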
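The backend preference order (WebGL 2.0 first, then WebGL 1.0, then the CPU) can be illustrated with a small selection function. This is a hypothetical sketch of the ordering only, not deeplearnjs's actual detection code. The `hasContext` probe is an assumed parameter, injected so that the logic can run outside a browser.

```typescript
type Backend = 'webgl2' | 'webgl1' | 'cpu';
type ContextKind = 'webgl2' | 'webgl';

// Returns the first available backend in the order WebGL 2.0, WebGL 1.0, CPU.
// hasContext reports whether a rendering context of the given kind exists.
function pickBackend(hasContext: (kind: ContextKind) => boolean): Backend {
  if (hasContext('webgl2')) return 'webgl2';
  if (hasContext('webgl')) return 'webgl1';
  return 'cpu';
}
```

In a browser, the probe could create a canvas element and check whether canvas.getContext('webgl2') or canvas.getContext('webgl') returns a non-null context.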
References
Official site:
https://deeplearnjs.org/
GitHub:
https://github.com/PAIR-code/deeplearnjs
deeplearn.js: a machine intelligence framework for the browser, by Xu Jin (InfoQ):
http://www.infoq.com/cn/news/2017/08/deeplearn-js-Browser-machine-int