题目描述 给出以1号点为根的一棵有根树,问每个点的子树中与它距离小于等于l的点有多少个。 输入格式 Line 1: 2 integers, N and L (1 = N = 200,000, 1 = L = 10^18) Lines 2..N: The ith line contains two i
题目描述
给出以1号点为根的一棵有根树,问每个点的子树中与它距离小于等于l的点有多少个。
输入格式
Line 1: 2 integers, N and L (1 <= N <= 200,000, 1 <= L <= 10^18)
Lines 2..N: The ith line contains two integers p_i and l_i. p_i (1 <= p_i < i) is the first pasture on the shortest path between pasture i and the barn, and l_i (1 <= l_i <= 10^12) is the length of that path.
输出格式
Lines 1..N: One number per line, the number on line i is the number pastures that can be reached from pasture i by taking roads that lead strictly farther away from the barn (pasture 1) whose total length does not exceed L.
这道题有很多高级的做法,但是我都不会
我们分析题目可以得出这样一条结论——对于当前节点u,u的子树中与u的距离大于l的点与u的所有祖先的距离都大于l(u也是自己的祖先)。所以不难想到我们对于每个节点u,我们计算出u的第一个与它距离大于l的祖先anc,那么对于这个祖先,它的答案就要减去size(u)。size表示子树的节点数,初始化每个点的答案为子树的节点数。然后结合之前得到的性质,我们可以用树上前缀和的思想,把这个减去的size(u)累加到anc的祖先中去。
但是你会发现,直接算是有问题的。
首先对于u,它对anc的答案做了值为-size(u)的贡献,并且我们要将这个贡献累加到anc的祖先中去。然后我们发现,对于u的祖先,比如u的父亲fa(u),第一个与fa(u)距离大于l的祖先也必定是anc的祖先,但我们将-size(fa(u))加到了这个祖先中,也就是说这个祖先的答案累加了两次-size(u),答案显然是错的。如何避免呢?很简单,我们将size(fa(u))减去size(u)即可。那么问题就解决了。
对于求第一个距离大于l的祖先,我们可以用倍增来做,那么总的时间复杂度就是O(NlogN)。
*由于size(fa(u))减去的是size(u)原本的大小,而此时size(u)可能已经被u的子节点减去了一些,所以我们要再开一个size数组来记录原本的size。
*不开long long见祖宗
#include<iostream> #include<cstring> #include<cstdio> #include<cmath> #define maxn 200001 using namespace std; struct edge{ int to,next; long long dis; edge(){} edge(const int &_to,const long long &_dis,const int &_next){ to=_to,dis=_dis,next=_next; } }e[maxn<<1]; int head[maxn],k; int fa[maxn][20],size[maxn],size2[maxn],sum[maxn],maxdep; int n; long long m,dis[maxn][20]; inline long long read(){ register long long x(0),f(1); register char c(getchar()); while(c<'0'||'9'<c){ if(c=='-') f=-1; c=getchar(); } while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar(); return x*f; } inline void add(const int &u,const int &v,const long long &w){ e[k]=edge(v,w,head[u]),head[u]=k++; } void dfs(int u){ size[u]=1; for(register int i=head[u];~i;i=e[i].next){ int v=e[i].to; if(v==fa[u][0]) continue; fa[v][0]=u,dis[v][0]=e[i].dis; for(register int j=1;j<=maxdep;j++) fa[v][j]=fa[fa[v][j-1]][j-1],dis[v][j]=dis[v][j-1]+dis[fa[v][j-1]][j-1]; dfs(v),size[u]+=size[v]; } } void dfs_getsum(int u){ for(register int i=head[u];~i;i=e[i].next){ int v=e[i].to; if(v==fa[u][0]) continue; dfs_getsum(v); long long len=0; int tmp=size[v],tmp2=size2[v]; for(register int j=maxdep;j>=0;j--) if(len+dis[v][j]<=m&&fa[v][j]) len+=dis[v][j],v=fa[v][j]; if(len+dis[v][0]>m&&fa[v][0]) sum[fa[v][0]]+=tmp,size[u]-=tmp2; } } void dfs_getans(int u){ for(register int i=head[u];~i;i=e[i].next){ int v=e[i].to; if(v==fa[u][0]) continue; dfs_getans(v),sum[u]+=sum[v]; } } int main(){ memset(head,-1,sizeof head); n=read(),m=read(); for(register int i=2;i<=n;i++){ int v=read(); long long w=read(); add(i,v,w),add(v,i,w); } maxdep=(int)log(n)/log(2),dfs(1); for(register int i=1;i<=n;i++) size2[i]=size[i]; dfs_getsum(1); dfs_getans(1); for(register int i=1;i<=n;i++) printf("%d\n",size2[i]-sum[i]); return 0; }