1. 程式人生 > 實用技巧 >拉格朗日插值法學習筆記

拉格朗日插值法學習筆記

適用範圍:給出n個點$(x_i,y_i)$,過這n個點能夠確定一個最高n-1次的多項式$f(x)$,求$f(k)$

做法:如圖所示,我們將每一個點$(x_i,y_i)$在x軸上的投影$(x_i,0)$記為$H_i$。對每一個i,我們選擇一個點集$\lbrace P_i\rbrace \cup \lbrace H_j \vert 1 \le i\le n, j \neq i\rbrace$,求過這n個點的最高n-1次的多項式$g_i(x)$。這樣我們就得到了n個$g_i(x)$,它們都在各自對應的$x_i$處的值為$y_i$,並且在其它$x_j(i \neq j)$處值為0

很容易就能夠構造出$g_i(x)$的表示式:

$$g_i(x)=y_i*\prod_{j\neq i}\frac{x-x_j}{x_i-x_j}$$

顯然最後有:

$$f(x)=\sum_{i=1}^{n}g_i(x)=\sum _{i=1}^{n}y_i*\prod_{j\neq i}\frac{x-x_j}{x_i-x_j}$$

由於只用求$f(k)$的值,代入得$f(k)=\sum _{i=1}^{n}y_i*\prod_{j\neq i}\frac{k-x_j}{x_i-x_j}$

例題:

luogu 4781 [模板]拉格朗日插值

#include <iostream>
#include <algorithm>
#include 
<cstring> #include <cstdio> using namespace std; typedef long long ll; const ll mod = 998244353; const int N = 2010; int n; ll k, x[N], y[N]; ll power(ll a, ll n) { ll res = 1; while (n) { if (n & 1) res = res * a % mod; a = a * a % mod; n >>= 1
; } return res; } int main() { scanf("%d%lld", &n, &k); for (int i = 1; i <= n; i++) scanf("%lld%lld", &x[i], &y[i]); ll res = 0; for (int i = 1; i <= n; i++) { ll ta = y[i] % mod, tb = 1; for (int h = 1; h <= n; h++) { if (h == i) continue; ta = (ta * ((k - x[h]) % mod + mod) % mod) % mod; tb = (tb * ((x[i] - x[h]) % mod + mod) % mod) % mod; } res = (res + ta * power(tb, mod - 2) % mod) % mod; } printf("%lld\n", res); return 0; }
[模板]拉格朗日插值

Codeforces - 622FThe Sum of the k-th Powers

題意:求$\sum_{i=1}^{n}i^k,1\leq n\leq 10^9,0\leq k \leq 10^6$

思路:首先有一個結論,$\sum_{i=1}^{n}i^k$為k+1階多項式,所以我們只需要暴力算出$f(n)=\sum_{i=1}^{n}i^k$的前k+2項,然後用拉格朗日插值法求第n項即可

下面簡單證明一下$\sum_{i=1}^{n}i^k$為k+1階多項式

對於一個數列${a_n}$來說,把數列${a_n}$的元素兩兩做差得到數列${b_n}$,我們稱數列${b_n}$為數列${a_n}$的一階階差數列,如果再將數列${b_n}$的元素兩兩做差得到數列${c_n}$,我們稱數列${c_n}$為數列${a_n}$的一階階差數列,依次類推定義p階階差數列


定理:數列${a_n}$是一個p階等差數列的充要條件是數列的通項$a_n$為n的一個p次多項式

證明:設$f(x)=\sum_{i=0}^{n}u_i x^i$,令${b_n}$為${a_n}$的一階階差數列

那麼$\Deltaf(x)=f(x+1)-f(x)=\sum_{i=0}^{n}u_i (x+1)^i-\sum_{i=0}^{n}u_i x^i$

我們只考慮$x^n$有$\Delta f(x)=u_n(x+1)^n-u_n x^n$

將$u_n(x+1)^n$二項式展開,僅考慮$x^n$有$\Delta f(x)=u_n x^n-u_n x^n=0$

所以每次差分後多項式的最高次數會減1,p次差分後,變為常數項,此時多項式的最高次數為0,定理成立


我們將$\sum_{i=1}^{n}i^k$差分後得$1^k,2^k,3^k\dots n^k$

顯然$1^k,2^k,3^k\dots n^k$為k階多項式,所以$\sum_{i=1}^{n}i^k$為k+1階多項式

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
 
using namespace std;
 
typedef long long ll;
 
const int N = 2000010;
const ll mod = 1000000007;
 
int n, k;
ll a[N], f[N], nf[N], now;
 
ll power(ll a, ll n)
{
    ll res = 1;
    while (n) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}
 
ll solve()
{
    now = f[0] = nf[0] = 1;
    for (int i = 1; i <= k + 2; i++) {
        now = now * (n - i) % mod;
        f[i] = f[i - 1] * i % mod;
        nf[i] = -nf[i - 1] * i % mod;
    }
    ll res = 0;
    for (int i = 1; i <= k + 2; i++) {
        ll t = power(f[i - 1] * nf[k + 2 - i] % mod, mod - 2);
        res = (res + a[i] * now % mod * power(n - i, mod - 2) % mod * t % mod) % mod;
        res = (res + mod) % mod;
    }
    return res;
}
 
int main()
{
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    scanf("%d%d", &n, &k);
    for (int i = 1; i <= min(k + 2, n); i++)
        a[i] = (a[i - 1] + power(i, k)) % mod;
    if (n <= k + 2) printf("%lld\n", a[n]);
    else printf("%lld\n", solve());
    return 0;
}
The Sum of the k-th Powers

luogu 4593[TJOI2018]教科書般的褻瀆

題意:求$\sum_{i=0}^m f(n-a_i)-\sum_{i=0}^{m-1} \sum_{j=i+1}^{m}(a_j-a_i)^m+1$,其中$f(n)=\sum_{i=1}^{n} i^{m+1}$

思路:後半部分$\sum_{i=0}^{m-1} \sum_{j=i+1}^{m}(a_j-a_i)^m+1$可以直接暴力求解

$f(n)=\sum_{i=1}^{n} i^{m+1}$為m+2階多項式,考慮到m不是很大,所以我們可以插入m+3個值後,暴力對每一個$a_i$求一次f(n-a_i),最後相加即可

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

typedef long long ll;

const int N = 60;
const ll mod = 1000000007;

int T;
ll n, m, now;
ll a[N], c[N], f[N], nf[N];

ll power(ll a, ll n)
{
    ll res = 1;
    while (n) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

ll solve(ll n)
{
    now = f[0] = nf[0] = 1;
    for (int i = 1; i <= m + 3; i++) {
        now = now * (n - i) % mod;
        f[i] = f[i - 1] * i % mod;
        nf[i] = -nf[i - 1] * i % mod;
    }
    ll res = 0;
    for (int i = 1; i <= m + 3; i++) {
        ll t = power(f[i - 1] * nf[m + 3 - i] % mod, mod - 2);
        res = (res + c[i] * now % mod * power(n - i, mod - 2) % mod * t % mod) % mod;
        res = (res + mod) % mod;
    }
    return res;
}

ll calc(ll x)
{
    if (x <= m + 3) return c[x];
    return solve(x);
}

int main()
{
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    scanf("%d", &T);
    while (T--) {
        scanf("%lld%lld", &n, &m);
        for (int i = 1; i <= m; i++) scanf("%lld", &a[i]);
        sort(a + 1, a + m + 1);
        for (int i = 1; i <= min(m + 3, n); i++)
            c[i] = (c[i - 1] + power(i, m + 1)) % mod;
        ll res = 0;
        for (int i = 0; i <= m; i++) res = (res + calc(n - a[i])) % mod;
        for (int i = 0; i <= m - 1; i++) {
            for (int h = i + 1; h <= m; h++) {
                ll t = power(a[h] - a[i], m + 1);
                res = ((res - t) % mod + mod) % mod;
            }
        }
        printf("%lld\n", res);
    }
    return 0;
}
[TJOI2018]教科書般的褻瀆

bzoj 3453 tyvj 1858 XLkxc

題意:求$\sum_{i=0}^n \sum_{j=1}^{a+i*d} \sum_{x=1}^{j}x^k$

思路:$f(n)=\sum_{x=1}^{n}x^k$,為k+1階多項式

$g(n)=\sum_{i=1}^{n} \sum_{j=1}^{i} j^k$,g(n)差分後的結果為f(n),所以g(n)為k+2階多項式

$res(n)=\sum_{i=0}^{n}g(a+i*d)$,res(n)進行k+5次差分後為0,所以res(n)為k+4階多項式

然後插值求解即可

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

typedef long long ll;

const int N = 200;
const ll mod = 1234567891;

int T, k;
ll a, n, d, now, fc[N], nfc[N];
ll f[N], g[N];

ll power(ll a, ll n)
{
    ll res = 1;
    while (n) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

ll solve(ll n, int m, ll *c)
{
    now = fc[0] = nfc[0] = 1;
    for (int i = 1; i <= m; i++) {
        now = now * (n - i) % mod;
        fc[i] = fc[i - 1] * i % mod;
        nfc[i] = -nfc[i - 1] * i % mod;
    }
    ll res = 0;
    for (int i = 1; i <= m; i++) {
        ll t = power(fc[i - 1] * nfc[m - i] % mod, mod - 2);
        res = (res + c[i] * now % mod * power(n - i, mod - 2) % mod * t % mod) % mod;
        res = (res + mod) % mod;
    }
    return res;
}

ll calc(ll n, int m, ll *c)
{
    if (n <= m) return c[n];
    return solve(n, m, c);
}

int main()
{
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    scanf("%d", &T);
    while (T--) {
        scanf("%d%lld%lld%lld", &k, &a, &n, &d);
        for (int i = 0; i <= k + 3; i++) f[i] = (f[i - 1] + power(i, k)) % mod;
        for (int i = 1; i <= k + 3; i++) f[i] = (f[i] + f[i - 1]) % mod;
        for (int i = 0; i <= k + 5; i++) g[i] = calc((a + i * d) % mod, k + 3, f);
        for (int i = 1; i <= k + 5; i++) g[i] = (g[i] + g[i - 1]) % mod;
        printf("%lld\n", calc(n, k + 5, g));
    }
    return 0;
}
tyvj 1858 XLkxc