AT_abc298_h [ABC298Ex] Sum of Min of Length

模拟赛怒码 7KB 错解,赛后 10min AC。

思路

首先观察 l,rl,r 不同的关系对于结果的构成有什么影响,记 gf=LCA(l,r)gf = LCA(l,r)

  1. l=rl = r。很显然,答案就是以 ll 为根的节点的深度和。

  2. gflgfrgf \neq l \wedge gf \neq r。在 ll 子树中的节点在式子中一定会取 dist(i,l)dist(i,l),在 rr 子树中的节点在式子中一定会取 dist(i,r)dist(i,r)。如果 dldrd_l \leq d_r,则在 gfgf 之上的节点(包括 gfgf 除开 l,rl,r 所在的子树);否则是同理的。其次对于 lrl \rightsquigarrow r 这段路径,长度为 lenlen,则前 len2\frac{len}{2} 个节点及其不在链上的子树归 ll,后 len2\frac{len}{2} 个节点及其不在链上的子树归 rr

  3. gf=lgf=rgf = l \vee gf = r。与情况 2 同理,除了在处理 gfgf 子树的时候需要除去包含链的子树。

发现这些贡献可以离线下来换根。发现需要动态求子树深度和,所以用线段树换根即可。注意判断中点与需要处理的点的位置关系。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#include <bits/stdc++.h>
#define re register
#define int long long

using namespace std;

const int N = 2e5 + 10,M = 4e5 + 10,K = 24,inf = 1e9 + 10;
int n,q,ans[N];
int num,id[N],d[N],sz[N],val[N];
int lg[N],f[N][K];
int idx,h[N],ne[M],e[M];

struct Query{
int u,id;
int gf,midx,midy;
};
vector<Query> Q[N];

inline int read(){
int r = 0,w = 1;
char c = getchar();
while (c < '0' || c > '9'){
if (c == '-') w = -1;
c = getchar();
}
while (c >= '0' && c <= '9'){
r = (r << 3) + (r << 1) + (c ^ 48);
c = getchar();
}
return r * w;
}

inline void add(int a,int b){
ne[idx] = h[a];
e[idx] = b;
h[a] = idx++;
}

struct seg{
#define ls(u) (u << 1)
#define rs(u) (u << 1 | 1)

struct node{
int l,r;
int sum,tag;
}tr[N << 2];

inline void calc(int u,int k){
tr[u].sum += k * (tr[u].r - tr[u].l + 1); tr[u].tag += k;
}

inline void pushup(int u){
tr[u].sum = tr[ls(u)].sum + tr[rs(u)].sum;
}

inline void pushdown(int u){
if (tr[u].tag){
calc(ls(u),tr[u].tag); calc(rs(u),tr[u].tag);
tr[u].tag = 0;
}
}

inline void build(int u,int l,int r){
tr[u] = {l,r};
if (l == r) return tr[u].sum = val[l] - 1,void();
int mid = l + r >> 1;
build(ls(u),l,mid); build(rs(u),mid + 1,r);
pushup(u);
}

inline void modify(int u,int l,int r,int k){
if (l <= tr[u].l && tr[u].r <= r) return calc(u,k);
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(ls(u),l,r,k);
if (r > mid) modify(rs(u),l,r,k);
pushup(u);
}

inline int query(int u,int l,int r){
if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int res = 0;
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) res += query(ls(u),l,r);
if (r > mid) res += query(rs(u),l,r);
return res;
}

#undef ls
#undef rs
}T;

inline void get(int u,int fa){
sz[u] = 1; f[u][0] = fa;
val[id[u] = ++num] = d[u] = d[fa] + 1;
for (re int i = 1;i <= lg[d[u]];i++) f[u][i] = f[f[u][i - 1]][i - 1];
for (re int i = h[u];~i;i = ne[i]){
int j = e[i];
if (j == fa) continue;
get(j,u); sz[u] += sz[j];
}
}

inline int lca(int x,int y){
while (d[x] != d[y]){
if (d[x] < d[y]) swap(x,y);
x = f[x][lg[d[x] - d[y]]];
}
if (x == y) return x;
for (re int i = lg[d[x]];~i;i--){
if (f[x][i] != f[y][i]) x = f[x][i],y = f[y][i];
}
return f[x][0];
}

inline bool check(int u,int v){
if (d[v] > d[u]) return false;
while (d[u] != d[v]) u = f[u][lg[d[u] - d[v]]];
return (u == v);
}

inline void dfs(int u,int fa){
for (auto p:Q[u]){
int res = 0;
int v = p.u,gf = p.gf;
int midx = p.midx,midy = p.midy;
if (u == v) ans[p.id] = T.tr[1].sum;
else if (check(u,midx) && !check(v,midx)) res = T.query(1,id[midx],id[midx] + sz[midx] - 1);
else res = T.tr[1].sum - T.query(1,id[midy],id[midy] + sz[midy] - 1);
ans[p.id] += res;
}
for (re int i = h[u];~i;i = ne[i]){
int j = e[i];
if (j == fa) continue;
T.modify(1,1,n,1); T.modify(1,id[j],id[j] + sz[j] - 1,-2);
dfs(j,u);
T.modify(1,1,n,-1); T.modify(1,id[j],id[j] + sz[j] - 1,2);
}
}

signed main(){
memset(h,-1,sizeof(h));
n = read();
for (re int i = 2;i <= n;i++) lg[i] = lg[i >> 1] + 1;
for (re int i = 1;i < n;i++){
int a,b; a = read(),b = read();
add(a,b); add(b,a);
}
get(1,0); T.build(1,1,n);
q = read();
for (re int i = 1;i <= q;i++){
int x,y;
int gf = lca(x = read(),y = read());
int len = d[x] + d[y] - 2 * d[gf] - 1;
int mid = len / 2,midx,midy;
int u = x,v = y;
if (d[u] <= d[v]){
for (re int j = 20;~j;j--){
if ((1ll << j) <= mid){
v = f[v][j]; mid -= (1ll << j);
}
}
midx = f[v][0],midy = v;
}
else{
for (re int j = 20;~j;j--){
if ((1ll << j) <= mid){
u = f[u][j]; mid -= (1ll << j);
}
}
midx = u,midy = f[u][0];
}
Q[x].push_back({y,i,gf,midx,midy});
Q[y].push_back({x,i,gf,midy,midx});
}
dfs(1,0);
for (re int i = 1;i <= q;i++) printf("%lld\n",ans[i]);
return 0;
}

AT_abc298_h [ABC298Ex] Sum of Min of Length
http://watersun.top/[题解]AT_abc298_h [ABC298Ex] Sum of Min of Length/
作者
WaterSun
发布于
2024年5月1日
许可协议