機器學習筆記 - 20. EM算法實踐(講師:鄒博)

主要內(nèi)容

2018-12-12 20_25_51-【鄒博_chinahadoop】機器學習升級版VII(七).png

多維高斯混合分布聚類

EM算法的聚類效果或許比K均值聚類好一些。


2018-12-12 20_27_30-【鄒博_chinahadoop】機器學習升級版VII(七).png

如圖罕袋,對于二維數(shù)據(jù)形成概率密度曲線改淑,或者說等值線:


2018-12-12 20_28_09-【鄒博_chinahadoop】機器學習升級版VII(七).png

這個圖也說明,身高一定符合高斯分布浴讯,不一定對朵夏。
下圖表明礼搁,男性符合幾個混合高斯分布惫谤,女性符合幾個混合高斯分布


2018-12-25 19_27_33-【鄒博_chinahadoop】機器學習升級版VII(七).png

問答

問:歸一化的幾種優(yōu)劣之處澎语?
答:比如做min-max宽气,Scalar或標準版,如果數(shù)據(jù)服從均勻分布痛悯,可能做min-max好一些痊银,但是如果數(shù)據(jù)服從高斯分布循头,可能標準化更好一些唠摹。
問:為什么鳶尾花沒有隱變量爆捞?鳶尾花也可能有某個未知的特征決定它的分類奉瘤,任何分布不是都可能有隱變量么勾拉?
答:是的。也許花萼長寬與花瓣長寬并不是鳶尾花最重要特征盗温。只是沒有提取
問:針對Kmeans多特征的情景藕赞,是不是用PCA處理以后,變成3維或2維卖局,然后再用聚類的方式處理斧蜕?
答:如果不需要做算法解釋的話,這么做是合理的砚偶;但是如果需要做算法解釋批销,建議不要用PCA,否則特征無法解釋染坯。
比如如果數(shù)據(jù)是200維的均芽,那么就針對這200維做K-Means聚類。

2018-12-25 19_29_38-【鄒博_chinahadoop】機器學習升級版VII(七).png

高斯分布的公式:
f(x) = 1/((2*π)0.5*σ) * e-(x-μ)2/(2*σ2)
其中:μ是均值单鹿,σ2是方差
如果f(x)是多元的掀宋,則得到多元高斯分布的概率密度函數(shù):f(x) = (2*π)-n/2*(Σ-1)n*e-(1/2)*(x-μ)TT*(x-μ)
此處Σ為協(xié)方差矩陣,首先它是一個nxn的對稱方陣仲锄。
這個Σ矩陣劲妙,在做混合高斯模型的時候,就出問題了儒喊。
比如現(xiàn)在做一個二聚類:
GMM:
N(μ1,Σ1)以及N(μ2,Σ2)
μ1與μ2都是n元的镣奋,而Σ1與Σ2都是矩陣:

  1. 如果矩陣是單位矩陣,如:
    1 0 0 0
    0 1 0 0
    0 0 1 0
    0 0 0 1怀愧,
    則σ 乘以單位矩陣得到:σ · I
    則圖形為球面的侨颈,即圖中的Spherical柱狀圖富雅。
    參數(shù)有1個

  2. 如果矩陣是對角矩陣,如:
    σ12 0 0 0
    0 σ22 0 0
    0 0 σ32 0
    0 0 0 σ42
    則得到diag柱狀圖
    參數(shù)有n個

  3. 如果Σ1 = Σ2肛搬,則形成tied没佑,即相互關(guān)聯(lián)的
    即圖中的tied柱狀圖
    參數(shù)有nxn個,準確的說是nx(n+1)個參數(shù)

  4. Σ1 與 Σ2沒有任何關(guān)聯(lián)温赔,我們求正常的EM算法
    理論上有2倍的nx(n+1)個參數(shù)
    即圖中的full柱狀圖

只要是做混合高斯模型蛤奢,基本都會涉及這四個參數(shù)

問答

問:協(xié)方差矩陣為什么是個對稱陣?
答:因為這是定義陶贼。協(xié)方差矩陣是對稱陣啤贩。
問:怎么看出來是球形?
答:如果隨機變量是三元的拜秧,協(xié)方差矩陣如果三個方差都相等痹屹,即主軸,副軸與短軸都相等枉氮,得到球形
問:這四種情況怎么來的志衍?
答:就是參數(shù)的設(shè)置。不管是做EM, sklearn的設(shè)置聊替,還有隱馬爾科夫模型楼肪,如果說隱馬爾科夫模型符合高斯分布,那么就是高斯隱馬爾科夫模型惹悄,那個模型里面春叫,不同的隱變量,如果是方差也有方差是否相等泣港,方差是不是對角陣等情況
問:參數(shù)設(shè)置時暂殖,這幾種情況怎么選?
答:如果不知道怎么選当纱,我們選full呛每,即參數(shù)有2倍的nx(n+1)個,如果知道是對角陣惫东,選diagonal莉给,當然都試一下也無妨。
問:EM算法是無監(jiān)督學習么廉沮?
答:EM算法可以看成無監(jiān)督學習颓遏,雖然它是一種算法,是描述how而不是what滞时。比如EM叁幢,MLE(最大似然估計),SGD(隨機梯度下降)坪稽,L-BFGS(擬牛頓)都是講的how曼玩,即解決what的具體的方法鳞骤。

模型選擇的準則

2018-12-25 20_12_36-【鄒博_chinahadoop】機器學習升級版VII(七).png

AIC解釋:
負的對數(shù)似然,就可以作為目標函數(shù)黍判;
但是我們不希望過擬合豫尽,所以需要在損失函數(shù)的前提下,加一個模型的復(fù)雜程度顷帖,比如模型的維度作為一個復(fù)雜標準美旧。
哪個模型的這個值小,哪個模型最優(yōu)贬墩。
即2k就成為了正則項

BIC解釋:
樣本多可能帶來模型復(fù)雜度變化榴嗅,如果兩個模型,一個樣本多陶舞,一個樣本少嗽测,在結(jié)果相同的情況下,樣本少的模型肿孵,看起來要好一些唠粥。
所以乘以與樣本個數(shù)有關(guān)的項,是有道理的颁井。
即(lnn)k
BIC也可以認為貝葉斯信息準則
BIC看相對大小才有意義厅贪,絕對大小沒有意義

2018-12-25 20_27_03-【鄒博_chinahadoop】機器學習升級版VII(七).png

很顯然,當參數(shù)選擇full的時候雅宾,錯誤率幾乎就是0,并且BIC最小葵硕。即選擇full參數(shù)的時候眉抬,模型是最優(yōu)的。

問題:為什么上圖右下角有一小塊紅色懈凹?
答:因為紅色方差大蜀变。

問答

問:上述的例子說明什么?
答:對于Σ1與Σ的選擇介评,引入更多的參數(shù)是否值得
問:平時計算EM的Σ不都是full類型么库北?
答:是的
問:樣本的個數(shù)n為什么越大,BIC就大呢们陆?為什么和樣本個數(shù)n有關(guān)系寒瓦?
答:能達到相同效果的時候,如果樣本比別人多坪仇,那么模型就沒有別人好杂腰。

2018-12-25 20_37_50-【鄒博_chinahadoop】機器學習升級版VII(七).png
2018-12-25 20_38_39-【鄒博_chinahadoop】機器學習升級版VII(七).png

如圖,上圖中三分類效果遠遠比二分類要差椅文,所以可以加入一些先驗知識喂很。
如圖惜颇,如果模型中的參數(shù)θ是未知定值,則可以通過最大似然估計(MLE)以及期望最大化(EM)去求少辣。
如果θ也是變化的凌摄,且符合概率分布,即P(θ|α)漓帅,這個是先驗分布望伦,
對于樣本y,只要給出x就能算出y的分布煎殷,且是對于θ的概率分布屯伞,這個是似然分布
P(θ|x, y),則是屬于后驗分布


2019-01-03 19_41_14-【鄒博_chinahadoop】機器學習升級版VII(七).png

2019-01-03 19_48_16-【鄒博_chinahadoop】機器學習升級版VII(七).png

接著進行計算豪直,如果θ有無窮多個劣摇,那么哪一個θ是最大的,就是我們想要求的:


2019-01-03 19_51_04-【鄒博_chinahadoop】機器學習升級版VII(七).png

θ這個值如何去求弓乙?
如圖末融,后驗分布,可以認為與似然分布 * 先驗分布成正比暇韧。

2019-01-03 19_53_45-【鄒博_chinahadoop】機器學習升級版VII(七).png

如果θ是Dirichlet(狄利克雷)分布勾习,可以演化為:

Dirichlet分布(參數(shù)為α+x) = 多項分布 * Dirichlet分布(參數(shù)為α)

如果α采樣的值,拍腦袋選擇1懈玻,10巧婶,或100,
假定α=1涂乌,得到θ1艺栈,θ2,θ3湾盒,湿右。。罚勾。毅人,θ100
從而分別得到(x1, y1),(x2, y2), (x3, y3), 尖殃。丈莺。。, (x100, y100)
這些是我們看到的樣本數(shù)據(jù)分衫。
其中每一個θ都是根據(jù)α采樣得到的场刑,即每一個θ都是一個隨機變量,構(gòu)成了一個隨機過程,或者說構(gòu)成了Dirichlet過程


2019-01-03 20_07_03-【鄒博_chinahadoop】機器學習升級版VII(七).png

α取1的時候是最特殊的牵现,即此時為均勻分布

再回頭來看這張圖铐懊,我們使用了符合高斯混合分布的模型,但是我們希望對參數(shù)做一個影響瞎疼,那么使用Dirichlet過程+高斯混合分布模型科乎,就得到DPGMM的模型,此時分類是合理的贼急。
如圖左邊是正常的高斯混合模型茅茂,分類分錯了;
右邊是使用了貝葉斯的高斯混合模型太抓,是特定的Dirichlet過程+高斯混合分布模型得到的結(jié)果空闲,即使分類選擇3,但是得到的結(jié)果也是正確的

在sk-learn中走敌, DPGMM的相關(guān)類為:BayesianGaussianMixture


2018-12-25 20_38_39-【鄒博_chinahadoop】機器學習升級版VII(七).png

相關(guān)的代碼如下:

# !/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
import scipy as sp
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse


def expand(a, b, rate=0.05):
    d = (b - a) * rate
    return a-d, b+d


matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False


if __name__ == '__main__':
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    N1 = 500
    N2 = 300
    N = N1 + N2
    x1 = np.random.multivariate_normal(mean=(3, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)
    n_components = 3

    # 繪圖使用
    colors = '#A0FFA0', '#2090E0', '#FF8080'
    cm = mpl.colors.ListedColormap(colors)
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)

    plt.figure(figsize=(6, 6), facecolor='w')
    plt.suptitle('GMM/DPGMM比較', fontsize=15)

    ax = plt.subplot(211)
    gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
    gmm.fit(x)
    centers = gmm.means_
    covs = gmm.covariances_
    print('GMM均值 = \n', centers)
    print('GMM方差 = \n', covs)
    y_hat = gmm.predict(x)

    grid_hat = gmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=20, c=y, cmap=cm, marker='o', edgecolors='#202020')

    clrs = list('rgbmy')
    for i, (center, cov) in enumerate(zip(centers, covs)):
        value, vector = sp.linalg.eigh(cov)
        width, height = value[0], value[1]
        v = vector[0] / sp.linalg.norm(vector[0])
        angle = 180* np.arctan(v[1] / v[0]) / np.pi
        e = Ellipse(xy=center, width=width, height=height,
                    angle=angle, color=clrs[i], alpha=0.5, clip_box = ax.bbox)
        ax.add_artist(e)

    ax1_min, ax1_max, ax2_min, ax2_max = plt.axis()
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title('GMM', fontsize=15)
    plt.grid(b=True, ls=':', color='#606060')

    # DPGMM
    dpgmm = BayesianGaussianMixture(n_components=n_components, covariance_type='full', max_iter=1000, n_init=5,
                                    weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=0.1)
    dpgmm.fit(x)
    centers = dpgmm.means_
    covs = dpgmm.covariances_
    print('DPGMM均值 = \n', centers)
    print('DPGMM方差 = \n', covs)
    y_hat = dpgmm.predict(x)
    print(y_hat)

    ax = plt.subplot(212)
    grid_hat = dpgmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=20, c=y, cmap=cm, marker='o', edgecolors='#202020')

    for i, cc in enumerate(zip(centers, covs)):
        if i not in y_hat:
            continue
        center, cov = cc
        value, vector = sp.linalg.eigh(cov)
        width, height = value[0], value[1]
        v = vector[0] / sp.linalg.norm(vector[0])
        angle = 180* np.arctan(v[1] / v[0]) / np.pi
        e = Ellipse(xy=center, width=width, height=height,
                    angle=angle, color='m', alpha=0.5, clip_box = ax.bbox)
        ax.add_artist(e)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title('DPGMM', fontsize=15)
    plt.grid(b=True, ls=':', color='#606060')
    plt.tight_layout(2, rect=(0, 0, 1, 0.95))
    plt.show()

得到的結(jié)果如圖:


2019-01-03 20_15_36_20.EM_20.png

問:看來要講貝葉斯啊
答:是的碴倾,要為下次LDA做鋪墊
問:DPGMM其實就是一個主題和樣本的分布作為權(quán)重,乘以主題和樣本的高斯混合分布掉丽?
答:一定程度可以這樣解釋跌榔。
問:P(,;|)這三個符號一般是指什么呢?
答:P(x,y) <=> P(y, x)捶障,這個代表x與y的聯(lián)合分布
P(y; x)與P(y|x)是等價的僧须,即x是條件,y是x的因變量
但是如果代入θ项炼,就不一樣了:
但是P(y;θ)担平,屬于頻率學派;其中θ為參數(shù)芥挣,θ是未知的定值驱闷;
而P(y|θ),屬于貝葉斯學派空免,則樣本θ是未知的隨機變量
問:都是同樣阿爾法,怎么能采樣出多個θ呢盆耽?
答:高斯分布中蹋砚,如果均值為170,方差為10摄杂,取樣可以為226,175坝咐,168多個值;同理析恢,同樣的阿爾法墨坚,也可以采樣出多個θ
問:模型和里面的實現(xiàn)可以組合么?比如這種混合高斯模型映挂,能用隨機或批量梯度下降達到目的么泽篮?
答:其實是達不到的盗尸。因為混合高斯分布,其目標函數(shù)是有隱變量的存在帽撑,所以沒辦法對其直接求取梯度泼各,只能固定隱變量求梯度;固定梯度求隱變量亏拉,二者不斷迭代扣蜻,最后才得到EM
問:均值為0和不為0有什么區(qū)別?效果會變化么及塘?
答:均值是否為0只是一個解釋莽使,因為不為0的時候,我們總是會將其調(diào)整為0附近的笙僚。比如事先減均值芳肌。

2019-02-04 15_07_45-【鄒博_chinahadoop】機器學習升級版VII(七).png

求導(dǎo)的過程很簡單:
?h(p)/?p = n*pn-1*(1-p)(N-n) - pn*(N-n)*(1-p)(N-n-1)
假定導(dǎo)數(shù)為0,即:
?h(p)/?p = n*pn-1*(1-p)(N-n) - pn*(N-n)*(1-p)(N-n-1) = 0
則等式兩邊除以pn*(1-p)(N-n),得到:
?h(p)/?p = n*p-1 - (N-n)*(1-p)-1 = n/p - (N-n)/(1-p) = 0味咳,
可得:p = n/N
下面看二項分布與先驗舉例:
2019-02-04 15_47_50-【鄒博_chinahadoop】機器學習升級版VII(七).png

可以觀察到庇勃,修正公式的分子各加了一個5,而這個5是Dirichlet(狄利克雷)分布的超參數(shù)α槽驶。

2019-02-04 15_52_03-【鄒博_chinahadoop】機器學習升級版VII(七).png

問答
問:是不是這個課程代碼都敲會责嚷,加上一個項目經(jīng)驗就OK?
答:這個看需要。比如現(xiàn)在的情況掂铐,機器學習是研究整個這套方式的一個根基罕拂。不用基礎(chǔ)這個詞,是因為以為其很簡單全陨,所以用根基這個詞爆班。有了這個根基之后,大家再去往上做其他應(yīng)用辱姨,不會感覺困難柿菩。比如用卷積網(wǎng)絡(luò),最后一層我們使用SoftMax的全連接雨涛,還是用SVM枢舶,本質(zhì)上是換損失函數(shù)。然后我們解釋模型是否有效替久,都是可以用上的凉泄。
基礎(chǔ)是夠的,但是如果大家沒有深度學習的應(yīng)用實踐蚯根,或者只有一個項目經(jīng)驗后众,還是不夠的。所以需要實際項目進行反復(fù)驗證與活學活用,或者參加競賽蒂誉,比賽也可以教藻。
問:強化學習和機器學習的關(guān)聯(lián)大么?強化學習未來的應(yīng)用前景如何拗盒?
答:強化學習可能是近兩三年的爆發(fā)點怖竭,可能是大公司玩的。需要的算力陡蝇,比數(shù)據(jù)要強痊臭。比如飛翔的小鳥,或者行走的人登夫,需要對當前的動作進行反饋广匙,然后根據(jù)反饋的結(jié)果,去更正動作恼策,并不斷學習鸦致。算力要求非常高,前景要求是有的涣楷,但目前只能進行簡單游戲分唾、博弈、對抗這種內(nèi)容狮斗,沒法成為最主力的算法應(yīng)用绽乔。也許不對,但最主流的還是在有監(jiān)督應(yīng)用碳褒。
問:算力是什么折砸?可以簡單理解為硬件速度么?
答:這個理解沒問題的沙峻。

EM算法代碼

下面代碼有自己實現(xiàn)的高斯混合模型睦授,以及通過sk-learn庫的高斯混合模型類直接實現(xiàn)的兩種方式。
自己實現(xiàn)的高斯混合模型摔寨,其實就是實現(xiàn)期望最大化去枷,即EM算法,公式如下:


2019-02-04 17_19_35-【鄒博_chinahadoop】機器學習升級版VII(七).png
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin


mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False


if __name__ == '__main__':
    #style = 'sklearn'
    style = 'myself'
    np.random.seed(0)
    mu1_fact = (0, 0, 0)
    cov1_fact = np.diag((1, 2, 3))
    # 根據(jù)實際情況生成一個多元正態(tài)分布矩陣是复,np.random.multivariate_normal
    # 參數(shù)就是高斯分布所需的均值與方差
    # 第一個參數(shù): mean:mean是多維分布的均值維度為1沉填;
    # 第二個參數(shù):cov:協(xié)方差矩陣,注意:協(xié)方差矩陣必須是對稱的且需為半正定矩陣佑笋;
    # 第三個參數(shù):size:指定生成的正態(tài)分布矩陣的維度
    data1 = np.random.multivariate_normal(mu1_fact, cov1_fact, 400)
    print('data1 shape: {0}'.format(data1.shape))
    mu2_fact = (2, 2, 1)
    # 方差對稱且正定(positive-semidefinite): (4, 1, 3), (1, 2, 1), (3, 1, 4)
    cov2_fact = np.array(((4, 1, 3), (1, 2, 1), (3, 1, 4)))
    data2 = np.random.multivariate_normal(mu2_fact, cov2_fact, 100)
    print('data2 shape: {0}'.format(data2.shape))

    data = np.vstack((data1, data2))
    print('data shape: {0}'.format(data.shape))
    y = np.array([True] * 400 + [False] * 100)

    if style == 'sklearn':
        g = GaussianMixture(n_components=2, covariance_type='full', tol=1e-6, max_iter=1000)
        g.fit(data)
        print('類別概率:\t', g.weights_[0])
        print('均值:\n', g.means_, '\n')
        print('方差:\n', g.covariances_, '\n')
        mu1, mu2 = g.means_
        sigma1, sigma2 = g.covariances_
    else:
        num_iter = 100
        n, d = data.shape
        # 隨機指定
        # mu1 = np.random.standard_normal(d)
        # print mu1
        # mu2 = np.random.standard_normal(d)
        # print mu2
        mu1 = data.min(axis=0)
        mu2 = data.max(axis=0)
        # 創(chuàng)建d行d列的單位矩陣(對角線為1,其余為0)
        sigma1 = np.identity(d)
        sigma2 = np.identity(d)
        pi = 0.5
        # EM
        for i in range(num_iter):
            # E Step
            # 通過初始化的均值與方差斑鼻,做多元的正態(tài)分布
            norm1 = multivariate_normal(mu1, sigma1)
            norm2 = multivariate_normal(mu2, sigma2)
            # 概率密度 * pi
            tau1 = pi * norm1.pdf(data)
            tau2 = (1 - pi) * norm2.pdf(data)
            gamma = tau1 / (tau1 + tau2)

            # M Step
            mu1 = np.dot(gamma, data) / np.sum(gamma)
            mu2 = np.dot((1 - gamma), data) / np.sum((1 - gamma))
            sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma)
            sigma2 = np.dot((1 - gamma) * (data - mu2).T, data - mu2) / np.sum(1 - gamma)
            pi = np.sum(gamma) / n
            print(i, ":\t", mu1, mu2)
        print('類別概率:\t', pi)
        print('均值:\t', mu1, mu2)
        print('方差:\n', sigma1, '\n\n', sigma2, '\n')

    # 預(yù)測分類
    # multivariate_normal獲得多元正態(tài)分布
    norm1 = multivariate_normal(mu1, sigma1)
    norm2 = multivariate_normal(mu2, sigma2)
    # pdf: Probability density function蒋纬,連續(xù)性概率分布函數(shù)
    tau1 = norm1.pdf(data)
    tau2 = norm2.pdf(data)

    fig = plt.figure(figsize=(10, 5), facecolor='w')
    ax = fig.add_subplot(121, projection='3d')
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='b', s=30, marker='o', edgecolors='k', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('原始數(shù)據(jù)', fontsize=15)
    ax = fig.add_subplot(122, projection='3d')
    # 求取點距離
    order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='euclidean')
    # order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='cosine')

    # 通過歐式距離,將點分為兩類
    print(order)
    if order[0] == 0:
        c1 = tau1 > tau2
    else:
        c1 = tau1 < tau2
    c2 = ~c1
    # 機器學習計算準確率的常用做法
    # 原理:真實值是y,預(yù)測值是c1蜀备,相等則為True关摇,否則為False。True為1碾阁,F(xiàn)alse為0
    # 求均值則為:預(yù)測準確的數(shù)目/總數(shù)目输虱,這不就是準確率么
    acc = np.mean(y == c1)
    print('準確率:%.2f%%' % (100*acc))
    ax.scatter(data[c1, 0], data[c1, 1], data[c1, 2], c='r', s=30, marker='o', edgecolors='k', depthshade=True)
    ax.scatter(data[c2, 0], data[c2, 1], data[c2, 2], c='g', s=30, marker='^', edgecolors='k', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('EM算法分類', fontsize=15)
    plt.suptitle('EM算法的實現(xiàn)', fontsize=18)
    plt.subplots_adjust(top=0.90)
    # plt.tight_layout()
    plt.show()

得到的圖形界面:

emdraw.png

如果將data2,那100個數(shù)據(jù)脂凶,均值設(shè)置為(5, 5, 5)宪睹,則分類效果更明顯,如圖:
emdraw.png

當均值為5, 5, 5時蚕钦,且自己實現(xiàn)高斯分布亭病,輸出如下。
可以發(fā)現(xiàn)嘶居,迭代24次之后罪帖,均值就不再變化了,可以稱為模型收斂了邮屁。
真實值為[0, 0, 0]與[5, 5, 5]整袁,計算的均值為:
[-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ],有差別佑吝,但是靠譜坐昙。
當然,也可以增加樣本迹蛤,使得結(jié)果更接近真實值的情況民珍。比如將400, 100替換為4000, 1000

data1 shape: (400, 3)
data2 shape: (100, 3)
data shape: (500, 3)

0 :  [-0.02992749  0.09146815  0.03351835] [5.43632719 5.10518101 5.44044355]
1 :  [-0.0577343   0.02743837  0.00975419] [5.23010882 5.07679997 5.22085292]
2 :  [-0.064707    0.00613346 -0.0011657 ] [5.13849255 5.04871819 5.14698533]
3 :  [-0.0673045  -0.00120726 -0.00558504] [5.0995056  5.03026551 5.1158224 ]
4 :  [-0.068465   -0.00420981 -0.00743412] [5.08222193 5.02089901 5.10147628]
5 :  [-0.06899479 -0.00554239 -0.00824704] [5.07428887 5.01640324 5.09475058]
6 :  [-0.06923923 -0.00615431 -0.00861565] [5.07059253 5.01427675 5.09158405]
7 :  [-0.06935293 -0.00643926 -0.00878579] [5.06886018 5.01327497 5.09009261]
8 :  [-0.06940609 -0.00657271 -0.00886508] [5.06804645 5.01280356 5.0893904 ]
9 :  [-0.06943103 -0.00663537 -0.00890221] [5.06766389 5.01258178 5.0890599 ]
10 :     [-0.06944274 -0.00666482 -0.00891964] [5.06748396 5.01247744 5.08890438]
11 :     [-0.06944825 -0.00667867 -0.00892783] [5.06739933 5.01242836 5.08883121]
12 :     [-0.06945084 -0.00668519 -0.00893168] [5.06735951 5.01240527 5.08879678]
13 :     [-0.06945206 -0.00668825 -0.00893349] [5.06734079 5.01239441 5.08878059]
14 :     [-0.06945263 -0.00668969 -0.00893434] [5.06733197 5.0123893  5.08877297]
15 :     [-0.0694529  -0.00669037 -0.00893474] [5.06732783 5.0123869  5.08876938]
16 :     [-0.06945302 -0.00669069 -0.00893493] [5.06732588 5.01238577 5.0887677 ]
17 :     [-0.06945308 -0.00669084 -0.00893502] [5.06732496 5.01238523 5.0887669 ]
18 :     [-0.06945311 -0.00669091 -0.00893506] [5.06732453 5.01238498 5.08876653]
19 :     [-0.06945313 -0.00669095 -0.00893508] [5.06732433 5.01238487 5.08876635]
20 :     [-0.06945313 -0.00669096 -0.00893509] [5.06732423 5.01238481 5.08876627]
21 :     [-0.06945314 -0.00669097 -0.00893509] [5.06732419 5.01238478 5.08876623]
22 :     [-0.06945314 -0.00669097 -0.00893509] [5.06732417 5.01238477 5.08876621]
23 :     [-0.06945314 -0.00669097 -0.0089351 ] [5.06732416 5.01238477 5.08876621]
24 :     [-0.06945314 -0.00669097 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
25 :     [-0.06945314 -0.00669097 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
26 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
27 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
28 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
29 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
30 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
31 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
32 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
33 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
34 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
35 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
36 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
37 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
38 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
39 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
40 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
41 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
42 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
43 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
44 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
45 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
46 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
47 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
48 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
49 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
50 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
51 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
52 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
53 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
54 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
55 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
56 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
57 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
58 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
59 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
60 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
61 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
62 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
63 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
64 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
65 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
66 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
67 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
68 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
69 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
70 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
71 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
72 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
73 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
74 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
75 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
76 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
77 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
78 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
79 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
80 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
81 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
82 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
83 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
84 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
85 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
86 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
87 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
88 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
89 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
90 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
91 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
92 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
93 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
94 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
95 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
96 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
97 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
98 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
99 :     [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
類別概率:    0.7987220297951044
均值:  [-0.06945314 -0.00669098 -0.0089351 ] [5.06732415 5.01238476 5.0887662 ]
方差:
 [[ 0.87148101 -0.05642494  0.03198856]
 [-0.05642494  2.09700921 -0.12547629]
 [ 0.03198856 -0.12547629  2.745459  ]] 

 [[4.08142083 0.79087313 3.107469  ]
 [0.79087313 1.79995257 0.75954681]
 [3.107469   0.75954681 4.04331614]] 

[0 1]
準確率:98.60%

問答
問:協(xié)方差一定是對稱的么?
答:是的盗飒,協(xié)方差一定是對稱的
問:np.identity是什么意思嚷量?
答:創(chuàng)建單位矩陣,即對角線為1逆趣,其余值為0

GMM代碼實現(xiàn)

對應(yīng)業(yè)務(wù)為性別-身高-體重數(shù)據(jù)
通過高斯混合模型蝶溶,預(yù)測身高與體重所屬的性別

# !/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt

mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False
# from matplotlib.font_manager import FontProperties
# font_set = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=15)
# fontproperties=font_set


def expand(a, b):
    d = (b - a) * 0.05
    return a-d, b+d


if __name__ == '__main__':
    data = np.loadtxt('./HeightWeight.csv', dtype=np.float, delimiter=',', skiprows=1)
    print(data.shape)
    y, x = np.split(data, [1, ], axis=1)
    x, x_test, y, y_test = train_test_split(x, y, train_size=0.6, random_state=0)
    gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
    x_min = np.min(x, axis=0)
    x_max = np.max(x, axis=0)
    gmm.fit(x)
    print('均值 = \n', gmm.means_)
    print('方差 = \n', gmm.covariances_)
    y_hat = gmm.predict(x)
    y_test_hat = gmm.predict(x_test)
    change = (gmm.means_[0][0] > gmm.means_[1][0])
    if change:
        z = y_hat == 0
        y_hat[z] = 1
        y_hat[~z] = 0
        z = y_test_hat == 0
        y_test_hat[z] = 1
        y_test_hat[~z] = 0
    acc = np.mean(y_hat.ravel() == y.ravel())
    acc_test = np.mean(y_test_hat.ravel() == y_test.ravel())
    acc_str = '訓練集準確率:%.2f%%' % (acc * 100)
    acc_test_str = '測試集準確率:%.2f%%' % (acc_test * 100)
    print(acc_str)
    print(acc_test_str)

    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g'])
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    grid_hat = gmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    if change:
        z = grid_hat == 0
        grid_hat[z] = 1
        grid_hat[~z] = 0
    plt.figure(figsize=(7, 6), facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
    plt.scatter(x[:, 0], x[:, 1], s=50, c=y.ravel(), marker='o', cmap=cm_dark, edgecolors='k')
    plt.scatter(x_test[:, 0], x_test[:, 1], s=60, c=y_test.ravel(), marker='^', cmap=cm_dark, edgecolors='k')

    p = gmm.predict_proba(grid_test)
    print(p)
    p = p[:, 0].reshape(x1.shape)
    CS = plt.contour(x1, x2, p, levels=(0.1, 0.5, 0.8), colors=list('rgb'), linewidths=2)
    plt.clabel(CS, fontsize=12, fmt='%.1f', inline=True)
    ax1_min, ax1_max, ax2_min, ax2_max = plt.axis()
    xx = 0.95*ax1_min + 0.05*ax1_max
    yy = 0.05*ax2_min + 0.95*ax2_max
    plt.text(xx, yy, acc_str, fontsize=12)
    yy = 0.1*ax2_min + 0.9*ax2_max
    plt.text(xx, yy, acc_test_str, fontsize=12)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.xlabel('身高(cm)', fontsize=13)
    plt.ylabel('體重(kg)', fontsize=13)
    plt.title('EM算法估算GMM的參數(shù)', fontsize=15)
    plt.grid(b=True, ls=':', color='#606060')
    plt.tight_layout(2)
    plt.show()

其中HeightWeight.csv的數(shù)據(jù)如下,直接將其拷貝到文本文件宣渗,然后保存為文件名為HeightWeight.csv的文件即可

Sex,Height(cm),Weight(kg)
0,156,50
0,160,60
0,162,54
0,162,55
0,160.5,56
0,160,53
0,158,55
0,164,60
0,165,50
0,166,55
0,158,47.5
0,161,49
0,169,55
0,161,46
0,160,45
0,167,44
0,155,49
0,154,57
0,172,52
0,155,56
0,157,55
0,165,65
0,156,52
0,155,50
0,156,56
0,160,55
0,158,55
0,162,70
0,162,65
0,155,57
0,163,70
0,160,60
0,162,55
0,165,65
0,159,60
0,147,47
0,163,53
0,157,54
0,160,55
0,162,48
0,158,60
0,155,48
0,165,60
0,161,58
0,159,45
0,163,50
0,158,49
0,155,50
0,162,55
0,157,63
0,159,49
0,152,47
0,156,51
0,165,49
0,154,47
0,156,52
0,162,48
1,162,60
1,164,62
1,168,86
1,187,75
1,167,75
1,174,64
1,175,62
1,170,65
1,176,73
1,169,58
1,178,54
1,165,66
1,183,68
1,171,61
1,179,64
1,172,60
1,173,59
1,172,58
1,175,62
1,160,60
1,160,58
1,160,60
1,175,75
1,163,60
1,181,77
1,172,80
1,175,73
1,175,60
1,167,65
1,172,60
1,169,75
1,172,65
1,175,72
1,172,60
1,170,65
1,158,59
1,167,63
1,164,61
1,176,65
1,182,95
1,173,75
1,176,67
1,163,58
1,166,67
1,162,59
1,169,56
1,163,59
1,163,56
1,176,62
1,169,57
1,173,61
1,163,59
1,167,57
1,176,63
1,168,61
1,167,60
1,170,69

圖形示例如下:


2019-02-04 17_41_49-Start.png

問:from sklearn.metrics.pairwise import pairwise_distances_argmin抖所,這個是干嘛的?
答:是用于計算任意的兩個值里面痕囱,誰和誰是最小的田轧。比如:order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='euclidean'),返回的值是[0,1]鞍恢,表明mu1與mu1_fact最近傻粘,mu2與mu2_fact最近每窖。換句話說,我們做的順序是做對了弦悉。

通過GMM實現(xiàn)鳶尾花分類

通過高斯混合模型窒典,對鳶尾花數(shù)據(jù)做分類

# !/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin

mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False

iris_feature = '花萼長度', '花萼寬度', '花瓣長度', '花瓣寬度'


def expand(a, b, rate=0.05):
    d = (b - a) * rate
    return a-d, b+d


if __name__ == '__main__':
    path = '..\9.Regression\iris.data'
    data = pd.read_csv(path, header=None)
    x_prime = data[np.arange(4)]
    y = pd.Categorical(data[4]).codes

    n_components = 3
    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    plt.figure(figsize=(8, 6), facecolor='w')
    for k, pair in enumerate(feature_pairs, start=1):
        x = x_prime[pair]
        m = np.array([np.mean(x[y == i], axis=0) for i in range(3)])  # 均值的實際值
        print('實際均值 = \n', m)

        gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
        gmm.fit(x)
        print('預(yù)測均值 = \n', gmm.means_)
        print('預(yù)測方差 = \n', gmm.covariances_)
        y_hat = gmm.predict(x)
        print(y_hat)
        order = pairwise_distances_argmin(m, gmm.means_, axis=1, metric='euclidean')
        print(order)
        print('順序:\t', order)

        n_sample = y.size
        n_types = 3
        change = np.empty((n_types, n_sample), dtype=np.bool)
        for i in range(n_types):
            change[i] = y_hat == order[i]
        for i in range(n_types):
            y_hat[change[i]] = i
        acc = '準確率:%.2f%%' % (100*np.mean(y_hat == y))
        print(acc)

        cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0', '#A0A0FF'])
        cm_dark = mpl.colors.ListedColormap(['r', 'g', '#6060FF'])
        x1_min, x2_min = x.min()
        x1_max, x2_max = x.max()
        x1_min, x1_max = expand(x1_min, x1_max)
        x2_min, x2_max = expand(x2_min, x2_max)
        x1, x2 = np.mgrid[x1_min:x1_max:200j, x2_min:x2_max:200j]
        grid_test = np.stack((x1.flat, x2.flat), axis=1)
        grid_hat = gmm.predict(grid_test)

        change = np.empty((n_types, grid_hat.size), dtype=np.bool)
        for i in range(n_types):
            change[i] = grid_hat == order[i]
        for i in range(n_types):
            grid_hat[change[i]] = i

        grid_hat = grid_hat.reshape(x1.shape)
        plt.subplot(2, 3, k)
        plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
        plt.scatter(x[pair[0]], x[pair[1]], s=20, c=y, marker='o', cmap=cm_dark, edgecolors='k')
        xx = 0.95 * x1_min + 0.05 * x1_max
        yy = 0.1 * x2_min + 0.9 * x2_max
        plt.text(xx, yy, acc, fontsize=10)
        plt.xlim((x1_min, x1_max))
        plt.ylim((x2_min, x2_max))
        plt.xlabel(iris_feature[pair[0]], fontsize=11)
        plt.ylabel(iris_feature[pair[1]], fontsize=11)
        plt.grid(b=True, ls=':', color='#606060')
    plt.suptitle('EM算法無監(jiān)督分類鳶尾花數(shù)據(jù)', fontsize=14)
    plt.tight_layout(1, rect=(0, 0, 1, 0.95))
    plt.show()

圖例如下:


2019-02-04 18_17_05-Figure 1.png

繪制高斯混合模型的等值線

# !/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
from sklearn.mixture import GaussianMixture
import scipy as sp
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import warnings


def expand(a, b, rate=0.05):
    d = (b - a) * rate
    return a-d, b+d


if __name__ == '__main__':
    warnings.filterwarnings(action='ignore', category=RuntimeWarning)
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    N1 = 500
    N2 = 300
    N = N1 + N2
    x1 = np.random.multivariate_normal(mean=(3, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)

    gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
    gmm.fit(x)
    centers = gmm.means_
    covs = gmm.covariances_
    print('GMM均值 = \n', centers)
    print('GMM方差 = \n', covs)
    y_hat = gmm.predict(x)

    colors = '#A0FFA0', '#E080A0',
    levels = 10
    cm = mpl.colors.ListedColormap(colors)
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    print(gmm.score_samples(grid_test))
    grid_hat = -gmm.score_samples(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    plt.figure(figsize=(7, 6), facecolor='w')
    ax = plt.subplot(111)
    cmesh = plt.pcolormesh(x1, x2, grid_hat, cmap=plt.cm.Spectral)
    plt.colorbar(cmesh, shrink=0.9)
    CS = plt.contour(x1, x2, grid_hat, levels=np.logspace(0, 2, num=levels, base=10), colors='w', linewidths=1)
    plt.clabel(CS, fontsize=9, inline=True, fmt='%.1f')
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o', edgecolors='#202020')

    for i, cc in enumerate(zip(centers, covs)):
        center, cov = cc
        value, vector = sp.linalg.eigh(cov)
        width, height = value[0], value[1]
        v = vector[0] / sp.linalg.norm(vector[0])
        angle = 180* np.arctan(v[1] / v[0]) / np.pi
        e = Ellipse(xy=center, width=width, height=height,
                    angle=angle, color='m', alpha=0.5, clip_box = ax.bbox)
        ax.add_artist(e)

    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    mpl.rcParams['axes.unicode_minus'] = False
    plt.title('GMM似然函數(shù)值', fontsize=15)
    plt.grid(b=True, ls=':', color='#606060')
    plt.tight_layout(2)
    plt.show()

圖例如下:


2019-02-04 18_29_19-Start.png

問答
問:DPGMM選的k是不是要盡量小稽莉?
答:不一定瀑志,與k值選擇是否小沒關(guān)系。
問:矩陣運算不是不能用交換律么污秆?怎么直接交換了劈猪?
答:是對這個代碼:sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma),矩陣不能交換混狠,但是標量值就可以進行交換岸霹。

EM算法內(nèi)容完結(jié)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市将饺,隨后出現(xiàn)的幾起案子贡避,更是在濱河造成了極大的恐慌,老刑警劉巖予弧,帶你破解...
    沈念sama閱讀 221,635評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件刮吧,死亡現(xiàn)場離奇詭異,居然都是意外死亡掖蛤,警方通過查閱死者的電腦和手機杀捻,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,543評論 3 399
  • 文/潘曉璐 我一進店門舍咖,熙熙樓的掌柜王于貴愁眉苦臉地迎上來麻削,“玉大人,你說我怎么就攤上這事顿锰∑髟蓿” “怎么了垢袱?”我有些...
    開封第一講書人閱讀 168,083評論 0 360
  • 文/不壞的土叔 我叫張陵,是天一觀的道長港柜。 經(jīng)常有香客問我请契,道長,這世上最難降的妖魔是什么夏醉? 我笑而不...
    開封第一講書人閱讀 59,640評論 1 296
  • 正文 為了忘掉前任爽锥,我火速辦了婚禮,結(jié)果婚禮上畔柔,老公的妹妹穿的比我還像新娘氯夷。我一直安慰自己,他們只是感情好靶擦,可當我...
    茶點故事閱讀 68,640評論 6 397
  • 文/花漫 我一把揭開白布肠槽。 她就那樣靜靜地躺著擎淤,像睡著了一般。 火紅的嫁衣襯著肌膚如雪秸仙。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 52,262評論 1 308
  • 那天桩盲,我揣著相機與錄音寂纪,去河邊找鬼。 笑死赌结,一個胖子當著我的面吹牛捞蛋,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播柬姚,決...
    沈念sama閱讀 40,833評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼拟杉,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了量承?” 一聲冷哼從身側(cè)響起搬设,我...
    開封第一講書人閱讀 39,736評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎撕捍,沒想到半個月后拿穴,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,280評論 1 319
  • 正文 獨居荒郊野嶺守林人離奇死亡忧风,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,369評論 3 340
  • 正文 我和宋清朗相戀三年默色,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片狮腿。...
    茶點故事閱讀 40,503評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡腿宰,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出缘厢,到底是詐尸還是另有隱情吃度,我是刑警寧澤,帶...
    沈念sama閱讀 36,185評論 5 350
  • 正文 年R本政府宣布昧绣,位于F島的核電站规肴,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏夜畴。R本人自食惡果不足惜拖刃,卻給世界環(huán)境...
    茶點故事閱讀 41,870評論 3 333
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望贪绘。 院中可真熱鬧兑牡,春花似錦、人聲如沸税灌。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,340評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至苞也,卻和暖如春洛勉,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背如迟。 一陣腳步聲響...
    開封第一講書人閱讀 33,460評論 1 272
  • 我被黑心中介騙來泰國打工收毫, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人殷勘。 一個月前我還...
    沈念sama閱讀 48,909評論 3 376
  • 正文 我出身青樓此再,卻偏偏與公主長得像,于是被迫代替她去往敵國和親玲销。 傳聞我的和親對象是個殘疾皇子输拇,可洞房花燭夜當晚...
    茶點故事閱讀 45,512評論 2 359

推薦閱讀更多精彩內(nèi)容