Competition Overview
In response to the national Healthy China strategy and the policies promoting the integration of healthcare and big data, the First China ECG Intelligence Competition has officially launched, jointly hosted by the School of Clinical Medicine and the Institute for Data Science of Tsinghua University, the Jingjin Gaocun Science and Technology Innovation Park in Wuqing District, Tianjin, and several leading hospitals. Global registration is open from today until 24:00 on March 31, 2019, and the total prize pool is expected to reach one million RMB. The official registration site is now live; universities, hospitals, startup teams, and anyone committed to advancing ECG AI in China are welcome to participate.
Official registration site for the First China ECG Intelligence Competition >> http://mdi.ids.tsinghua.edu.cn
Data Description
The full training and test sets contain 1,000 routine ECG recordings in total: 600 in the training set and 400 in the test set, drawn from several public datasets. Teams must design and implement an algorithm using the training data, which is labeled normal/abnormal, and then produce predictions on the unlabeled test set.
The ECG data are sampled at 500 Hz. So that teams can read the data from any programming language, all recordings are stored in MAT format; each file contains the voltage signals of 12 leads. The training labels are stored in a txt file, where 0 denotes normal and 1 denotes abnormal.
Problem Analysis
Briefly: the preliminary-round dataset has 1,000 samples, 600 in the training set and 400 in the test set. The 600 training samples are labeled and can be used to train a model; the 400 test samples are unlabeled and must be predicted with the trained model.
This is a binary classification problem, and a solution should cover the following steps:
- Data loading and preprocessing
- Building the network model
- Training the model
- Applying the model and submitting predictions
Hands-On Walkthrough
Following the analysis above, we split the work into four sub-tasks. The first step is:
1. Data loading and preprocessing
The ECG data are sampled at 500 Hz. So that teams can read the data from any programming language, all recordings are stored in MAT format; each file contains the voltage signals of 12 leads. The training labels are stored in a txt file, where 0 denotes normal and 1 denotes abnormal.
From this description we know:
- The data are stored in MAT files (this determines how we will read them)
- The sampling rate is 500 Hz (this is not used directly; it just means 500 points are collected per second, and since each recording turns out to contain 5,000 points per lead, each sample is a 10-second ECG)
- There are 12 leads of voltage signals (12 lead placements, which you can loosely think of as measuring with 12 thermometers at once to get a more complete picture; the figure below gives a brief introduction to lead placement. Note that since 12 leads are provided, we should use all of them: training and predicting on a single lead would work, but experience says more features usually give better results)
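The numbers in these bullets are easy to sanity-check with a couple of lines. This is a minimal sketch with a dummy array, not the real MAT files:

```python
import numpy as np

SAMPLING_RATE = 500      # Hz, from the data description
POINTS_PER_LEAD = 5000   # points per lead in each recording
NUM_LEADS = 12

# Duration of one recording in seconds: 5000 points at 500 Hz
duration_s = POINTS_PER_LEAD / SAMPLING_RATE
print(duration_s)  # 10.0

# A dummy sample with the same geometry as one MAT file's "data" array
sample = np.zeros((NUM_LEADS, POINTS_PER_LEAD))
print(sample.shape)  # (12, 5000)
```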
Data-processing function definitions:

```python
import keras
from scipy.io import loadmat
import matplotlib.pyplot as plt
import glob
import numpy as np
import pandas as pd
import math
import os
from keras.layers import *
from keras.models import *
from keras.objectives import *

BASE_DIR = "preliminary/TRAIN/"

# Normalize each lead: subtract its mean, divide by its max
# (the 2e-12 guards against division by zero)
def normalize(v):
    return (v - v.mean(axis=1).reshape((v.shape[0], 1))) / (v.max(axis=1).reshape((v.shape[0], 1)) + 2e-12)

# Open one recording with loadmat and return its normalized leads
def get_feature(wav_file, Lens=12, BASE_DIR=BASE_DIR):
    mat = loadmat(BASE_DIR + wav_file)
    dat = mat["data"]
    feature = dat[0:Lens]
    return normalize(feature).transpose()

# Convert an integer label to one-hot form
def convert2oneHot(index, Lens):
    hot = np.zeros((Lens,))
    hot[index] = 1
    return hot

TXT_DIR = "preliminary/reference.txt"
MANIFEST_DIR = "preliminary/reference.csv"
```
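A quick sanity check of the two helpers, redefined here so the snippet is self-contained and runs without the competition data:

```python
import numpy as np

def normalize(v):
    # Same per-lead normalization as above: subtract each row's mean,
    # divide by each row's max (plus a tiny epsilon against zero division)
    return (v - v.mean(axis=1).reshape((v.shape[0], 1))) / (v.max(axis=1).reshape((v.shape[0], 1)) + 2e-12)

def convert2oneHot(index, Lens):
    hot = np.zeros((Lens,))
    hot[index] = 1
    return hot

v = np.array([[1.0, 2.0, 3.0]])  # one "lead" with three points
print(normalize(v))              # ≈ [[-1/3, 0, 1/3]]
print(convert2oneHot(1, 2))      # [0. 1.]
```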
Read one sample and display it:

```python
if __name__ == "__main__":
    dat1 = get_feature("TRAIN101.mat")  # get_feature already prepends BASE_DIR
    print(dat1.shape)
    # one sample has shape (5000, 12) after the transpose
    plt.plot(dat1[:, 0])  # plot the first lead
    plt.show()
```
The output shows that each lead is a sequence of 5,000 points, so with 12 leads every sample is a 12×5000 matrix, much like an image with a 12×5000 resolution.
All we need to do is read each sample, normalize it, and feed it into the network for training.
Label processing

```python
def create_csv(TXT_DIR=TXT_DIR):
    lists = pd.read_csv(TXT_DIR, sep=r"\t", header=None)
    lists = lists.sample(frac=1)  # shuffle the rows
    lists.to_csv(MANIFEST_DIR, index=None)
    print("Finish save csv")
```
Here I read the labels from reference.txt, shuffle them, and save the result to reference.csv. The shuffle is essential: without it training performs poorly, because in the raw file all the 1 labels come first and all the 0 labels come after.
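The shuffling step can be seen on a toy frame (the file names and labels here are made up, mimicking the sorted layout of reference.txt):

```python
import pandas as pd

# A toy manifest in the spirit of reference.txt: all 1 labels first
lists = pd.DataFrame({0: ["A1", "A2", "A3", "A4"], 1: [1, 1, 0, 0]})

# sample(frac=1) returns every row in random order;
# a fixed random_state makes the shuffle repeatable
shuffled = lists.sample(frac=1, random_state=0)
print(shuffled[1].tolist())  # the labels 1,1,0,0 in a mixed order
```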
Data iteration

```python
Batch_size = 20

def xs_gen(path=MANIFEST_DIR, batch_size=Batch_size, train=True):
    img_list = pd.read_csv(path)
    if train:
        img_list = np.array(img_list)[:500]  # first 500 rows for training
        print("Found %s train items." % len(img_list))
        print("list 1 is", img_list[0])
        steps = math.ceil(len(img_list) / batch_size)  # batches per epoch
    else:
        img_list = np.array(img_list)[500:]  # remaining rows for validation
        print("Found %s test items." % len(img_list))
        print("list 1 is", img_list[0])
        steps = math.ceil(len(img_list) / batch_size)  # batches per epoch
    while True:
        for i in range(steps):
            batch_list = img_list[i * batch_size : i * batch_size + batch_size]
            np.random.shuffle(batch_list)
            batch_x = np.array([get_feature(file) for file in batch_list[:, 0]])
            batch_y = np.array([convert2oneHot(label, 2) for label in batch_list[:, 1]])
            yield batch_x, batch_y
```
I read the data with a generator so it can be loaded batch by batch, which speeds up training; loading everything at once also works, so go with whatever you prefer. For background on generators, see my earlier post on the topic.
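The batching logic of xs_gen can be sketched with synthetic arrays, so it runs without any MAT files (the shapes mirror the real data; the item count of 50 is arbitrary):

```python
import math
import numpy as np

def toy_gen(n_items=50, batch_size=20):
    # Stand-ins for the real samples: random 5000x12 arrays and one-hot labels
    x = np.random.randn(n_items, 5000, 12)
    y = np.eye(2)[np.random.randint(0, 2, n_items)]
    steps = math.ceil(n_items / batch_size)
    while True:  # loop forever, like xs_gen, so fit_generator can keep drawing
        for i in range(steps):
            yield x[i * batch_size:(i + 1) * batch_size], y[i * batch_size:(i + 1) * batch_size]

gen = toy_gen()
bx, by = next(gen)
print(bx.shape, by.shape)  # (20, 5000, 12) (20, 2)
```

Note that with 50 items and a batch size of 20, the last of the three batches per epoch holds only 10 samples, which is why xs_gen uses math.ceil for the step count.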
2. Building the network model
With the data handled, the next step is the model. I build it with Keras, which is simple and convenient; TensorFlow, PyTorch, or scikit-learn are fine too, according to your preference.
The network can be a CNN, an RNN, an attention architecture, or a fusion of several models. As a starting point, this baseline uses a one-dimensional CNN (1-D CNN tutorial link).
Model definition

```python
TIME_PERIODS = 5000
num_sensors = 12

def build_model(input_shape=(TIME_PERIODS, num_sensors), num_classes=2):
    model = Sequential()
    model.add(Conv1D(16, 16, strides=2, activation='relu', input_shape=input_shape))
    model.add(Conv1D(16, 16, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(64, 8, strides=2, activation='relu', padding="same"))
    model.add(Conv1D(64, 8, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(128, 4, strides=2, activation='relu', padding="same"))
    model.add(Conv1D(128, 4, strides=2, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(256, 2, strides=1, activation='relu', padding="same"))
    model.add(Conv1D(256, 2, strides=1, activation='relu', padding="same"))
    model.add(MaxPooling1D(2))
    model.add(GlobalAveragePooling1D())
    model.add(Dropout(0.3))
    model.add(Dense(num_classes, activation='softmax'))
    return model
```
The network printed by model.summary():
```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
reshape_1 (Reshape)          (None, 5000, 12)          0
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 2493, 16)          3088
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 1247, 16)          4112
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 623, 16)           0
_________________________________________________________________
conv1d_3 (Conv1D)            (None, 312, 64)           8256
_________________________________________________________________
conv1d_4 (Conv1D)            (None, 156, 64)           32832
_________________________________________________________________
max_pooling1d_2 (MaxPooling1 (None, 78, 64)            0
_________________________________________________________________
conv1d_5 (Conv1D)            (None, 39, 128)           32896
_________________________________________________________________
conv1d_6 (Conv1D)            (None, 20, 128)           65664
_________________________________________________________________
max_pooling1d_3 (MaxPooling1 (None, 10, 128)           0
_________________________________________________________________
conv1d_7 (Conv1D)            (None, 10, 256)           65792
_________________________________________________________________
conv1d_8 (Conv1D)            (None, 10, 256)           131328
_________________________________________________________________
max_pooling1d_4 (MaxPooling1 (None, 5, 256)            0
_________________________________________________________________
global_average_pooling1d_1 ( (None, 256)               0
_________________________________________________________________
dropout_1 (Dropout)          (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 514
=================================================================
Total params: 344,482
Trainable params: 344,482
Non-trainable params: 0
_________________________________________________________________
```
The parameter count is fairly small; feel free to modify the architecture to your own ideas.
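The output lengths in the summary above can be reproduced by hand: a Keras Conv1D with padding='valid' yields floor((n - k) / s) + 1 timesteps, padding='same' yields ceil(n / s), and MaxPooling1D(2) halves the length (rounding down). A quick arithmetic check against the table:

```python
import math

def conv_valid(n, k, s):
    # Conv1D, padding='valid': floor((n - k) / s) + 1
    return (n - k) // s + 1

def conv_same(n, s):
    # Conv1D, padding='same': ceil(n / s)
    return math.ceil(n / s)

def pool2(n):
    # MaxPooling1D(2): floor(n / 2)
    return n // 2

n = 5000
n = conv_valid(n, 16, 2)   # 2493
n = conv_same(n, 2)        # 1247
n = pool2(n)               # 623
n = conv_same(n, 2)        # 312
n = conv_same(n, 2)        # 156
n = pool2(n)               # 78
n = conv_same(n, 2)        # 39
n = conv_same(n, 2)        # 20
n = pool2(n)               # 10
n = conv_same(n, 1)        # 10
n = conv_same(n, 1)        # 10
n = pool2(n)
print(n)  # 5 timesteps of 256 channels feed the global average pooling
```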
3. Training the model
Training code
```python
if __name__ == "__main__":
    if not os.path.exists(MANIFEST_DIR):
        create_csv()
    train_iter = xs_gen(train=True)
    test_iter = xs_gen(train=False)
    model = build_model()
    print(model.summary())
    # Save a checkpoint whenever validation accuracy improves
    ckpt = keras.callbacks.ModelCheckpoint(
        filepath='best_model.{epoch:02d}-{val_acc:.2f}.h5',
        monitor='val_acc', save_best_only=True, verbose=1)
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam', metrics=['accuracy'])
    model.fit_generator(
        generator=train_iter,
        steps_per_epoch=500 // Batch_size,
        epochs=20,
        initial_epoch=0,
        validation_data=test_iter,
        validation_steps=100 // Batch_size,  # the old nb_val_samples argument is deprecated
        callbacks=[ckpt],
    )
```
Training log (best result: loss: 0.0565 - acc: 0.9820 - val_loss: 0.8307 - val_acc: 0.8800):
```
Epoch 10/20
25/25 [==============================] - 1s 37ms/step - loss: 0.2329 - acc: 0.9040 - val_loss: 0.4041 - val_acc: 0.8700
Epoch 00010: val_acc improved from 0.85000 to 0.87000, saving model to best_model.10-0.87.h5
Epoch 11/20
25/25 [==============================] - 1s 38ms/step - loss: 0.1633 - acc: 0.9380 - val_loss: 0.5277 - val_acc: 0.8300
Epoch 00011: val_acc did not improve from 0.87000
Epoch 12/20
25/25 [==============================] - 1s 40ms/step - loss: 0.1394 - acc: 0.9500 - val_loss: 0.4916 - val_acc: 0.7400
Epoch 00012: val_acc did not improve from 0.87000
Epoch 13/20
25/25 [==============================] - 1s 38ms/step - loss: 0.1746 - acc: 0.9220 - val_loss: 0.5208 - val_acc: 0.8100
Epoch 00013: val_acc did not improve from 0.87000
Epoch 14/20
25/25 [==============================] - 1s 38ms/step - loss: 0.1009 - acc: 0.9720 - val_loss: 0.5513 - val_acc: 0.8000
Epoch 00014: val_acc did not improve from 0.87000
Epoch 15/20
25/25 [==============================] - 1s 38ms/step - loss: 0.0565 - acc: 0.9820 - val_loss: 0.8307 - val_acc: 0.8800
Epoch 00015: val_acc improved from 0.87000 to 0.88000, saving model to best_model.15-0.88.h5
Epoch 16/20
25/25 [==============================] - 1s 38ms/step - loss: 0.0261 - acc: 0.9920 - val_loss: 0.6443 - val_acc: 0.8400
Epoch 00016: val_acc did not improve from 0.88000
Epoch 17/20
25/25 [==============================] - 1s 38ms/step - loss: 0.0178 - acc: 0.9960 - val_loss: 0.7773 - val_acc: 0.8700
Epoch 00017: val_acc did not improve from 0.88000
Epoch 18/20
25/25 [==============================] - 1s 38ms/step - loss: 0.0082 - acc: 0.9980 - val_loss: 0.8875 - val_acc: 0.8600
Epoch 00018: val_acc did not improve from 0.88000
Epoch 19/20
25/25 [==============================] - 1s 37ms/step - loss: 0.0045 - acc: 1.0000 - val_loss: 1.0057 - val_acc: 0.8600
Epoch 00019: val_acc did not improve from 0.88000
Epoch 20/20
25/25 [==============================] - 1s 37ms/step - loss: 0.0012 - acc: 1.0000 - val_loss: 1.1088 - val_acc: 0.8600
Epoch 00020: val_acc did not improve from 0.88000
```
4. Applying the model and predicting
Prediction code
```python
if __name__ == "__main__":
    # (comment out the training code above when running prediction)
    PRE_DIR = "sample_codes/answers.txt"
    model = load_model("best_model.15-0.88.h5")
    pre_lists = pd.read_csv(PRE_DIR, sep=r" ", header=None)
    print(pre_lists.head())
    pre_datas = np.array([get_feature(item, BASE_DIR="preliminary/TEST/") for item in pre_lists[0]])
    pre_result = model.predict_classes(pre_datas)  # predicted class per sample (0 or 1)
    print(pre_result.shape)
    pre_lists[1] = pre_result
    pre_lists.to_csv("sample_codes/answers1.txt", index=None, header=None)
    print("predict finish")
```
The first ten predictions:
```
TEST394,0
TEST313,1
TEST484,0
TEST288,0
TEST261,1
TEST310,0
TEST286,1
TEST367,1
TEST149,1
TEST160,1
```
Note that my prediction format differs from the official one; you will need to format your predictions and submission according to the competition's requirements.
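A side note on predict_classes: it only exists on Sequential models and was removed in later Keras releases. Taking the argmax of the model.predict output is equivalent; sketched here on a dummy probability array rather than real model output:

```python
import numpy as np

# Dummy softmax outputs for 4 samples, shaped like model.predict(pre_datas)
probs = np.array([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.6, 0.4],
                  [0.3, 0.7]])

pre_result = np.argmax(probs, axis=-1)  # class index per sample
print(pre_result)  # [0 1 0 1]
```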
Outlook
With the simplest one-dimensional convolutions, this baseline reaches 88% test accuracy (which may fluctuate with random initialization). Try GRUs, attention, ResNet-style architectures, and so on; test accuracy should be able to exceed 95%.
My abilities are limited, so criticism and corrections are very welcome.