創(chuàng)建簡單序列分類網(wǎng)絡
本示例說明了如何創(chuàng)建簡單的長期短期記憶(LSTM)分類網(wǎng)絡宋欺。
要訓??練深度神經(jīng)網(wǎng)絡對序列數(shù)據(jù)進行分類殷勘,可以使用LSTM網(wǎng)絡。 LSTM網(wǎng)絡是一種遞歸神經(jīng)網(wǎng)絡(RNN)竭贩,用于學習序列數(shù)據(jù)的時間步長之間的長期依賴關系印屁。
該示例演示如何:?加載序列數(shù)據(jù)。
?定義網(wǎng)絡體系結構稼锅。
?指定訓練選項吼具。
?訓練網(wǎng)絡。
?預測新數(shù)據(jù)的標簽并計算分類準確性矩距。
1)導入數(shù)據(jù)
如[1]和[2]中所述加載日語元音數(shù)據(jù)集拗盒。 預測變量是一個單元陣列,其中包含長度可變的特征維為12的序列锥债。標簽是標簽1,2陡蝇,...,9的分類向量赞弥。
[XTrain,YTrain] = japaneseVowelsTrainData; [XValidation,YValidation] = japaneseVowelsTestData;
查看前幾個訓練序列的大小。 序列是具有12行(每個要素一行)和不同列數(shù)(每個時間步長一列)的矩陣趣兄。
XTrain(1:5)
- 定義網(wǎng)絡架構
定義LSTM網(wǎng)絡體系結構绽左。 指定輸入層中要素的數(shù)量,以及完全連接層中的類的數(shù)量艇潭。
numFeatures = 12;
numHiddenUnits = 100; numClasses = 9;
layers = [ ...
sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer];
3)訓練網(wǎng)絡
指定訓練選項并訓練網(wǎng)絡拼窥。
由于小批量生產(chǎn)的序列短,因此CPU更適合訓練蹋凝。 將“ ExecutionEnvironment”設置為“ cpu”鲁纠。 要在GPU上進行訓練(如果有),請將'ExecutionEnvironment'設置為'auto'(默認值)鳍寂。
miniBatchSize = 27;
options = trainingOptions('adam', ...
'ExecutionEnvironment','cpu', ...
'MaxEpochs',100, ...
'MiniBatchSize',miniBatchSize, ...
'ValidationData',{XValidation,YValidation}, ...
'GradientThreshold',2, ...
'Shuffle','every-epoch', ...
'Verbose',false, ...
'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,layers,options);
有關指定訓練選項的更多信息改含,請參見“設置參數(shù)和訓練卷積神經(jīng)網(wǎng)絡”。
- 測試網(wǎng)絡
對測試數(shù)據(jù)進行分類并計算分類精度迄汛。 指定用于訓練的相同小批量大小捍壤。
YPred = classify(net,XValidation,'MiniBatchSize',miniBatchSize);
acc = mean(YPred == YValidation)
acc=0.9351
對于下一步骤视,您可以嘗試使用雙向LSTM(BiLSTM)層或創(chuàng)建更深的網(wǎng)絡來提高準確性。 有關更多信息鹃觉,請參見“長短期存儲網(wǎng)絡”专酗。
有關顯示如何使用卷積網(wǎng)絡對序列數(shù)據(jù)進行分類的示例,請參見“使用深度學習的語音命令識別”盗扇。
References
1 M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
2 UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
See Also
lstmLayer | trainNetwork | trainingOptions
More About
? “Long Short-Term Memory Networks”
? “Try Deep Learning in 10 Lines of MATLAB Code” on page 1-12
? “Classify Image Using Pretrained Network” on page 1-15
? “Get Started with Transfer Learning” on page 1-18
? “Transfer Learning with Deep Network Designer”
? “Create Simple Image Classification Network” on page 1-22