思路
显然我们需要求出原树的最大匹配。
定义 d p i , 0 / 1 dp_{i,0/1} d p i , 0/1 表示在 i i i 为根的子树中进行匹配,且 i i i 不选/选 的最大匹配。状态转移方程比较显然:
{ d p u , 0 = ∑ v a l v d p u , 1 = ( ∑ v a l v ) − v a l t + d p t , 0 + 1 \left\{\begin{matrix}
dp_{u,0} = \sum{val_v}\\
dp_{u,1} = (\sum{val_v}) - val_t + dp_{t,0} + 1
\end{matrix}\right.
{ d p u , 0 = ∑ v a l v d p u , 1 = ( ∑ v a l v ) − v a l t + d p t , 0 + 1
其中 v a l u = max ( d p u , 0 , d p u , 1 ) val_u = \max(dp_{u,0},dp_{u,1}) v a l u = max ( d p u , 0 , d p u , 1 ) ,然后 t t t 是使得 d p u , 1 dp_{u,1} d p u , 1 最大的一个 v v v 。显然这一步是可以 Θ ( n ) \Theta(n) Θ ( n ) 计算的。
注意到,我们需要计算删除 u u u 后的最大匹配,这等同于计算以 u u u 为根时的 d p u , 0 dp_{u,0} d p u , 0 。于是考虑换根 DP。
令当前的根由 u u u 转为 v v v 。接下来只需简单分讨,记转移 d p u , 1 dp_{u,1} d p u , 1 时的 t t t 为 i d u id_u i d u ,d p t , 0 − v a l t + 1 dp_{t,0} - val_t + 1 d p t , 0 − v a l t + 1 为 M a x u Max_u M a x u 。
由于 v v v 由 u u u 的儿子变为了父亲,d p u dp_u d p u 的值一定发生变化。显然 d p u , 0 ← d p u , 0 − v a l v dp_{u,0} \leftarrow dp_{u,0} - val_v d p u , 0 ← d p u , 0 − v a l v 。其次:
当 i d u = v id_u = v i d u = v 时:d p u , 1 ← d p u , 0 − d p v , 1 − 1 dp_{u,1} \leftarrow dp_{u,0} - dp_{v,1} - 1 d p u , 1 ← d p u , 0 − d p v , 1 − 1 ,然后和其次大值进行匹配。因此还需在第一次 DFS 时处理出次大的儿子。
否则:d p u , 1 dp_{u,1} d p u , 1 直接减去 v a l v val_v v a l v 即可。
然后可以更新 v a l u ← max ( d p u , 0 , d p u , 1 ) val_u \leftarrow \max(dp_{u,0},dp_{u,1}) v a l u ← max ( d p u , 0 , d p u , 1 ) 。接下来更新 d p v dp_v d p v 。
显然的是 d p v , 0 / 1 dp_{v,0/1} d p v , 0/1 都需要加上 v a l u val_u v a l u 。
其次,如果 M a x v Max_v M a x v 小于 u u u 带来的贡献,那么更新掉即可。需要注意的是,i d v , M a x v id_v,Max_v i d v , M a x v 以及维护的次大值都需要被更新。
这样就完成了所有元素的更新,继续向下换根即可。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 #include <bits/stdc++.h> #define re register using namespace std;const int N = 2e5 + 10 ,M = 4e5 + 10 ,inf = 1e9 + 10 ;int n,ans;int idx,h[N],ne[M],e[M];int res,dp[N][2 ],val[N],Max[2 ][N],id[2 ][N];inline int read () { int r = 0 ,w = 1 ; char c = getchar (); while (c < '0' || c > '9' ){ if (c == '-' ) w = -1 ; c = getchar (); } while (c >= '0' && c <= '9' ){ r = (r << 3 ) + (r << 1 ) + (c ^ 48 ); c = getchar (); } return r * w; }inline void add (int a,int b) { ne[idx] = h[a]; e[idx] = b; h[a] = idx++; }inline void dfs1 (int u,int fa) { int mMax = -inf,fMax = -inf; int mid = 0 ,fid = 0 ; val[u] = dp[u][0 ] = dp[u][1 ] = 0 ; for (re int i = h[u];~i;i = ne[i]){ int j = e[i]; if (j == fa) continue ; dfs1 (j,u); dp[u][0 ] += val[j]; dp[u][1 ] += val[j]; int t = dp[j][0 ] - val[j] + 1 ; if (mMax < t){ fMax = mMax; fid = mid; mMax = t; mid = j; } else if (fMax < t){ fMax = t; fid = j; } } Max[0 ][u] = mMax; Max[1 ][u] = fMax; id[0 ][u] = mid; id[1 ][u] = fid; dp[u][1 ] += max (0 ,mMax); val[u] = max (dp[u][0 ],dp[u][1 ]); }inline void dfs2 (int u,int fa) { ans += (dp[u][0 ] == res); for (re int i = h[u];~i;i = ne[i]){ int j = e[i]; if (j == fa) continue ; int dpu0 = dp[u][0 ],dpu1 = dp[u][1 ],valu = val[u]; int Max0u = Max[0 ][u],Max1u = Max[1 ][u],id0u = id[0 ][u],id1u = id[1 ][u]; int dpj0 = dp[j][0 ],dpj1 = dp[j][1 ],valj = val[j]; int Max0j = Max[0 ][j],Max1j = Max[1 ][j],id0j = id[0 ][j],id1j = id[1 ][j]; dp[u][0 ] -= val[j]; if (id[0 ][u] == j){ dp[u][1 ] -= (dp[j][0 ] + 1 ); if (id[1 ][u]) dp[u][1 ] += Max[1 ][u]; } else dp[u][1 ] -= val[j]; val[u] = max (dp[u][0 ],dp[u][1 ]); dp[j][0 ] += val[u]; dp[j][1 ] += val[u]; int t = dp[u][0 ] - val[u] + 1 ; if (Max[0 ][j] < t){ dp[j][1 ] -= Max[0 ][j]; Max[1 ][j] = Max[0 ][j]; id[1 ][j] = id[0 ][j]; Max[0 ][j] = t; id[0 ][j] = u; dp[j][1 ] += Max[0 ][j]; } else if (Max[1 ][j] < t){ Max[1 ][j] = t; id[1 ][j] = u; } dfs2 (j,u); dp[u][0 ] = dpu0; dp[u][1 ] = dpu1; val[u] = valu; Max[0 ][u] = Max0u; Max[1 ][u] = Max1u; id[0 ][u] = id0u; id[1 ][u] = id1u; dp[j][0 ] = dpj0; dp[j][1 ] = dpj1; val[j] = valj; Max[0 ][j] = Max0j; Max[1 ][j] = Max1j; id[0 ][j] = id0j; id[1 ][j] = id1j; } }int main () { memset (h,-1 ,sizeof (h)); n = read (); for (re int i = 1 ;i < n;i++){ int a,b; a = read (),b = read (); add (a,b); add (b,a); } dfs1 (1 ,0 ); res = val[1 ]; dfs2 (1 ,0 ); printf ("%d" ,ans); return 0 ; }