在DLRM中有對(duì)訓(xùn)練集做處理的函數(shù),我們對(duì)訓(xùn)練序列做了研究,
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
# WARNING: notice that we are processing the batch at once. We implicitly
# assume that the data is laid out such that:
# 1. each embedding is indexed with a group of sparse indices,
# corresponding to a single lookup
# 2. for each embedding the lookups are further organized into a batch
# 3. for a list of embedding tables there is a list of batched lookups
ly = []
for k, sparse_index_group_batch in enumerate(lS_i):
sparse_offset_group_batch = lS_o[k]
# embedding lookup
# We are using EmbeddingBag, which implicitly uses sum operator.
# The embeddings are represented as tall matrices, with sum
# happening vertically across 0 axis, resulting in a row vector
# E = emb_l[k]
if v_W_l[k] is not None:
per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch)
else:
per_sample_weights = None
if:
....
else:
E = emb_l[k]
V = E(
sparse_index_group_batch,
sparse_offset_group_batch,
per_sample_weights=per_sample_weights,
)
ly.append(V)
重點(diǎn)是這個(gè)地方川慌,其中E是所有打包好的Embedding:
其中第一維為這個(gè)Embedding table中包括的vector的數(shù)量驹饺,第二維64為vector的維度(有64個(gè)float)。
sparse_index_group_batch
以及sparse_offset_group_batch
為訓(xùn)練時(shí)需要的index以及offset最仑,Embedding table會(huì)根據(jù)index找具體的vector藐俺。
offset需要注意,offset = torch.LongTensor([0,1,4]).to(0)
代表三個(gè)樣本泥彤,第一個(gè)樣本是0 ~ 1欲芹,第二個(gè)是1 ~ 4,第三個(gè)是4(網(wǎng)上解釋的都不夠清楚吟吝,所以我這里通過代碼實(shí)際跑了一下測(cè)出來(lái)是這個(gè)結(jié)果) 菱父。且左閉右開[0,1)這種形式取整數(shù)(已經(jīng)根據(jù)代碼進(jìn)行過驗(yàn)證)。
詳細(xì)解釋一下流程:
首先在apply_emb
函數(shù)中每次循環(huán)會(huì)取出當(dāng)前第k個(gè)Emb table:E = emb_l[k]
,其中k是當(dāng)前所在輪數(shù)浙宜。
對(duì)于index數(shù)組與offset數(shù)組:
我們能看到官辽,第一個(gè)tensor是index,有五個(gè)元素粟瞬,代表我要取的當(dāng)前table中的vector的編號(hào)(共5個(gè))同仆。
而后面的offset就代表我取出來(lái)的這5個(gè)數(shù)組哪些要進(jìn)行reduce操作(加和等)。
例如我如果取offset為[0,3]裙品,則代表0俗批,1,2相加進(jìn)行reduce市怎,3岁忘,4進(jìn)行reduce。所以最終出來(lái)的數(shù)字個(gè)數(shù)就是offset的size区匠。
IS_I以及IS_O生成的位置
在dlrm_data_pytorch.py中的collate_wrapper_criteo_offset()
函數(shù)里:
def collate_wrapper_criteo_offset(list_of_tuples):
# where each tuple is (X_int, X_cat, y)
transposed_data = list(zip(*list_of_tuples))
X_int = torch.log(torch.tensor(transposed_data[0], dtype=torch.float) + 1)
X_cat = torch.tensor(transposed_data[1], dtype=torch.long)
T = torch.tensor(transposed_data[2], dtype=torch.float32).view(-1, 1)
batchSize = X_cat.shape[0]
featureCnt = X_cat.shape[1]
lS_i = [X_cat[:, i] for i in range(featureCnt)]
lS_o = [torch.tensor(range(batchSize)) for _ in range(featureCnt)]
return X_int, torch.stack(lS_o), torch.stack(lS_i), T
在這里生成訪問序列干像,首先將傳入的數(shù)據(jù)解析為X_cat,當(dāng)bs=2時(shí)辱志,X_cat為:
tensor([[ 0, 17, 36684, 11838, 1, 0, 145, 9, 0, 1176,
24, 34569, 24, 5, 24, 15109, 0, 19, 14, 3,
32351, 0, 1, 4159, 32, 5050],
[ 3, 12, 33818, 19987, 0, 5, 1426, 1, 0, 8616,
729, 31879, 658, 1, 50, 26833, 1, 12, 89, 0,
29850, 0, 1, 1637, 3, 1246]])
其中每一個(gè)tensor有26個(gè)數(shù)字蝠筑,代表26個(gè)Embedding table。每一個(gè)數(shù)字代表其中每個(gè)table需要訪問的vector揩懒。(比如0代表訪問第一個(gè)table的0號(hào)vector)
下面將訪問序列打包什乙,IS_i為:
[tensor([0, 3]), tensor([17, 12]), tensor([36684, 33818]), tensor([11838, 19987]), tensor([1, 0]), tensor([0, 5]), tensor([ 145, 1426]), tensor([9, 1]), tensor([0, 0]), tensor([1176, 8616]), tensor([ 24, 729]), tensor([34569, 31879]), tensor([ 24, 658]), tensor([5, 1]), tensor([24, 50]), tensor([15109, 26833]), tensor([0, 1]), tensor([19, 12]), tensor([14, 89]), tensor([3, 0]), tensor([32351, 29850]), tensor([0, 0]), tensor([1, 1]), tensor([4159, 1637]), tensor([32, 3]), tensor([5050, 1246])]
這里bs為2,所以[tensor([0, 3])
代表訪問第一個(gè)table的0已球,3個(gè)vactor臣镣。
這里我們要再次理解一下數(shù)據(jù)集的含義,這里每一個(gè)table都是用戶的一個(gè)特征(所在城市智亮、年齡等)忆某,所以每一個(gè)用戶也就是每個(gè)table擁有一個(gè)數(shù)值,所以當(dāng)bs=2時(shí)阔蛉,這里的tensor[0弃舒,3]代表對(duì)兩個(gè)用戶進(jìn)行訓(xùn)練,其中第一個(gè)用戶的第一個(gè)table取值是0號(hào)vector状原,第二個(gè)用戶第一個(gè)table取值是3號(hào)vector聋呢。