1. 程式人生 > 其它 >題解 P4427 [BJOI2018]求和

題解 P4427 [BJOI2018]求和

Solution

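LCA + 前綴和（樹上差分）：設 val[v][k] 為根到 v 路徑上所有節點 dep^k 之和，則路徑 (x,y) 的答案為 val[x][k]+val[y][k]-val[lca][k]-val[fa[lca]][k]。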
快讀的迴圈中不要忘記寫 c=getchar()（只寫 getchar() 不賦值會造成死迴圈）。
複雜度：預處理 \(O(nk)\)，每次詢問用樹鏈剖分求 LCA 為 \(O(\log n)\)，總計 \(O(nk + m\log n)\)。

Code

#include<iostream>
#include<cstdio>
#include<cstdlib>
#define ll long long
using namespace std;
const int maxn=3e5+5; // capacity for node-indexed arrays (n plus a little slack)

// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped (a '-' sign is
// ignored, so negative numbers are not supported).
// Returns 0 if end of input is reached before any digit is seen; the old
// version looped forever in that case because it never tested for EOF.
int read()
{
    int ret=0;
    int c=getchar(); // int, not char, so EOF stays distinguishable from data
    while(c!=EOF&&(c<'0'||c>'9'))c=getchar();
    while(c>='0'&&c<='9')ret=ret*10+(c-'0'),c=getchar();
    return ret;
}
const int mod=998244353;
class graph
{
    public:
    int n,m;
    int head[maxn],ver[maxn<<1],nxt[maxn<<1];
    int tot;
    void add(int x,int y)
    {
        ver[++tot]=y;nxt[tot]=head[x];
        head[x]=tot;
    }
    void link(int x,int y){add(x,y);add(y,x);}
    int dep[maxn],fa[maxn],son[maxn],siz[maxn],top[maxn],val[maxn][51];
    void getdep(int u,int f)
    {
        siz[u]=1;fa[u]=f;
        for(int i=head[u];i;i=nxt[i])
        {
            if(ver[i]==f)continue;
            dep[ver[i]]=dep[u]+1;
            getdep(ver[i],u);
            siz[u]+=siz[ver[i]];
            son[u]=(siz[son[u]]>siz[ver[i]])?son[u]:ver[i];
        }
    }
    void getroad(int u,int ance)
    {
        top[u]=ance;
        if(son[u])getroad(son[u],ance);
        for(int i=head[u];i;i=nxt[i])
        {
            if(ver[i]==fa[u]||ver[i]==son[u])continue;
            getroad(ver[i],ver[i]);
        }
    }
    void getval(int u)
    {
        for(int i=head[u];i;i=nxt[i])
        {
            if(ver[i]==fa[u])continue;
            for(int j=1;j<=50;j++)val[ver[i]][j]=(val[ver[i]][j]+val[u][j])%mod;
            getval(ver[i]);
        }
    }
    int lca(int x,int y)
    {
        while(top[x]!=top[y])
        {
            if(dep[top[x]]<dep[top[y]])swap(x,y);
            x=fa[top[x]];
        }
        return (dep[x]<dep[y])?x:y;
    }
    void init()
    {
        n=read();
        for(int i=1;i<=n-1;i++)link(read(),read());
        getdep(1,1);getroad(1,1);
        for(int i=1;i<=n;i++)
        {
            val[i][0]=1;
            for(int j=1;j<=50;j++)val[i][j]=(ll)val[i][j-1]*dep[i]%mod;
        }
        getval(1);
    }
    void solve()
    {
        m=read();
        while(m--)
        {
            int x,y,k;x=read();y=read();k=read();
            int f=lca(x,y);
            ll ret=(ll)val[x][k]+val[y][k]-val[f][k]-val[fa[f]][k];
            ret=(ret%mod+mod)%mod;
            printf("%lld\n",ret);
        }
    }
}o;
int main()
{
    o.init();
    o.solve();
    return 0;
}