跳转至

4 多项式

4.1 NTT

init 需要传入卷积的最长长度（对 n 次多项式做 inverse/log/exp 时，变换长度要大于 2n，因此至少传入 2n+1）；数组大小 N 必须大于“不小于该长度的最小 2 的幂”。

#include <bits/stdc++.h>
#define int long long
using namespace std;
// Array sizing and modulus.
constexpr int RN = 1000005;                       // base size (1e6 + 5) that the array bound N is derived from
constexpr int p = 998244353;                      // NTT-friendly prime 119 * 2^23 + 1 (3 is used as its primitive root)
constexpr int N = (1 << (__lg(RN * 2) + 1)) + 5;  // smallest power of two above 2 * RN, plus a little slack
// Modular helpers for operands already reduced to [0, p).
inline int add(const int &x, const int &y) {
    const int sum = x + y;
    return sum < p ? sum : sum - p;  // (x + y) mod p
}
inline int dec(const int &x, const int &y) {
    const int diff = x - y;
    return diff < 0 ? diff + p : diff;  // (x - y) mod p
}

// Binary exponentiation: returns a^t mod p. Expects 0 <= a < p and t >= 0.
inline int power(int a, int t) {
    int result = 1;
    for (; t != 0; t >>= 1, a = a * a % p)
        if (t & 1) result = result * a % p;
    return result;
}

int siz;      // log2 of the transform length prepared by init()
int rev[N];   // bit-reversal permutation for length 2^siz
int rt[N];    // rt[m + k] = w_{2m}^k for power-of-two m, w_{2m} a primitive 2m-th root of unity
int inv[N];   // inv[i] = i^{-1} mod p for i >= 1
int fac[N];   // fac[i] = i! mod p
int ifac[N];  // ifac[i] = (i!)^{-1} mod p

// Prepares the NTT tables (rev, rt, siz) and the combinatorics tables (fac, ifac, inv).
//   n: the longest transform length that will be needed (n >= 0). Every power-of-two
//      length up to lim = (smallest power of two >= n) becomes usable by dft/idft.
//      inverse/log/exp of a degree-d polynomial use a length > 2d, so they need n >= 2d + 1.
// Fills rev/rt for indices [0, lim) and fac/ifac/inv for [0, n].
// Requires lim <= N, n < N, and lim dividing p - 1 (lim <= 2^23).
// Safe to call repeatedly: siz is recomputed from scratch each time.
void init(int n) {
    siz = 0;  // previously accumulated across calls, corrupting the tables on a second init()
    int lim = 1;
    while (lim < n) lim <<= 1, ++siz;

    // Bit-reversal table. For i >= 1 we have lim >= 2, hence siz >= 1, so the shift is valid.
    rev[0] = 0;
    for (int i = 1; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (siz - 1));

    // rt[lim/2 + k] = w^k with w = 3^((p-1)/lim) of order lim; the lower halves take every
    // other entry of the half above, so rt[m + k] = (primitive 2m-th root)^k for each m.
    const int w = power(3, (p - 1) >> siz);
    rt[lim >> 1] = 1;
    for (int i = (lim >> 1) + 1; i < lim; ++i) rt[i] = rt[i - 1] * w % p;
    for (int i = (lim >> 1) - 1; i > 0; --i) rt[i] = rt[i << 1];  // `i > 0`: lim == 1 starts at -1

    fac[0] = ifac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % p;
    ifac[n] = power(fac[n], p - 2);
    for (int i = n - 1; i > 0; --i) ifac[i] = ifac[i + 1] * (i + 1) % p;  // `i > 0`: n == 0 starts at -1
    for (int i = 1; i <= n; ++i) inv[i] = fac[i - 1] * ifac[i] % p;
}

// Forward NTT of f[0, n), in place, result in natural order.
//   n: a power of two no larger than 2^siz (the length prepared by init()).
//   f: entries are expected to lie in [0, p); on return they are reduced to [0, p).
// Uses a function-local static scratch buffer, so it is not reentrant.
// Butterflies are reduced lazily in 64-bit unsigned arithmetic. Each level can raise the
// largest entry by less than p, so after L unreduced levels entries are below (L + 1) * p.
// The product with a twiddle (< p) is then below (L + 1) * p^2. That stays under 2^64
// only while L + 1 <= 18, so it could overflow for n >= 2^19. The buffer is therefore
// reduced mod p every 16 levels.
inline void dft(int *f, int n) {
    static unsigned long long a[N];
    const int shift = siz - __builtin_ctz(n);
    // rev was built for length 2^siz; shifting it right gives the permutation for length n.
    for (int i = 0; i != n; ++i) a[rev[i] >> shift] = f[i];
    int level = 0;
    for (int mid = 1; mid != n; mid <<= 1, ++level) {
        if (level != 0 && level % 16 == 0)
            for (int i = 0; i != n; ++i) a[i] %= p;
        for (int j = 0; j != n; j += (mid << 1))
            for (int k = 0; k != mid; ++k) {
                const unsigned long long x = a[j | k | mid] * rt[mid | k] % p;
                a[j | k | mid] = a[j | k] + p - x;  // + p keeps the unsigned difference non-negative
                a[j | k] += x;
            }
    }
    for (int i = 0; i != n; ++i) f[i] = a[i] % p;
}

// Inverse NTT of f[0, n), in place (n a power of two prepared by init()).
// Reversing f[1..n-1] turns the forward transform into the inverse one scaled by n.
// That scale is then removed by multiplying with n^{-1} mod p.
inline void idft(int *f, int n) {
    std::reverse(f + 1, f + n);
    dft(f, n);
    // n divides p - 1 (n <= 2^23), and n * (p - (p - 1) / n) = n * p - (p - 1) == 1 (mod p).
    const int n_inv = p - ((p - 1) >> __builtin_ctz(n));
    for (int *it = f; it != f + n; ++it) *it = *it * n_inv % p;
}

// Polynomial product: r = f * g, where f has fn coefficients and g has gn coefficients.
// The product has len = fn + gn - 1 coefficients. It is computed with a length-lim NTT,
// where lim is the smallest power of two >= len (init() must cover lim).
// f, g and r must each have room for lim entries.
// Side effects: f and g are zero-padded to lim and overwritten with their transforms.
// r receives lim entries; those at index len and beyond are 0.
// If either input is empty, nothing is computed or written.
inline void conv(int *f, int fn, int *g, int gn, int *r) {
    if (fn <= 0 || gn <= 0) return;  // product with an empty polynomial has no coefficients
    const int len = fn + gn - 1;
    int lim = 1;
    while (lim < len) lim <<= 1;
    // Zero-pad explicitly: stale data past fn / gn would otherwise be treated as coefficients.
    std::fill(f + fn, f + lim, 0);
    std::fill(g + gn, g + lim, 0);
    dft(f, lim);
    dft(g, lim);
    for (int i = 0; i < lim; ++i) r[i] = f[i] * g[i] % p;
    idft(r, lim);
}

// Returns the smallest power of two strictly greater than n, so getlen(0) == 1.
// The explicit guard is needed because __builtin_clz(0) is undefined.
// Valid for n < 2^30.
inline int getlen(int n) {
    if (n <= 0) return 1;
    return 1 << (32 - __builtin_clz(static_cast<unsigned>(n)));
}

inline void inverse(const int *f, int n, int *r) {
    static int g[N], h[N], st[30];
    memset(g, 0, getlen(n << 1) * sizeof(int));
    int lim = 1, top = 0;
    while (n) {
        st[++top] = n;
        n >>= 1;
    }
    g[0] = 1;
    while (top--) {
        n = st[top + 1];
        while (lim <= (n << 1)) lim <<= 1;
        memcpy(h, f, (n + 1) * sizeof(int));
        memset(h + n + 1, 0, (lim - n) * sizeof(int));
        dft(g, lim), dft(h, lim);
        for (int i = 0; i != lim; ++i) g[i] = g[i] * (2 - g[i] * h[i] % p + p) % p;
        idft(g, lim);
        memset(g + n + 1, 0, (lim - n) * sizeof(int));
    }
    memcpy(r, g, (n + 1) * sizeof(int));
}

// Power-series logarithm: writes r[0..n] = ln f (mod x^{n+1}), computed as the integral of f' / f.
//   f: coefficients f[0..n]; f[0] must be 1 for ln f to be defined.
//   r: receives n + 1 coefficients with r[0] = 0. It may alias f, because f is fully read
//      (by inverse() and the derivative loop) before r is written.
// Needs inv[1..n] from init() and a transform length > 2n (init() argument >= 2n + 1).
inline void log(const int *f, int n, int *r) {
    static int g[N], h[N];
    if (n == 0) {  // only the constant term, which is 0; also avoids getlen(0), which is undefined
        r[0] = 0;
        return;
    }
    inverse(f, n, g);                                              // g = 1 / f
    for (int i = 0; i != n; ++i) h[i] = f[i + 1] * (i + 1) % p;  // h = f'
    h[n] = 0;
    const int lim = getlen(n << 1);                                // product degree <= 2n < lim
    memset(g + n + 1, 0, (lim - n) * sizeof(int));
    memset(h + n + 1, 0, (lim - n) * sizeof(int));
    dft(g, lim), dft(h, lim);
    for (int i = 0; i != lim; ++i) g[i] = g[i] * h[i] % p;
    idft(g, lim);
    for (int i = 1; i <= n; ++i) r[i] = g[i - 1] * inv[i] % p;    // integrate term by term
    r[0] = 0;
}

// Power-series exponential: writes r[0..n] = exp(f) (mod x^{n+1}).
//   f: coefficients f[0..n]; requires f[0] == 0.
//   r: receives n + 1 coefficients.
// Uses the Newton iteration g <- g * (1 - ln g + f), doubling the precision each round.
// Every round calls log(), so the tables log() needs must be ready:
// inv[1..n] and a transform length > 2n.
inline void exp(const int *f, int n, int *r) {
    static int g[N], h[N], st[30];
    if (n == 0) {  // exp(f) == 1 (mod x) when f[0] == 0; also avoids getlen(0), which is undefined
        r[0] = 1;
        return;
    }
    memset(g, 0, getlen(n << 1) * sizeof(int));
    int lim = 1, top = 0;
    // Precision targets n, n/2, n/4, ..., 1, consumed from the smallest upwards.
    while (n) {
        st[++top] = n;
        n >>= 1;
    }
    g[0] = 1;
    while (top--) {
        n = st[top + 1];
        while (lim <= (n << 1)) lim <<= 1;  // product degree <= 2n: keep the cyclic convolution exact
        memcpy(h, g, (n + 1) * sizeof(int));  // h = current approximation
        memset(h + n + 1, 0, (lim - n) * sizeof(int));
        log(g, n, g);                                             // g = ln h
        for (int i = 0; i <= n; ++i) g[i] = dec(f[i], g[i]);  // g = f - ln h
        g[0] = add(g[0], 1);                                      // g = 1 - ln h + f
        dft(g, lim), dft(h, lim);
        for (int i = 0; i != lim; ++i) g[i] = g[i] * h[i] % p;
        idft(g, lim);
        memset(g + n + 1, 0, (lim - n) * sizeof(int));  // truncate to mod x^{n+1}
    }
    memcpy(r, g, (n + 1) * sizeof(int));
}

4.2 FFT

“三次变两次”卷积：把 a 放进复数的实部、b 放进虚部，做一次正变换、逐点平方、再做一次逆变换，结果虚部的一半即为 a·b，只需两次 FFT 而不是三次。

#include <bits/stdc++.h>
using namespace std;
using cp = complex<double>;
using vp = vector<cp>;
using vi = vector<int>;
vi rev;  // bit-reversal permutation for the current FFT length (rebuilt by init_rev)
// Rebuilds rev as the bit-reversal permutation of [0, limit); limit is expected to be a power of two.
void init_rev(int limit) {
    rev.resize(limit);
    const int half = limit / 2;
    for (int i = 0; i < limit; ++i)
        rev[i] = rev[i >> 1] / 2 + ((i & 1) ? half : 0);
}
// Iterative radix-2 FFT of x[0, limit), in place.
//   limit: a power of two; init_rev(limit) must have been called so `rev` matches it.
//   inv:   false -> forward transform with root e^{+i*pi/len};
//          true  -> conjugate roots, then the window is divided by limit.
// Only the relative sign of the two directions matters, so the pair is a valid transform/inverse.
void FFT(vp &x, int limit, bool inv = false) {
    const double PI = acos(-1.0);  // PI was used but never defined in this snippet
    for (int i = 0; i < limit; ++i)
        if (i < rev[i]) swap(x[i], x[rev[i]]);
    for (int len = 1; len < limit; len <<= 1) {
        const cp wn(cos(PI / len), (inv ? -1.0 : 1.0) * sin(PI / len));
        for (int i = 0; i < limit; i += 2 * len) {
            cp w(1);
            for (int j = i; j < i + len; j++, w *= wn) {
                const cp u = x[j], v = w * x[j + len];
                x[j] = u + v, x[j + len] = u - v;
            }
        }
    }
    if (!inv) return;
    for (int i = 0; i < limit; ++i) x[i] /= limit;  // scale only the transformed window
}
// Integer polynomial product via the "three transforms become two" trick.
// a goes into the real part and b into the imaginary part. Since
// (a + ib)^2 = a^2 - b^2 + 2iab, a * b = Im(c^2) / 2 after one forward and one inverse FFT.
// Returns |a| + |b| - 1 coefficients, or an empty vector if that count is not positive.
// Coefficients are rounded to the nearest integer, so the results must fit in int and stay
// within double precision.
vi operator*(const vi &a, const vi &b) {
    const int len = static_cast<int>(a.size() + b.size()) - 1;
    if (len <= 0) return {};  // previously reached __lg(len) / vi(len) with len <= 0
    int limit = 1;
    while (limit < len) limit <<= 1;
    init_rev(limit);
    vp c(limit);
    for (int i = 0; i < limit; ++i) {
        const double re = i < static_cast<int>(a.size()) ? a[i] : 0;
        const double im = i < static_cast<int>(b.size()) ? b[i] : 0;
        c[i] = cp(re, im);  // the snippet's `I` was never defined
    }
    FFT(c, limit);
    for (auto &z : c) z = z * z;
    FFT(c, limit, true);
    vi res(len);
    for (int i = 0; i < len; ++i)
        // llround rounds negative products correctly; (int)(v + 0.5) would map -3.0 to -2.
        res[i] = static_cast<int>(llround(0.5 * c[i].imag()));
    return res;
}