多项式2·NTT以及任意模数NTT

$0x01\quad \rm{Preface}$

无特别说明,本文不区分$n$和$N$两种符号,均表示形式为$2^j(j \in \mathbb{N+})$的多项式长度(或者,次数)。

我们知道,对于$FFT$而言,其得以优化成$\log$的根本原因是找到了单位复根这个东西,可以方便处理+计算。而另一种方法则是在模意义下,利用原根的美妙性质,进行多项式卷积。

$\boldsymbol{NTT~\text{(Fast Number-Theoreti Transform)}}$,快速数论变换。在分析$NTT$是如何利用原根之前,需要先分析$FFT$是如何利用的单位复根$^{[1]}$:

  • $\omega_n^n = 1$。

  • $\omega_n^1,\omega_n^2, \omega_n^3\cdots\omega_n^{n-1}$是互不相同的,这样带入计算才能保证插出一个完整的$n$次多项式。

  • {$\omega_n^2$} = {$\omega_{\frac{n}{2}}^{2}$},这使得问题规模可以在计算的时候减半。

  • $$
    \sum \limits _{k=0}^{n-1} (\omega_n^{j-i})^k =
    \begin{cases} 0 \quad i \neq j \newline n \quad i = j \newline \end{cases}
    $$

这样可以保证我们能够使用相同的方法进行逆变换。


首先,原根的基本定义:设$g$为$p$的一个原根,则满足:
$$
𝑔^{𝑝−1} \equiv 1(\mod p) \
∀1≤𝑘<𝑝−1, 𝑔^𝑘 \not \equiv 1(\mod p)
$$
换句话说$g^0,g^1,g^2\cdots,g^{p-2} \quad (\bmod p)$ 是互不相同的数,满足性质二。

同时如果我们令$p-2$作为这个群的阶,那么$g^{p-1}$和$\omega_n^n$,其实就是等价的,只不过
$$
g^{p-1} \equiv1(\bmod~p) \ \omega_n^n=1
$$
而已。于是就满足性质一。

而对于性质三,我们先考虑一个转化。我们如果要将$g$作为单位根的替代的话,就需要用到$g^{\frac{p-1}{N}}$。换句话说,$N | (p-1)$。那么我们便可以令$g_n^k \equiv g^{\frac{k(p-1)}{N}} (\bmod~p)$,得到一个和单位根相似的形式。

那么接下来,因为$p$是素数,所以在$g_n^n\equiv g^{\frac{N(p-1)}{N}}\equiv g^{p-1} \equiv 1 (\bmod ~p)$的基础上,我们可以得到$g_{2n}^{n} \equiv1\quad\text{or}\quad \text{-1} (\bmod~ p)$,那么平方之后性质三便显而易见;或者考虑另一种思路,我们根据刚才得出的、跟二次剩余有些相似的式子,可以得到以下结论:
$$
g_n^{\frac{n}{2}+k}=-g_n^k (\bmod~p)
$$
再结合显而易证的消去引理$g_n^k \equiv g_{jn}^{jk}$,我们可以很自然像$FFT$证明单位复根的折半性一样,证出这个结论。

至于性质四,证明的大体相似于单位单位复根。即:
$$
\sum\limits_{j =0}^{n-1}{(g_n^k)^j} \equiv \frac{(g_n^k)^n -1}{g_n^k -1} \Longrightarrow \frac{(g_n^n)^k -1}{g_n^k -1} \equiv \frac{(1)^k -1}{g_n^k -1} = 0
$$
而对于$n=k$的情况,不适用于普通的几何级数求和,所以直接就是$\sum 1 =n$ 。

$0x02\quad \rm{Codes}$

呃,于是NTT就完了。注意因为要保证$N | (p-1)$,且$N$是$2$的幂次,所以素数$p$一定要是$2^j+1$的形式。

至于求原根,不是本界探讨的内容。普通的NTT模数,原根可以背过;其余情况暴力求+验证即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
int N, M, K, qaq ;
const int MAXN = 3000010 ;
LL A[MAXN], B[MAXN], Inv ;
const double Pi = acos(-1.0) ;
int i, j, k, l, Lim = 1, L, R[MAXN] ;
const int P = 998244353, G = 3, Gi = 332748118 ;

namespace IO{
const int ch_top=4e7+3;
char ch[ch_top],*now_r=ch-1,*now_w=ch-1;
inline int read(){
while(*++now_r<'0');
register int x=*now_r-'0';
while(*++now_r>='0')x=x*10+*now_r-'0';
return x;
}
inline void write(int x){
static char st[20];static int top;
while(st[++top]='0'+x%10,x/=10);
while(*++now_w=st[top],--top);
*++now_w=' ';
}
}
inline LL expow(LL a, LL b){
register LL res = 1 ;
while (b){
if (b & 1) (res *= a) %= P ;
(a *= a) %= P, b >>= 1 ;
}
return res ;
}
void NTT(LL *J, int flag){
for(i = 0; i < Lim; i ++)
if(i < R[i]) swap(J[i], J[R[i]]) ;
for(j = 1; j < Lim; j <<= 1){
LL Gn = expow(flag == 1 ? G : Gi, (P - 1) / (j << 1)) ;
for(k = 0; k < Lim; k += (j << 1) ){
LL g = 1 ;
for(l = 0 ; l < j ; l ++, g = (g * Gn) % P){
LL Nx = J[k + l], Ny = g * J[k + j + l] % P ;
J[k + l] = (Nx + Ny) % P, J[k + j + l] = (Nx - Ny + P) % P ;
}
}
}
}
using namespace IO ;
int main(){
fread(ch,1,ch_top,stdin);
N = read(), M = read() ;
while(Lim <= N + M) Lim <<= 1, L ++ ;
for(i = 0; i <= N; i ++) A[i] = (read() + P) % P ;
for(i = 0; i <= M; i ++) B[i] = (read() + P) % P ;
for(i = 0; i < Lim; i ++ ) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1)) ;
NTT(A, 1), NTT(B, 1);
for(i = 0; i <= Lim; i ++) A[i] = (A[i] * B[i]) % P ; Inv = expow(Lim, P - 2) ;
NTT(A, -1) ;
for(i = 0; i <= N + M; i ++) write((long long) (A[i] * Inv + P) % P) ; fwrite(ch, 1, now_w - ch, stdout) ; return 0 ;
}

其中Gi表示$998244353$的原根的逆元。

$0x03\quad \rm{Extending}$

接上节内容,$NTT$本质上是只能处理“$NTT$模数($p=2^k+1$)”。但是当我们需要对其进行任意模数取模时,就需要我们用$CRT$合并。

然后…我也不知道从哪找到了一张比较浅显的图:

但是这个地方仍然会存在不少问题,巨佬KsCla如此解释:

具体做法是这样:先选取三个乘积在$10^{23 }$以上的便于使用NTT的模数。在这里我选的是$m_1=998244353=223∗119+1$,$m_2=1004535809=221∗479+1$,$m_3=469762049=226∗7+1$。选这三个模数的好处在于它们的原根都是3。
然后用这三个模数做NTT,可以得到以下三条式子:
$$
~ans≡c_1~(\bmod m_1)\\
ans≡c_2~(\bmod m_2)\\
ans≡c_3~(\bmod m_3)
$$
虽然这三条式子可以在$10^{23}$以内唯一固定$ans$的值,但问题也随之而来:$m_1∗m_2∗m_3$很大,无法直接用long long存下,而用long double之类的则会丢失精度,所以无法用普通的$CRT$。难道要写高精度?

不,有一种很妙的方法可以解决这个问题。
首先注意到这里只有三个模数,而且两个模数乘起来是不会爆long long的,所以可以先合并前两条式子。根据CRT,有:
$$
ans≡(c_1m_2Inv(m_2,m_1)+c_2m_1Inv(m_1,m_2))(\bmod m_1m_2)
$$
其中$Inv(x,y)$表示x关于y的逆元。
这条式子涉及到两个很大的数相乘然后再取模,而直接相乘会爆long long。可以用$O(\log(m_1m_2))$的快速乘,或者$O(1)$转double后相乘。
为了方便,把上式化成这样的形式:
$$
ans≡C(\bmod M)
$$
然后设:
$$
ans=xM+C=ym_3+c_3
$$
接下来的部分才是精髓。我们求出$x$在$\bmod m_3$意义下的值:
$$
xM≡c_3−C(\bmod m_3)
$$
在$\bmod m_3$意义下,$ym_3$被消掉了。

然后有:
$$
x≡(c_3−C)M−1(\bmod m_3)
$$
算出右半部分的值为$q$,则可令$x=km_3+qx=km_3+q$。将其代入$ans=xM+C$:
$$
ans=(km_3+q)M+C=km_3M+qM+C
$$
也就是说:
$$ans=km_1m_2m_3+qM+C$$
而由于$ans∈[0,m_1m_2m_3)$,所以$k$必为0。也就是说$ans$就是$qM+C$!直接把这条式子对题面要求的模数取模即可.

嗝……其实我就是加了一遍mkd,但是也算是复习了一遍吧XD。

有一点是需要注意的:

  • 为什么要选三个乘积大于$10^{23}$的质数作为模数?
    • 为了是最后的结果可以不取模

嗯,然后就是板子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <cmath>
#include <cstdio>
#include <iostream>
#include <algorithm>

#define rr register
#define LL long long

const int G = 3 ;
int N, M, K, qaq ;
using namespace std ;
const int MAXN = 600010 ;
int i, j, k, l, Lim = 1, L, R[MAXN] ;
int Mod, P, pr[] = {469762049, 998244353, 1004535809} ;
LL Ans[4][MAXN], A[MAXN], B[MAXN], Inv, A1[MAXN], B1[MAXN] ;

inline LL qr(){
LL k = 0, f = 1 ; char c = getchar() ;
while(!isdigit(c)) {if(c == '-') f = -1 ;c = getchar() ;}
while(isdigit(c)) k = (k << 1) + (k << 3) + c - 48 ,c = getchar() ;
return k * f ;
}
inline LL expow(LL a, LL b, LL p){
register LL res = 1 ;
while (b){
if (b & 1) (res *= a) %= p ;
(a *= a) %= p, b >>= 1 ;
}
return res % p ;
}
void NTT(LL *J, int flag){
rr int i, j, k, l ;
for(i = 0; i < Lim; i ++)
if(i < R[i]) swap(J[i], J[R[i]]) ;
for(j = 1; j < Lim; j <<= 1){
LL Gn = expow(G, (P - 1) / (j << 1), P) ;
for(k = 0; k < Lim; k += (j << 1) ){
LL g = 1 ;
for(l = 0 ; l < j ; l ++, g = (g * Gn) % P){
LL Nx = J[k + l], Ny = g * J[k + j + l] % P ;
J[k + l] = (Nx + Ny) % P, J[k + j + l] = (Nx - Ny + P) % P ;
}
}
}
if (flag > 0) return ;
int Inv = expow(Lim, P - 2, P) ; reverse(J + 1, J + Lim) ;
for (i = 0 ; i <= Lim ; ++ i) J[i] = 1ll * J[i] * Inv % P ;
}
void clear(LL *J, LL *L){
for (rr int i = 0 ; i < Lim ; ++ i) J[i] = L[i] = 0 ;
}
void egg(){ return ; }
LL mul(LL a,LL b,LL p){
LL re = 0;
for (; b; b >>= 1,a = (a + a) % p)
if (b & 1) re = (re + a) % p;
return re;
}
void T_NTT(){
rr int i, j, k ;
for (i = 0 ; i <= 2 ; ++ i){
P = pr[i] ; i ? clear(A1, B1) : egg() ;
for (j = 0 ; j <= N ; ++ j) A1[j] = A[j] ;
for (j = 0 ; j <= M ; ++ j) B1[j] = B[j] ;
NTT(A1, 1), NTT(B1, 1) ;
for (j = 0 ; j < Lim ; ++ j) A1[j] = A1[j] * B1[j] % P ;
NTT(A1, -1) ; for (j = 0 ; j <= Lim ; ++ j) Ans[i + 1][j] = A1[j] ;
}
rr LL Mo = 1ll * pr[1] * pr[0], k1, k2, a, b, c, mod = Mod ;
LL inv1 = expow(pr[1] % pr[0], pr[0] - 2, pr[0]), inv0 = expow(pr[0] % pr[1], pr[1] - 2, pr[1]), inv3 = expow(Mo % pr[2], pr[2] - 2, pr[2]) ;
for (i = 0 ; i <= N + M ; ++ i){
a = Ans[1][i], b = Ans[2][i], c = Ans[3][i] ;
k1 = (mul(a * pr[1] % Mo, inv1, Mo) + mul(b * pr[0] % Mo, inv0, Mo)) % Mo ;
k2 = ((c - k1 % pr[2]) % pr[2] + pr[2]) % pr[2] * inv3 % pr[2], Ans[0][i] = ((k2 % mod) * (Mo % mod) % mod + k1 % mod) % mod;
}/*
LL a,b,c,t,k,M = 1ll * pr[0] * pr[1];
LL inv1 = inv(pr[1],pr[0]),inv0 = inv(pr[0],pr[1]),inv3 = inv(M % pr[2],pr[2]);
for (int i = 0; i <= deg; i++){
a = fft[0].A[i],b = fft[1].A[i],c = fft[2].A[i];
t = (mul(a * pr[1] % M,inv1,M) + mul(b * pr[0] % M,inv0,M)) % M;
k = ((c - t % pr[2]) % pr[2] + pr[2]) % pr[2] * inv3 % pr[2];
ans[i] = ((k % md) * (M % md) % md + t % md) % md;
}*/
}
int main(){
rr int i ;
N = qr(), M = qr(), Mod = qr() ;
while(Lim <= N + M) Lim <<= 1, L ++ ;
for(i = 0; i <= N; i ++) A[i] = qr() % Mod ;
for(i = 0; i <= M; i ++) B[i] = qr() % Mod ;
for(i = 0; i < Lim; i ++ ) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1)) ;
T_NTT() ; for(i = 0; i <= N + M; i ++) printf("%d ", Ans[0][i]) ; return 0 ;
}

于是这玩意儿进行了9遍NTT,那是真的慢。。。

$\rm{Reference}$