4 多项式
4.1 NTT
调用 init 时需要传入卷积的最大长度；数组大小 N 必须大于“大于该长度的最小的 2 的幂”。
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Size bound the template is dimensioned for (presumably the largest
// polynomial length of the problem -- confirm against the caller).
constexpr int RN = 1e6 + 5;
// Capacity of every global and function-local work array below: the smallest
// power of two strictly greater than 2 * RN, plus 5 elements of slack.
constexpr int N = (1 << (__lg(2 * RN) + 1)) + 5;
// NTT-friendly prime modulus used by every routine in this listing.
constexpr int p = 998244353;

// Modular sum of two residues; a single conditional subtraction replaces `%`.
inline int add(const int &x, const int &y) {
  const int sum = x + y;
  return sum < p ? sum : sum - p;
}

// Modular difference of two residues; a single conditional addition replaces `%`.
inline int dec(const int &x, const int &y) {
  const int diff = x - y;
  return diff < 0 ? diff + p : diff;
}
// Binary exponentiation: returns a raised to the t-th power, modulo p.
// The exponent is consumed bit by bit from the least significant end while
// the base is squared once per bit.
inline int power(int a, int t) {
  int acc = 1;
  for (; t; t >>= 1, a = a * a % p) {
    if (t & 1) acc = acc * a % p;
  }
  return acc;
}
// Bit width of the transform size prepared by init(): after a single init()
// call the prepared size equals 1 << siz. dft() reads it to rescale rev[] for
// shorter transforms.
int siz;
// Tables filled by init():
//   rev  - bit-reversal permutation for the prepared transform size
//   rt   - twiddle factors; rt[mid | k] is the k-th power of a primitive
//          (2 * mid)-th root of unity modulo p
//   inv  - modular inverses of 1..n
//   fac  - factorials 0!..n! modulo p
//   ifac - inverse factorials modulo p
int rev[N], rt[N], inv[N], fac[N], ifac[N];
// Precomputes every table the NTT routines rely on.
//
// n is the longest transform (convolution) length that will be requested;
// it is rounded up to a power of two `lim`. The call fills
//   rev[0..lim)  bit-reversal permutation,
//   rt[1..lim)   twiddle factors (rt[mid | k] = k-th power of a primitive
//                (2 * mid)-th root of unity),
//   fac / ifac / inv for indices up to n.
// Preconditions: lim and n must both be smaller than N, and lim must divide
// p - 1 (i.e. lim <= 2^23 for this modulus).
//
// The function may be called repeatedly: the global bit width `siz` is reset
// on entry instead of accumulating across calls, and the degenerate sizes
// n <= 1 no longer shift by a negative amount or walk off the front of the
// tables.
void init(int n) { // longest convolution length
  if (n < 0) n = 0;
  siz = 0;
  int lim = 1;
  while (lim < n) lim <<= 1, ++siz;
  // rev[0] is always 0; starting at 1 guarantees siz >= 1 inside the loop.
  rev[0] = 0;
  for (int i = 1; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (siz - 1));
  // 3 is a primitive root of p, so w is a primitive lim-th root of unity.
  const int w = power(3, (p - 1) >> siz);
  fac[0] = ifac[0] = rt[lim >> 1] = 1;
  // Upper half: rt[lim/2 + k] = w^k.
  for (int i = (lim >> 1) + 1; i < lim; ++i) rt[i] = rt[i - 1] * w % p;
  // Lower levels reuse every second entry of the level above.
  for (int i = (lim >> 1) - 1; i > 0; --i) rt[i] = rt[i << 1];
  for (int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % p;
  ifac[n] = power(fac[n], p - 2);
  for (int i = n - 1; i > 0; --i) ifac[i] = ifac[i + 1] * (i + 1) % p;
  // 1 / i = (i - 1)! / i!
  for (int i = 1; i <= n; ++i) inv[i] = fac[i - 1] * ifac[i] % p;
}
// In-place forward number-theoretic transform of f[0..n).
//
// Preconditions: n is a power of two not larger than the size prepared by
// init(), and every f[i] lies in [0, p).
//
// The butterflies run on unsigned 64-bit values without reducing after each
// addition, so values grow by up to p per level. Multiplying such a value by
// a twiddle (< p) only fits in 64 bits while it is below roughly 18.5 * p;
// without intervention the 19th level (n >= 2^19) can overflow. The buffer is
// therefore reduced modulo p after every 16 levels, which keeps each
// multiplicand below 16 * p. Results are unchanged wherever the unreduced
// computation did not overflow.
inline void dft(int *f, int n) {
  static unsigned long long a[N];
  // rev[] was built for the full prepared size; dropping the low `shift` bits
  // yields the bit reversal within log2(n) bits.
  const int shift = siz - __builtin_ctz(n);
  for (int i = 0; i != n; ++i) a[rev[i] >> shift] = f[i];
  int unreduced = 0; // butterfly levels applied since the last reduction
  for (int mid = 1; mid != n; mid <<= 1) {
    if (unreduced == 16) {
      for (int i = 0; i != n; ++i) a[i] %= p;
      unreduced = 0;
    }
    for (int j = 0; j != n; j += (mid << 1))
      for (int k = 0; k != mid; ++k) {
        const unsigned long long x = a[j | k | mid] * rt[mid | k] % p;
        a[j | k | mid] = a[j | k] + p - x; // + p keeps the difference non-negative
        a[j | k] += x;
      }
    ++unreduced;
  }
  for (int i = 0; i != n; ++i) f[i] = a[i] % p;
}
// In-place inverse transform of f[0..n): reversing f[1..n) turns the forward
// transform into the inverse one, after which every entry is scaled by the
// modular inverse of n.
inline void idft(int *f, int n) {
  std::reverse(f + 1, f + n);
  dft(f, n);
  // p - (p - 1) / n is congruent to 1 / n modulo p when n is a power of two
  // dividing p - 1.
  const int scale = p - ((p - 1) >> __builtin_ctz(n));
  for (int *it = f, *end = f + n; it != end; ++it) *it = *it * scale % p;
}
// Cyclic-free product of two polynomials via NTT.
//
// f holds fn coefficients and g holds gn coefficients; both buffers must be
// zero-padded by the caller up to lim, the smallest power of two that is at
// least fn + gn - 1. Both inputs are overwritten with their transforms, and
// r receives lim output coefficients (r may be the same buffer as f or g).
//
// f and g may alias (squaring a polynomial): the shared buffer is then
// transformed only once instead of twice.
inline void conv(int *f, int fn, int *g, int gn, int *r) {
  const int len = fn + gn - 1;
  int lim = 1;
  while (lim < len) lim <<= 1;
  dft(f, lim);
  if (g != f) dft(g, lim);
  for (int i = 0; i < lim; ++i) r[i] = f[i] * g[i] % p;
  idft(r, lim);
}
// Returns the smallest power of two strictly greater than n.
// __builtin_clz is undefined for 0, so non-positive n is answered directly
// with 1 instead of reaching the builtin.
inline int getlen(int n) {
  if (n <= 0) return 1;
  return 1 << (32 - __builtin_clz(n));
}
inline void inverse(const int *f, int n, int *r) {
static int g[N], h[N], st[30];
memset(g, 0, getlen(n << 1) * sizeof(int));
int lim = 1, top = 0;
while (n) {
st[++top] = n;
n >>= 1;
}
g[0] = 1;
while (top--) {
n = st[top + 1];
while (lim <= (n << 1)) lim <<= 1;
memcpy(h, f, (n + 1) * sizeof(int));
memset(h + n + 1, 0, (lim - n) * sizeof(int));
dft(g, lim), dft(h, lim);
for (int i = 0; i != lim; ++i) g[i] = g[i] * (2 - g[i] * h[i] % p + p) % p;
idft(g, lim);
memset(g + n + 1, 0, (lim - n) * sizeof(int));
}
memcpy(r, g, (n + 1) * sizeof(int));
}
inline void log(const int *f, int n, int *r) {
static int g[N], h[N];
inverse(f, n, g);
for (int i = 0; i != n; ++i) h[i] = f[i + 1] * (i + 1) % p;
h[n] = 0;
int lim = getlen(n << 1);
memset(g + n + 1, 0, (lim - n) * sizeof(int));
memset(h + n + 1, 0, (lim - n) * sizeof(int));
dft(g, lim), dft(h, lim);
for (int i = 0; i != lim; ++i) g[i] = g[i] * h[i] % p;
idft(g, lim);
for (int i = 1; i <= n; ++i) r[i] = g[i - 1] * inv[i] % p;
r[0] = 0;
}
// Power-series exponential: writes to r[0..n] the coefficients of exp(f)
// modulo x^(n+1) using the Newton step g <- g * (1 - ln g + f), with the
// precision doubling along the chain n, n/2, ..., 1 (smallest first).
// The iteration starts from g = 1.
inline void exp(const int *f, int n, int *r) {
  static int g[N], h[N], st[30];
  const int degree = n;
  memset(g, 0, getlen(n << 1) * sizeof(int));
  // Record the precision chain; st[1] is the requested degree.
  int top = 0;
  for (int m = n; m; m >>= 1) st[++top] = m;
  g[0] = 1;
  int lim = 1;
  for (int idx = top; idx >= 1; --idx) {
    const int cur = st[idx];
    while (lim <= (cur << 1)) lim <<= 1;
    const size_t tail = (lim - cur) * sizeof(int);
    // h keeps the current approximation while g is turned into 1 - ln g + f.
    memcpy(h, g, (cur + 1) * sizeof(int));
    memset(h + cur + 1, 0, tail);
    log(g, cur, g);
    for (int i = 0; i <= cur; ++i) g[i] = dec(f[i], g[i]);
    g[0] = add(g[0], 1);
    dft(g, lim);
    dft(h, lim);
    for (int i = 0; i < lim; ++i) g[i] = g[i] * h[i] % p;
    idft(g, lim);
    // Drop everything above the current precision.
    memset(g + cur + 1, 0, tail);
  }
  memcpy(r, g, (degree + 1) * sizeof(int));
}
4.2 FFT
“三次变两次”优化：把 a 放在实部、b 放在虚部，对 a + bi 做一次正变换、平方后再做一次逆变换，结果虚部的一半即为卷积，因此只需两次 FFT 而不是常规的三次。
#include <bits/stdc++.h>
using namespace std;
using cp = std::complex<double>;
using vp = std::vector<cp>;
using vi = std::vector<int>;

// Bit-reversal permutation for the transform size most recently passed to
// init_rev().
vi rev;

// Rebuilds rev for a transform of `limit` points (limit must be a power of
// two): rev[i] is i with its log2(limit) low bits reversed.
void init_rev(int limit) {
  rev.assign(limit, 0);
  for (int i = 1; i < limit; ++i) rev[i] = rev[i / 2] / 2 + (i % 2) * limit / 2;
}

// In-place iterative radix-2 FFT over x[0..limit).
//
// init_rev(limit) must have been called first. With inv == false the
// transform uses the roots e^{+i*pi/len}; with inv == true it uses the
// conjugate roots and afterwards divides every element of x by limit.
void FFT(vp &x, int limit, bool inv = false) {
  const double pi = std::acos(-1.0);
  for (int i = 0; i < limit; ++i)
    if (i < rev[i]) std::swap(x[i], x[rev[i]]);
  const double dir = inv ? -1.0 : 1.0;
  for (int len = 1; len < limit; len <<= 1) {
    const double angle = pi / len;
    const cp wn(std::cos(angle), dir * std::sin(angle));
    for (int i = 0; i < limit; i += 2 * len) {
      cp w(1);
      for (int j = i; j < i + len; j++, w *= wn) {
        const cp u = x[j], v = w * x[j + len];
        x[j] = u + v, x[j + len] = u - v;
      }
    }
  }
  if (!inv) return;
  for (auto &value : x) value /= limit;
}

// Polynomial product of two integer coefficient vectors using two FFTs.
//
// a is packed into the real part and b into the imaginary part; since
// (a + ib)^2 = (a^2 - b^2) + 2ab*i, half of the imaginary part of the squared
// spectrum, transformed back, is the convolution.
//
// Returns a.size() + b.size() - 1 coefficients, or an empty vector when either
// operand is empty. Coefficients are rounded to the nearest integer, so
// negative products are handled correctly. Accuracy is limited by double
// precision.
vi operator*(const vi &a, const vi &b) {
  if (a.empty() || b.empty()) return {};
  const int len = static_cast<int>(a.size() + b.size() - 1);
  int limit = 1;
  while (limit < len) limit <<= 1;
  init_rev(limit);
  vp c(limit);
  for (std::size_t i = 0; i < a.size(); ++i) c[i].real(static_cast<double>(a[i]));
  for (std::size_t i = 0; i < b.size(); ++i) c[i].imag(static_cast<double>(b[i]));
  FFT(c, limit);
  for (auto &value : c) value = value * value;
  FFT(c, limit, true);
  vi res(len);
  for (int i = 0; i < len; ++i) {
    res[i] = static_cast<int>(std::llround(c[i].imag() / 2));
  }
  return res;
}