NOI2007 生成樹計數:狀壓 DP + 矩陣乘法
阿新 • 發佈:2019-01-23
暴力解法(矩陣樹定理,按排列展開行列式,只適用於很小的 n):
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
// Laplacian-style matrix filled by main(): vertex degree on the diagonal,
// -1 where two vertices are at most k apart, 0 elsewhere (1-based indices).
int mat[3300][3300];
// Maximum index distance between two adjacent vertices (first input value).
int k;
// Number of vertices (second input value); main() decrements it once the
// matrix is built so that only the leading (n-1)x(n-1) minor is expanded.
int n;
// b[row] = column currently chosen for that row by find().
int b[3300];
// Running sum of the signed permutation products, i.e. the determinant.
long long ans;
// get[col] is true while column col is taken by the permutation being built.
// NOTE(review): with `using namespace std;` in effect this name can collide
// with std::get under C++11 and later; refer to it as ::get if that happens.
bool get[3300];
// Adds the signed term of the permutation stored in b[1..n] to ans.
// The sign is +1 for an even number of inversions and -1 for an odd number;
// the magnitude is the product of mat[row][b[row]] over all rows.
void work(){
    // Only the parity of the inversion count matters, so track it as a flag.
    bool odd=false;
    for (int lo=1;lo<n;++lo){
        for (int hi=lo+1;hi<=n;++hi){
            if (b[lo]>b[hi]) odd=!odd;
        }
    }

    long long product=1;
    int row=1;
    while (row<=n){
        product*=mat[row][b[row]];
        ++row;
    }

    const long long sign=odd?-1:1;
    ans+=product*sign;
}
void find(int x){
if (x==n+1){
work();return;
}
for (int i=1;i<=n;++i){
if (!get[i]&&mat[x][i]!=0){
b[x]=i;get[i]=true;find(x+1);
b[x]=0;get[i]=false;
}
}
}
int main(){
//freopen("count.in","r",stdin);
//freopen("count.out","w",stdout);
scanf("%d%d",&k,&n);
for (int i=1;i<=n;++i){
for (int j=1;j<=n;++j){
if (i==j) {mat[i][j]=i-max(1,i-k)+min(n,k+i)-i;continue;}//前面多少連過來+後面多少連過去
if (abs(i-j)<=k) {mat[i][j]=-1;continue;}
mat[i][j]=0;
}
}
n--;
memset(get,false,sizeof(get));
find(1);
/*for (int i=1;i<=n;++i){
for (int j=1;j<=n;++j){
printf("%d ",mat[i][j]);
}
printf("\n");
}*/
printf("%lld ",ans);
//printf("%d ",-5%4);
return 0;
}
矩陣加速版本:ORZ,純手打程式碼,沒有參考別人的題解,所以大概很長……
#include <cstdio>
#include <cstring>
// Modulus applied to every matrix product and to the final answer.
#define N 65521

// Window width (first input value).
int k;
// b[state][pos] = component label of window position pos in that state;
// rows 1..num are filled by dfs().
int b[60][10];
// Number of states generated by dfs().
int num;
// numm[label] = per-label counter maintained by dfs2() while it backtracks.
int numm[10];
// Scratch labelling: the partial state under construction in dfs().
int a[10];
// s[state] = decimal encoding of that state's labels (one digit per position).
int s[60];
// sh[code] = index of the state whose decimal encoding is code.
int sh[1000000];
// link[pos] = 1 when dfs2() marked position pos as the one that connects its
// component, 0 otherwise (only read by commented-out debug output).
int link[10];
// Scratch used by print(): merged component label of each position; slot k+1
// stands for the newly added vertex.
int father[10];
// map[from][to] = number of transitions from state `from` to state `to`.
int map[60][60];

// Total number of vertices (second input value).
long long n;

// flag[label] = true while dfs2() has the component with that label marked
// as connected to the new vertex.
bool flag[10];

// NUM[m] = m^(m-2) for m = 1..5 (1, 1, 3, 16, 125); used by main() as the
// weight of each starting state. Entries not listed are zero.
int NUM[10]={0,1,1,3,16,125};
// Fixed-capacity matrix with 1-based indexing, used for the transition
// matrix and its powers.
struct matrix{
    // Entries; multiply() reads and writes rows 1..l and columns 1..c.
    long long f[60][60];
    // Number of rows in use.
    long long l;
    // Number of columns in use.
    long long c;
};

// Starting value of the accumulator in pow(); build() turns it into the
// num x num identity matrix.
matrix base1;
// Returns the product a*b with every entry reduced modulo N.
// The result has a.l rows and b.c columns; the shared dimension is a.c.
// Only the 1-based range is computed, every other entry of the result is 0.
inline matrix multiply(matrix a,matrix b){
    matrix product;
    memset(product.f,0,sizeof(product.f));
    product.l=a.l;
    product.c=b.c;
    for (int row=1;row<=product.l;++row){
        for (int col=1;col<=product.c;++col){
            long long cell=0;
            for (int mid=1;mid<=a.c;++mid){
                // Reduce each term and the running sum so nothing overflows.
                cell=(cell+a.f[row][mid]*b.f[mid][col]%N)%N;
            }
            product.f[row][col]=cell;
        }
    }
    return product;
}
// Raises the square matrix a to the power b (mod N) by binary exponentiation.
//
// The accumulator starts from the global base1, so build() must already have
// set base1 to the identity matrix of the matching size.
// An exponent of zero or less returns base1 unchanged. The loop condition is
// b>0 because a negative value never becomes 0 under >>=, which previously
// made the loop run forever.
inline matrix pow(matrix a,long long b){
    matrix r=base1;
    while (b>0){
        if (b&1) r=multiply(a,r);
        b>>=1;
        // Square only while bits remain; the last squaring would be unused.
        if (b>0) a=multiply(a,a);
    }
    return r;
}
// Generates every labelling a[1..k] in which each label is at most one more
// than the largest label used so far, and stores each finished labelling as
// a new row b[num][1..k].
//   x - position being filled (the initial call is dfs(1,1))
//   p - smallest label not used yet, i.e. the largest value allowed at x
void dfs(int x,int p){
    if (x>k){
        const int slot=++num;
        int pos=1;
        while (pos<=k){
            b[slot][pos]=a[pos];
            ++pos;
        }
        return;
    }
    // Labels run from 1 to p but never beyond k.
    const int limit=(p<k)?p:k;
    for (int label=1;label<=limit;++label){
        a[x]=label;
        // Using the fresh label p makes p+1 available from the next position.
        dfs(x+1,(label==p)?p+1:p);
    }
}
void print(int se){
// for (int i=1;i<=k;++i) printf("%d ",link[i]);printf("dd ");
// for (int i=1;i<=k;++i) printf("%d ",b[se][i]);printf("dd ");
// for (int i=1;i<=k;++i) printf("%d ",flag[b[se][i]]);
memset(father,0,sizeof(father));
int num=0;int flagf[60];
memset(flagf,0,sizeof(flagf));
for (int i=1;i<=k;++i) father[i]=b[se][i];
bool flag1[10];memcpy(flag1,flag,sizeof(flag));
for (int i=1;i<=k;++i){
if (flag1[b[se][i]]==true){
father[k+1]=k+1;
for (int j=1;j<=k;++j){
if (b[se][j]==b[se][i]) father[j]=k+1;
}
flag1[b[se][i]]=false;
//if (link[i]==1) flagf[k+1]=num;
}
}
//for (int i=1;i<=k+1;++i) printf("%d ",father[i]);
int sheet[10];memset(sheet,-1,sizeof(sheet));
//printf("\n");
for (int i=2;i<=k+1;++i) {
if (sheet[father[i]]==-1){
++num;flagf[i]=num;sheet[father[i]]=num;continue;
}
flagf[i]=sheet[father[i]];
}
int tmp=0;
//for (int i=1;i<=k+1;++i) printf("%d ",flagf[i]);
for (int i=2;i<=k+1;++i) tmp=tmp*10+flagf[i];
// printf("%d\n",tmp);
map[se][sh[tmp]]+=1;
}
// Enumerates, for state se, the choices of which window positions set the
// flag of their component label, and calls print(se) once per complete
// choice.
//   se       - index of the source state (a row of b)
//   position - window position being decided, 1..k; k+1 means "all decided"
// Uses the globals flag[], link[] and numm[] as backtracking state; main()
// clears all three before the top-level call dfs2(se,1).
// Presumably a set flag means "the new vertex gets an edge into this
// component"; print() merges every flagged component with slot k+1.
void dfs2(int se,int position){
// Every position has been decided: record the resulting transition.
if (position==k+1){
print(se);return;
}
if (position==1){
// Check whether any other position shares label 1 with position 1.
bool tmp=false;
for (int i=2;i<=k;++i) if (b[se][i]==1) {
tmp=true;break;
}
// Position 1 is alone under label 1: its flag is forced on and no
// alternative is explored. (print() drops position 1 from the next state,
// so presumably this is its last chance to be connected.)
if (tmp==false){
flag[1]=true;link[1]=1;
dfs2(se,position+1);
return;
}
}
// The label of this position is already flagged by an earlier position, so
// this position cannot set it again: the only option is link 0.
if (flag[b[se][position]]){
link[position]=0;numm[b[se][position]]++;//printf("%d ",numm[1]);
dfs2(se,position+1);
if (numm[b[se][position]]) numm[b[se][position]]--;
return;
}
// Option 1: this position is the one that flags its label.
flag[b[se][position]]=true;link[position]=1;numm[b[se][position]]++;//printf("%d ",numm[1]);
dfs2(se,position+1);
if (numm[b[se][position]]) numm[b[se][position]]--;
// Only the flag reset is guarded by the if; link[position]=0 is a separate
// statement and always runs.
if (numm[b[se][position]]==0) flag[b[se][position]]=false;link[position]=0;
// Option 2: this position leaves its label unflagged.
dfs2(se,position+1);
}
// Resets base1 to the num x num identity matrix (1-based), which pow() uses
// as the starting value of its accumulator.
void build(){
    memset(base1.f,0,sizeof(base1.f));
    for (int diag=1;diag<=num;++diag) base1.f[diag][diag]=1;
    base1.c=num;
    base1.l=num;
}
// Matrix-exponentiation solution.
//
// Reads "k n" from count.in and writes the answer modulo N to count.out.
// Steps:
//   1. dfs() enumerates the states (component labellings of a window of k
//      consecutive vertices).
//   2. dfs2()/print() fill map[from][to] with the number of ways to go from
//      one state to another when the next vertex is appended.
//   3. The transition matrix is raised to the power n-k.
//   4. Each starting state i is weighted by NUM[size of its largest
//      component] and multiplied by the number of paths from i to state 1
//      (the first state produced by dfs(), all labels equal to 1).
//
// Returns 1 without writing an answer when the input is missing or out of
// range. k must be in 1..5 because the state table b[60][10] and the NUM
// table hold nothing beyond that, and n must be at least k so that the
// exponent n-k is not negative.
int main(){
    freopen("count.in","r",stdin);
    freopen("count.out","w",stdout);
    if (scanf("%d %lld",&k,&n)!=2||k<1||k>5||n<k) return 1;

    num=0;
    dfs(1,1);

    // Give every state a decimal code (one digit per label) and index it.
    for (int i=1;i<=num;++i){
        for (int j=1;j<=k;++j) s[i]=s[i]*10+b[i][j];
        sh[s[i]]=i;
    }

    // largest[i] = size of the biggest component of state i.
    int largest[60];
    memset(largest,0,sizeof(largest));
    for (int i=1;i<=num;++i){
        for (int j=1;j<=k;++j){
            int size=0;
            for (int other=1;other<=k;++other){
                if (b[i][other]==b[i][j]) ++size;
            }
            if (largest[i]<size) largest[i]=size;
        }
        // dfs2() backtracks over these globals, so clear them per state.
        memset(flag,false,sizeof(flag));
        memset(link,0,sizeof(link));
        memset(numm,0,sizeof(numm));
        dfs2(i,1);
    }

    build();
    matrix p;
    p.l=num;
    p.c=num;
    for (int i=1;i<=num;++i){
        for (int j=1;j<=num;++j) p.f[i][j]=map[i][j];
    }
    p=pow(p,n-k);

    long long ans=0;
    for (int i=1;i<=num;++i){
        ans=(ans+p.f[i][1]*NUM[largest[i]])%N;
    }
    printf("%lld",ans);
    return 0;
}