MeanShift
該算法也叫做均值漂移,在目標(biāo)追蹤中應(yīng)用廣泛欧引。本身其實是一種基于密度的聚類算法垮媒。
主要思路是:計算某一點(diǎn)A與其周圍半徑R內(nèi)的向量距離的平均值M,計算出該點(diǎn)下一步漂移(移動)的方向(A=M+A)猫胁。當(dāng)該點(diǎn)不再移動時箱亿,其與周圍點(diǎn)形成一個類簇,計算這個類簇與歷史類簇的距離弃秆,滿足小于閾值D即合并為同一個類簇届惋,不滿足則自身形成一個類簇髓帽。直到所有的數(shù)據(jù)點(diǎn)選取完畢。
一般形式
對于給定的 n 維空間 中的 m 個樣本點(diǎn)
脑豹,i=1...m郑藏,對于其中一個樣本X,他的均值漂移向量為:
瘩欺,其中
指的是一個半徑為h的球狀領(lǐng)域必盖,定義為
,如下圖所示
藍(lán)色圈內(nèi)表示半徑h的區(qū)域
首先,我們再看一下上圖和公式:藍(lán)色圈區(qū)域內(nèi)棉圈,每一個與
核函數(shù)形式
設(shè)是輸入空間捆等,是實數(shù)空間的一個子集滞造。設(shè)
為希爾伯特空間(完備的空間,抽象意義上對有限維歐式空間的擴(kuò)展)栋烤,設(shè)存在一個映射:
谒养,此時有函數(shù)
,其中
买窟。關(guān)于希爾伯特空間和核函數(shù)的概念丰泊,本人了解的也不深,歡迎探討始绍。
高斯核函數(shù)是一種應(yīng)用廣泛的核函數(shù):
其中h為bandwidth 帶寬瞳购,不同帶寬的核函數(shù)形式也不一樣
由上圖可以看到,橫坐標(biāo)指的是兩變量之間的距離亏推。距離越近(接近于0)則函數(shù)值越大学赛,否則越小。h越大吞杭,相同距離的情況下 函數(shù)值會越小盏浇。因此我們可以選取適當(dāng)?shù)膆值,得到滿足上述要求的那種權(quán)重(兩變量距離越近芽狗,得到權(quán)重越大)缠捌,故經(jīng)過核函數(shù)改進(jìn)后的均值漂移為:
其中
看到其他的文章說,經(jīng)過核函數(shù)改進(jìn)后的均值漂移译蒂,經(jīng)過證明(求導(dǎo)),會朝著概率密度上升的區(qū)域移動谊却。
上代碼及實驗結(jié)果:
Python代碼
class MeanShift(object):
"""
均值漂移聚類-基于密度
"""
def __init__(self,radius = 0.5,distance_between_groups = 2.5,bandwidth = 1,use_gk = True):
self._radius = radius
self._groups = []
self._bandwidth = bandwidth
self._distance_between_groups = distance_between_groups
self._use_gk = use_gk #是否啟用高斯核函數(shù)
def _find_nearst_indexes(self,xi,XX):
if XX.shape[0] == 0:
return []
distances= eculide(xi,XX)
nearst_indexes = np.where(distances <= self._distance_between_groups)[0].tolist()
return nearst_indexes
def _compute_mean_vector(self,xi,datas):
distances = datas-xi
if self._use_gk:
sum1 = self.gaussian_kernel(distances)
sum2 = sum1*(distances)
mean_vector = np.sum(sum2,axis=0)/np.sum(sum1,axis=0)
else:
mean_vector = np.sum(datas - xi, axis=0) / datas.shape[0]
return mean_vector
def fit(self,X):
XX = X
while(XX.shape[0]!=0):
# 1.從原始數(shù)據(jù)選取一個中心點(diǎn)及其半徑周邊的點(diǎn) 進(jìn)行漂移運(yùn)算
index = np.random.randint(0,XX.shape[0],1).squeeze()
group = Group()
xi = XX[index]
XX = np.delete(XX,index,axis=0) # 刪除XX中的一行并重新賦值
nearest_indexes = self._find_nearst_indexes(xi, XX)
nearest_datas = None
mean_vector = None
if len(nearest_indexes) != 0:
nearest_datas = None
# 2.不斷進(jìn)行漂移柔昼,中心點(diǎn)達(dá)到穩(wěn)定值
epos = 1.0
while (True):
nearest_datas = XX[nearest_indexes]
mean_vector = self._compute_mean_vector(xi,nearest_datas)
xi = mean_vector + xi
nearest_indexes = self._find_nearst_indexes(xi, XX)
epos = np.abs(np.sum(mean_vector))
if epos < 0.00001 : break
if len(nearest_indexes) == 0 : break
# 有些博客說在一次漂移過程中 每個漂移點(diǎn)周邊的點(diǎn)都需要納入該類簇中,我覺得不妥炎辨,此處不是這樣實現(xiàn)的捕透,
# 只把穩(wěn)定點(diǎn)周邊的數(shù)據(jù)納入該類簇中
group.members = nearest_datas.tolist()
group.center = xi
XX = np.delete(XX, nearest_indexes, axis=0)
else:
group.center = xi
# 3.與歷史類簇進(jìn)行距離計算,若小于閾值則加入歷史類簇碴萧,并更新類簇中心及成員
for i in range(len(self._groups)):
h_group = self._groups[i]
distance = eculide(h_group.center,group.center)
if distance <= self._distance_between_groups:
h_group.members = group.members
h_group.center = (h_group.center+group.center)/2
else:
group.name = len(self._groups) + 1
self._groups.append(group)
break
if len(self._groups) == 0:
group.name = len(self._groups) + 1
self._groups.append(group)
# 4.從余下的點(diǎn)中重復(fù)1-3的計算乙嘀,直到所有數(shù)據(jù)完成選取
def plot_example(self):
figure = plt.figure()
ax = figure.add_subplot(111)
ax.set_title("MeanShift Iris Example")
plt.xlabel("first dim")
plt.ylabel("third dim")
legends = []
cxs = []
cys = []
for i in range(len(self._groups)):
group = self._groups[i]
members = group.members
x = [member[0] for member in members]
y = [member[2] for member in members]
cx = group.center[0]
cy = group.center[2]
cxs.append(cx)
cys.append(cy)
ax.scatter(x, y, marker='o')
#ax.scatter(cx,cy,marker='+',c='r')
legends.append(group.name)
plt.scatter(cxs,cys,marker='+',c='k')
plt.legend(legends, loc="best")
plt.show()
def gaussian_kernel(self,distances):
"""
高斯核函數(shù)
:param distances:
:param h:
:return:
"""
left = 1/(self._bandwidth*np.sqrt(2*np.pi))
right = np.exp(-np.power(distances,2)/(2*np.power(self._bandwidth,2)))
return left*right
def test_meanshift(use_gk = False):
data,t,tn=load_data()
ms = MeanShift(radius=0.66,distance_between_groups=1.4,use_gk=use_gk)
ms.fit(data)
ms.plot_example()
test_meanshift(use_gk = True)
上述定義的Group類及一些import導(dǎo)入包,參見K均值聚類及代碼實現(xiàn)
實驗結(jié)果還是利用了iris數(shù)據(jù)集破喻,結(jié)果如下虎谢,第一幅圖是一般形式,第二幅圖是高斯核函數(shù)曹质。黑色“+”代表的是聚類中心
與KMeans相比較而言婴噩,meashift可以不用指定類簇的個數(shù),自動發(fā)現(xiàn)類簇結(jié)構(gòu)羽德。
但是Kmeans也類似几莽,發(fā)現(xiàn)的類簇多為球狀類簇,不能發(fā)現(xiàn)一些混合度較高宅静,非球狀類簇章蚣。
下面是經(jīng)過調(diào)參得到的分為3個類圖像。此時
MeanShift(radius=1.5,distance_between_groups=2.3,use_gk=use_gk)
此處實現(xiàn)的與sklearn中的MeanShift不同姨夹,后續(xù)會研究一下sklearn的實現(xiàn)方法纤垂。
參考文獻(xiàn)
1.簡單易學(xué)的機(jī)器學(xué)習(xí)算法——Mean Shift聚類算法
2.python機(jī)器學(xué)習(xí)算法-趙志勇
1中的文章也是2作者寫的