[ZJOI2019] 语言

给定一棵 \(n\) 个节点的树,并给定树上的 \(m\) 条链 \((s_i,t_i)\)

我们称点对 \((u,v)\) 合法,当且仅当:

  • \(1 \leq u < v \leq n\)

  • 存在一条链 \((s_i,t_i)\) 同时经过 \(u\)\(v\)

求合法点对的数量,\(1 \leq n,m \leq 10^5\)


以下认为 \(n,m\) 同阶。

\(O(n \log^3 n)\)

容易想到将所有点对 \((u,v)\) 看成平面直角坐标系上的点。

我们尝试用链覆盖点对,容易使用树剖将 \((s_i,t_i)\) 拆分成 \(O(\log n)\) 个 dfn 连续段。

接下来我们可以将一条链拆分成 \(O(\log^2 n)\) 个矩形,做扫描线即可。

\(O(n \log^2 n)\)

考虑对于每个点 \(u\),计算合法的点 \(v\) 数量,这里我们不考虑 \(u < v\) 的限制。

此时我们发现,\(v\) 一定在所有经过 \(u\) 的链的并集上,我们只需要统计这个并集的大小即可。

因为这些链存在公共点 \(u\),所以其并集就是所有端点构成的虚树大小。

虚树大小可以将所有点按 dfn 排序后,相邻两个点距离和除以 \(2\) 得到,这很典。

然而上面的是边的条数,我们为了统计点数需要将其 \(+1\)

接下来,只要考虑如何快速维护每个 \(u\) 对应的端点集合,即可解决问题。

不难利用树上差分技巧,将一条链拆成 \(4\) 次单点加删点对。

此时我们需要将子树的虚树端点合并到父亲节点,启发式合并即可。

超大常数,\(O(n \log^2 n)\),洛谷跑了 \(6.00\) 秒。

//Ad astra per aspera
#include<iostream>
#include<cstdio>
#include<vector>
#include<map> 
#include<algorithm>
using namespace std;
vector<int> G[100010];
int n,dfn[100010],rev[100010],tot;
int pa[100010][20],depth[100010];
void dfs(int u,int fa,int cur){
	tot++;
	dfn[u]=tot;
	rev[tot]=u;
	depth[u]=cur;
	for(int i=0;i<G[u].size();i++){
		int v=G[u][i];
		if(v!=fa){
			pa[v][0]=u;
			for(int i=1;i<=19;i++){
				pa[v][i]=pa[pa[v][i-1]][i-1];
			}
			dfs(v,u,cur+1);
		}
	}
}
int LCA(int u,int v){
	if(depth[u]<depth[v]){
		swap(u,v);
	}
	for(int i=19;i>=0;i--){
		if(depth[pa[u][i]]>=depth[v]){
			u=pa[u][i];
		}
	}
	if(u==v){
		return u;
	}
	for(int i=19;i>=0;i--){
		if(pa[u][i]!=pa[v][i]){
			u=pa[u][i];
			v=pa[v][i];
		}
	}
	return pa[u][0];
}
int dist(int dfn_u,int dfn_v){
	if(min(dfn_u,dfn_v)==0  ||  max(dfn_u,dfn_v)==n+1){
		return 0;
	}
	else{
		return depth[rev[dfn_u]]+depth[rev[dfn_v]]-2*depth[LCA(rev[dfn_u],rev[dfn_v])];
	}
}
vector<int> vec_add[100010],vec_sub[100010];
void add(int u,int v){
	vec_add[u].push_back(dfn[u]);
	vec_add[u].push_back(dfn[v]);
	vec_add[v].push_back(dfn[u]);
	vec_add[v].push_back(dfn[v]);
	int lca=LCA(u,v);
	vec_sub[lca].push_back(dfn[u]);
	vec_sub[lca].push_back(dfn[v]);
	if(pa[lca][0]){
		vec_sub[pa[lca][0]].push_back(dfn[u]);
		vec_sub[pa[lca][0]].push_back(dfn[v]);
	}
}
long long ans,pre_ans[100010],child[100010];
map<int,int> dict[100010];
void add_dot(int id,int dfn_id){
	map<int,int> :: iterator iter=dict[id].lower_bound(dfn_id);
	if(iter->first!=dfn_id){
		map<int,int> :: iterator iter2=iter;
		iter2--;
		pre_ans[id]-=dist(iter2->first,iter->first);
		pre_ans[id]+=dist(iter2->first,dfn_id);
		pre_ans[id]+=dist(dfn_id,iter->first);
		dict[id][dfn_id]=1;
	}
	else{
		iter->second++;
	}
	child[id]++;
}
void sub_dot(int id,int dfn_id){
	map<int,int> :: iterator iter=dict[id].lower_bound(dfn_id);
	if(iter->second==1){
		map<int,int> :: iterator iter2=iter,iter3=iter;
		iter2--;
		iter3++;
		pre_ans[id]-=dist(iter2->first,iter->first);
		pre_ans[id]-=dist(iter->first,iter3->first);
		pre_ans[id]+=dist(iter2->first,iter3->first);
		dict[id].erase(iter);
	}
	else{
		iter->second--;
	}
	child[id]--;
}
void dfs2(int u,int fa){
	int son=0;
	for(int i=0;i<G[u].size();i++){
		int v=G[u][i];
		if(v!=fa){
			dfs2(v,u);
			if(child[v]>=child[son]){
				son=v;
			}
		}
	}
	if(son){
		swap(dict[u],dict[son]);
		swap(pre_ans[u],pre_ans[son]);
		swap(child[u],child[son]);
	}
	for(int i=0;i<G[u].size();i++){
		int v=G[u][i];
		if(v!=fa  &&  v!=son){
			for(map<int,int> :: iterator iter=dict[v].begin();iter!=dict[v].end();iter++){
				int id=iter->first,cnt=iter->second;
				while(cnt--){
					if(id!=0  &&  id!=n+1){
						add_dot(u,id);
					}
				}
			}
			dict[v].clear();
		}
	}
	for(int i=0;i<vec_add[u].size();i++){
		int id=vec_add[u][i];
		add_dot(u,id);
	}
	for(int i=0;i<vec_sub[u].size();i++){
		int id=vec_sub[u][i];
		sub_dot(u,id);
	}
	map<int,int> :: iterator iter1=dict[u].begin(),iter2=dict[u].end();
	iter1++;
	iter2--;
	iter2--;
	ans+=(pre_ans[u]+dist(iter1->first,iter2->first))/2;
}
int main(){
	int m;
	scanf("%d %d",&n,&m);
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d %d",&u,&v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	dfs(1,0,1);
	while(m--){
		int u,v;
		scanf("%d %d",&u,&v);
		add(u,v);
	}
	for(int i=1;i<=n;i++){
		add(i,i);
	}
	for(int i=1;i<=n;i++){
		dict[i][0]=1;
		dict[i][n+1]=1;
		child[i]=2;
	}
	dfs2(1,0);
	printf("%lld",ans/2);
	return 0;
}
posted @ 2026-02-04 08:51  Oken喵~  阅读(2)  评论(0)    收藏  举报