你喜歡什么樣的音樂?目前仪壮,很多公司實(shí)現(xiàn)了對(duì)音樂的分類绽榛,要么是為了向客戶提供推薦(如Spotify湿酸、SoundCloud),要么只是作為一種產(chǎn)品(如Shazam)灭美。對(duì)音樂進(jìn)行分類推溃,首先要確定音樂類型。事實(shí)證明届腐,用機(jī)器學(xué)習(xí)技術(shù)從大量數(shù)據(jù)中找出音樂的各種潮流和類型是非常成功的铁坎。音樂分析亦然。
本文我們將學(xué)習(xí)如何用Python進(jìn)行音頻/音樂信號(hào)分析以及之后如何用該技能對(duì)不同類型的音樂片段進(jìn)行分類犁苏。
在這里給大家推薦一個(gè)python系統(tǒng)學(xué)習(xí)q群:250933691有免費(fèi)開發(fā)工具以及初學(xué)資料硬萍,(數(shù)據(jù)分析,爬蟲围详,機(jī)器學(xué)習(xí)朴乖,神經(jīng)網(wǎng)絡(luò))每天有老師給大家免費(fèi)授課,歡迎一起交流學(xué)習(xí)
用Python處理音頻
聲音以音頻信號(hào)的形式表示助赞,音頻信號(hào)具有頻率买羞、帶寬、分貝等參數(shù)雹食,音頻信號(hào)一般可表示為振幅和時(shí)間的函數(shù)哩都。
這些聲音有多種格式,因此計(jì)算機(jī)可以對(duì)其進(jìn)行讀取和分析婉徘。例如:
??mp3 格式
??WMA (Windows Media Audio) 格式
??wav (Waveform Audio File) 格式
音頻庫
Python有一些很好用的音頻處理庫漠嵌,比如Librosa和PyAudio咐汞。另外,還有一些基本的音頻功能的內(nèi)置模塊儒鹿。
我們將主要使用兩個(gè)音頻庫進(jìn)行音頻采集和回放化撕,如下:
1.Librosa
Librosa是一個(gè)Python模塊,通常用于分析音頻信號(hào)约炎,但更適合音樂信號(hào)分析植阴。它包括構(gòu)建一個(gè)音樂信息檢索(MIR)系統(tǒng)的具體細(xì)節(jié),目前圾浅,Librosa已充分實(shí)現(xiàn)文檔化掠手,并具有許多相關(guān)的示例和教程。
安裝
pip install librosa
or
conda install -c conda-forge librosa
可以安裝附帶很多音頻解碼器的ffmpeg(一個(gè)開源免費(fèi)跨平臺(tái)的視頻和音頻流方案)以提高音頻解碼功率狸捕。
2.IPython.display.Audio
IPython.display.Audio可以讓用戶直接在Jupyter notebook中播放音頻喷鸽。
音頻包加載
import librosa
audio_path = '../T08-violin.wav'
x , sr = librosa.load(audio_path)
print(type(x), type(sr))
<class 'numpy.ndarray'> <class 'int'>
print(x.shape, sr)
(396688,) 22050
以上步驟的返回值為一段音頻的時(shí)間序列,其默認(rèn)采樣頻率(sr)為22KHZ mono灸拍。我們可將其改為:
librosa.load(audio_path, sr=44100)
可重新采樣為44.1KHZ做祝,
librosa.load(audio_path, sr=None)
或者不重新采樣。
采樣頻率指音頻每秒鐘的采樣樣本數(shù)鸡岗,以Hz或kHz表示混槐。
音頻播放
用Ipython.display.Audio 播放音頻。
import IPython.display as ipd
ipd.Audio(audio_path)
以上步驟的返回值為Jupyter notebook的一個(gè)音頻插件轩性。如下:
這里的插件不起作用声登,不過放到你的notebooks上就可以了。
以下音頻也可用mp3格式或WMA格式聽揣苏。
可視化音頻(Visualizing Audio)
波形音頻 (Waveform)
我們可以用librosa.display.waveplot來繪制音頻悯嗓。
%matplotlib inline
import matplotlib.pyplot as plt
import librosa.display
plt.figure(figsize=(14, 5))
librosa.display.waveplot(x, sr=sr)
上圖顯示了該段波形音頻的振幅包絡(luò)線(amplitude envelope)。
聲譜圖(spectrogram)
聲譜圖(spectrogram)是聲音或其他信號(hào)的頻率隨時(shí)間變化時(shí)的頻譜(spectrum)的一種直觀表示舒岸。聲譜圖有時(shí)也稱sonographs,voiceprints,或者voicegrams。當(dāng)數(shù)據(jù)以三維圖形表示時(shí)芦圾,可稱其為瀑布圖(waterfalls)蛾派。在二維數(shù)組中,第一個(gè)軸是頻率个少,第二個(gè)軸是時(shí)間洪乍。
我們可以用librosa.display.specshow 來展示聲譜圖。
X = librosa.stft(x)
Xdb = librosa.amplitude_to_db(abs(X))
plt.figure(figsize=(14, 5))
librosa.display.specshow(Xdb, sr=sr, x_axis='time', y_axis='hz')
plt.colorbar()
縱軸顯示的是頻率(從0到10kHz)夜焦,橫軸顯示的是音頻的時(shí)間壳澳。因?yàn)樗锌梢姷牟▌?dòng)都發(fā)生在頻譜的底部,故這里將頻率軸轉(zhuǎn)換成對(duì)數(shù)軸茫经。
librosa.display.specshow(Xdb, sr=sr, x_axis='time', y_axis='log')
plt.colorbar()
音頻編寫
用librosa.output.write_wav 將NumPy數(shù)組保存到WAV文件中巷波。
librosa.output.write_wav('example.wav', x, sr)
創(chuàng)建音頻信號(hào)
現(xiàn)在讓我們創(chuàng)建一個(gè)220HZ的音頻信號(hào)萎津。由于音頻信號(hào)是一個(gè)numpy數(shù)組,所以創(chuàng)建后需將其轉(zhuǎn)換為音頻函數(shù)抹镊。
import numpy as np
sr = 22050 # sample rate
T = 5.0? ? # seconds
t = np.linspace(0, T, int(T*sr), endpoint=False) # time variable
x = 0.5*np.sin(2*np.pi*220*t)# pure sine wave at 220 Hz
Playing the audio
ipd.Audio(x, rate=sr) # load a NumPy array
Saving the audio
librosa.output.write_wav('tone_220.wav', x, sr)
然后锉屈,這就是你創(chuàng)建的第一個(gè)音頻信號(hào)。
特征提取
每一個(gè)音頻信號(hào)都有很多特征垮耳。然而颈渊,我們必須提取出與我們?cè)噲D解決的問題相關(guān)的特征。提取特征以用于分析的過程稱為特征提取终佛。接下來我們將詳細(xì)研究其中幾個(gè)特征俊嗽。
過零率(Zero Crossing Rate)
過零率(zero crossing rate)是一個(gè)信號(hào)符號(hào)變化的比率挑童,即犁河,在每幀中届巩,語音信號(hào)從正變?yōu)樨?fù)或從負(fù)變?yōu)檎拇螖?shù)营密。 這個(gè)特征已在語音識(shí)別和音樂信息檢索領(lǐng)域得到廣泛使用玛荞,通常對(duì)類似金屬契讲、搖滾等高沖擊性的聲音的具有更高的價(jià)值带膀。
現(xiàn)在我們來計(jì)算示例音頻片段的過零率:
# Load the signal
x, sr = librosa.load('../T08-violin.wav')
#Plot the signal:
plt.figure(figsize=(14, 5))
librosa.display.waveplot(x, sr=sr)
# Zooming in
n0 = 9000
n1 = 9100
plt.figure(figsize=(14, 5))
plt.plot(x[n0:n1])
plt.grid()
上圖似乎有6個(gè)過零點(diǎn)撩银,用librosa來驗(yàn)證下該結(jié)果鹃共。
zero_crossings = librosa.zero_crossings(x[n0:n1], pad=False)
print(sum(zero_crossings))
頻譜質(zhì)心(Spectral Centroid)
頻譜質(zhì)心指示聲音的“質(zhì)心”位于何處鬼佣,并按照聲音的頻率的加權(quán)平均值來加以計(jì)算。 假設(shè)現(xiàn)有兩首歌曲霜浴,一首是藍(lán)調(diào)歌曲晶衷,另一首是金屬歌曲。現(xiàn)在阴孟,與同等長度的藍(lán)調(diào)歌曲相比晌纫,金屬歌曲在接近尾聲位置的頻率更高。所以藍(lán)調(diào)歌曲的頻譜質(zhì)心會(huì)在頻譜偏中間的位置永丝,而金屬歌曲的頻譜質(zhì)心則靠近頻譜末端锹漱。
用librosa.feature.spectral_centroid 計(jì)算出每一幀音頻信號(hào)的頻譜質(zhì)心。
spectral_centroids = librosa.feature.spectral_centroid(x, sr=sr)[0]
spectral_centroids.shape
(775,)
# Computing the time variable for visualization
frames = range(len(spectral_centroids))
t = librosa.frames_to_time(frames)
# Normalising the spectral centroid for visualisation
def normalize(x, axis=0):
? ? return sklearn.preprocessing.minmax_scale(x, axis=axis)
#Plotting the Spectral Centroid along the waveform
librosa.display.waveplot(x, sr=sr, alpha=0.4)
plt.plot(t, normalize(spectral_centroids), color='r')
頻譜質(zhì)心在接近末端處有上升慕嚷。
譜滾降(Spectral Rolloff)
譜滾降(Spectral Rolloff)是對(duì)信號(hào)形狀的測(cè)量哥牍,表示的是在譜能量的特定百分比(如85%)時(shí)的頻率。
librosa.feature.spectral_rolloff 計(jì)算出每一幀信號(hào)的滾降頻率喝检。
spectral_rolloff = librosa.feature.spectral_rolloff(x+0.01, sr=sr)[0]
librosa.display.waveplot(x, sr=sr, alpha=0.4)
plt.plot(t, normalize(spectral_rolloff), color='r')
梅爾頻率倒譜系數(shù)(Mel-Frequency Cepstral Coefficients)
信號(hào)的梅爾頻率倒譜系數(shù)(MFCC)是一個(gè)通常由10-20個(gè)特征構(gòu)成的集合嗅辣,可簡明地描述頻譜包絡(luò)的總體形狀,對(duì)語音特征進(jìn)行建模挠说。
這次我們使用一個(gè)簡單的循環(huán)波澡谭。
x, fs = librosa.load('../simple_loop.wav')
librosa.display.waveplot(x, sr=sr)
用librosa.feature.mfcc 計(jì)算出音頻信號(hào)的梅爾頻率倒譜系數(shù):
mfccs = librosa.feature.mfcc(x, sr=fs)
print mfccs.shape
(20, 97)
#Displaying? the MFCCs:
librosa.display.specshow(mfccs, sr=sr, x_axis='time')
計(jì)算出該段超過97幀的音頻的梅爾頻率倒譜系數(shù)為20。
我們也可以給特征標(biāo)上刻度损俭,使其每個(gè)系數(shù)有相應(yīng)的零均值和單位方差蛙奖。
import sklearn
mfccs = sklearn.preprocessing.scale(mfccs, axis=1)
print(mfccs.mean(axis=1))
print(mfccs.var(axis=1))
librosa.display.specshow(mfccs, sr=sr, x_axis='time')
Chroma Frequencies
色度特征是對(duì)音樂音頻的一種有趣生動(dòng)的表示潘酗,可將整個(gè)頻譜投射到代表“八度”(在音樂中,相鄰的音組中相同音名的兩個(gè)音外永,包括變化音級(jí)崎脉,稱之為八度。)上12個(gè)不同的半音(或色度)的12進(jìn)制上伯顶。色度向量(chroma vector )(維基百科)(FMP囚灼,p.123)是一個(gè)通常包含12個(gè)元素特征的向量,表示信號(hào)中每個(gè)音級(jí){C, C#, D, D#, E, …, B}中的能量祭衩。
用librosa.feature.chroma_stft 計(jì)算Chroma Frequencies灶体。
# Loadign the file
x, sr = librosa.load('../simple_piano.wav')
hop_length = 512
chromagram = librosa.feature.chroma_stft(x, sr=sr, hop_length=hop_length)
plt.figure(figsize=(15, 5))
librosa.display.specshow(chromagram, x_axis='time', y_axis='chroma', hop_length=hop_length, cmap='coolwarm')
案例分析:對(duì)歌曲類型進(jìn)行分類
以上我們對(duì)聲學(xué)(聽覺)信號(hào)及其特征和特征提取過程進(jìn)行了概述,現(xiàn)在讓我們用剛習(xí)得的技能來解決機(jī)器學(xué)習(xí)問題掐暮。
目標(biāo)
本節(jié)蝎抽,我們將嘗試創(chuàng)建一個(gè)分類器將歌曲歸為不同的類型。假設(shè)這樣一個(gè)場(chǎng)景:出于某種原因路克,我們?cè)谟脖P上找到一堆隨機(jī)命名的MP3文件樟结,且文件里有音樂。我們的任務(wù)是根據(jù)音樂類型將它們分到不同的文件夾中精算,如爵士瓢宦、古典音樂、鄉(xiāng)村音樂灰羽、流行音樂驮履、搖滾樂和金屬樂。
數(shù)據(jù)集
我們將用最常用的的GITZAN數(shù)據(jù)集進(jìn)行案例研究廉嚼。G. Tzanetakis和P. Cook在2002年IEEETransactions on audio and Speech Processing中發(fā)表的著名論文: Musical genre classification of audio signals (音頻信號(hào)的音樂類型分類)中曾用到該數(shù)據(jù)集玫镐。
該數(shù)據(jù)集每30秒包含1000條音軌,共包含10個(gè)音樂類型怠噪,即布魯斯恐似、古典、鄉(xiāng)村傍念、迪斯科矫夷、嘻哈、爵士捂寿、雷鬼口四、搖滾孵运、金屬和流行音樂秦陋。每種類型包含100段聲頻。
數(shù)據(jù)處理
在訓(xùn)練分類模型之前治笨,我們須將原始數(shù)據(jù)從音頻樣本轉(zhuǎn)換成更有意義的表示形式驳概。需將音頻片段從.au格式轉(zhuǎn)換為能與python的 wave模塊兼容的.wav格式赤嚼,以讀取音頻文件。不過我常用的是開源SoX模塊顺又。
sox input.au output.wav
分類
特征提雀洹(Feature Extraction)
我們接下來需要從音頻文件中提取出有意義的特征。為了對(duì)音頻片段進(jìn)行分類稚照,這里將選擇5個(gè)特征蹂空,即梅爾頻率倒譜系數(shù)(Mel-Frequency Cepstral Coefficients),頻譜質(zhì)心 (Spectral Centroid)果录,過零率(Zero Crossing Rate)上枕, Chroma Frequencies,譜滾降(Spectral Roll-off)弱恒。然后將所有特征附加到.csv文件中辨萍,以便使用分類算法。
分類(Classification)
提取出特征后返弹,用現(xiàn)有的分類算法將歌曲分為不同的類型锈玉。你可以直接用聲譜圖進(jìn)行分類,也可以在提取特征后使用分類模型义起。
無論采用哪種方式拉背,都要在模型上進(jìn)行大量的實(shí)驗(yàn)。你可以進(jìn)行試驗(yàn)和改進(jìn)結(jié)果并扇。建議試試CNN模型去团,它(在聲譜圖上)的精確度更高。
音樂類型分類的筆記
導(dǎo)入庫
In [0]:
# feature extractoring and preprocessing data
import librosa
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os
from PIL import Image
import pathlib
import csv
# Preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
#Keras
import keras
import warnings
warnings.filterwarnings('ignore')
音樂和特征提取
數(shù)據(jù)集
我們用GTZAN genre collection 數(shù)據(jù)集進(jìn)行分類穷蛹。
數(shù)據(jù)集包含10中音樂類型土陪。如下:
??布魯斯
??古典
??鄉(xiāng)村
??迪斯科
??嘻哈
??爵士
??金屬樂
??流行音樂
??雷鬼
??搖滾
每一種音樂類型包含100首歌曲。共計(jì)1000首歌曲肴熏。
音頻的聲譜圖提取
In [0]:
cmap = plt.get_cmap('inferno')
plt.figure(figsize=(10,10))
genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()
for g in genres:
? ? pathlib.Path(f'img_data/{g}').mkdir(parents=True, exist_ok=True)? ? ?
? ? for filename in os.listdir(f'./MIR/genres/{g}'):
? ? ? ? songname = f'./MIR/genres/{g}/{filename}'
? ? ? ? y, sr = librosa.load(songname, mono=True, duration=5)
? ? ? ? plt.specgram(y, NFFT=2048, Fs=2, Fc=0, noverlap=128, cmap=cmap, sides='default', mode='default', scale='dB');
? ? ? ? plt.axis('off');
? ? ? ? plt.savefig(f'img_data/{g}/{filename[:-3].replace(".", "")}.png')
? ? ? ? plt.clf()
將所有的音頻文件轉(zhuǎn)換成相應(yīng)的聲譜圖鬼雀,以方便提取特征。
聲譜圖特征提取
我們將提取以下特征:
??梅爾頻率倒譜系數(shù)(Mel-frequency cepstral coefficients (MFCC))(20 個(gè))
??頻譜質(zhì)心(Spectral Centroid)
??過零率(Zero Crossing Rate)
??Chroma Frequencies
??譜滾降(Spectral Roll-off)
In [0]:
header = 'filename chroma_stft rmse spectral_centroid spectral_bandwidth rolloff zero_crossing_rate'
for i in range(1, 21):
? ? header += f' mfcc{i}'
header += ' label'
header = header.split()
將數(shù)據(jù)寫入csv 文件
In [0]:
file = open('data.csv', 'w', newline='')
with file:
? ? writer = csv.writer(file)
? ? writer.writerow(header)
genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()
for g in genres:
? ? for filename in os.listdir(f'./MIR/genres/{g}'):
? ? ? ? songname = f'./MIR/genres/{g}/{filename}'
? ? ? ? y, sr = librosa.load(songname, mono=True, duration=30)
? ? ? ? chroma_stft = librosa.feature.chroma_stft(y=y, sr=sr)
? ? ? ? spec_cent = librosa.feature.spectral_centroid(y=y, sr=sr)
? ? ? ? spec_bw = librosa.feature.spectral_bandwidth(y=y, sr=sr)
? ? ? ? rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)
? ? ? ? zcr = librosa.feature.zero_crossing_rate(y)
? ? ? ? mfcc = librosa.feature.mfcc(y=y, sr=sr)
? ? ? ? to_append = f'{filename} {np.mean(chroma_stft)} {np.mean(rmse)} {np.mean(spec_cent)} {np.mean(spec_bw)} {np.mean(rolloff)} {np.mean(zcr)}'? ??
? ? ? ? for e in mfcc:
? ? ? ? ? ? to_append += f' {np.mean(e)}'
? ? ? ? to_append += f' {g}'
? ? ? ? file = open('data.csv', 'a', newline='')
? ? ? ? with file:
? ? ? ? ? ? writer = csv.writer(file)
? ? ? ? ? ? writer.writerow(to_append.split())
以上數(shù)據(jù)已被提取并寫入data.csv文件蛙吏。
用Pandas進(jìn)行數(shù)據(jù)分析
In [6]:
data = pd.read_csv('data.csv')
data.head()
Out[6]:
5行× 28列
In [7]:
data.shape
Out[7]:
(1000, 28)
In [0]:
# Dropping unneccesary columns
data = data.drop(['filename'],axis=1)
對(duì)標(biāo)簽進(jìn)行編碼
In [0]:
genre_list = data.iloc[:, -1]
encoder = LabelEncoder()
y = encoder.fit_transform(genre_list
給特征欄標(biāo)上刻度
In [0]:
scaler = StandardScaler()
X = scaler.fit_transform(np.array(data.iloc[:, :-1], dtype = float))
將數(shù)據(jù)分為訓(xùn)練集和測(cè)試集
In [0]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
In [12]:
len(y_train)
Out[12]:
800
In [13]:
len(y_test)
Out[13]:
200
In [14]:
X_train[10]
Out[14]:
array([-0.9149113 ,? 0.18294103, -1.10587131, -1.3875197 , -1.14640873,
? ? ? ?-0.97232926, -0.29174214,? 1.20078936, -0.68458101, -0.55849017,
? ? ? ?-1.27056582, -0.88176926, -0.74844069, -0.40970382,? 0.49685952,
? ? ? ?-1.12666045,? 0.59501437, -0.39783853,? 0.29327275, -0.72916871,
? ? ? ? 0.63015786, -0.91149976,? 0.7743942 , -0.64790051,? 0.42229852,
? ? ? ?-1.01449461])
使用Keras進(jìn)行分類?
創(chuàng)建自己的網(wǎng)絡(luò)
In [0]:
from keras import models
from keras import layers
model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_shape=(X_train.shape[1],)))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
In [0]:
model.compile(optimizer='adam',
? ? ? ? ? ? ? loss='sparse_categorical_crossentropy',
? ? ? ? ? ? ? metrics=['accuracy'])
In [19]:
history = model.fit(X_train,
? ? ? ? ? ? ? ? ? ? y_train,
? ? ? ? ? ? ? ? ? ? epochs=20,
? ? ? ? ? ? ? ? ? ? batch_size=128)
Epoch 1/20
800/800 [==============================] - 1s 811us/step - loss: 2.1289 - acc: 0.2400
Epoch 2/20
800/800 [==============================] - 0s 39us/step - loss: 1.7940 - acc: 0.4088
Epoch 3/20
800/800 [==============================] - 0s 37us/step - loss: 1.5437 - acc: 0.4450
Epoch 4/20
800/800 [==============================] - 0s 38us/step - loss: 1.3584 - acc: 0.5413
Epoch 5/20
800/800 [==============================] - 0s 38us/step - loss: 1.2220 - acc: 0.5750
Epoch 6/20
800/800 [==============================] - 0s 41us/step - loss: 1.1187 - acc: 0.6288
Epoch 7/20
800/800 [==============================] - 0s 37us/step - loss: 1.0326 - acc: 0.6550
Epoch 8/20
800/800 [==============================] - 0s 44us/step - loss: 0.9631 - acc: 0.6713
Epoch 9/20
800/800 [==============================] - 0s 47us/step - loss: 0.9143 - acc: 0.6913
Epoch 10/20
800/800 [==============================] - 0s 37us/step - loss: 0.8630 - acc: 0.7125
Epoch 11/20
800/800 [==============================] - 0s 36us/step - loss: 0.8095 - acc: 0.7263
Epoch 12/20
800/800 [==============================] - 0s 37us/step - loss: 0.7728 - acc: 0.7700
Epoch 13/20
800/800 [==============================] - 0s 36us/step - loss: 0.7433 - acc: 0.7563
Epoch 14/20
800/800 [==============================] - 0s 45us/step - loss: 0.7066 - acc: 0.7825
Epoch 15/20
800/800 [==============================] - 0s 43us/step - loss: 0.6718 - acc: 0.7787
Epoch 16/20
800/800 [==============================] - 0s 36us/step - loss: 0.6601 - acc: 0.7913
Epoch 17/20
800/800 [==============================] - 0s 36us/step - loss: 0.6242 - acc: 0.7963
Epoch 18/20
800/800 [==============================] - 0s 44us/step - loss: 0.5994 - acc: 0.8038
Epoch 19/20
800/800 [==============================] - 0s 42us/step - loss: 0.5715 - acc: 0.8125
Epoch 20/20
800/800 [==============================] - 0s 39us/step - loss: 0.5437 - acc: 0.8250
In [20]:
test_loss, test_acc = model.evaluate(X_test,y_test)
200/200 [==============================] - 0s 244us/step
In [21]:
print('test_acc: ',test_acc)
test_acc:? 0.68
以上數(shù)據(jù)的精確度不如訓(xùn)練數(shù)據(jù)的精確度高源哩,這說明可能存在“過度擬合”(Overfitting)。
對(duì)所用方法進(jìn)行驗(yàn)證
我們要從訓(xùn)練數(shù)據(jù)中留出200個(gè)樣本作為測(cè)驗(yàn)集:
In [0]:
x_val = X_train[:200]
partial_x_train = X_train[200:]
y_val = y_train[:200]
partial_y_train = y_train[200:]
Now let's train our network for 20 epochs:
In [37]:
model = models.Sequential()
model.add(layers.Dense(512, activation='relu', input_shape=(X_train.shape[1],)))
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',
? ? ? ? ? ? ? loss='sparse_categorical_crossentropy',
? ? ? ? ? ? ? metrics=['accuracy'])
model.fit(partial_x_train,
? ? ? ? ? partial_y_train,
? ? ? ? ? epochs=30,
? ? ? ? ? batch_size=512,
? ? ? ? ? validation_data=(x_val, y_val))
results = model.evaluate(X_test, y_test)
Train on 600 samples, validate on 200 samples
Epoch 1/30
600/600 [==============================] - 1s 1ms/step - loss: 2.3074 - acc: 0.0950 - val_loss: 2.1857 - val_acc: 0.2850
Epoch 2/30
600/600 [==============================] - 0s 65us/step - loss: 2.1126 - acc: 0.3783 - val_loss: 2.0936 - val_acc: 0.2400
Epoch 3/30
600/600 [==============================] - 0s 59us/step - loss: 1.9535 - acc: 0.3633 - val_loss: 1.9966 - val_acc: 0.2600
Epoch 4/30
600/600 [==============================] - 0s 58us/step - loss: 1.8082 - acc: 0.3833 - val_loss: 1.8713 - val_acc: 0.3250
Epoch 5/30
600/600 [==============================] - 0s 59us/step - loss: 1.6663 - acc: 0.4083 - val_loss: 1.7302 - val_acc: 0.3450
Epoch 6/30
600/600 [==============================] - 0s 52us/step - loss: 1.5329 - acc: 0.4550 - val_loss: 1.6233 - val_acc: 0.3700
Epoch 7/30
600/600 [==============================] - 0s 62us/step - loss: 1.4236 - acc: 0.4850 - val_loss: 1.5402 - val_acc: 0.3950
Epoch 8/30
600/600 [==============================] - 0s 57us/step - loss: 1.3250 - acc: 0.5117 - val_loss: 1.4655 - val_acc: 0.3800
Epoch 9/30
600/600 [==============================] - 0s 52us/step - loss: 1.2338 - acc: 0.5633 - val_loss: 1.3927 - val_acc: 0.4650
Epoch 10/30
600/600 [==============================] - 0s 61us/step - loss: 1.1577 - acc: 0.5983 - val_loss: 1.3338 - val_acc: 0.5500
Epoch 11/30
600/600 [==============================] - 0s 64us/step - loss: 1.0981 - acc: 0.6317 - val_loss: 1.3111 - val_acc: 0.5550
Epoch 12/30
600/600 [==============================] - 0s 52us/step - loss: 1.0529 - acc: 0.6517 - val_loss: 1.2696 - val_acc: 0.5400
Epoch 13/30
600/600 [==============================] - 0s 52us/step - loss: 0.9994 - acc: 0.6567 - val_loss: 1.2480 - val_acc: 0.5400
Epoch 14/30
600/600 [==============================] - 0s 65us/step - loss: 0.9673 - acc: 0.6633 - val_loss: 1.2384 - val_acc: 0.5700
Epoch 15/30
600/600 [==============================] - 0s 58us/step - loss: 0.9286 - acc: 0.6633 - val_loss: 1.1953 - val_acc: 0.5800
Epoch 16/30
600/600 [==============================] - 0s 59us/step - loss: 0.8849 - acc: 0.6783 - val_loss: 1.2000 - val_acc: 0.5550
Epoch 17/30
600/600 [==============================] - 0s 61us/step - loss: 0.8621 - acc: 0.6850 - val_loss: 1.1743 - val_acc: 0.5850
Epoch 18/30
600/600 [==============================] - 0s 61us/step - loss: 0.8195 - acc: 0.7150 - val_loss: 1.1609 - val_acc: 0.5750
Epoch 19/30
600/600 [==============================] - 0s 62us/step - loss: 0.7976 - acc: 0.7283 - val_loss: 1.1238 - val_acc: 0.6150
Epoch 20/30
600/600 [==============================] - 0s 63us/step - loss: 0.7660 - acc: 0.7650 - val_loss: 1.1604 - val_acc: 0.5850
Epoch 21/30
600/600 [==============================] - 0s 65us/step - loss: 0.7465 - acc: 0.7650 - val_loss: 1.1888 - val_acc: 0.5700
Epoch 22/30
600/600 [==============================] - 0s 65us/step - loss: 0.7099 - acc: 0.7517 - val_loss: 1.1563 - val_acc: 0.6050
Epoch 23/30
600/600 [==============================] - 0s 68us/step - loss: 0.6857 - acc: 0.7683 - val_loss: 1.0900 - val_acc: 0.6200
Epoch 24/30
600/600 [==============================] - 0s 67us/step - loss: 0.6597 - acc: 0.7850 - val_loss: 1.0872 - val_acc: 0.6300
Epoch 25/30
600/600 [==============================] - 0s 67us/step - loss: 0.6377 - acc: 0.7967 - val_loss: 1.1148 - val_acc: 0.6200
Epoch 26/30
600/600 [==============================] - 0s 64us/step - loss: 0.6070 - acc: 0.8200 - val_loss: 1.1397 - val_acc: 0.6150
Epoch 27/30
600/600 [==============================] - 0s 66us/step - loss: 0.5991 - acc: 0.8167 - val_loss: 1.1255 - val_acc: 0.6300
Epoch 28/30
600/600 [==============================] - 0s 62us/step - loss: 0.5656 - acc: 0.8333 - val_loss: 1.0955 - val_acc: 0.6350
Epoch 29/30
600/600 [==============================] - 0s 66us/step - loss: 0.5513 - acc: 0.8300 - val_loss: 1.1030 - val_acc: 0.6050
Epoch 30/30
600/600 [==============================] - 0s 56us/step - loss: 0.5498 - acc: 0.8233 - val_loss: 1.0869 - val_acc: 0.6250
200/200 [==============================] - 0s 65us/step
In [38]:
results
Out[38]:
[1.2261371064186095, 0.65]
對(duì)測(cè)試集進(jìn)行預(yù)測(cè)
In [0]:
predictions = model.predict(X_test)
In [26]:
predictions[0].shape
Out[26]:
(10,)
In [27]:
np.sum(predictions[0])
Out[27]:
1.0
In [28]:
np.argmax(predictions[0])
Out[28]:
8
In [0]:
下一步
音樂類型分類是音樂信息檢索的眾多分支之一鸦做。你還可以對(duì)音樂數(shù)據(jù)執(zhí)行其他任務(wù)励烦,如節(jié)拍跟蹤(beat tracking)、音樂生成泼诱、推薦系統(tǒng)坛掠、音軌分離(track separation)、樂器識(shí)別等。音樂分析是一個(gè)既多元化又有趣的領(lǐng)域屉栓。音樂在某種程度上代表了用戶的一個(gè)時(shí)刻舷蒲。在數(shù)據(jù)科學(xué)領(lǐng)域,發(fā)現(xiàn)并描述這些時(shí)刻將會(huì)是一個(gè)有趣的挑戰(zhàn)友多。