一、問題背景
在使用機(jī)器學(xué)習(xí)模型預(yù)測的各類場景中孽江,對象間的序關(guān)系是否預(yù)測準(zhǔn)確,是考量模型效果的指標(biāo)之一番电。
正序數(shù)岗屏、逆序數(shù)可作為模型效果的衡量指標(biāo)。當(dāng)樣本的之間存在序關(guān)系時漱办,由樣本間兩兩組成的
这刷,若模型預(yù)測結(jié)果的序關(guān)系與
之間的序關(guān)系相同,稱為正序娩井;若模型預(yù)測結(jié)果的序關(guān)系與
之間的序關(guān)系相反暇屋,稱為逆序。當(dāng)正序數(shù)量越多洞辣、逆序數(shù)量越少時咐刨,表明模型對序關(guān)系的刻畫越準(zhǔn)確,模型效果越好扬霜。正逆序即為正序數(shù)量與逆序數(shù)量的比值定鸟。
計算正序數(shù)量、逆序數(shù)量時著瓶,一種直觀的方法是暴力構(gòu)造所有的對并一一驗證联予,在
個樣本下時間復(fù)雜度為
,當(dāng)
的量級在5萬時普通的服務(wù)器上已需要耗費(fèi)近10分鐘(Python)蟹但,在更大量級時計算時間無法忍受躯泰。
進(jìn)一步,如果在模型訓(xùn)練過程中华糖,希望在訓(xùn)練的每一輪迭代時都查看正逆序的值(據(jù)此做終止條件)麦向;或是在多組參數(shù)間使用正逆序做驗證調(diào)參的標(biāo)準(zhǔn)時,正逆序的計算速度問題則更為突出客叉。我們需要復(fù)雜度更低的算法來快速計算出正逆序诵竭。
二话告、數(shù)學(xué)表述
給定個樣本,現(xiàn)有人工打好的每個樣本的
卵慰,記為
以及模型預(yù)測出的每個樣本的分值沙郭,記為
以 表示集合
的大小,
表示邏輯與裳朋、
表示邏輯或
以下 均在
中取值病线、且均滿足
,不再特別寫出
構(gòu)造出的所有pair集合為(注意我們只對不同的樣本構(gòu)造
)
正序集合為
逆序集合為
嚴(yán)格正序集合為
嚴(yán)格逆序集合為
由以上的定義容易看出鲤嫡,
三送挑、算法思路
注:以下排序均指升序排列。
MergeSort
熟悉排序算法的同學(xué)應(yīng)已看出:這個問題像極了經(jīng)典的找逆序?qū)栴}暖眼。
給定數(shù)組
惕耕,找出滿足
的
![]()
對數(shù)量。使用
可以在
時間計算出結(jié)果诫肠。
MergeSort計算逆序的思路較為直接司澎,使用的是Divide-and-Conquer的思想:
- 將數(shù)組二分
- 左半邊數(shù)組排好序并計算出其內(nèi)部的逆序數(shù) (可通過遞歸調(diào)用實現(xiàn))
- 右半邊數(shù)組排好序并計算出其內(nèi)部的逆序數(shù) (可通過遞歸調(diào)用實現(xiàn))
- 左半邊數(shù)組與右半邊數(shù)組合并,得到整體排好序的數(shù)組栋豫、以及左半邊與右半邊形成的逆序?qū)?shù)量
- 返回整體排好序的數(shù)組挤安、總逆序數(shù)(左半邊內(nèi)部 + 右半邊內(nèi)部 + 左右半邊聯(lián)合形成)
嚴(yán)格逆序 StrictWrong
我們從計算嚴(yán)格逆序集合入手。
根據(jù)上一點關(guān)于的討論笼才,一個自然的想法出現(xiàn)了:
先按
排序得到數(shù)組
漱受,接著計算數(shù)組
中的關(guān)于
的逆序?qū)?shù)量。
但這與我們的原始需求仍有細(xì)微的差別:在按排序后骡送,
我們對于下標(biāo)只能得到
然而
因此直接計算逆序?qū)Φ脑挵合郏瑫?img class="math-inline" src="https://math.jianshu.com/math?formula=label" alt="label" mathimg="1">相等的情形也算進(jìn)來,得不到正確答案摔踱。
為了消除這一影響虐先,我們可以使用一個小trick:
在按
排序時,排序的key不僅僅使用
派敷,而是按照二元組
排序蛹批。
即:先按排序,當(dāng)
值相等時篮愉,按
排序
于是
因此在這種情形下我們不會統(tǒng)計到任何label相等時的逆序?qū)Ω帧?稍?img class="math-inline" src="https://math.jianshu.com/math?formula=O(N%20%5Clog%20N)" alt="O(N \log N)" mathimg="1">時間內(nèi)計算得到
嚴(yán)格正序 StrictRight
思路與嚴(yán)格逆序完全一致试躏,只是不等號方向變反猪勇。為了程序復(fù)用,可將正負(fù)號變反后颠蕴、直接調(diào)用嚴(yán)格逆序計算的程序得到結(jié)果
正序Right, 逆序Wrong, Pair
因為
結(jié)合前面的結(jié)果泣刹, 只需計算出 助析, 即可獲得
和
的計算較為簡單,將
排好序后椅您,依次遍歷處理即可外冀,總復(fù)雜度為
結(jié)論及實驗
根據(jù)以上討論,各個集合的大小均可在 時間計算得出掀泳。
隨機(jī)構(gòu)造的樣本在普通服務(wù)器上的實際運(yùn)行時間統(tǒng)計如下(Python)雪隧,可以看出優(yōu)化后的算法執(zhí)行時間大幅提升。
樣本數(shù)量N | 基于 MergeSort 計算時間(秒) | 基于 暴力法 計算時間(秒) |
---|---|---|
500 | 0.017 | 0.051 |
5000 | 0.164 | 5.048 |
50000 | 2.017 | 512.371 |
四开伏、總結(jié)與展望
總結(jié)
- 將原始問題轉(zhuǎn)為經(jīng)典的求解數(shù)組逆序?qū)栴}膀跌,并使用經(jīng)典的MergeSort進(jìn)行求解。在問題轉(zhuǎn)化建模的過程中使用多字段排序等trick固灵,使得逆序?qū)栴}與原始問題完全等價。
- 最終將計算正逆序的時間復(fù)雜度由
優(yōu)化至
展望
- 在模型訓(xùn)練的迭代過程中劫流,
保持不變巫玻,且每一輪參數(shù)變化帶來的預(yù)測變化較小。是否有可能依據(jù)
的變化計算逆序?qū)Φ淖兓艋恪⒓铀儆嬎悖ú槐孛看味紡念^開始)仍秤,使得多輪迭代時的逆序?qū)τ嬎阏w時間縮短?
- 是否存在分布式的解決方案可很,能夠應(yīng)對海量樣本數(shù)量的正逆序計算诗力?
注
- 本文中使用的各類名詞及其對應(yīng)英文表達(dá),均為隨手自創(chuàng)我抠,僅為上下文敘述方便(除了mergeSort等大眾熟知的術(shù)語)
附苇本、代碼片段
(代碼中的變量true即為上文中的,取"groundtruth"之意菜拓;pred即為上文中的
)
from itertools import groupby
class InversionCounter(object):
@classmethod
def merge_sort_count_sub(cls, vals):
if len(vals) <= 1:
return vals, 0
n = len(vals)
left_vals, left_cnt = cls.merge_sort_count_sub(vals[:n/2])
right_vals, right_cnt = cls.merge_sort_count_sub(vals[n/2:])
left_i = 0
right_i = 0
mid_cnt = 0
new_vals = []
while True:
if left_vals[left_i][1] <= right_vals[right_i][1]:
new_vals.append(left_vals[left_i])
left_i += 1
elif left_vals[left_i][1] > right_vals[right_i][1]:
mid_cnt += (len(left_vals) - left_i)
new_vals.append(right_vals[right_i])
right_i += 1
if left_i == len(left_vals):
new_vals.extend(right_vals[right_i:])
break
if right_i == len(right_vals):
new_vals.extend(left_vals[left_i:])
break
return new_vals, left_cnt + mid_cnt + right_cnt
@classmethod
def merge_sort_count_strict_right(cls, trues, preds):
neg_preds = (-p for p in preds)
vals = zip(trues, neg_preds)
vals.sort()
return cls.merge_sort_count_sub(vals)[1]
@classmethod
def merge_sort_count_strict_wrong(cls, trues, preds):
vals = zip(trues, preds)
vals.sort()
return cls.merge_sort_count_sub(vals)[1]
@classmethod
def merge_sort_count_right(cls, trues, preds):
return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_wrong(trues, preds)
@classmethod
def merge_sort_count_wrong(cls, trues, preds):
return cls.merge_sort_count_pair(trues) - cls.merge_sort_count_strict_right(trues, preds)
@classmethod
def merge_sort_count_pair(cls, trues, preds=None):
'''
preds: dummpy variable, no need inside function
'''
trues = sorted(trues)
acc_num = 0
pair = 0
for k, ks in groupby(trues):
current_num = sum(1 for _ in ks)
acc_num += current_num
pair += (len(trues) - acc_num) * current_num
return pair