题意
你梦见了一棵树,这是一棵很茂密的树,因此它有很多的分支。
你注意到这颗树的有 nn
个果实,每一棵果实都有自己的编号,且标号为 11
的果实在最上面,像是一个根节点,树上的一个果实 uu 到另一个果实
vv 的距离,都恰好是一个整数 cc ,因为已经固定好了 11
号果实为根节点,所以这棵树的形状已经确定了,你想知道摘下一颗果实,会连带着把它的子树的果实也给摘下来。
而这个摘下来所得到的贡献为(数字出现的次数*数字)的平方
比如2出现了5次,那么贡献即为(2*5)^2(2∗5)2
数字为两个果实之间的距离即树的边权值,边权值的范围为 c*c*
。
所以你有m组询问,想知道当前询问的果实连带着它的子树果实被摘下来时的贡献是多少。
输入
第11行,三个整数n,m,cn,m,c分别表示树的大小,询问的个数,边权的范围。(1
n,m,c )(1≤n,m,c≤100000)
第2-n2−n行,每行三个整数u,v,viu,v,vi表示从uu到vv有一条vivi边权的边。
接下来mm行,每行一个整数表示询问的节点。
输出
输出mm行,每行一个整数代表子树的权值大小。(保证不会超过long
long)
样例输入
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| 11 6 10 1 2 9 2 3 1 3 4 6 2 5 7 4 6 5 5 7 7 7 8 8 7 9 3 7 10 6 3 11 3 5 7 10 6 1 5
|
样例输出
思路
比赛的时候基本啥算法没想,感觉建一个树,暴力记一下每个子树的贡献时间复杂度也可以,结果爆内存了。。。
然后看题解
首先是dfs序把树变成线性,用区间表示一个子树,具体可以参考这个博客https://blog.csdn.net/qq_37275680/article/details/82793691我尝试自己再总结一遍
然后是对区间求贡献,可以用线段树,复杂度是O(nlogn)
,也可以用莫队(又是我没学过的算法),复杂度O(nsqrt(n))可以参考https://blog.csdn.net/ThinFatty/article/details/72581276?spm=1001.2101.3001.6661.1&utm_medium=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1.pc_relevant_default&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-1.pc_relevant_default&utm_relevant_index=1,我自己也会尝试总结一遍
然后贴上代码,因为没找到比赛补题的地方,不知道能不能过,就只过了样例意思意思吧
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| #include <bits/stdc++.h> using namespace std; #define int long long const int N=1e5+10; int block; struct Sec { int l,r,id; bool operator <(const Sec&b)const { if((l/block)!=(b.l/block))return l<b.l; return r<b.r; } } sec[N],op[N]; vector<int>t[N]; int tot,cnt[N],val[N],d[N],id[N],ret[N]; void dfs(int u) { sec[u].l=++tot; for(int i=0; i<t[u].size(); i++) { dfs(t[u][i]); } sec[u].r=tot; } int add(int x){ cnt[x]++; return (cnt[x]*x)*(cnt[x]*x)-(cnt[x]-1)*x*(cnt[x]-1)*x; } int remove(int x){ cnt[x]--; return (cnt[x]*x)*(cnt[x]*x)-(cnt[x]+1)*x*(cnt[x]+1)*x; } signed main() { int n,m,c; scanf("%lld%lld%lld",&n,&m,&c);block=sqrt(n); for(int i=1; i<n; i++) { int u,v,w; scanf("%lld%lld%lld",&u,&v,&w); t[u].push_back(v); val[v]=w; } dfs(1); for(int i=1;i<=n;i++){ d[sec[i].l]=val[i]; } for(int i=1;i<=m;i++){ int u; cin>>u; op[i].l=sec[u].l+1; op[i].r=sec[u].r; op[i].id=i; } sort(op+1,op+1+m); int ans=0,l=1,r=1; for(int i=1;i<=m;i++){ while(l<op[i].l){ans+=remove(d[l]);l++;} while(l>op[i].l){l--;ans+=add(d[l]);} while(r<op[i].r){r++;ans+=add(d[r]);} while(r>op[i].r){ans+=remove(d[r]);r--;} ret[op[i].id]=ans; } for(int i=1;i<=m;i++)cout<<ret[i]<<'\n'; return 0; }
|