Codeforces Round #499 (Div. 1) F. Tree
题目链接
\(\rm CodeForces\):https://codeforces.com/contest/1010/problem/F
Solution
设\(v_i\)表示第\(i\)个点的果子数,设\(b_i=v_i-\sum_{x\in son}v_x\),显然依题意要满足\(b_i\geqslant 0\)。
根据差分的性质我们可以得到\(\sum b_i=x\)。
假设我们硬点树上剩下了\(m\)个点,则根据插板法\(b_i\)的方案数为\(\displaystyle\binom{x+m-1}{m-1}\)。
由于\(b\)唯一确定\(v\),所以\(v\)的方案数也是这么多。
那么我们就考虑剩下\(m\)个点的方案数。
考虑一种暴力的\(dp\),设\(f[u][i]\)表示\(x\)子树保留了\(i\)个点(包括\(i\))的方案数,转移显然:
\[ f[u][i]=\sum_{k=1}^{i-1}f[l][k]\cdot f[r][i-1-k] \]
写成生成函数就是:
\[ F_u(x)=xF_l(x)F_r(x)+1 \]
若\(u\)只有一个儿子也同理:
\[ F_u(x)=xF_{son}(x)+1 \]
若\(u\)为叶子则:
\[ F_u(x)=x+1 \]
证明显然。
考虑这个玩意怎么优化,显然如果\(\rm NTT\)优化复杂度\(O(n^2\log n)\),这是我们无法接受的。
我们考虑对这棵树进行轻重链剖分,那么对于一条重链上的每个点我们假设求出了非重儿子的\(F(x)\)。
那么我们对这条重链进行编号,从顶端到叶子为\(1\cdots c\),设\(a_i=xF_{son_i}(x)\)。
那么链顶的答案就是:
\[ F_1(x)=xF_{son_1}(x)F_2(x)+1=a_1F_2(x)+1 \]
我们递归的写完所有\(c\)个:
\[ F_1(x)=a_1(a_2((\cdots(a_c+1))+1)+1)+1 \]
暴力展开就是:
\[ F_1(x)=a_1a_2\cdots a_c+a_1a_2\cdots a_{c-1}+\cdots+a_1+1 \]
这个可以分治\(\rm FFT\)解决,代码大概长这样:
void solve(int lt,int rt,vec &a,vec &b) { if(lt==rt) return a=b=r[lt],void(); vec al,ar,bl,br;int mid=(lt+rt)>>1;solve(lt,mid,al,bl),solve(mid+1,rt,br); b=poly::pmul(bl,br);a=poly::padd(poly::pmul(ar,al); }
其中\(al,ar\)表示答案,\(bl,br\)表示区间乘积,其他的变量可以参考下下面的代码。
那么我们就解决了这个问题,因为总轻边的子树大小之和为\(O(n\log n)\),所以总复杂度为\(O(n\log ^3n)\)。
Code
#include<bits/stdc++.h> using namespace std; void read(int &x) { x=0;int f=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f; for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f; } void print(int x) { if(x<0) putchar('-'),x=-x; if(!x) return ;print(x/10),putchar(x%10+48); } void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');} #define lf double #define ll long long #define pii pair<int,int > #define vec vector<int > #define pb push_back #define mp make_pair #define fr first #define sc second #define _sz(x) ((int)x.size()) #define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++) const int maxn = 1<<19|10; const int inf = 1e9; const lf eps = 1e-8; const int mod = 998244353; int add(int x,int y) {return x+y>=mod?x+y-mod:x+y;} int del(int x,int y) {return x-y<0?x-y+mod:x-y;} int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;} int qpow(int a,int x) { int res=1; for(;x;x>>=1,a=mul(a,a)) if(x&1) res=mul(res,a); return res; } namespace poly { int N,w[maxn],pos[maxn],bit,mxn,t[2][maxn]; void init(int l) { for(mxn=1;mxn<=l;mxn<<=1) ; w[0]=1,w[1]=qpow(3,(mod-1)/mxn); for(int i=2;i<=mxn;i++) w[i]=mul(w[i-1],w[1]); } void ntt(int *r,int op) { FOR(i,1,N-1) if(pos[i]>i) swap(r[pos[i]],r[i]); for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1) for(int j=0;j<N;j+=i<<1) for(int k=0;k<i;k++) { int x=r[j+k],y=mul(r[i+j+k],w[k*d]); r[j+k]=add(x,y),r[i+j+k]=del(x,y); } if(op==-1) { reverse(r+1,r+N);int d=qpow(N,mod-2); for(int i=0;i<N;i++) r[i]=mul(r[i],d); } } vec pmul(vec a,vec b) { if(1ll*_sz(a)*_sz(b)<=5000) { vec c;c.resize(_sz(a)+_sz(b)-1); FOR(i,_sz(a)-1) FOR(j,_sz(b)-1) c[i+j]=add(c[i+j],mul(a[i],b[j])); return c; } for(N=1,bit=0;N<_sz(a)+_sz(b);N<<=1,bit++); FOR(i,N-1) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1)); FOR(i,_sz(a)-1) t[0][i]=a[i];FOR(i,_sz(a),N) t[0][i]=0; FOR(i,_sz(b)-1) t[1][i]=b[i];FOR(i,_sz(b),N) t[1][i]=0; ntt(t[0],1),ntt(t[1],1); FOR(i,N-1) t[0][i]=mul(t[0][i],t[1][i]); ntt(t[0],-1);vec c; FOR(i,_sz(a)+_sz(b)-1) c.pb(t[0][i]); return c; } vec padd(vec a,vec b) { if(_sz(a)>_sz(b)) {FOR(i,_sz(b)-1) a[i]=add(a[i],b[i]);return a;} FOR(i,_sz(a)-1) b[i]=add(a[i],b[i]);return b; } } ll k; int n,ch[maxn],head[maxn],tot,sz[maxn],F[maxn],cnt; struct edge{int to,nxt;}e[maxn<<1]; vec f[maxn],r[maxn]; void ins(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;} void dfs(int x,int fa) { sz[x]=1;F[x]=fa; for(int i=head[x],v;i;i=e[i].nxt) if((v=e[i].to)!=fa) {dfs(v,x);sz[x]+=sz[v];if(sz[ch[x]]<sz[v]) ch[x]=v;} } void solve(int lt,al); } vec dfs2(int x) { for(int t=x;t;t=ch[t]) { for(int i=head[t];i;i=e[i].nxt) if(e[i].to!=F[t]&&e[i].to!=ch[t]) f[t]=dfs2(e[i].to); if(_sz(f[t])<1) f[t].resize(1);f[t][0]++; f[t].insert(f[t].begin(),0); } cnt=0;for(int t=x;t;t=ch[t]) r[++cnt]=f[t]; vec a,b;solve(1,cnt,a,b);return a; } int main() { read(n),scanf("%lld",&k);poly::init(n<<1);k%=mod; for(int i=1,x,y;i<n;i++) read(x),read(y),ins(x,ins(y,x); dfs(1,0);vec res=dfs2(1);int t=1,ans=0; for(int i=1;i<_sz(res);i++) { ans=add(ans,mul(res[i],t)); t=mul(t,mul((k+i)%mod,qpow(i,mod-2))); }write(ans); return 0; }
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。