摘要:局部敏感哈希
,Python
洒沦,矢量檢索
,推薦系統(tǒng)
單獨(dú)記錄一下LSH算法的原理价淌,結(jié)合代碼深入理解一下申眼,因?yàn)檫@個(gè)算法的調(diào)參對(duì)結(jié)果影響極大,不懂原理就不會(huì)調(diào)參蝉衣,導(dǎo)致最終效果不理想
LSH概述
知識(shí)準(zhǔn)備
重點(diǎn)概念
算法流程
其他策略
Python代碼實(shí)現(xiàn)
這個(gè)代碼參考了https://blog.csdn.net/sgyuanshi/article/details/108132214
我在他的基礎(chǔ)上增加了idMap的映射括尸,在模型中把向量的標(biāo)識(shí)也灌了進(jìn)去
import numpy as np
from typing import List, Union
class EuclideanLSH:
def __init__(self, num_hash_tables: int, bucket_len: int, embedding_size: int):
"""
LSH
:param num_hash_tables:
:param bucket_len:
:param embedding_size:
"""
self.num_hash_tables = num_hash_tables
self.bucket_len = bucket_len
# 同一個(gè)hash_table采用同一個(gè)hash函數(shù),k和hash函數(shù)相同病毡,R相同
self.R = np.random.random([embedding_size, num_hash_tables])
# 一個(gè)hash——table一個(gè)隨機(jī)的b
self.b = np.random.uniform(0, bucket_len, [1, num_hash_tables])
# 初始化空的hash_table
self.hash_tables = [dict() for i in range(num_hash_tables)]
# ids和vector的對(duì)應(yīng)關(guān)系
self.ids_map = {}
def _hash(self, inputs: Union[List[List], np.ndarray]):
"""
將向量映射到對(duì)應(yīng)的hash_table的索引
:param inputs: 輸入的單個(gè)或多個(gè)向量
:return: 每一行代表一個(gè)向量輸出的所有索引濒翻,每一列代表位于一個(gè)hash_table中的索引
"""
# H(V) = |V·R + b| / a,R是一個(gè)隨機(jī)向量啦膜,a是桶寬有送,b是一個(gè)在[0,a]之間均勻分布的隨機(jī)變量,這個(gè)乘10是為了hash值分布地更開
hash_val = np.floor(np.abs(np.matmul(inputs, self.R) + self.b) * 10 / self.bucket_len)
return hash_val
def insert(self, inputs, ids): # 增加id和向量的對(duì)應(yīng)關(guān)系
"""
將向量映射到對(duì)應(yīng)的hash_table的索引僧家,并插入到所有hash_table中
:param inputs:
:return:
"""
self.ids_map = dict(zip(ids, inputs))
# 將inputs轉(zhuǎn)化為二維向量
inputs = np.array(inputs)
if len(inputs.shape) == 1:
inputs = inputs.reshape([1, -1])
hash_index = self._hash(inputs)
# 一條輸入向量雀摘,一條映射到所有hash_table的索引值
for id_one, inputs_one, indexs in zip(ids, inputs, hash_index):
# i代表第i個(gè)hash_table,key則為當(dāng)前hash_table的索引位置
for i, key in enumerate(indexs):
# 第n個(gè)hash_table的第k個(gè)桶位置八拱,將這條數(shù)據(jù)灌進(jìn)去
# 還可以這樣寫, 這個(gè)地方用元組是因?yàn)楹竺嫘枰尤雜et阵赠,不可變可以hash
self.hash_tables[i].setdefault(key, []).append(id_one)
def query(self, id_one, nums=20):
"""
查詢與id_one的inputs相似的向量,并輸出相似度最高的nums個(gè)
:param id_one:
:param nums:
:return:
"""
assert id_one in self.ids_map, "元素不存在在"
id_vector = self.ids_map[id_one]
hash_val = self._hash(id_vector).ravel() # 計(jì)算新輸入的在所有hash_table的值肌稻,然后拉平為一維向量
candidates = set()
# 每一張hash table中相同桶位置的向量全部加入候選集豌注,去重
for i, key in enumerate(hash_val):
# 后面是一個(gè)集合,所以用update灯萍,將集合拆碎加入,集合內(nèi)部元素必須是不可變的轧铁,所以提前轉(zhuǎn)了元組
candidates.update(self.hash_tables[i][key])
candidates = [x for x in candidates if x != id_one]
print("LSH之后所有hash_table一個(gè)桶下的候選集總計(jì){}".format(len(candidates)))
# 根據(jù)向量距離進(jìn)行排序
# 候選集暴力求解
res = [(x, self.euclidean_dis(self.ids_map[x], id_vector)) for x in candidates]
return sorted(res, key=lambda x: x[1])[:nums]
@staticmethod
def euclidean_dis(x, y):
"""
計(jì)算歐式距離
:param x:
:param y:
:return:
"""
x = np.array(x)
y = np.array(y)
return np.sqrt(np.sum(np.power(x - y, 2)))
現(xiàn)在我們了來(lái)調(diào)用一發(fā),數(shù)據(jù)采用離線訓(xùn)練好的170萬(wàn)的實(shí)體embedding向量旦棉,向量維度為16齿风,大概長(zhǎng)這個(gè)樣子
下一步初始化模型參數(shù),向量維度16固定绑洛,設(shè)置num_hash_table和bucket_len都是8
最后輸入一個(gè)實(shí)體尋找Top10救斑,對(duì)比一下LSH近似搜索和全表暴力搜索的性能和準(zhǔn)確度
結(jié)果是LSH從170萬(wàn)實(shí)體中先召回了所有和輸入實(shí)體共同桶的7.1萬(wàn)個(gè)候選,最終在桶下暴力搜索耗時(shí)3.6秒真屯,而全表暴力搜索耗時(shí)60秒脸候,從準(zhǔn)確率來(lái)看全表掃描更準(zhǔn)備,但是LSH的Top1也是全表掃描的第三名,整體來(lái)看LSH準(zhǔn)確率表現(xiàn)也不錯(cuò)