树链剖分

树链剖分

树链剖分,是将树的结构分解成几条链的线性结构的算法。本以为是很复杂精细的算法,实际上却非常暴力而且码量不小。好处是化树形为线形后可以使用线段树或者树状数组等方便的数据结构,而且算法本身并不复杂而且容易理解,总的来说除了码量大还是优点多多。

例题:P3384 【模板】轻重链剖分/树链剖分

题目描述

如题,已知一棵包含 N 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 1 x y z,表示将树从 xy 结点最短路径上所有节点的值都加上 z
  • 2 x y,表示求树从 xy 结点最短路径上所有节点的值之和。
  • 3 x z,表示将以 x 为根节点的子树内所有节点值都加上 z
  • 4 x 表示求以 x 为根节点的子树内所有节点值之和

输入格式

第一行包含 4 个正整数 N,M,R,P 分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含 N 个非负整数,分别依次表示各个节点上初始的数值。

接下来 N-1 行每行包含两个整数 x,y,表示点 x 和点 y 之间连有一条边(保证无环且连通)。

接下来 M 行每行包含若干个正整数,每行表示一个操作。

输出格式

输出包含若干行,分别依次表示每个操作 2 或操作 4 所得的结果(对 P 取模)。

输入输出样例

输入 #1

1
2
3
4
5
6
7
8
9
10
11
5 5 2 24
7 3 7 8 0
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3

输出 #1

1
2
2
21

说明/提示

【数据规模】

对于 30% 的数据: 1≤N≤10,1≤M≤10;

对于 70% 的数据: 1≤N≤103,1≤M≤103

对于 100% 的数据: 1≤N≤105,1≤M≤105,1≤RN,1≤P≤231−1。

思路

首先感谢题解,写的非常详细,简单易懂,质量很高。

因为配图很麻烦,所以我也就不使用图片,浅谈一下我的理解。

首先树链剖分是想将树形转化为线形,这样的操作其实使用dfs序也可以做到。但是dfs序只能表示子树,不能表示两点之间路径。但是树链剖分通过不同的dfs方法建立的线性结构可以表示两点之间路径的属性,这就是优势。

这样的dfs顺序,其实就是通过先重儿子再轻儿子的遍历顺序对点进行编号,来将树变成许多条链,而链与链之间通过父子关系进行连接。

要进行这样的编号,我们要进行两个dfs

dfs1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int a[N],dep[N],fa[N],siz[N],son[N];//a点权,dep深度,fa父节点编号,siz子树大小(算上自己),son重儿子编号
void dfs1(int u,int f,int d) {
dep[u]=d;
fa[u]=f;
siz[u]=1;
int mx=-1;
for(int i=0; i<tr[u].size(); i++) {
int v=tr[u][i];
if(v==f)continue;
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(siz[v]>mx)son[u]=v,mx=siz[v];
}
}

dfs2

1
2
3
4
5
6
7
8
9
10
11
12
13
int cnt,id[N],w[N],top[N];//cnt新编号指针,id编号,w新编号下点权,top所在链的链首
void dfs2(int u,int topf) {
id[u]=++cnt;
w[cnt]=a[u];
top[u]=topf;
if(!son[u])return ;
dfs2(son[u],topf);
for(int i=0; i<tr[u].size(); i++) {
int v=tr[u][i];
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}

处理问题

两点间路径操作

对于在同一条重链上的节点,编号是连续的。因此如果两点不在同一条链上,就将深度较深的点到所在链的链首的区间,更新或求和都可。然后跳到链首的父节点,重复以上步骤,直到两点处于同一条链,再对两点之间的区间操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
int qrange(int x,int y) {
int ans=0;
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans=(ans+query(1,1,n,id[top[x]],id[x]))%mod;
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans=(ans+query(1,1,n,id[x],id[y]))%mod;
return ans;
}
void updrange(int x,int y,int c) {
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]])swap(x,y);
add(1,1,n,id[top[x]],id[x],c);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
add(1,1,n,id[x],id[y],c);
}

子树操作

对于在同一棵子树上的节点,编号也是连续的。值得一提的是,这里的连续指的是,子树以某种dfs遍历的序号是连续的。可以直接对从子树根节点开始,子树大小长度的区间进行操作。

1
2
3
4
5
6
int qson(int x) {
return query(1,1,n,id[x],id[x]+siz[x]-1);
}
void updson(int x,int c) {
add(1,1,n,id[x],id[x]+siz[x]-1,c);
}

上面的add和query都是常规的线段树操作,具体可以看我的这篇博客

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <bits/stdc++.h>
using namespace std;
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define ls 2*x
#define rs 2*x+1
#define mid ((l+r)/2)
const int N=1e5+10;
vector<int>tr[N];
int a[N],dep[N],fa[N],siz[N],son[N];
void dfs1(int u,int f,int d) {
dep[u]=d;
fa[u]=f;
siz[u]=1;
int mx=-1;
for(int i=0; i<tr[u].size(); i++) {
int v=tr[u][i];
if(v==f)continue;
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(siz[v]>mx)son[u]=v,mx=siz[v];
}
}
int cnt,id[N],w[N],top[N],sum[N<<2],t[N<<2];
int n,m,r,mod;
void dfs2(int u,int topf) {
id[u]=++cnt;
w[cnt]=a[u];
top[u]=topf;
if(!son[u])return ;
dfs2(son[u],topf);
for(int i=0; i<tr[u].size(); i++) {
int v=tr[u][i];
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
void pushup(int x) {
t[x]=(t[ls]+t[rs])%mod;
}
void pushdown(int x,int l,int r) {
t[ls]=(t[ls]+sum[x]*(mid-l+1))%mod;
t[rs]=(t[rs]+sum[x]*(r-mid))%mod;
sum[ls]=(sum[ls]+sum[x])%mod;
sum[rs]=(sum[rs]+sum[x])%mod;
sum[x]=0;
}
void build(int x,int l,int r) {
if(l==r) {
t[x]=w[l]%mod;
return;
}
build(ls,l,mid);
build(rs,mid+1,r);
pushup(x);
}
void add(int x,int l,int r,int L,int R,int c) {
if(l>=L&&r<=R) {
sum[x]=(sum[x]+c)%mod;
t[x]=(t[x]+c*(r-l+1))%mod;
return ;
}
pushdown(x,l,r);
if(L<=mid)add(ls,l,mid,L,R,c);
if(mid<R)add(rs,mid+1,r,L,R,c);
pushup(x);
}
int query(int x,int l,int r,int L,int R) {
if(l>=L&&r<=R) {
return t[x];
}
pushdown(x,l,r);
int ret=0;
if(L<=mid)ret+=query(ls,l,mid,L,R);
if(mid<R)ret+=query(rs,mid+1,r,L,R);
return ret%mod;
}

int qrange(int x,int y) {
int ans=0;
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans=(ans+query(1,1,n,id[top[x]],id[x]))%mod;
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans=(ans+query(1,1,n,id[x],id[y]))%mod;
return ans;
}
void updrange(int x,int y,int c) {
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]])swap(x,y);
add(1,1,n,id[top[x]],id[x],c);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
add(1,1,n,id[x],id[y],c);
}
int qson(int x) {
return query(1,1,n,id[x],id[x]+siz[x]-1);
}
void updson(int x,int c) {
add(1,1,n,id[x],id[x]+siz[x]-1,c);
}
int main() {
IOS
cin>>n>>m>>r>>mod;
for(int i=1; i<=n; i++) {
cin>>a[i];
}
for(int i=1; i<n; i++) {
int u,v;
cin>>u>>v;
tr[u].push_back(v);
tr[v].push_back(u);
}
dfs1(r,r,1);
dfs2(r,r);
build(1,1,n);
while(m--) {
int op,a,b,c;
cin>>op;
if(op==1) {
cin>>a>>b>>c;
updrange(a,b,c);
} else if(op==2) {
cin>>a>>b;
cout<<qrange(a,b)<<'\n';
} else if(op==3) {
cin>>a>>c;
updson(a,c);
} else {
cin>>a;
cout<<qson(a)<<'\n';
}
}
return 0;
}


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!