現(xiàn)在模型已經(jīng)定義好了嵌施,數(shù)據(jù)也下載并進行了處理腿箩,一切準備就緒準備開始訓(xùn)練挣磨。
async function trainModel(model, inputs, labels) {
// 準備要訓(xùn)練的模型
model.compile({
optimizer: tf.train.adam(),
loss: tf.losses.meanSquaredError,
metrics: ['mse'],
});
const batchSize = 32;
const epochs = 50;
return await model.fit(inputs, labels, {
batchSize,
epochs,
shuffle: true,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training Performance' },
['loss', 'mse'],
{ height: 200, callbacks: ['onEpochEnd'] }
)
});
}
訓(xùn)練前的一些準備
model.compile({
optimizer: tf.train.adam(),
loss: tf.losses.meanSquaredError,
metrics: ['mse'],
});
在訓(xùn)練模型之前雇逞,需要 "編譯 "該模型荤懂,那么具體應(yīng)該如何做呢? 我們需要一個優(yōu)化和一個損失函數(shù),損失函數(shù)也可以理解目標函數(shù)喝峦,主要是指定訓(xùn)練势誊,讓我們訓(xùn)練一個目標,優(yōu)化器這是給出一個策略如何在訓(xùn)練過程更新參數(shù)谣蠢。
- 優(yōu)化器粟耻。這是一種算法,是更新參數(shù)的算法眉踱。在 TensorFlow.js 中有許多優(yōu)化器可用挤忙。這里選擇了 adam 優(yōu)化器,也可以嘗試用其他優(yōu)化器
- 損失函數(shù):其實就是一個函數(shù)谈喳,告訴模型在學(xué)習(xí)過程中册烈,在每個批次(數(shù)據(jù)子集)時的表現(xiàn)如何。這里選擇 meanSquaredError 來比較模型的預(yù)測和真實值
const batchSize = 32;
const epochs = 50;
設(shè)置超參數(shù) batchSize 和一個 epochs 的數(shù)量婿禽。
batchSize 指的是模型在每次迭代訓(xùn)練中看到的數(shù)據(jù)子集的大小赏僧。常見的批次大小往往在 32-512 之間取值。批次大小對于訓(xùn)練速度是有所影響的
epochs 完成整個數(shù)據(jù)集進行訓(xùn)練的次數(shù)
開始訓(xùn)練
return await model.fit(inputs, labels, {
batchSize,
epochs,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training Performance' },
['loss', 'mse'],
{ height: 200, callbacks: ['onEpochEnd'] }
)
});
model.fit 是來啟動訓(xùn)練的函數(shù)扭倾。這是一個異步函數(shù)淀零,所以返回會是一個 promise。
為了監(jiān)控訓(xùn)練進度膛壹,回調(diào)傳函數(shù)作為 model.fit 來獲取訓(xùn)練過程中信息驾中。然后回調(diào)函數(shù)使用 tfvis.show.fitCallbacks 來定義,然后可以繪制損失值對于迭代的圖標
const tensorData = convertToTensor(data);
const {inputs, labels} = tensorData;
// Train the model
await trainModel(model, inputs, labels);
console.log('Done Training');
這的注意的這部分代碼要寫在 run 函數(shù)中模聋,具體如下
async function run() {
// 加載數(shù)據(jù)
const data = await getData();
// 處理原始數(shù)據(jù)肩民,將數(shù)據(jù) horsepower 映射為 x 而 mpg 則映射為 y
const values = data.map(d => ({
x: d.horsepower,
y: d.mpg,
}));
// 將數(shù)據(jù)以散點圖形式顯示在開發(fā)者調(diào)試工具
tfvis.render.scatterplot(
{name: 'Horsepower v MPG'},
{values},
{
xLabel: 'Horsepower',
yLabel: 'MPG',
height: 300
}
);
const model = createModel();
const tensorData = convertToTensor(data);
const {inputs, labels} = tensorData;
// Train the model
await trainModel(model, inputs, labels);
console.log('Done Training');
}