Advise Category: Algorithm >> 樹狀數(shù)組
Scenario
- 單點更新
- 區(qū)間求和(前綴和)
Objects
- 待處理數(shù)組兆龙,a[1...n]
- 待維護樹狀數(shù)組逼纸,c[1...n]
- 結(jié)果數(shù)組(前綴和)(待處理數(shù)組的前綴和)缀去,r[1...n]
Ideas
Q:
A1:
每次計算r[i]
都遍歷a
的前i
項干旁,一個結(jié)果r[i]
一次遍歷撒桨,n
個結(jié)果r[n]
需要n
次遍歷(不適合大數(shù)據(jù)數(shù)組的情況)董虱。
A2:(樹狀)
孽椰?如何加速遍歷
不與a[i]
為維度進行運算鹃愤,而是通過維護一個數(shù)組(數(shù)組)c[1...n]
其中每一項是通過一定規(guī)律的m
項和搓彻,再通過規(guī)律找到前綴和r[i]
的待累積c[i]
并求和如绸。
Method for A2
a[i] -> c[i]:
c[1] = c[0001] = a[1];
c[2] = c[0010] = a[1]+a[2];
c[3] = c[0011] = a[3];
c[4] = c[0100] = a[1]+a[2]+a[3]+a[4];
c[5] = c[0101] = a[5];
c[6] = c[0110] = a[5]+a[6];
c[7] = c[0111] = a[7];
c[8] = c[1000] = a[1]+a[2]+a[3]+a[4]+a[5]+a[6]+a[7]+a[8];
......
Rule:
c[i]=a[i-2^k+1]+a[i-2^k+2]+......a[i];
Note:
-
k
為i
的二進制中從最低位到高位連續(xù)零的長度, 例如i=8(1000)
時,k=3
旭贬。 - 可以理解為這是一種分類方法怔接,通過維護分類(要么有零,要么沒零稀轨,有零說明進位了需要把其下的數(shù)都考慮在內(nèi))數(shù)組
c[i]
讓前綴求和變得更快扼脐。
單點更新
?如果修改a
中的一個元素奋刽,c
中的元素如何變化
Assumption:a[3] = a[3] + 1
瓦侮,從3往后找,直到數(shù)組結(jié)束佣谐。
lowbit(3)=0001=2^0 3+lowbit(3)=04(00100) c[04] += 1
lowbit(4)=0100=2^2 4+lowbit(4)=08(01000) c[08] += 1
lowbit(8)=1000=2^3 8+lowbit(8)=16(10000) c[16] += 1
......
Note:
- 可以看出a[3]變化之后肚吏,會涉及到c[4]/c[8]/[16]...的變化,所以需要更新跟隨變化的c中的元素狭魂。
罚攀?lowbit
lowbit(x)是取出x的最低位1(從右往左數(shù)第一個1)党觅,滿足:
int lowbit(x){ return x & (-x); }
Note:
一個數(shù)的負數(shù)就等于對這個數(shù)取反+1
補碼和原碼必然相反,所以原碼有0的部位補碼全是1,補碼再+1之后由于進位那么最末尾的1和原碼最右邊的1一定是同一個位置
剛好等于
2^k
,k
為x
的二進制中從最低位到高位連續(xù)零的長度
Code:
void update(int x,int y,int n){
for(int i=x; i <= n; i += lowbit(i)) //x為更新的位置,y為更新后的數(shù),n為數(shù)組最大值
c[i] += y;
}
區(qū)間求和
e.g 求r[5]
Disappear:
c[4]=a[1]+a[2]+a[3]+a[4];
c[5]=a[5];
sum(i = 5) = c[4] + c[5];
sum(i = 101) = c[(100)] + c[(101)];
Note:
- 首次從101,減去最低位的1就是100斋泄,剛好是單點更新的你操作杯瞻。
Code:
int sum(int x){
int ans = 0;
for(int i = x; i >= 0; i -= lowbit(i))
ans += c[i];
return ans;
}
Example
leet-code:
Ans:
class Solution{
public List<Integer> countSmaller3(int[] nums){
if(nums == null) return null;
if(nums.length == 0) return new ArrayList<>();
List<Integer> result = new ArrayList<>();
Set<Integer> set = new HashSet<Integer>();
for (int i = 0; i < nums.length; i++) {
set.add(nums[i]);
}
int[] c = Arrays.stream(set.toArray()).sorted().mapToInt(e -> Integer.parseInt(e.toString())).toArray();
int[] d = new int[c.length + 2];
for (int i = nums.length - 1; i > -1; i--) {
int idx = Arrays.binarySearch(c, nums[i]) + 1;
int s = sum(d, idx - 1);
update(d, idx, 1);
result.add(s);
}
Collections.reverse(result);
return result;
}
private int lowbit(int x){
return x & (-x);
}
private void update(int[] c, int i, int delta){
while( i <= c.length - 1 ){
c[i] += delta;
i += lowbit(i);
}
}
private int sum(int[] c, int i){
int ans = 0;
while( i > 0 ){
ans += c[i];
i -= lowbit(i);
}
return ans;
}
}