# 快速傅里叶变换

快速傅里叶变换(Fast Fourier Transform, FTT)在CP中最主要的应用是计算多项式乘法。

# 多项式的系数表示和点值表示

假设f(x)f(x)xxnn阶多项式,则其可以表示为:

f(x)=i=0naixif(x)=\sum_{i=0}^na_ix^i

这里的n+1n+1个系数{a0,a1,,an}\{a_0,a_1,\cdots,a_n\}就称为多项式f(x)f(x)的系数表示。

另一方面,我们也可以把f(x)f(x)看成是一个关于xx的函数,我们可以取n+1n+1个不同的xix_i,用{(x0,f(x0)),(x1,f(x1)),(xn,f(xn))}\{(x_0,f(x_0)),(x_1,f(x_1)),\cdots(x_n,f(x_n))\}n+1n+1个数值对来唯一确定f(x)f(x),这种表示形式就称为多项式f(x)f(x)的点值表示。

# 点值表示与多项式乘法的关系

假设我们现在要求的是F(x)=f(x)g(x)F(x)=f(x)\cdot g(x),如果我们已知f(x)f(x)g(x)g(x)的点值表示,那么我们可以非常容易地得到F(x)F(x)的点值表示为

{(x0,f(x0)g(x0)),(x1,f(x1)g(x1)),,(xn,f(xn)g(xn))}\{(x_0,f(x_0)g(x_0)),(x_1,f(x_1)g(x_1)),\cdots,(x_n,f(x_n)g(x_n))\}

注意这里的nn实际上要取到f(x)f(x)g(x)g(x)的阶数之和。

现在的关键问题是,如何快速将这一点值表示转换为系数表示。

# FFT的实现

为了解决这一问题,我们首先考虑其逆问题,也即:如何从系数表示快速计算点值表示。

# FFT

暴力计算nn对点值的总时间复杂度为O(n2)O(n^2)。如何优化呢?我们希望我们选择的nnxix_i之间存在一定的关系,使得我们可以复用xikx_i^k的计算结果。那么,应该如何选择呢?

前人的经验告诉我们,可以选择单位复根ωni\omega_n^i。它有三个重要的性质:

ωnn=1\omega_n^n=1

ωni=ω2n2i\omega_n^i=\omega_{2n}^{2i}

ω2nn+i=ω2ni\omega_{2n}^{n+i}=-\omega_{2n}^i

利用上述这三个性质,我们可以实现计算过程的简化。

不妨考虑一个最高阶为7阶的多项式

f(x)=a0+a1x1+a2x2+a3x3+a4x4+a5x5+a6x6+a7x7f(x)=a_0+a_1x^1+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7

可以把奇偶项分别处理

f(x)=(a0+a2x2+a4x4+a6x6)+x(a1+a3x2+a5x4+a7x6)=G(x2)+xH(x2)\begin{aligned} f(x) &=(a_0+a_2x^2+a_4x^4+a_6x^6)+x(a_1+a_3x^2+a_5x^4+a_7x^6) \\ &=G(x^2)+xH(x^2) \end{aligned}

从而

DFT(f(x))=DFT(G(x2))+xDFT(H(x2))\text{DFT}(f(x))=\text{DFT}(G(x^2))+x\text{DFT}(H(x^2))

这时把单位复根ωnk\omega_n^kk<n/2k<n/2)代入,可以得到

DFT(f(ωnk))=DFT(G(ωn2k))+ωnkDFT(H(ωn2k))=DFT(G(ωn/2k))+ωnkDFT(H(ωn/2k))\begin{aligned} \text{DFT}(f(\omega_n^k))&=\text{DFT}(G(\omega_n^{2k}))+\omega_n^k\text{DFT}(H(\omega_n^{2k})) \\ &=\text{DFT}(G(\omega_{n/2}^k))+\omega_n^k\text{DFT}(H(\omega_{n/2}^k)) \end{aligned}

而另一方面,代入ωnk+n/2\omega_n^{k+n/2}可以得到

DFT(f(ωnk+n/2))=DFT(G(ωn2k+n))+ωnkDFT(H(ωn2k+n))=DFT(G(ωn/2k+n/2))+ωnk+n/2DFT(H(ωn/2k+n/2))=DFT(G(ωn/2k))ωnkDFT(H(ωn/2k))\begin{aligned} \text{DFT}(f(\omega_n^{k+n/2}))&=\text{DFT}(G(\omega_n^{2k+n}))+\omega_n^k\text{DFT}(H(\omega_n^{2k+n})) \\ &=\text{DFT}(G(\omega_{n/2}^{k+n/2}))+\omega_n^{k+n/2}\text{DFT}(H(\omega_{n/2}^{k+n/2})) \\ &=\text{DFT}(G(\omega_{n/2}^k))-\omega_n^k\text{DFT}(H(\omega_{n/2}^k)) \end{aligned}

因此,我们只要求得DFT(G(ωn/2k))\text{DFT}(G(\omega_{n/2}^k))DFT(H(ωn/2k))\text{DFT}(H(\omega_{n/2}^k)),就可以同时求得DFT(f(ωnk))\text{DFT}(f(\omega_n^k))DFT(f(ωnk+n/2))\text{DFT}(f(\omega_n^{k+n/2})),这样就把问题规模缩小了一半。

使用同样的方法对DFT(G(ωn/2k))\text{DFT}(G(\omega_{n/2}^k))DFT(H(ωn/2k))\text{DFT}(H(\omega_{n/2}^k))进行递归求解,我们有

T(n)=2T(n/2)T(n)=2T(n/2)

可知总的时间复杂度为O(nlogn)O(n\log n)

在这一过程中,我们默认n/2n/2总是整数,因此我们需要n=2kn=2^k。所以在计算之前,我们要先对系数补0,使得总的项数变为2的幂次。

# 逆FFT

将FFT的运算过程看做一个矩阵乘法,逆FFT,也即从点值表示求取系数表示的过程,可以视为左乘逆矩阵。在点值表示的点选取为ωnk\omega_n^k时,FFT矩阵A(ωnk)\mathbb{A}(\omega_n^k)的逆矩阵恰好为1nA(ωnk)\frac{1}{n}\mathbb{A}(\omega_n^{-k}),因此可以复用FFT的计算过程,只需要加上一个标志变量来表示当前是在进行FFT还是IFFT。

# 模板题:洛谷 P3803 - 多项式乘法(FFT) (opens new window)

下面给出了本题的递归实现。

参考代码(C++)
#include <cmath>
#include <complex>
#include <iostream>
#define MAXN (1 << 22)

using namespace std;
typedef complex<double> cd;
const cd I{0, 1};
cd tmp[MAXN], a[MAXN], b[MAXN];
void fft(cd *f, int n, int rev) {
  if (n == 1)
    return;
  for (int i = 0; i < n; ++i)
    tmp[i] = f[i];
  for (int i = 0; i < n; ++i) {
    if (i & 1)
      f[n / 2 + i / 2] = tmp[i];
    else
      f[i / 2] = tmp[i];
  }
  cd *g = f, *h = f + n / 2;
  fft(g, n / 2, rev), fft(h, n / 2, rev);
  cd omega = exp(I * (2 * M_PI / n * rev)), now = 1;
  for (int k = 0; k < n / 2; ++k) {
    tmp[k] = g[k] + now * h[k];
    tmp[k + n / 2] = g[k] - now * h[k];
    now *= omega;
  }
  for (int i = 0; i < n; ++i)
    f[i] = tmp[i];
}
int main() {
  int n, m;
  cin >> n >> m;
  int k = 1 << (32 - __builtin_clz(n + m + 1));
  for (int i = 0; i <= n; ++i)
    cin >> a[i];
  for (int j = 0; j <= m; ++j)
    cin >> b[j];
  fft(a, k, 1);
  fft(b, k, 1);
  for (int i = 0; i < k; ++i)
    a[i] *= b[i];
  fft(a, k, -1);
  for (int i = 0; i < k; ++i)
    a[i] /= k;
  for (int i = 0; i < n + m + 1; ++i)
    cout << (int)round(a[i].real()) << " ";
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

上述递归方法的常数较大,不能通过洛谷P3803的最后两个测试点。

为了改写非递归方法,我们引入蝴蝶变换的概念。

# 蝴蝶变换

继续使用前面的例子,经过第一步分治,将原来的系数分为两组:

{a0,a2,a4,a6},{a1,a3,a5,a7}\{a_0,a_2,a_4,a_6\},\{a_1,a_3,a_5,a_7\}

继续进行第二步分治,得到四组系数:

{a0,a4},{a2,a6},{a1,a5},{a3,a7}\{a_0,a_4\},\{a_2,a_6\},\{a_1,a_5\},\{a_3,a_7\}

最后一步分治,得到八组系数:

{a0},{a4},{a2},{a6},{a1},{a5},{a3},{a7}\{a_0\},\{a_4\},\{a_2\},\{a_6\},\{a_1\},\{a_5\},\{a_3\},\{a_7\}

所谓蝴蝶变换,指的就是从a0,a1,,an1{a_0,a_1,\cdots,a_{n-1}}这一原始系数序列,变换得到最后一步分治后的系数序列。

观察后可以发现,在蝴蝶变换的最终结果中,系数下标的二进制表示恰好是其所在位置二进制表示的逆序,因此,可以利用这一规律来求取蝴蝶变换的结果。

直接利用规律来计算的复杂度是O(nlogn)O(n\log n),如果从小到大递推实现,复杂度则为O(n)O(n)

# FFT的非递归实现

下面给出了洛谷P3803的非递归实现。

参考代码(C++)
#include <cmath>
#include <complex>
#include <iostream>
#define MAXN (1 << 22)

using namespace std;
typedef complex<double> cd;
const cd I{0, 1};
cd a[MAXN], b[MAXN];
void change(cd *f, int n) {
  int i, j, k;
  for (int i = 1, j = n / 2; i < n - 1; i++) {
    if (i < j)
      swap(f[i], f[j]);
    k = n / 2;
    while (j >= k) {
      j = j - k;
      k = k / 2;
    }
    if (j < k)
      j += k;
  }
}
void fft(cd *f, int n, int rev) {
  change(f, n);
  for (int len = 2; len <= n; len <<= 1) {
    cd omega = exp(I * (2 * M_PI / len * rev));
    for (int j = 0; j < n; j += len) {
      cd now = 1;
      for (int k = j; k < j + len / 2; ++k) {
        cd g = f[k], h = now * f[k + len / 2];
        f[k] = g + h, f[k + len / 2] = g - h;
        now *= omega;
      }
    }
  }
  if (rev == -1)
    for (int i = 0; i < n; ++i)
      f[i] /= n;
}
int main() {
  int n, m;
  cin >> n >> m;
  int k = 1 << (32 - __builtin_clz(n + m + 1));
  for (int i = 0; i <= n; ++i)
    cin >> a[i];
  for (int j = 0; j <= m; ++j)
    cin >> b[j];
  fft(a, k, 1);
  fft(b, k, 1);
  for (int i = 0; i < k; ++i)
    a[i] *= b[i];
  fft(a, k, -1);
  for (int i = 0; i < n + m + 1; ++i)
    cout << (int)round(a[i].real()) << " ";
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
参考代码(Kotlin)
import kotlin.math.*

fun readInts(): List<Int> {
    return readLine()!!.trim().split(" ").map(String::toInt)
}

fun nearest(n: Int): Int {
    return 1.shl(ceil(log2(n.toDouble())).toInt() + 1)
}

data class Complex(val real: Double, val imag: Double) {
    operator fun plus(other: Complex): Complex {
        return Complex(real + other.real, imag + other.imag)
    }

    operator fun minus(other: Complex): Complex {
        return Complex(real - other.real, imag - other.imag)
    }

    operator fun times(other: Double): Complex {
        return Complex(real * other, imag * other)
    }

    operator fun times(other: Complex): Complex {
        return Complex(real * other.real - imag * other.imag, real * other.imag + imag * other.real)
    }
}

val rev = IntArray(nearest(1000002))

fun change(f: Array<Complex>, n: Int) {
    if (rev[1] == 0)
        for (i in 0 until n) {
            rev[i] = rev[i shr 1] shr 1
            if (i % 2 == 1)
                rev[i] = rev[i] or (n shr 1)
        }
    for (i in 0 until n)
        if (i < rev[i]) {
            val tmp = f[i]
            f[i] = f[rev[i]]
            f[rev[i]] = tmp
        }
}

fun fft(f: Array<Complex>, n: Int, rev: Int) {
    change(f, n)
    var len = 2
    while (len <= n) {
        val omega = Complex(cos(2.0 * PI / len * rev), sin(2.0 * PI / len * rev))
        var l = 0
        while (l < n) {
            var now = Complex(1.0, 0.0)
            for (i in l until l + len / 2) {
                val g = f[i]
                val h = now * f[i + len / 2]
                f[i] = g + h
                f[i + len / 2] = g - h
                now *= omega
            }
            l += len
        }
        len *= 2
    }
    if (rev == -1)
        for (i in 0 until n)
            f[i] = f[i] * (1.0 / n)
}

fun main() {
    val (n, m) = readInts()
    val k = nearest(n + m + 1)
    val a = Array(k) { Complex(0.0, 0.0) }
    for ((i, v) in readInts().withIndex())
        a[i] = Complex(v.toDouble(), 0.0)
    val b = Array(k) { Complex(0.0, 0.0) }
    for ((i, v) in readInts().withIndex())
        b[i] = Complex(v.toDouble(), 0.0)
    fft(a, k, 1)
    fft(b, k, 1)
    for (i in 0 until k)
        a[i] = a[i] * b[i]
    fft(a, k, -1)
    val ans = a.slice(0 until n + m + 1).map { round(it.real).toInt() }
    println(ans.joinToString(" "))
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

由于语言差异,Kotlin版本并不能通过洛谷的测试。

# 学习资源

# Matters Computational (opens new window)

  • 第二十一章 快速傅里叶变换

# 练习题

裸FFT并不可怕,本身FFT的码量并不算大,背一背也不是多大的事,关键是如何看出一道题目是FFT。

# SPOJ - ADAMATCH (opens new window)

如果暴力枚举子串,时间复杂度为O(r2)O(|r|^2),显然不行。如何降低复杂度呢?

提示一

首先考虑字母'A'。不妨把字符串为'A'的位置设为11,其余位置设为00。看起来似乎可以进行多项式乘法,但乘法的结果似乎没有明显的意义。

提示二

如果把r串逆序呢?看看此时乘积的每一项有怎样的含义。

参考代码(C++)
#include <cmath>
#include <complex>
#include <cstring>
#include <iostream>
#include <vector>
#define MAXN (1 << 22)

using namespace std;
typedef complex<double> cd;
const cd I{0, 1};
cd a[MAXN], b[MAXN];
void change(cd *f, int n) {
  for (int i = 1, j = n / 2; i < n - 1; i++) {
    if (i < j)
      swap(f[i], f[j]);
    int k = n / 2;
    while (j >= k) {
      j = j - k;
      k = k / 2;
    }
    if (j < k)
      j += k;
  }
}
void fft(cd *f, int n, int rev) {
  change(f, n);
  for (int len = 2; len <= n; len <<= 1) {
    cd omega = exp(I * (2 * M_PI / len * rev));
    for (int j = 0; j < n; j += len) {
      cd now = 1;
      for (int k = j; k < j + len / 2; ++k) {
        cd g = f[k], h = now * f[k + len / 2];
        f[k] = g + h, f[k + len / 2] = g - h;
        now *= omega;
      }
    }
  }
  if (rev == -1)
    for (int i = 0; i < n; ++i)
      f[i] /= n;
}
int main() {
  string s, r;
  cin >> s >> r;
  int n = s.size(), m = r.size();
  int k = 1 << (32 - __builtin_clz(n + m + 1));
  vector<int> cnt(k);
  for (char c : "ACGT") {
    memset(a, 0, sizeof(a));
    memset(b, 0, sizeof(b));
    for (int i = 0; i < n; ++i)
      a[i] = s[i] == c;
    for (int i = 0; i < m; ++i)
      b[i] = r[m - i - 1] == c;
    fft(a, k, 1);
    fft(b, k, 1);
    for (int i = 0; i < k; ++i)
      a[i] *= b[i];
    fft(a, k, -1);
    for (int i = 0; i < k; ++i)
      cnt[i] += (int)round(a[i].real());
  }
  int ans = m;
  for (int i = m - 1; i < n; ++i)
    ans = min(ans, m - cnt[i]);
  cout << ans;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67

# SPOJ - TSUM (opens new window)

如果暴力枚举,时间复杂度为O(n3)O(n^3),显然不行。如何降低复杂度呢?

提示一

加法可以变为多项式的乘法。

提示二

如何去除包含重复元素的项?

参考代码(C++)
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>

#define MAXN 131072
#define OFFSET 20000

using namespace std;
typedef complex<double> cd;
const cd I{0, 1};

void change(vector<cd> &f, int n) {
  for (int i = 1, j = n / 2; i < n - 1; i++) {
    if (i < j)
      swap(f[i], f[j]);
    int k = n / 2;
    while (j >= k) {
      j = j - k;
      k = k / 2;
    }
    if (j < k)
      j += k;
  }
}

void fft(vector<cd> &f, int n, int rev) {
  change(f, n);
  for (int len = 2; len <= n; len <<= 1) {
    cd omega = exp(I * (2 * M_PI / len * rev));
    for (int j = 0; j < n; j += len) {
      cd now = 1;
      for (int k = j; k < j + len / 2; ++k) {
        cd g = f[k], h = now * f[k + len / 2];
        f[k] = g + h, f[k + len / 2] = g - h;
        now *= omega;
      }
    }
  }
  if (rev == -1)
    for (int i = 0; i < n; ++i)
      f[i] /= n;
}

int main() {
  int n;
  cin >> n;
  vector<cd> a(MAXN), a2(MAXN);
  vector<int> a3(MAXN);
  for (int i = 0; i < n; ++i) {
    int m;
    cin >> m;
    a[m + OFFSET] = cd{1, 0};
    a2[(m + OFFSET) << 1] = cd{1, 0};
    a3[(m + OFFSET) * 3] = 1;
  }
  vector<cd> tot(a), b(a);
  fft(tot, MAXN, 1);
  fft(b, MAXN, 1);
  fft(a2, MAXN, 1);
  for (int i = 0; i < MAXN; ++i)
    tot[i] *= b[i] * b[i], a2[i] *= b[i];
  fft(tot, MAXN, -1);
  fft(a2, MAXN, -1);
  for (int i = 0; i < MAXN; ++i) {
    int cnt1 = round(tot[i].real()); // ABC, with permutation
    int cnt2 = round(a2[i].real());  // AAB, no permutation
    int cnt3 = a3[i];                // AAA
    int cnt = (cnt1 - cnt2 * 3 + cnt3 * 2) / 6;
    if (cnt > 0)
      cout << i - OFFSET * 3 << " : " << cnt << endl;
  }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73