[CF1260F]树链剖分+线段树

题目大意

给定一棵树,每个节点颜色 $h_i\in[l_i,r_i]$ ,这样总共有 $\prod_{i=1}^n(r_i-l_i+1)$ 种不同的树,定义树的权值为
$$
\sum_{h_i=h_j,1\leq i<j\leq n}dis(i,j)
$$
求权值和

思路

首先可以把答案化一下 $\sum_{i<j}|[l_i,r_i]\and [l_j,r_j|dis(i,j)\prod_{k\not=i,j}(r_k-l_k+1)$

意思是枚举两个点,答案是这两个点的颜色的并的个数乘两个点的距离再乘上剩下的点随意选的方案数

首先想到直接求这个式子,但是发现颜色的并不方便处理,如果用树套树暴力维护也会多 $log$ ,后来看了 $tutorial$ 后才发现一个小 $Trick$

我们并不需要从左到右枚举 $j$ 再求所有 $i$ 的答案,可以考虑枚举颜色,因为每个点是有一个出现的区间的,所以可以在颜色变化的同时加入和删除点,在改变点集的时候维护答案即可

然后随便维护一下这个式子,用树剖+线段树处理就可以了,对于这个式子的卫华很套路,$dis_{lca}$ 通常就将其到根的路径+1然后再询问其到根的路径和就可以了

具体的可以参考这篇 : 戳我

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
/***************************************************************
File name: f.cpp
Author: ljfcnyali
Create time: 2019年12月15日 星期日 14时38分45秒
***************************************************************/
#include<bits/stdc++.h>

using namespace std;

#define REP(i, a, b) for ( int i = (a), _end_ = (b); i <= _end_; ++ i )
#define mem(a) memset ( (a), 0, sizeof ( a ) )
#define str(a) strlen ( a )
#define pii pair<int, int>
#define x first
#define y second
#define lson root << 1
#define rson root << 1 | 1
#define int long long
#define inv(x) power(x, Mod - 2);
typedef long long LL;

const int maxn = 2e5 + 10;
const int Mod = 1e9 + 7;

int n, Begin[maxn], Next[maxn], To[maxn], e, size[maxn], dfn[maxn], dis[maxn], p[maxn], id[maxn];
int head[maxn], son[maxn], f[maxn], cnt, tot, ans, N, sum1, sum2, sum3, g[maxn];
vector<int> a[maxn];
struct node { int lazy, sum; } Tree[maxn << 2];

inline void add(int u, int v) { To[++ e] = v; Next[e] = Begin[u]; Begin[u] = e; }

inline void DFS1(int u, int fa)
{
size[u] = 1; int Max = 0;
for ( int i = Begin[u]; i; i = Next[i] )
{
int v = To[i]; if ( v == fa ) continue ;
f[v] = u; dis[v] = dis[u] + 1;
DFS1(v, u);
size[u] += size[v];
if ( size[v] > Max ) { Max = size[v]; son[u] = v; }
}
}

inline void DFS2(int u, int t)
{
head[u] = t; p[u] = ++ cnt; id[cnt] = u;
if ( son[u] ) DFS2(son[u], t);
for ( int i = Begin[u]; i; i = Next[i] )
{
int v = To[i]; if ( v == f[u] || v == son[u] ) continue ;
DFS2(v, v);
}
}

inline void PushUp(int root) { Tree[root].sum = (Tree[lson].sum + Tree[rson].sum) % Mod; }

inline void PushTag(int root, int l, int r, int val)
{
Tree[root].lazy = (Tree[root].lazy + val) % Mod;
Tree[root].sum = (Tree[root].sum + (r - l + 1) * val) % Mod;
}

inline void PushDown(int root, int l, int r)
{
if ( !Tree[root].lazy ) return ;
int Mid = l + r >> 1;
PushTag(lson, l, Mid, Tree[root].lazy);
PushTag(rson, Mid + 1, r, Tree[root].lazy);
Tree[root].lazy = 0;
}

inline void Modify2(int root, int l, int r, int L, int R, int val)
{
if ( L <= l && r <= R ) { PushTag(root, l, r, val); return ; }
PushDown(root, l, r);
int Mid = l + r >> 1;
if ( L <= Mid ) Modify2(lson, l, Mid, L, R, val);
if ( Mid < R ) Modify2(rson, Mid + 1, r, L, R, val);
PushUp(root);
}

inline int Query2(int root, int l, int r, int L, int R)
{
if ( L <= l && r <= R ) return Tree[root].sum;
PushDown(root, l, r);
int Mid = l + r >> 1, ret = 0;
if ( L <= Mid ) ret += Query2(lson, l, Mid, L, R);
if ( Mid < R ) ret += Query2(rson, Mid + 1, r, L, R);
return ret % Mod;
}

inline void Modify1(int x, int val)
{
while ( x ) { Modify2(1, 1, n, p[head[x]], p[x], val); x = f[head[x]]; }
}

inline int Query1(int x)
{
int ret = 0;
while ( x ) { ret += Query2(1, 1, n, p[head[x]], p[x]); x = f[head[x]]; }
return ret % Mod;
}

inline int power(int a, int b) { int r = 1; while ( b ) { if ( b & 1 ) r = r * a % Mod; a = a * a % Mod; b >>= 1; } return r; }

signed main()
{
#ifndef ONLINE_JUDGE
freopen("f.in", "r", stdin);
freopen("f.out", "w", stdout);
#endif
scanf("%lld", &n);
int sum = 1;
REP(i, 1, n)
{
int x, y; scanf("%lld%lld", &x, &y);
g[i] = y - x + 1; sum = sum * g[i] % Mod; g[i] = inv(g[i]);
a[x].push_back(i); a[y + 1].push_back(-i); N = max(N, y);
}
REP(i, 1, n - 1) { int u, v; scanf("%lld%lld", &u, &v); add(u, v); add(v, u); }
dis[1] = 1; DFS1(1, 0); DFS2(1, 1);
int ret = 0;
REP(i, 1, N)
{
for ( int j : a[i] )
{
int x = j;
if ( x > 0 )
{
sum1 = (sum1 + dis[x] * g[x] % Mod) % Mod;
sum2 = (sum2 + g[x]) % Mod;
sum3 = (sum3 + (dis[x] * g[x] % Mod) * g[x] % Mod) % Mod;
ret = (ret + g[x] * Query1(x)) % Mod;
Modify1(x, g[x]);
}
else
{
x = -x;
sum1 = (sum1 - dis[x] * g[x] % Mod) % Mod;
sum2 = (sum2 - g[x]) % Mod;
sum3 = (sum3 - (dis[x] * g[x] % Mod) * g[x] % Mod) % Mod;
Modify1(x, Mod - g[x]);
ret = (ret - g[x] * Query1(x)) % Mod;
}
}
ans = (ans + (sum1 * sum2 % Mod) - sum3 - 2 * ret) % Mod;
}
ans = (ans + Mod) % Mod;
printf("%lld\n", sum * ans % Mod);
return 0;
}