使用Pytorch實現(xiàn)Kmeans聚類

Kmeans是一種簡單易用的聚類算法翠霍,是少有的會出現(xiàn)在深度學習項目中的傳統(tǒng)算法以蕴,比如人臉搜索項目蝴乔、物體檢測項目(yolov3中用到了Kmeans進行anchors聚類)等歇攻。

一般使用Kmeans會直接調(diào)sklearn译隘,如果任務(wù)比較復(fù)雜,可以通過numpy進行自定義洛心,這里介紹使用Pytorch實現(xiàn)的方式固耘,經(jīng)測試,通過Pytorch調(diào)用GPU之后词身,能夠提高多特征聚類的速度厅目。


import torch
import time
from tqdm import tqdm

class KMEANS:
    def __init__(self, n_clusters=20, max_iter=None, verbose=True,device = torch.device("cpu")):

        self.n_cluster = n_clusters
        self.n_clusters = n_clusters
        self.labels = None
        self.dists = None  # shape: [x.shape[0],n_cluster]
        self.centers = None
        self.variation = torch.Tensor([float("Inf")]).to(device)
        self.verbose = verbose
        self.started = False
        self.representative_samples = None
        self.max_iter = max_iter
        self.count = 0
        self.device = device

    def fit(self, x):
        # 隨機選擇初始中心點,想更快的收斂速度可以借鑒sklearn中的kmeans++初始化方法
        init_row = torch.randint(0, x.shape[0], (self.n_clusters,)).to(self.device)
        init_points = x[init_row]
        self.centers = init_points
        while True:
            # 聚類標記
            self.nearest_center(x)
            # 更新中心點
            self.update_center(x)
            if self.verbose:
                print(self.variation, torch.argmin(self.dists, (0)))
            if torch.abs(self.variation) < 1e-3 and self.max_iter is None:
                break
            elif self.max_iter is not None and self.count == self.max_iter:
                break

            self.count += 1

        self.representative_sample()

    def nearest_center(self, x):
        labels = torch.empty((x.shape[0],)).long().to(self.device)
        dists = torch.empty((0, self.n_clusters)).to(self.device)
        for i, sample in enumerate(x):
            dist = torch.sum(torch.mul(sample - self.centers, sample - self.centers), (1))
            labels[i] = torch.argmin(dist)
            dists = torch.cat([dists, dist.unsqueeze(0)], (0))
        self.labels = labels
        if self.started:
            self.variation = torch.sum(self.dists - dists)
        self.dists = dists
        self.started = True

    def update_center(self, x):
        centers = torch.empty((0, x.shape[1])).to(self.device)
        for i in range(self.n_clusters):
            mask = self.labels == i
            cluster_samples = x[mask]
            centers = torch.cat([centers, torch.mean(cluster_samples, (0)).unsqueeze(0)], (0))
        self.centers = centers

    def representative_sample(self):
        # 查找距離中心點最近的樣本法严,作為聚類的代表樣本损敷,更加直觀
        self.representative_samples = torch.argmin(self.dists, (0))


def time_clock(matrix,device):
    a = time.time()
    k = KMEANS(max_iter=10,verbose=False,device=device)
    k.fit(matrix)
    b = time.time()
    return (b-a)/k.count

def choose_device(cuda=False):
    if cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    return device

if __name__ == "__main__":
    import matplotlib.pyplot as plt

    plt.figure()

    device = choose_device(False)

    cpu_speeds = []
    for i in tqdm([20,100,500,2000,8000,20000]):
        matrix = torch.rand((10000,i)).to(device)
        speed = time_clock(matrix,device)
        cpu_speeds.append(speed)
    l1, = plt.plot([20,100,500,2000,8000,20000],cpu_speeds,color = 'r',label = 'CPU')

    device = choose_device(True)

    gpu_speeds = []
    for i in tqdm([20, 100, 500, 2000, 8000, 20000]):
        matrix = torch.rand((10000, i)).to(device)
        speed = time_clock(matrix,device)
        gpu_speeds.append(speed)
    l2, = plt.plot([20, 100, 500, 2000, 8000, 20000], gpu_speeds, color='g',label = "GPU")



    plt.xlabel("num_features")
    plt.ylabel("speed(s/iter)")
    plt.title("Speed with cuda")
    plt.legend(handles = [l1,l2],labels = ['CPU','GPU'],loc='best')
    plt.savefig("../result/speed.jpg")

cpu和gpu運行的結(jié)果對比如下:

speed.jpg

可以看到,在特征數(shù)<3000的情況下深啤,cpu運行速度更快拗馒,但是特征數(shù)量超過3000之后,gpu的優(yōu)勢越來越明顯溯街。

因為pytorch的矩陣運算接口基本是照著numpy寫的瘟忱,所以numpy的實現(xiàn)方式大概只需要將代碼中的torch替換成numpy就可以了。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末苫幢,一起剝皮案震驚了整個濱河市访诱,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌韩肝,老刑警劉巖触菜,帶你破解...
    沈念sama閱讀 217,734評論 6 505
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異哀峻,居然都是意外死亡涡相,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,931評論 3 394
  • 文/潘曉璐 我一進店門剩蟀,熙熙樓的掌柜王于貴愁眉苦臉地迎上來催蝗,“玉大人,你說我怎么就攤上這事育特”牛” “怎么了?”我有些...
    開封第一講書人閱讀 164,133評論 0 354
  • 文/不壞的土叔 我叫張陵缰冤,是天一觀的道長犬缨。 經(jīng)常有香客問我,道長棉浸,這世上最難降的妖魔是什么怀薛? 我笑而不...
    開封第一講書人閱讀 58,532評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮迷郑,結(jié)果婚禮上枝恋,老公的妹妹穿的比我還像新娘创倔。我一直安慰自己,他們只是感情好焚碌,可當我...
    茶點故事閱讀 67,585評論 6 392
  • 文/花漫 我一把揭開白布畦攘。 她就那樣靜靜地躺著,像睡著了一般呐能。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上抑堡,一...
    開封第一講書人閱讀 51,462評論 1 302
  • 那天摆出,我揣著相機與錄音,去河邊找鬼首妖。 笑死偎漫,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的有缆。 我是一名探鬼主播象踊,決...
    沈念sama閱讀 40,262評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼棚壁!你這毒婦竟也來了杯矩?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,153評論 0 276
  • 序言:老撾萬榮一對情侶失蹤袖外,失蹤者是張志新(化名)和其女友劉穎史隆,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體曼验,經(jīng)...
    沈念sama閱讀 45,587評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡泌射,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,792評論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了鬓照。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片奏甫。...
    茶點故事閱讀 39,919評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡咳榜,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情暖璧,我是刑警寧澤,帶...
    沈念sama閱讀 35,635評論 5 345
  • 正文 年R本政府宣布七嫌,位于F島的核電站董虱,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏获讳。R本人自食惡果不足惜阴颖,卻給世界環(huán)境...
    茶點故事閱讀 41,237評論 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望丐膝。 院中可真熱鬧量愧,春花似錦钾菊、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,855評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至累颂,卻和暖如春滞详,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背紊馏。 一陣腳步聲響...
    開封第一講書人閱讀 32,983評論 1 269
  • 我被黑心中介騙來泰國打工料饥, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人朱监。 一個月前我還...
    沈念sama閱讀 48,048評論 3 370
  • 正文 我出身青樓岸啡,卻偏偏與公主長得像,于是被迫代替她去往敵國和親赫编。 傳聞我的和親對象是個殘疾皇子巡蘸,可洞房花燭夜當晚...
    茶點故事閱讀 44,864評論 2 354

推薦閱讀更多精彩內(nèi)容