引子:线段树是什么
不知道你有没有被某些算法题搞得焦头烂额,例如给你次询问求区间和,区间最大最小,区间子段和这些神秘题目,但是现在,他们都将变成板子题!
线段树是一种二叉树结构,其实作用是可以快速修改和查询,我们先从支持查询的线段数开始
线段树思想
线段树通过将一段区间不断二分,直到分解成单个节点的元素,然后通过合并子区间的信息从而得到父区间的结果,将区间问题变成 分治合并问题
实现
一颗完整的静态线段树需要进行
- 建树
- 合并
- 查询
这些操作
维护信息
根据题目的意义,我们的节点可能维护的信息是不一样的
- 区间和(sum)
- 区间最值(max / min)
- 区间最大子段和 (
sum,lmax,rmax,tmax) 一般这些信息都被维护在一个结构体节点下
我们接下来的演示都以 区间子段和 为例子进行演示
struct Node{
i64 sum,lmax,rmax,tmax;
}这个表示我们每个节点保存的信息是这个区间的区间和,区间前缀最大值,区间后缀最大值,区间子段最大值
合并策略
在有了维护信息之后我们就需要进行确定合并策略,不同的题目有不同的合并策略
确定合并策略的思考点主要是去寻思两个区间的信息怎么通过一些手段变成他们合并后的信息
下面给出一些合并策略
Tip
- 区间和
res.sum = L.sum + R.sum;
- 区间最大/最小值
res.max = max(L.max,R.max); res.min = min(L.min,R.min);
- 区间最大子段和
res.sum = L.sum + R.sum; res.lmax = max(L.lmax,L.sum + R.lmax); res.rmax = max(R.rmax,L.rmax + R.sum); res.tmax = max({ L.tmax, R.tmax, L.rmax + R.lmax });
我们的写法则是:
Node merge(const Node &L, const Node &R)
{
Node res;
res.sum = L.sum + R.sum;
res.lmax = max(L.lmax, L.sum + R.lmax);
res.rmax = max(R.rmax, R.sum + L.rmax);
res.tmax = max({L.tmax, R.tmax, L.rmax + R.lmax});
}建树
做好了一切的准备后我们便可以开始建树
在建树的时候,我们主要用数组模拟树,同时使用位运算优化建树过程,递归的构建整棵树
vector<Node> seg;
i64 n;
// 建树需要开四倍空间
Seg(i64 n): n(n),seg(n * 4 + 4){}
void build(i64 idx, i64 l, i64 r, const vint &v)
{
if (l == r) {
i64 val = v[l];
seg[idx] = {val, val, val, val};
return;
}
i64 mid = (l + r) >> 1;
build(idx << 1, l, mid, v);
build(idx << 1 | 1, mid + 1, r, v);
seg[idx] = merge(seg[idx << 1], seg[idx << 1 | 1]);
}建树的时间复杂度为
查询
当查询到[ql,qr]区间时候:
- 若当前节点
[l,r]完全在区间[ql,qr]时,直接返回 - 若有交叉,则递归到左右子区间
- 再合并左右结果
Node query(i64 idx, i64 l, i64 r, i64 ql, i64 qr)
{
if (ql <= l && r <= qr) return seg[idx];
i64 mid = (l + r) >> 1;
if (qr <= mid) {
return query(idx << 1, l, mid, ql, qr);
}
else if (ql > mid) {
return query(idx << 1 | 1, mid + 1, r, ql, qr);
}
else {
return merge(query(idx << 1, l, mid, ql, qr), query(idx << 1 | 1, mid + 1, r, ql, qr));
}
}则整个树就建立完成了,这是最基本的静态查询线段树,线段树的区间查询是
struct Seg
{
struct Node
{
i64 sum, lmax, rmax, tmax;
};
i64 n;
vector<Node> seg;
Seg(i64 n): n(n), seg(4 * n + 4, Node()) { }
//
Node merge(const Node &L, const Node &R)
{
Node res;
res.sum = L.sum + R.sum;
res.lmax = max(L.lmax, L.sum + R.lmax);
res.rmax = max(R.rmax, R.sum + L.rmax);
res.tmax = max({L.tmax, R.tmax, L.rmax + R.lmax});
}
void build(i64 idx, i64 l, i64 r, const vint &v)
{
if (l == r) {
i64 val = v[l];
seg[idx] = {val, val, val, val};
return;
}
i64 mid = (l + r) >> 1;
build(idx << 1, l, mid, v);
build(idx << 1 | 1, mid + 1, r, v);
seg[idx] = merge(seg[idx << 1], seg[idx << 1 | 1]);
}
Node query(i64 idx, i64 l, i64 r, i64 ql, i64 qr)
{
if (ql <= l && r <= qr) return seg[idx];
i64 mid = (l + r) >> 1;
if (qr <= mid) {
return query(idx << 1, l, mid, ql, qr);
}
else if (ql > mid) {
return query(idx << 1 | 1, mid + 1, r, ql, qr);
}
else {
return merge(query(idx << 1, l, mid, ql, qr), query(idx << 1 | 1, mid + 1, r, ql, qr));
}
}
};