說起歸并排序(Merge Sort)驹饺,其在排序界的地位可不低,畢竟O(nlogn)比較排序的三大排序方法捞高,就是Quick Sort, Merge Sort和Heap Sort砸琅。歸并排序是典型的分而治之方法,先來看看其最簡單的遞歸實現(xiàn):
def merge_sort(lst):
"""Sortsthe input list using the merge sort algorithm.
# >>> lst = [4, 5, 1, 6, 3]
# >>> merge_sort(lst)
[1, 3, 4, 5, 6]
"""
if len(lst) <= 1:
return lst
mid = len(lst) // 2
left = merge_sort(lst[:mid])
right = merge_sort(lst[mid:])
return merge(left, right)
def merge(left, right):
"""Takestwo sorted lists and returns a single sorted list by comparing the
elements one at a time.
# >>> left = [1, 5, 6]
# >>> right = [2, 3, 4]
# >>> merge(left, right)
[1, 2, 3, 4, 5, 6]
"""
if not left:
return right
if not right:
return left
if left[0] < right[0]:
return [left[0]] + merge(left[1:], right)
return [right[0]] + merge(left, right[1:])
很明顯劝枣,歸并排序是典型的分而治之(Divide and Conquer,D&C)算法汤踏,思想就是先把兩半數(shù)據(jù)分別排序,然后再歸并到一起舔腾。這樣T(n) = 2T(n/2) + O(n)溪胶,由Master Theorem可以得到其時間復雜度是O(nlogn)。
再看具體的實現(xiàn)稳诚。排序主體函數(shù)用的是遞歸哗脖,歸并算法一般都是這樣;而merge部分其實也可以用迭代來完成:
def merge_2(left, right):
p1 = p2 = 0
temp = []
while p1 < len(left) and p2 < len(right):
if left[p1] <= right[p2]:
temp.append(left[p1])
p1 += 1
else:
temp.append(right[p2])
p2 += 1
while p1 < len(left):
temp.append(left[p1])
p1 += 1
while p2 < len(right):
temp.append(right[p2])
p2 += 1
return temp
單純就此情景而言扳还,迭代的merge顯得冗長而且效率沒有提升才避。但是其好處就是適用性廣,因為有很多merge sort的變形氨距,不太方便遞歸調(diào)用merge函數(shù)桑逝。
變形:merge sort有很多tweak的應用,大部分是需要考慮數(shù)組前后關(guān)系俏让。
例1. 給出一個數(shù)組nums楞遏,需要對每個數(shù)其index之后比它大的數(shù)的個數(shù)求和。例如給出[7, 4, 5, 2, 8, 9, 0, 1]首昔,返回11寡喝,因為7后面有8,9兩個比它大的,4有3個沙廉,5有2個拘荡,2有2個,8有1個撬陵,0有1一個珊皿,總共2+3+2+2+1+1=11。
【解】
Method 1:一個naive的方法就是對于每一個數(shù)巨税,遍歷搜索其后面所有比其大的數(shù)蟋定,顯然時間復雜度是O(n^2)。
Method 2:還有一個方法就是考慮Segment Tree草添,先構(gòu)建從min到max的線段樹驶兜,O(max(n) - min(n)),初始count都是0远寸。然后從反方向考慮抄淑,考慮前面比其小的有多少個。也就是說對于某個n[i]驰后,考慮[min(n), n[i]-1]這個區(qū)間里面有多少count肆资。完了再把n[i]的count++。就這個例子而言灶芝,先放7郑原,然后4進來的時候搜索[0,3]區(qū)間唉韭,因為是要比4小,再把4的count設置為1犯犁;這樣5進來的時候属愤,就能搜索到4的存在。
這個算法后面的步驟是O(nlogn)酸役,但是需要構(gòu)造一個線段樹住诸,假如max很大很大,就不太合適簇捍。當然也可以argue說我就干脆構(gòu)造一個囊括最小到最大的32位int的線段樹只壳,這還是O(1)呢XD俏拱。
Method 3:這個解法考慮使用merge sort的tweak暑塑。因為要求每個n[i]之后比它大的數(shù),可以用分而治之的思想锅必,即考慮前一半有多少個事格,后一半有多少個,然后再考慮之間有多少個搞隐。
就這個例子而言驹愚,考慮最后一次merge之前的樣子:[2,4,5,7][0,1,8,9],此時兩半里面都已經(jīng)計算完畢劣纲,只需要計算merge時候產(chǎn)生的結(jié)果逢捺。很明顯結(jié)果是8,因為前一半的4個數(shù)都比8和9要小癞季。但是如何計算呢劫瞳?
考慮到后一半已經(jīng)排好序,假如對后一半使用binary search绷柒,自然可以得到第一個大于前一半某一個數(shù)的index志于,從而獲得所有大于這個數(shù)的個數(shù)。也就是說這個merge是O(nlogn)废睦。那么總體就是T(n) = T(n/2) + O(nlogn)伺绽,由Master Theorem可知復雜度是O(n(logn)^2)。
Method 4:但是嗜湃,上面這個merge方法沒有利用前一半也排好序的條件奈应,因此可以做到更好。
考慮兩個指針p1和p2购披,分別指向前一半和后一半杖挣。p1初始是在2,p2在0今瀑,因為此時n[p2] < n[p1]因此p2增加程梦,直至指向8点把。那么因為第二半是遞增的,p2后面的數(shù)肯定也滿足屿附,因此這時候就可以獲得大于n[p1]的個數(shù):第二半的長度-p2郎逃。然后呢?p1遞增指向4挺份,假如p2重新回到0然后掃描褒翰,這個復雜度就是O(n^2),比上面的二分查找還要差匀泊。
因此做一些調(diào)整优训,不遞增p1,而是遞增p2各聘。也就是說換一個思路揣非,不是從第二半里面找比第一半大的,而是從第一半里面找比第二半小的:剛開始還是p1指向2躲因,p2指向0早敬,然后因為n[p1] >= n[p2],因為第一半遞增大脉,后面的肯定也比n[p2]要大搞监,因此沒必要往后看,可以直接計算個數(shù):p1-s镰矿,s是遞歸使用的開始的index琐驴,這里是0,也就是說對于n[p2]沒有比其更小的秤标。
然后遞增p2绝淡,但需要注意的是p1不用復位,這是很關(guān)鍵的一點抛杨。為什么够委?因為p1停止的條件,要么就是已經(jīng)掃完整個一半了怖现,要么就是現(xiàn)在的n[p1]比n[p2-1]要大茁帽,也就是說現(xiàn)在的p1之前的都比n[p2-1]要小,而n[p2]>n[p2-1]屈嗤,因此前面那些根本就不需要比較就能知道結(jié)論潘拨,可以直接沿用之前的p1的位置。
在這個例子里面饶号,比較明顯的就是8和9.對于8铁追,p1將會遞增至第一半的長度,也就是說整個第一半都比8要小茫船,那么對于9而言琅束,比8大扭屁,因此整個第一半也都比9小,無需再從頭比較涩禀。
再舉一個一般性一點的例子:[2,4,5,7][0,1,6,8]料滥,對于6,p1將會停在7上面艾船,計數(shù)是3葵腹;p2遞增后,對于8屿岂,可以知道p1前面都是比6小的践宴,那么肯定也就比8小,因此直接從p1在7上面開始爷怀,最后計數(shù)是4阻肩。
這樣一來,merge函數(shù)兩個指針就不需要走回頭路霉撵,效率O(n)磺浙,整體效率是O(nlogn),空間復雜度O(n)徒坡。當然,具體實現(xiàn)的時候瘤缩,還是要把兩半真正的merge排好序喇完,因為上面的計算都是在兩邊都排好序的情況下進行的。當只有一個元素的時候可以直接返回0剥啤。代碼如下:
# count number that larger than it and after it
def dc2(self, n, s, e):
if s >= e:
return 0
m = (s + e) // 2
ans = self.dc2(n, s, m) + self.dc2(n, m + 1, e)
p1 = s
for q in range(m + 1, e + 1):
while p1 <= m and n[q] > n[p1]:
p1 += 1
ans += p1 - s
# merge
temp = []
p1, p2 = s, m + 1
while p1 <= m and p2 <= e:
if n[p1] <= n[p2]:
temp.append(n[p1])
p1 += 1
else:
temp.append(n[p2])
p2 += 1
while p1 <= m:
temp.append(n[p1])
p1 += 1
while p2 <= m:
temp.append(n[p2])
p2 += 1
for i in range(len(temp)):
n[i + s] = temp[i]
return ans
例2. 給出一個數(shù)組n和一個范圍[a, b]锦溪,求n有多少個子區(qū)間的和在[a,b]之內(nèi)。假設數(shù)組n的元素和a,b都是整數(shù)府怯。例如給出[2,3,4,1]刻诊,和范圍[3,5],那么子區(qū)間[3][4][2,3][4,1]都滿足條件牺丙,返回4则涯。
【解】
Method 1:naive方法就是找出所有的子區(qū)間,然后看有多少個滿足條件冲簿。復雜度非常高粟判。
Method 2:看到子區(qū)間之和,當然想到prefix sum峦剔。也就是說可以造一個數(shù)組s档礁,每一個元素s[i] = n[0]+...+n[i]。那么所有的子區(qū)間除了n[0]都可以用s的后一個元素減去前一個元素獲得吝沫。
也就是說呻澜,問題轉(zhuǎn)換成為:給出一個數(shù)組s递礼,計算有多少對ij,使得s[i] - s[j] in [a,b]而且i < j羹幸?
假如s是升序的宰衙,那么好說;但s是無序的睹欲。Naive方法就是對每一個s[i]供炼,都掃一遍后面元素看看能不能滿足在區(qū)間a+s[i],b+s[i]里面,假如滿足那么減去s[i]就在要求的區(qū)間里面窘疮。當然最后還需要比較一下單個的s元素袋哼。這個做法復雜度O(n^2)。
Method 3:在子區(qū)間prefix sum的基礎上闸衫,考慮merge sort的tweak涛贯。假設n=[7, 4, 5, 2, 8, 9, 0, 1], a=0, b=7。
考慮最后一次merge之前的情況:[2,4,5,7][0,1,8,9].從上一題得到啟發(fā)蔚出,假如對第一半里的每一個數(shù)s[i]弟翘,在第二半里面二分查找第一個大于等于s[i]+a的index1,假如index1不存在那就不需要再找了骄酗,沒有符合條件的稀余;和第一個大于s[i]+b的index2,假如不存在那么index2=e也就是end的index趋翻。那么自然就可以得到個數(shù)index2 - index1睛琳。這樣merge的復雜度是O(nlogn),總體O(n(logn)^2)踏烙。
Method 4:
在Method 3的基礎上改進师骗。類似于例1,Method 3的問題還是在于沒有利用第一半排好序的條件讨惩。
考慮三個指針辟癌,p1p2和q,p1p2都指向第一半荐捻,p2指向第二半黍少。 因為要利用第一半排序的條件,因此還是固定遞增q靴患。對于s[q]仍侥,需要s[q] - s[i] 在區(qū)間[a,b]當中。也就是說s[q] - s[i] >= a, s[i] <= s[q] - a; s[q] - s[i] <= b, s[i] >= s[q] - b鸳君。
因此兩個指針p1p2农渊,p1不斷遞增直至不滿足s[p1] <= s[q] - a,p2不斷遞增直至不滿足s[p2] < s[q] - b。那么砸紊,p1之前的都是滿足s[q] - s[i] >= a的,p2之后的都是滿足s[q] - s[i] <= b传于,p1p2之間的就是滿足條件的,即count+=p1-p2.
然后遞增q醉顽,因為s[q] >= s[q-1]沼溜,因此之前p1p2的位置可以延續(xù),即s[q] - a >= s[q-1] - a >= s[i]游添,也就是說p1之前p2之前的元素還是滿足那些條件系草。因此,這個merge函數(shù)的復雜度是O(n)唆涝,總體時間復雜度O(nlogn)找都,空間復雜度O(n)。注意單個區(qū)間的情況已經(jīng)被涵蓋了廊酣。代碼如下:
class Solution:
# count numbers of subarray sum in range of [a,b]
def countSubarraySum(self, nums, a, b):
if not nums:
return 0
n = [0] * len(nums)
for i in range(len(nums)):
if i != 0:
n[i] = nums[i] + n[i - 1]
else:
n[i] = nums[i]
return self.dc(n, a, b, 0, len(n) - 1)
# count number of prefix sum that x[i] - x[j] in [a, b] and i > j plus itself in [a, b]
def dc(self, n, a, b, s, e):
if s > e:
return 0
if s == e:
return a <= n[s] <= b
m = (s + e) // 2
ans = self.dc(n, a, b, s, m) + self.dc(n, a, b, m + 1, e)
p1 = p2 = s
for q in range(m + 1, e + 1):
while p1 <= m and n[q] - n[p1] >= a:
p1 += 1
while p2 <= m and n[q] - n[p2] > b:
p2 += 1
if p2 <= p1:
ans += p1 - p2
# merge
temp = []
p1, p2 = s, m + 1
while p1 <= m and p2 <= e:
if n[p1] <= n[p2]:
temp.append(n[p1])
p1 += 1
else:
temp.append(n[p2])
p2 += 1
while p1 <= m:
temp.append(n[p1])
p1 += 1
while p2 <= m:
temp.append(n[p2])
p2 += 1
for i in range(len(temp)):
n[i + s] = temp[i]
return ans