P4689 [Ynoi2016] 这是我自己的发明 与 P5268 [SNOI2017] 一个简单的询问0

思路:

首先可以先考虑没有换根的情况。

先将树拍到 dfn 序上,那么一个子树 \(u\) 的所有点的 dfn 序区间为 \([dfn_u,dfn_u+siz_u-1]\)

那么询问变为:

  • 每次给定两个区间 \([l_1,r_1],[l_2,r_2]\),对于在第一个区间内的点 \(x\) 和在第二个区间的点 \(y\),若 \((x,y)\) 有贡献,当且仅当 \(w_x=w_y\)

  • 询问有贡献的点对数量。

P5268 [SNOI2017] 一个简单的询问

\(F(l_1,r_1,l_2,r_2)\) 表示 \([l_1,r_1]\)\([l_2,r_2]\) 的贡献,那么:

\[F(l_1,r_1,l_2,r_2) = F(1,r_1,1,r_2) - F(1,l_1-1,1,r_2) - F(1,r_1,1,l_2-1) - F(1,l_1-1,1,l_2-1) \]

那么一个询问就都转化为了四个 \(F(1,x,1,y)\) 的形式,考虑如何求 \(F(1,x,1,y)\),先钦定 \(x \le y\),那么考虑莫队:

  • 设当前 \(p_{1,x},p_{2,x}\) 分别表示两个区间 \(x\) 的出现次数。

  • \(x \gets x+1\) 时,贡献会增加 \(p_{2,a_{x+1}}\)

  • \(x \gets x-1\) 时,贡献会减少 \(p_{2,a_x}\)

  • \(y \gets y+1\) 时,贡献会增加 \(p_{1,a_{y+1}}\)

  • \(y \gets y-1\) 时,贡献会减少 \(p_{1,a_y}\)

现在再考虑换根操作,若当前以 \(rt\) 为根:

  • \(rt\) 不在初始以 \(1\) 为根时 \(x\) 的子树内,则不好造成影响。

  • 否则 \(x\) 子树内的点即为除了\((x \to rt)\) 路径上最接近 \(x\) 的点 \(y\) 子树内的点的全部点。

因为 \(x\) 在原始树上始终是 \(rt\) 的父亲,则 \(y\)\(rt\)\(dep_{rt}-dep_{x}-1\) 级祖先,直接倍增即可。

时间复杂度为 \(O(N\sqrt{M}+M \log N+M)\)

完整代码:

#include<bits/stdc++.h>
#define Add(x,y) (x+y>=mod)?(x+y-mod):(x+y)
#define lowbit(x) x&(-x)
#define pi pair<ll,ll>
#define pii pair<ll,pair<ll,ll>>
#define iip pair<pair<ll,ll>,ll>
#define ppii pair<pair<ll,ll>,pair<ll,ll>>
#define fi first
#define se second
#define full(l,r,x) for(auto it=l;it!=r;it++) (*it)=x
#define Full(a) memset(a,0,sizeof(a))
#define open(s1,s2) freopen(s1,"r",stdin),freopen(s2,"w",stdout);
#define For(i,l,r) for(int i=l;i<=r;i++)
#define _For(i,l,r) for(int i=r;i>=l;i--)
using namespace std;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const ll N=1e5+10,M=4e6+10,K=17;
inline ll read(){
    ll x=0,f=1;
    char c=getchar();
    while(c<'0'||c>'9'){
        if(c=='-')
          f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=(x<<1)+(x<<3)+(c^48);
        c=getchar();
    }
    return x*f;
}
inline void write(ll x){
	if(x<0){
		putchar('-');
		x=-x;
	}
	if(x>9)
	  write(x/10);
	putchar(x%10+'0');
}
ll op,n,m,t,q,u,v,rt,sum,l1,r1,l2,r2,l,r,cnt;
ll A[N],a[N],b[N],w[N],d[N],siz[N],dfn[N],p1[N],p2[N],ans[M];
ll F[N][K];
vector<pi> X,Y;
vector<ll> E[N];
struct Ques{
    ll x,y;
    ll id;
    ll v;
    inline bool operator<(const Ques &rhs)const{
        if(A[x]^A[rhs.x])
          return A[x]<A[rhs.x];
        return y>rhs.y;
    }
}Q[M];
inline void add(ll u,ll v){
    E[u].push_back(v);
    E[v].push_back(u);
}
inline void dfs(ll u,ll fa){
    For(i,1,K-1)
      F[u][i]=F[F[u][i-1]][i-1];
    dfn[u]=++cnt;
    w[cnt]=a[u];
    siz[u]=1;
    for(auto v:E[u]){
        if(v==fa)
          continue;
        F[v][0]=u;
        d[v]=d[u]+1;
        dfs(v,u);
        siz[u]+=siz[v];
    }
}
inline ll get_fa(ll u,ll k){
    _For(i,0,K-1){
        if((k>>i)&1ll){
            k-=(1ll<<i);
            u=F[u][i];
        }
    }
    return u;
}
inline vector<pi> get(ll x){
    vector<pi> ans;
    if(x==rt)
      ans.push_back({1,n});
    else if(dfn[x]<=dfn[rt]&&dfn[rt]<=dfn[x]+siz[x]-1){
        ll y=get_fa(rt,d[rt]-d[x]-1);
        if(dfn[y]!=1)
          ans.push_back({1,dfn[y]-1});
        if(dfn[y]+siz[y]<=n)
          ans.push_back({dfn[y]+siz[y],n});
    }
    else
      ans.push_back({dfn[x],dfn[x]+siz[x]-1});
    return ans;
}
inline void get(ll l1,ll r1,ll l2,ll r2){
    Q[++q]={r1,r2,cnt,1};
    if(l1-1)
      Q[++q]={l1-1,r2,cnt,-1};
    if(l2-1)
      Q[++q]={r1,l2-1,cnt,-1};
    if(l1-1&&l2-1)
      Q[++q]={l1-1,l2-1,cnt,1};
}
inline void insert1(ll x){
    sum+=p2[w[x]];
    p1[w[x]]++;
}
inline void insert2(ll x){
    sum+=p1[w[x]];
    p2[w[x]]++; 
}
inline void del1(ll x){
    sum-=p2[w[x]];
    p1[w[x]]--;
}
inline void del2(ll x){
    sum-=p1[w[x]];
    p2[w[x]]--;
}
bool End;
int main(){
    n=read(),m=read();
    For(i,1,n){
        a[i]=read();
        b[++cnt]=a[i];
    }
    sort(b+1,b+cnt+1);
    cnt=unique(b+1,b+cnt+1)-(b+1);
    For(i,1,n)
      a[i]=lower_bound(b+1,b+cnt+1,a[i])-b;
    cnt=0;
    For(i,1,n-1){
        u=read(),v=read();
        add(u,v);
    }
    dfs(1,1);
    cnt=0;
    For(i,1,m){
        op=read(),u=read();
        if(op==1){
            rt=u;
            continue;
        }
        ++cnt;
        v=read();
        X=get(u);
        Y=get(v);
        for(auto x:X)
          for(auto y:Y)
            get(x.fi,x.se,y.fi,y.se);
    }
    t=max(n/max((ll)sqrt(m),1ll),1ll);
    For(i,1,n)
      A[i]=(i-1)/t+1;
    For(i,1,q)
      if(Q[i].x>Q[i].y)
        swap(Q[i].x,Q[i].y);
    sort(Q+1,Q+q+1);
    For(i,1,q){
        while(l<Q[i].x)
          insert1(++l);
        while(l>Q[i].x)
          del1(l--);
        while(r<Q[i].y)
          insert2(++r);
        while(r>Q[i].y)
          del2(r--);
        ans[Q[i].id]+=sum*Q[i].v;
    }
    For(i,1,cnt){
        write(ans[i]);
        putchar('\n');
    }
	//cerr<<'\n'<<abs(&Begin-&End)/1048576<<"MB";
	return 0;
}