Skip to content

树状数组

介绍

树状数组 (Binary Indexed Tree(B.I.T), Fenwick Tree) 是一个查询和修改复杂度都为 log(n) 的数据结构。

主要用于查询任意两位之间的所有元素之和,但是每次只能修改一个元素的值;

经过简单修改可以在 log(n) 的复杂度下进行范围修改,但是这时只能查询其中一个元素的值(如果加入多个辅助数组则可以实现区间修改与区间查询)。

这种数据结构(算法)并没有 C++ 和 Java 的库支持,需要自己手动实现。在 Competitive Programming 的竞赛中被广泛的使用。

树状数组和线段树很像,但能用树状数组解决的问题,基本上都能用线段树解决,而线段树能解决的树状数组不一定能解决。相比较而言,树状数组效率要高很多。

功能:区间和查询、单点修改

树状图概念

假设数组 A[1..n] ,那么查询 A[1]+...+A[n] 的时间是 log 级别的,而且是一个在线的数据结构,支持随时修改某个元素的值,复杂度也为 log 级别。

来观察这个图:

树状数组

令这棵树的结点编号为 C[1],C[2],...,C[n] 。令每个结点的值为这棵树的值的总和,那么容易发现:

C[1] = A[1]
C[2] = A[1] + A[2]
C[3] = A[3]
C[4] = A[1] + A[2] + A[3] + A[4]
C[5] = A[5]
C[6] = A[5] + A[6]
C[7] = A[7]
C[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8]

基本算法

若区间结尾为 R ,则区间长度就等于 R 的“二进制分解”下最小的 2 的次幂,我们设为 lowbit(R)。即 R 在二进制表示下最小的一位。

lowbit(n) 表示取出非负整数 n 在二进制表示下最低位的 \(1\) 以及它后边的 \(0\) 构成的数值。

如:

lowbit(1) = 1 // 1
lowbit(2) = 2 // 10
lowbit(3) = 1 // 11
lowbit(4) = 4 // 100
lowbit(5) = 1 // 101
lowbit(6) = 2 // 110
lowbit(7) = 1 // 111
lowbit(8) = 8 // 1000

对于给定的序列 A 我们建立一个数组 C ,其中 C[x] 保存序列 A 的区间 \([x-lowbit(x)+1,x]\) 中所有数的和。

该结构满足下列性质:

  1. 每个内部节点 C[x] 保存以它为根的子树中所有叶节点的和。
  2. 每个内部节点 C[x] 的子节点个数等于 lowbit(x) 的大小。
  3. 除树根外,每个内部节点 C[x] 的父节点是 C[x+lowbit(x)]
  4. 树的深度为 log(n)

对于 lowbit(n) ,我们可以通过以下方式计算:

int lowbit(int n)
{
    return n&-n;
}

关于 x&-x 的原理

前置知识

关于原码,反码,补码

类型 正数 负数
二进制数值 +1101001 -1101001
原码 01101001 11101001
反码 01101001 10010110
补码 01101001 10010111

注意:对于正数,原码、补码、反码都相等,且符号位为 0

详解

注意,-x 为 x 的补码

如 1&(-1)的二进制位运算为(二个二进位都为 1):

  0000 0001
& 1111 1111
------------
  0000 0001

6&(-6)的二进制位运算为(二个二进位都为 1):

  0000 0110
& 1111 1010
------------
  0000 0010

x&-x == x&(~x+1)

最终结论

如果这个数是偶数,则结果为能被这个数整除的最大 2 的次幂

如果这个数是奇数,则结果必为 1

代码展示

#include<iostream>
using namespace std;
int n,m,i,num[100001],t[200001],l,r;//num:原数组;t:树状数组 
int lowbit(int x)
{
    return x&(-x);
}
void change(int x,int p)//将第x个数加p 
{
    while(x<=n)
    {
        t[x]+=p;
        x+=lowbit(x);
    }
    return;
}
int sum(int k)//前k个数的和 
{
    int ans=0;
    while(k>0)
    {
        ans+=t[k];
        k-=lowbit(k);
    }
    return ans;
}
int ask(int l,int r)//求l-r区间和 
{
    return sum(r)-sum(l-1); 
}
int main()
{
    cin>>n>>m;
    for(i=1;i<=n;i++)
    {
        cin>>num[i];
        change(i,num[i]);
    }
    for(i=1;i<=m;i++)
    {
        cin>>l>>r;
        cout<<ask(l,r)<<endl;
    }
    return 0;
}

注:部分资料来源于百度百科


Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License.