本文由ChardLau原創(chuàng),轉(zhuǎn)載請?zhí)砑釉逆溄?a target="_blank" rel="nofollow">https://www.chardlau.com/mean-shift/
今天的文章介紹如何利用Mean Shift
算法的基本形式對數(shù)據(jù)進行聚類操作簇秒。而有關(guān)Mean Shift
算法加入核函數(shù)計算漂移向量部分的內(nèi)容將不在本文講述范圍內(nèi)箱季。實際上除了聚類暇番,Mean Shift
算法還能用于計算機視覺等場合于个,有關(guān)該算法的理論知識請參考這篇文章镐作。
Mean Shift
算法原理
下圖展示了Mean Shift
算法計算飄逸向量的過程:
Mean Shift
算法的關(guān)鍵操作是通過感興趣區(qū)域內(nèi)的數(shù)據(jù)密度變化計算中心點的漂移向量,從而移動中心點進行下一次迭代呀酸,直到到達密度最大處(中心點不變)凉蜂。從每個數(shù)據(jù)點出發(fā)都可以進行該操作,在這個過程,統(tǒng)計出現(xiàn)在感興趣區(qū)域內(nèi)的數(shù)據(jù)的次數(shù)窿吩。該參數(shù)將在最后作為分類的依據(jù)茎杂。
與K-Means
算法不一樣的是,Mean Shift
算法可以自動決定類別的數(shù)目纫雁。與K-Means
算法一樣的是煌往,兩者都用集合內(nèi)數(shù)據(jù)點的均值進行中心點的移動。
算法步驟
下面是有關(guān)Mean Shift
聚類算法的步驟:
- 在未被標記的數(shù)據(jù)點中隨機選擇一個點作為起始中心點center轧邪;
- 找出以center為中心半徑為radius的區(qū)域中出現(xiàn)的所有數(shù)據(jù)點携冤,認為這些點同屬于一個聚類C。同時在該聚類中記錄數(shù)據(jù)點出現(xiàn)的次數(shù)加1闲勺。
- 以center為中心點,計算從center開始到集合M中每個元素的向量扣猫,將這些向量相加菜循,得到向量shift。
- center = center + shift申尤。即center沿著shift的方向移動癌幕,移動距離是||shift||。
- 重復(fù)步驟2昧穿、3勺远、4,直到shift的很惺蓖摇(就是迭代到收斂)胶逢,記住此時的center。注意饰潜,這個迭代過程中遇到的點都應(yīng)該歸類到簇C初坠。
- 如果收斂時當前簇C的center與其它已經(jīng)存在的簇C2中心的距離小于閾值,那么把C2和C合并彭雾,數(shù)據(jù)點出現(xiàn)次數(shù)也對應(yīng)合并碟刺。否則,把C作為新的聚類薯酝。
- 重復(fù)1半沽、2、3吴菠、4者填、5直到所有的點都被標記為已訪問。
- 分類:根據(jù)每個類做葵,對每個點的訪問頻率幔托,取訪問頻率最大的那個類,作為當前點集的所屬類。
算法實現(xiàn)
下面使用Python
實現(xiàn)了Mean Shift
算法的基本形式:
import numpy as np
import matplotlib.pyplot as plt
# Input data set
X = np.array([
[-4, -3.5], [-3.5, -5], [-2.7, -4.5],
[-2, -4.5], [-2.9, -2.9], [-0.4, -4.5],
[-1.4, -2.5], [-1.6, -2], [-1.5, -1.3],
[-0.5, -2.1], [-0.6, -1], [0, -1.6],
[-2.8, -1], [-2.4, -0.6], [-3.5, 0],
[-0.2, 4], [0.9, 1.8], [1, 2.2],
[1.1, 2.8], [1.1, 3.4], [1, 4.5],
[1.8, 0.3], [2.2, 1.3], [2.9, 0],
[2.7, 1.2], [3, 3], [3.4, 2.8],
[3, 5], [5.4, 1.2], [6.3, 2]
])
def mean_shift(data, radius=2.0):
clusters = []
for i in range(len(data)):
cluster_centroid = data[i]
cluster_frequency = np.zeros(len(data))
# Search points in circle
while True:
temp_data = []
for j in range(len(data)):
v = data[j]
# Handle points in the circles
if np.linalg.norm(v - cluster_centroid) <= radius:
temp_data.append(v)
cluster_frequency[i] += 1
# Update centroid
old_centroid = cluster_centroid
new_centroid = np.average(temp_data, axis=0)
cluster_centroid = new_centroid
# Find the mode
if np.array_equal(new_centroid, old_centroid):
break
# Combined 'same' clusters
has_same_cluster = False
for cluster in clusters:
if np.linalg.norm(cluster['centroid'] - cluster_centroid) <= radius:
has_same_cluster = True
cluster['frequency'] = cluster['frequency'] + cluster_frequency
break
if not has_same_cluster:
clusters.append({
'centroid': cluster_centroid,
'frequency': cluster_frequency
})
print('clusters (', len(clusters), '): ', clusters)
clustering(data, clusters)
show_clusters(clusters, radius)
# Clustering data using frequency
def clustering(data, clusters):
t = []
for cluster in clusters:
cluster['data'] = []
t.append(cluster['frequency'])
t = np.array(t)
# Clustering
for i in range(len(data)):
column_frequency = t[:, i]
cluster_index = np.where(column_frequency == np.max(column_frequency))[0][0]
clusters[cluster_index]['data'].append(data[i])
# Plot clusters
def show_clusters(clusters, radius):
colors = 10 * ['r', 'g', 'b', 'k', 'y']
plt.figure(figsize=(5, 5))
plt.xlim((-8, 8))
plt.ylim((-8, 8))
plt.scatter(X[:, 0], X[:, 1], s=20)
theta = np.linspace(0, 2 * np.pi, 800)
for i in range(len(clusters)):
cluster = clusters[i]
data = np.array(cluster['data'])
plt.scatter(data[:, 0], data[:, 1], color=colors[i], s=20)
centroid = cluster['centroid']
plt.scatter(centroid[0], centroid[1], color=colors[i], marker='x', s=30)
x, y = np.cos(theta) * radius + centroid[0], np.sin(theta) * radius + centroid[1]
plt.plot(x, y, linewidth=1, color=colors[i])
plt.show()
mean_shift(X, 2.5)
上述代碼執(zhí)行結(jié)果如下:
其他
Mean Shift
算法還有很多內(nèi)容未提及重挑。其中有“動態(tài)計算感興趣區(qū)域半徑”嗓化、“加入核函數(shù)計算漂移向量”等。本文作為入門引導(dǎo)谬哀,暫時只覆蓋這些內(nèi)容刺覆。