线段树

线段树是一种高级数据结构,可以快速地进行数组的区间修改、区间查询。加上懒标记,所有操作都是 O(logn)O(logn) 的。

本文要求有基础才能看懂,不了解什么是线段树的话可以看OI Wiki 线段树基础【左程云】线段树专题1-线段树原理和代码详解

操作

线段树主要有以下几个操作:初始化、建树、向上更新、向下更新、区间修改、区间查询。

我们以下面这个题目为例:

P3372 【模板】线段树 1

如题,已知一个数列 {ai}\{a_i\},你需要进行下面两种操作:

  1. 将某区间每一个数加上 kk
  2. 求出某区间每一个数的和。

看数据范围就知道会超时,不信的话可以自行尝试,这是线段树模板题,我们使用线段树。

初始化

为了方便调用,我们使用结构体封装线段树,命名为 Segment

#define lc p<<1       // 简写左右孩子节点
#define rc p<<1|1

struct Segment {
    int N;             // 元素数量
    struct Node {      // 每个区间存的数
        int l, r, sum, flag; // l 左区间,r 右区间,sum 区间和,flag 懒标记
    } *tr;

    // 构造函数初始化,开4倍空间防止溢出
    Segment(int n) { N = n; tr = new int[n * 4]; }
    // 析构函数,释放内存,可加可不加
    ~Segment() { delete[] tr; } 

    void build() { ... } // 下面的函数都写在结构体里面
};

向上更新

将某个点的值更新为左右孩子节点的和

void pushup(int p) {
    tr[p].sum = tr[lc].sum + tr[rc].sum; 
    // lc = p << 1 = p * 2, rc = p << 1 | 1 = p * 2 + 1
}

建树

通过数组建立线段树,建立每个区间

// 通过数组v来建立线段树,p默认为1,l表示左区间,r右区间
void build(vector<int>& v, int p, int l, int r) {
    tr[p] = {l, r, v[l], 0}; // 赋值区间值,sum值会被后续pushup覆盖
    if (l == r) return;      // 如果到最底层,则返回
    int mid = (l + r) >> 1;  // 分裂
    build(v, lc, l, mid);    // 建左子树
    build(v, rc, mid + 1, r); // 右子树
    pushup(p);                // 赋值完左右子树后,向上更新,更新当前值
}

向下更新

由于懒标记,我们需要向下更新

void pushdown(int p) {
    if (tr[p].flag) { // 如果当前区间有懒标记,则向下传播
        tr[lc].sum += (tr[lc.r] - tr[lc].l + 1) * tr[p].flag; // 区间和加上区间长度乘懒标记
        tr[rc].sum += (tr[rc.r] - tr[rc].l + 1) * tr[p].flag; // 右区间同理
        tr[lc].flag += tr[p].flag;  // 懒标记一同向下传递
        tr[rc].flag += tr[p].flag;  // 右区间同理
        tr[p].flag = 0;             // 清除当前区间懒标记
    }
}

区间修改

将任意区间加上某个数

// p当前区间,l区间左边界,r区间右边界,k区间加上k
void update(int p, int l, int r, int k) {
    if (l <= tr[p].l && r >= tr[p].r) {     // 如果当前区间被需求的区间包裹
        tr[p].sum += (tr[p].r - tr[p].l + 1) * k;  // 修改当前的区间的和,即相当于给区间所有元素加上了k
        tr[p].flag += k;                    // 区间懒标记加上k
        return;     // 返回
    }
    pushdown(p);      // 向下更新一层,因为后续需要访问儿子节点
    int mid = (tr[p].l + tr[p].r) >> 1; // 分裂中间点
    if (l <= mid) update(lc, l, r, k); // 如果包含左区间,递归更新左区间
    if (r >  mid) update(rc, l, r, k); // 更新右区间
    pushup(p);          // 子节点更新完必须更新父节点
}

区间查询

查询任意区间的和

// p当前区间,l区间左边界,r区间右边界
int query(int p, int l, int r) {   
    // 如果当前区间被查询区间全包含,则直接返回当前区间节点上的数据
    if (l <= tr[p].l && r >= tr[p].r) {
        return tr[p].sum; 
    }
    pushdown(p); // 向下查询之前一定要向下传递懒标记(有的话)
    int mid = (tr[p].r + tr[p].l) >> 1; // 分裂中点
    int sum = 0;    // 存左右子树的和
    if (l <= mid) sum += query(lc, l, r); // 如果左区间包含查询区间,则递归查询左区间
    if (r >  mid) sum += query(rc, l, r); // 右区间同理
    return sum; // 返回左右区间的和
}

完整代码

struct Segment {

#define lc p<<1
#define rc p<<1|1
    int N; 
    struct node {
        int l, r, sum, add; 
    } *tr;

    Segment(int n) { N = n; tr = new node[n * 4 + 5]; }

    void pushup(int p) {
        tr[p].sum = tr[lc].sum + tr[rc].sum;
    }
    
    void _build(vi&v, int p, int l, int r) {
        tr[p] = {l, r, v[l], 0};
        if (l == r) return;
        int mid = (l + r) >> 1;
        _build(v, lc, l, mid);
        _build(v, rc, mid + 1, r);
        pushup(p); 
    }

    void build(vi&v) {
        int p = 1, l = 1, r = N;
        _build(v, p, l, r);
    }

    void pushdown(int p) {
        if(tr[p].add) {
            tr[lc].sum += (tr[lc].r - tr[lc].l + 1) * tr[p].add;
            tr[rc].sum += (tr[rc].r - tr[rc].l + 1) * tr[p].add;
            tr[lc].add += tr[p].add;
            tr[rc].add += tr[p].add;
            tr[p].add = 0;
        }
    }

    void update(int p, int l, int r, int k) {
        if (l <= tr[p].l && r >= tr[p].r) {
            tr[p].sum += (tr[p].r - tr[p].l + 1) * k;
            tr[p].add += k;
            return;
        }
        pushdown(p);
        int mid = (tr[p].l + tr[p].r) >> 1;
        if (l <= mid) update(lc, l, r, k);
        if (r >  mid) update(rc, l, r, k);
        pushup(p);
    }

    int query(int p, int l, int r) {
        if (l <= tr[p].l && r >= tr[p].r) {
            return tr[p].sum;
        }
        pushdown(p);
        int mid = (tr[p].r + tr[p].l) >> 1;
        int sum = 0;
        if (l <= mid) sum += query(lc, l, r);
        if (r > mid) sum += query(rc, l, r);
        return sum;
    }
};

其他

不难看出,只要是父节点的值可以由左右子节点得来,就可以使用线段树。

同理,不难实现:查询区间乘区间max,min区间gcd、lcm等…

可以完成以下练习:

  1. P1531 I Hate It
  2. P5057 [CQOI2006] 简单题

技巧

建空树

如题:P4588 [TJOI2018] 数学计算

如果你一眼觉得很简单,那么可以自己尝试写一写,待会别来打我

仔细想象会发现,我们只需要输出的时候取模,并不能对原数据进行取模,不然会影响后续的操作,直接相乘很容易爆 long long。这么想想会发现这题其实很棘手,我们需要使用线段树。

你可能会觉得,为什么使用线段树?怎么使用线段树?我们要使用一个技巧:建空树

我们建立一颗空树

// 建树
void _build(int p, int l, int r) {
    tr[p] = {l, r, 1, 0};
    if (l == r) return;
    int mid = (l + r) >> 1;
    _build(lc, l, mid);
    _build(rc, mid + 1, r);
    pushup(p); 
}
...
Segment seg(q);

seg.build();

我们维护区间乘modM

void pushup(int p) {
    tr[p].sum = (tr[lc].sum * tr[rc].sum) % mod;
}

单点修改,动态添加点

void update(int p, int idx, int k) {
    if (tr[p].l == tr[p].r && tr[p].l == idx) {
        tr[p].sum = k % mod;
        return;
    }
    int mid = (tr[p].l + tr[p].r) >> 1;
    if (idx <= mid) update(lc, idx, k);
    if (idx >  mid) update(rc, idx, k);
    pushup(p);
}

每次操作1,则将m添加到线段树底部,操作2,则将前边第pos位的数改为1,每次输出整个区间的乘积

void solve() {
    int q;
    cin >> q >> mod;
    Segment seg(q);
    seg.build();
    int cnt = 1;
    vi ops = {1};
    while(q --) {
        int op, x; cin >> op >> x;
        ops.push_back(cnt);
        if (op == 1) {
            seg.update(1, cnt ++, x);
        } else {
            seg.update(1, ops[x], 1);
        }
        cout << seg.tr[1].sum << "\n";
    }
}

我们动态的添加点、修改点,区间维护区间乘(取模),即可解决此题,完整AC代码:#69606192 | Cysheper’s solution for [洛谷-P4588]

总结:建立一颗空树,动态添加点,每次获取区间信息。

练习:P1637 三元上升子序列

权值线段树

线段树的节点上可以放很多东西,比如区间和、区间max、甚至一个数组(注意不要爆空间)

如题:Codeforces Round 149 (Div. 2) E. xor on Segment

对于多元素异或操作,不难想到,可以对于每一个二进制位单独操作,增加不超过31的常数,一般是20(a<106a < 10^6

我们在区间上维护一个20空间的数组,存区间元素相应二进制位的个数

struct node {
    int l, r, flag; // 左右区间,懒标记
    int bits[20];  // 存当前区间相应二进制位的个数

    node() {
        l = r = flag = 0;
        memset(bits, 0, sizeof bits); // 初始化为0
    }
}*tr;  

向上更新:数组相应位置相加

void pushup(int p) {
    // 左区间元素二进制相应位加右区间...
    for (int i = 0; i < 20; ++ i) {
        tr[p].bits[i] = tr[lc].bits[i] + tr[rc].bits[i];
    }
}

建树终点赋值的时候,就是这个元素的二进制形式

void _build(vi&v, int p, int l, int r) {
    tr[p].l = l, tr[p].r = r;
    tr[p].flag = 0;
    if (l == r) {
        // 赋值为那个元素的二进制形式
        for (int i = 0; i < 20; ++ i) {
            if(v[l] & (1 << i)) 
                tr[p].bits[i] ++; 
        }
        return;
    }
    int mid = (l + r) >> 1;
    _build(v, lc, l, mid);
    _build(v, rc, mid + 1, r);
    pushup(p); 
}

然后这道题就不难写了,完整AC代码:#69645153 | Cysheper’s solution for [CodeForces-242E]