1. 程式人生 > >[NOI2011]兔農 矩陣乘法

[NOI2011]兔農 矩陣乘法

Descripition 你知道普通的Fibonacci數列嗎,他的遞推式為: f[i]=f[i-1]+f[i-2] 但這題有點不一樣。 這題給出一個k,如果當前f[i]%k==1,f[i]–。 其他還是照樣遞推,讓你輸出f[n]%mod。

Sample Input 6 7 100

Sample Output 7

這題很容易發現他是有迴圈節的,若無迴圈節那麼減1的f數量一定很有限,剩下的做矩陣乘法快速冪即可。 如何找出迴圈節呢,我使用的是exgcd求逆元的方式,因為假設你遇到一個i將他減1變成0,後面的第j項其實就是:fib[j]*f[i-1]。(其中f[i]表示在即成數列中的權值,fib[i]表示在原Fibonacci數列中的第i項的值,那麼其實你就是要找一個fib[j]*f[i-1]%k==1,你先預處理出每個值的位置,解一個exgcd即可。 然後中間是要用矩陣乘法加速的,矩陣有兩種: 一種是普通Fibonacci矩陣的遞推矩陣, 還有一種是減一的矩陣,自己推吧。。。 又臭又長的程式碼

#include <cstdio>
#include <cstring>

using namespace std;
typedef long long LL;

LL mod;
struct matrix {
    LL a[3][3];
    matrix() {memset(a, 0, sizeof(a));}
    friend matrix operator * (matrix a, matrix b) {
        matrix c;
        for(int i = 0; i < 3; i++) {
            for
(int j = 0; j < 3; j++) { for(int k = 0; k < 3; k++) { (c.a[i][j] += a.a[i][k] * b.a[k][j] % mod) %= mod; } } } return c; } } ans, A, h1, h2; int p[1100000]; LL pos[1100000], f[3100000]; int v[1100000]; LL exgcd(LL a, LL b, LL &x, LL &y) { if
(b == 0) {x = 1, y = 0; return a;} else { LL tx, ty, d = exgcd(b, a % b, tx, ty); x = ty, y = tx - ty * (a / b); return d; } } int main() { LL n; scanf("%lld", &n); LL k; scanf("%lld%lld", &k, &mod); LL now = 1, hh = 0; bool bk = 0; f[1] = f[2] = 1; for(int i = 3; ; i++) { f[i] = (f[i - 1] + f[i - 2]) % k; if(!pos[f[i]]) pos[f[i]] = i; if(f[i] == 1 && f[i - 1] == 1) break; } h1.a[0][1] = 1, h1.a[0][0] = h1.a[0][2] = 0; h1.a[1][0] = h1.a[1][1] = 1, h1.a[1][2] = 0; h1.a[2][0] = h1.a[2][1] = 0, h1.a[2][2] = 1; h2.a[0][0] = 1, h2.a[0][1] = h2.a[0][2] = 0; h2.a[1][0] = h2.a[1][2] = 0, h2.a[1][1] = 1; h2.a[2][0] = h2.a[2][1] = -1, h2.a[2][2] = 1; A.a[0][0] = A.a[1][1] = A.a[2][2] = 1; LL uy, ux = 1; for(int i = 1; ; i++) { LL x, y, d = exgcd(now, k, x, y); if(d > 1) {bk = 1; uy = i - 1; break;} x = (x % k + k) % k; if(v[x]) {ux = v[x], uy = i - 1; break;} hh += pos[x]; v[x] = i; p[i] = hh; now = now * f[pos[x] - 1] % k; } ans.a[0][0] = 0; ans.a[0][1] = ans.a[0][2] = 1; for(int i = ux; i <= uy; i++) { int u = p[i] - p[i - 1]; matrix o = h1; while(u) { if(u & 1) A = A * o; o = o * o; u /= 2; } A = A * h2; } if(bk) { if(n < hh) { LL f = n; LL u; for(int i = 1; i <= uy; i++) { if(p[i] > f) { u = f - p[i - 1]; matrix o = h1; while(u) { if(u & 1) ans = ans * o; o = o * o; u /= 2; } break; } u = p[i] - p[i - 1]; matrix o = h1; while(u) { if(u & 1) ans = ans * o; o = o * o; u /= 2; } ans = ans * h2; if(p[i] == f) break; } } else { n -= hh; ans = ans * A; while(n) { if(n & 1) ans = ans * h1; h1 = h1 * h1; n /= 2; } } printf("%lld\n", (ans.a[0][0] + mod) % mod); } else { if(n <= p[uy]) { LL f = n, u; for(int i = 1; i <= uy; i++) { if(p[i] > f) { u = f - p[i - 1]; matrix o = h1; while(u) { if(u & 1) ans = ans * o; o = o * o; u /= 2; } break; } u = p[i] - p[i - 1]; matrix o = h1; while(u) { if(u & 1) ans = ans * o; o = o * o; u /= 2; } ans = ans * h2; if(p[i] == f) break; } printf("%lld\n", (ans.a[0][0] + mod) % mod); } else { for(int i = 1; i < ux; i++) { int u = p[i] - p[i - 1]; matrix o = h1; while(u) { if(u & 1) ans = ans * o; o = o * o; u /= 2; } ans = ans * h2; } LL gg = p[ux - 1]; n -= gg; for(int i = ux; i <= uy; i++) p[i - ux + 1] = p[i] - gg; uy = uy - ux + 1; LL f = n % p[uy]; n /= p[uy]; while(n) { if(n & 1) ans = ans * A; A = A * A; n /= 2; } int u; for(int i = 1; i <= uy; i++) { if(p[i] > f) { u = f - p[i - 1]; matrix o = h1; while(u) { if(u & 1) ans = ans * o; o = o * o; u /= 2; } break; } u = p[i] - p[i - 1]; matrix o = h1; while(u) { if(u & 1) ans = ans * o; o = o * o; u /= 2; } ans = ans * h2; if(p[i] == f) break; } printf("%lld\n", (ans.a[0][0] + mod) % mod); } } return 0; }