本文來源于MATLAB例程:Classify Time Series Using Wavelet Analysis and Deep Learning
關(guān)鍵詞:ECG 遷移學(xué)習(xí) GoogLeNet
實(shí)驗(yàn)環(huán)境
- 軟件部分
本人使用的是MATLAB 2018b。官方文檔提到GoogLeNet,AlexNet分別是在MATLAB 2017b歼捏、MATLAB 2017a版本引入的。
此實(shí)驗(yàn)必備的Toolbox:
1.Wavelet Toolbox
2.Image Processing Toolbox
3.Deep Learning Toolbox
4.Deep Learning Toolbox Model for GoogLeNet Network support package
5.Deep Learning Toolbox Model for AlexNet Network
其中栅贴,兩個(gè)support package可能需要登錄MATLAB賬號才能下載,注冊一個(gè)即可熏迹。
- 硬件部分
這里我用的是顯卡是NVIDIA Geforce RTX 2080檐薯。實(shí)驗(yàn)可以用CPU跑,但最好還是使用GPU注暗,快的不是一丁半點(diǎn)坛缕。
數(shù)據(jù)集
共計(jì)162條數(shù)據(jù),下載地址捆昏,請點(diǎn)擊鏈接
讀取數(shù)據(jù)
代碼如下:
dir = 'E:\Transfer_learning\ECG\ECG_Matlab\ECGdata';
load(fullfile(dir,'ECGData.mat'));
parentDir = dir;
dataDir = 'data';
helperCreateECGDirectories(ECGData,parentDir,dataDir)
讀入數(shù)據(jù)后赚楚,工作空間生成了名為ECGdata的結(jié)構(gòu)數(shù)組,如下圖骗卜。Data為162?65536維宠页,即162條數(shù)據(jù)左胞,每條數(shù)據(jù)時(shí)長為512s,采樣率為128Hz举户。Labels存儲了每條數(shù)據(jù)的標(biāo)簽罩句。
顯示原始數(shù)據(jù)
借助helperPlotReps()函數(shù),繪制原始數(shù)據(jù)敛摘,如圖。
特征提取
這里主要使用cwtfilterbank函數(shù)乳愉,將原始的一維心電信號通過連續(xù)小波變換(CWT)轉(zhuǎn)換為時(shí)頻域表達(dá)兄淫,即scalograms。代碼和示例圖如下:
%時(shí)頻域表達(dá)蔓姚,CWT連續(xù)小波變換
Fs = 128;
fb = cwtfilterbank('SignalLength',1000,...
'SamplingFrequency',Fs,...
'VoicesPerOctave',12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;figure;pcolor(t,frq,abs(cfs))
set(gca,'yscale','log');shading interp;axis tight;
title('Scalogram');xlabel('Time (s)');ylabel('Frequency (Hz)')
%生成各個(gè)病種的RGB圖像捕虽,尺寸為224?224?3
helperCreateRGBfromTF(ECGData,parentDir,dataDir)
數(shù)據(jù)集的劃分
%劃分訓(xùn)練與測試數(shù)據(jù)集
allImages = imageDatastore(fullfile(parentDir,dataDir),...
'IncludeSubfolders',true,...
'LabelSource','foldernames');
% 80%作為訓(xùn)練,其余作為測試坡脐,隨機(jī)種子設(shè)為默認(rèn)泄私,以便可重復(fù)。
rng default
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized');
disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);
加載GoogLeNet并進(jìn)行訓(xùn)練
GooLeNet
net = googlenet;
lgraph = layerGraph(net);
numberOfLayers = numel(lgraph.Layers);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)%繪制結(jié)構(gòu)圖
title(['GoogLeNet Layer Graph: ',num2str(numberOfLayers),' Layers']);
GoogLeNet模型的參數(shù)修改
GoogleNet是使用ImageNet訓(xùn)練的對于1000分類的深層CNN網(wǎng)絡(luò)备闲,這里最后四層修改為針對三分類問題的輸出晌端。
lgraph = removeLayers(lgraph,{'pool5-drop_7x7_s1','loss3-classifier','prob','output'});
numClasses = numel(categories(imgsTrain.Labels));
newLayers = [
dropoutLayer(0.6,'Name','newDropout') % dropout概率60%
fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',5,'BiasLearnRateFactor',5) %全連接層,這里numClasses為3
softmaxLayer('Name','softmax')
classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);
lgraph = connectLayers(lgraph,'pool5-7x7_s1','newDropout');
inputSize = net.Layers(1).InputSize;
模型超參數(shù)設(shè)置及訓(xùn)練
參數(shù)設(shè)置和訓(xùn)練結(jié)果如下恬砂,最終的正確率為90.625%咧纠。ps.例程中做的訓(xùn)練過程動態(tài)圖真的太贊了!
options = trainingOptions('sgdm',...
'MiniBatchSize',15,...
'MaxEpochs',20,...
'InitialLearnRate',1e-4,...
'ValidationData',imgsValidation,...
'ValidationFrequency',10,...
'ValidationPatience',Inf,...
'Verbose',1,...
'ExecutionEnvironment','gpu',...
'Plots','training-progress');
例程中后面的內(nèi)容不再贅述泻骤,主要是探討網(wǎng)絡(luò)內(nèi)部的結(jié)構(gòu)漆羔。和下圖所示的人臉識別的很相似。
底層網(wǎng)絡(luò):各種邊緣結(jié)構(gòu)
中層網(wǎng)絡(luò):眼睛狱掂,鼻子演痒,嘴巴等局部特征
高層網(wǎng)絡(luò):將局部特征組合,得到各種人臉特征
后記
本文實(shí)驗(yàn)是將ECG轉(zhuǎn)換為二維的時(shí)頻域圖作為網(wǎng)絡(luò)輸入趋惨,在arXiv上瀏覽文獻(xiàn)發(fā)現(xiàn)有一篇文章做的工作很相似鸟顺,貼在這里,是基于DenseNet做的遷移器虾。
ECG Arrhythmia Classification Using Transfer Learning from 2-Dimensional Deep CNN Features