题目大意: 给你一棵$n(n\leqslant10^5)$个点的树,一个简单路径的集合$S_k$被称为$k$合法当且仅当:树的每个节点至多属于其中一条路径,且每条路径恰好包含$k$个点。对于$k\in[1,n]$,$|S
题目大意:给你一棵$n(n\leqslant10^5)$个点的树,一个简单路径的集合$S_k$被称为$k$合法当且仅当:树的每个节点至多属于其中一条路径,且每条路径恰好包含$k$个点。对于$k\in[1,n]$,$|S_k|$的最大值
题解:树形$DP$,$O(n^2)$的很好想,记录一个节点向下最长的没有被选的链长度和以这个点为根的子树的答案,转移显然。
发现对于答案$ans_k$,$ans_k\times k\leqslant n$,当$k\leqslant \sqrt n$时,$ans_k$最多只有$\sqrt n$个;当$k\geqslant\sqrt n$时,$ans_k\leqslant\sqrt n$,取值范围只有$\sqrt n$个,于是$ans_k$个数最多$2\sqrt n$个,可以二分相同的区间。复杂度$O(n\sqrt n\log_2n)$,但是这样会$TLE$。
猜想是$dfs$常数过大,于是按$dfs$序逆序排列后递推,就过了
卡点:$TLE$
C++ Code:
#include <cstdio> #include <cstring> #include <algorithm> #define maxn 100010 int head[maxn], cnt; struct Edge { int to, nxt; } e[maxn << 1]; inline void add(int a, int b) { e[++cnt] = (Edge) {b, head[a]}; head[a] = cnt; } int n; int ans[maxn]; int f[maxn], rem[maxn], fa[maxn]; int dfn[maxn], idx, rnk[maxn]; void dfs(int u) { dfn[u] = ++idx; rnk[u] = u; for (int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; if (v != fa[u]) { fa[v] = u; dfs(v); } } } inline bool cmp(int a, int b) {return dfn[a] > dfn[b];} int max_1[maxn], max_2[maxn]; inline void work(const int k) { memset(f, 0, sizeof f); memset(rem, 0, sizeof rem); memset(max_1, 0, sizeof max_1); memset(max_2, 0, sizeof max_2); for (int I = 1, u, fa; I <= n; I++) { u = rnk[I]; if (max_1[u] + max_2[u] + 1 >= k) { f[u]++; rem[u] = 0; } else rem[u] = max_1[u] + 1; fa = ::fa[u]; f[fa] += f[u]; if (max_1[fa] < rem[u]) { max_2[fa] = max_1[fa]; max_1[fa] = rem[u]; } else if (max_2[fa] < rem[u]) max_2[fa] = rem[u]; } ans[k] = f[1]; } int main() { scanf("%d", &n); for (int i = 1, a, b; i < n; i++) { scanf("%d%d", &a, &b); add(a, b); add(b, a); } memset(ans, -1, sizeof ans); dfs(1); std::sort(rnk + 1, rnk + n + 1, cmp); ans[1] = n; int l, r; for (l = 2; l <= n; l = r + 1) { if (ans[l] == -1) work(l); int last = ans[l]; int res = l, L = l, R = n; while (L <= R) { int mid = L + R >> 1; if (ans[mid] == -1) work(mid); if (ans[mid] == last) { res = mid; L = mid + 1; } else R = mid - 1; } r = res; for (int i = l; i <= r; i++) ans[i] = last; } for (int i = 1; i <= n; i++) printf("%d\n", ans[i]); return 0; }