FFT 快速傅里叶变换 学习笔记

发布于

概述

快速傅里叶变换(Fast Fourier Transform, FFT)可以在 $\Theta\left(n \log n\right)$ 的时间内计算多项式乘法。

但 FFT 的用处远不止于此。

多项式

一个多项式 $F\left(x\right) = a_0 + a_1x + a_2x^2 + \cdots$。

其中,$a_i$ 被称为 $F\left(x\right)$ 的 $i$ 次项系数。

多项式的点值表示法

$n + 1$ 个点可以唯一确定一个 $n$ 次多项式。可以使用 待定系数法 证明。

在某些情况下,使用点值表示法的多项式可以在线性复杂度内计算乘积。

给定两个 $n$ 次 $F\left(x\right)$ 和 $G\left(x\right)$:

他们的乘积 $\left(F * G\right)\left(x\right) = F\left(x\right) G\left(x\right)$。

因此,如果我们有 $k$ 个点 $x_1, x_2, \ldots, x_k$,并且知道 $F\left(x_1\right), F\left(x_2\right), \ldots, F\left(x_k\right)$ 和 $G\left(x_1\right), G\left(x_2\right), \ldots, G\left(x_k\right)$,那么我们就可以得到:

$$\left(F * G\right)\left(x_i\right) = F\left(x_i\right)G\left(x_i\right)$$

注意:因为 $\left(F * G\right)\left(x\right)$ 是一个 $2n$ 次的多项式,所以需要 $k = 2n + 1$ 个点才能确定。

把 系数表达 转换为 点值表达 称为 DFT(离散傅里叶变换)。

离散傅里叶变换

把 点值表达 转换为 系数表达 称为 IDFT(离散傅里叶逆变换)。

快速傅里叶变换 把 离散傅里叶变换 和 离散傅里叶逆变换 加速到了 $\mathcal O\left(n \log n\right)$。

虚数

虚数单位 $i$ 的定义为 $i = \sqrt{-1}$,即 $i^2 = -1$。

复数

复数定义为 $a + bi$。前项称为该复数的实部,后项称为该复数的虚部

几何意义

实数的几何意义可以看作是数轴上的一个点。

而复数则可以看作平面上的一个点。我们把这个平面称为复平面。它的 $x$ 轴表示复数的实部,而 $y$ 轴表示复数的虚部。

具体来说:一个复数 $x = a + bi$ 可以看作 $\left(a, b\right)$ 这个点。

这个点也可以使用极坐标系来定义,我们把这个点到原点的距离(极径)叫做模长,写作 $\lvert x \rvert$。把原点到该点的射线与实轴正方向射线组成的角(逆时针旋转)的角度(极角)称为这个复数的幅角,写作 $\arg\left(x\right)$。

图炸了,刷新试试吧

加法

实部 和 虚部 分别相加即可。

$\left(a + bi\right) + \left(c + di\right) = \left(a + c\right) + \left(b + d\right)i$

减法

实部 和 虚部 分别相加即可。

$\left(a + bi\right) - \left(c + di\right) = \left(a - c\right) + \left(b - d\right)i$

乘法

直接乘,然后暴力化简。

$\left(a + bi\right)\left(c + di\right) = ac + adi + bci + bdi^2 = ac + adi + bci - bd = \left(ac - bd\right) + \left(ad + bc\right)i$

几何意义:复数相乘时,模长相乘,幅角相加!

共轭

$\left(a + bi\right)\left(a - bi\right) = a^2 + b^2$

除法

直接乘,然后乱搞化简。

$\dfrac{a + bi}{c + di} = \dfrac{\left(a + bi\right)\left(c - di\right)}{\left(c + di\right)\left(c - di\right)} = \dfrac{\left(ac + bd\right) + \left(bc - ad\right)i}{c^2 + d^2} = \dfrac{ac + bd}{c^2 + d^2} + \dfrac{bc - ad}{c^2 + d^2}i$

单位根

定义

$n$ 次单位根,即方程 $x^n = 1$ 的复数解。

首先,单位圆定义为一个以原点为圆心,半径为 $1$ 的圆:

在这个单位圆上的点表示的复数的模长都是 $1$。

如果一个复数 $x$ 的模长大于 $1$,那么 $\lvert x^n \rvert = \lvert x \rvert^n > 1$。

如果一个复数 $x$ 的模长小于 $1$,那么 $\lvert x^n \rvert = \lvert x \rvert^n < 1$。

所以只有模长等于 $1$ 的复数才有可能成为 $n$ 次单位根。

在圆上的复数的模长都是 $1$,接下来考虑幅角。

容易找到幅角为 $0, \dfrac{2\pi}n, \dfrac{4\pi}n, \ldots \dfrac{2\pi\left(n - 1\right)}n$ 的复数,就是单位根。而根据代数基本定理:$n$ 次方程在复数域内有且只有 $n$ 个根。

$n$ 次单位根也可以定义为 $e^\dfrac{2\pi i}{n}$。

容易知道这玩意的 $n$ 次方幅角是圆周的若干倍,那么就和 $1$ 重合。

比如,$3$ 次单位根就长这样:

当然,单位根也有符号。第 $i$ 个 $n$ 次单位根:$\omega_n^i$

性质

  1. $\omega_n^0 = 1$
  2. $\omega_n^k \cdot \omega_n^j = \omega_n^{k + j}$:感性理解就是在单位圆上转。也可以套公式获得严格证明。
  3. $\omega_{2n}^{2k} = \omega_n^k$:把圆分成 $2n$ 份取 $2k$ 份跟把圆分成 $n$ 份取 $k$ 份是一样的。
  4. 如果 $n$ 是偶数,$\omega_n^{k + n / 2}=-\omega_n^k$:相当于转了半个圆,那就走到了跟当前位置完全相反的那个点上。

快速傅里叶变换

现在有多项式 $F\left(x\right) = a_0 + a_1x + a_2x^2 + a_3x^3 + \cdots$,我们想加速离散傅里叶变换。

首先,我们要做的,是把 $F\left(x\right)$ 的奇数次项和偶数次项分开考虑。

设 $G\left(x\right) = a_0 + a_2x + a_4x^2 + \cdots$,$H\left(x\right) = a_1 + a_3x + a_5x^2 + \cdots$。这里保证 $F\left(x\right)$ 的项数是 $2$ 的整数次幂。不会出现分不匀的情况。

现在,$F\left(x\right)$ 可以写作 $G\left(x^2\right) + xH\left(x^2\right)$。

把 $\omega_n^k$ 代入 $F\left(x\right)$:

$$ \begin{aligned} F\left(\omega_n^k\right) &= G\left(\left(\omega_n^k\right)^2\right) + \omega_n^k H\left(\left(\omega_n^k\right)^2\right)\\ &= G\left(\omega_{n / 2}^k\right) + \omega_n^k H\left(\omega_{n / 2}^k\right) \end{aligned} $$

而把 $\omega_n^{k + n / 2}$ 代入 $F\left(x\right)$:

$$ \begin{aligned} F\left(\omega_n^{k + n / 2}\right) &= G\left(\left(\omega_n^{k + n / 2}\right)^2\right) + \omega_n^{k + n / 2} H\left(\left(\omega_n^{k + n / 2}\right)^2\right)\\ &= G\left(\omega_n^{2k + n}\right) + \omega_n^{k + n / 2} H\left(\omega_n^{2k + n}\right)\\ &= G\left(\omega_n^{2k}\right) + \omega_n^{k + n / 2} H\left(\omega_n^{2k}\right)\\ &= G\left(\omega_{n / 2}^k\right) + \omega_n^{k + n / 2} H\left(\omega_{n / 2}^k\right)\\ &= G\left(\omega_{n / 2}^k\right) - \omega_n^k H\left(\omega_{n / 2}^k\right) \end{aligned} $$

把两个式子放在一起看看:

$$ \begin{aligned} F\left(\omega_n^k\right) &= G\left(\omega_{n / 2}^k\right) + \omega_n^k H\left(\omega_{n / 2}^k\right) \\ F\left(\omega_n^{k + n / 2}\right) &= G\left(\omega_{n / 2}^k\right) - \omega_n^k H\left(\omega_{n / 2}^k\right) \end{aligned} $$

就只有一个正负号的区别!

如果我们知道两个多项式 $G\left(x\right)$ 和 $H\left(x\right)$ 分别在 $\omega^0_{n/2}, \omega^1_{n / 2}, \omega^2_{n / 2}, \ldots, \omega^{n / 2 - 1}_{n / 2}$ 的点值表示,

根据 $F\left(\omega_n^k\right) = G\left(\omega_{n / 2}^k\right) + \omega_n^k H\left(\omega_{n / 2}^k\right)$,我们可以在线性复杂度内求出 $F\left(x\right)$ 在 $\omega^0_n, \omega^1_n, \omega^2_n, \ldots, \omega^{n / 2 - 1}_n$ 处的点值表示。

根据 $F\left(\omega_n^{k + n / 2}\right) = G\left(\omega_{n / 2}^k\right) - \omega_n^k H\left(\omega_{n / 2}^k\right)$ 我们可以在线性复杂度内求出 $F\left(x\right)$ 在 $\omega^{n / 2}_n, \omega^{n / 2 + 1}_n, \omega^{n / 2 + 2}_n, \ldots, \omega^{n - 1}_n$ 处的点值表示。

所以如果我们知道两个多项式 $G\left(x\right)$ 和 $H\left(x\right)$ 分别在 $\omega^0_{n/2}, \omega^1_{n / 2}, \omega^2_{n / 2}, \ldots, \omega^{n / 2 - 1}_{n / 2}$ 的点值表示,我们就可以在线性复杂度内求出 $F\left(x\right)$ 在 $\omega^0_n, \omega^1_n, \omega^2_n, \ldots, \omega^{n - 1}_n$ 处的点值表示。

由于计算 $G\left(x\right)$ 和 $H\left(x\right)$ 跟求 $F\left(x\right)$ 的流程是同样的,所以我们可以递归去求。

得到复杂度的递归式为 $T\left(n\right) = 2T\left(n / 2\right) + \Theta\left(n\right)$,根据主定理,$T\left(n\right) = \Theta\left(n \log n\right)$。

我们就在 $\Theta\left(n \log n\right)$ 的复杂度内成功获得了 $F\left(x\right)$ 的点值表示。

实现

首先,我们要保证 $F\left(x\right)$ 的项数是 $2$ 的整数次幂。因此要在后面添加系数为 $0$ 的项。

然后就是求单位根的事了。我们知道 $\omega_n^k = \left(\omega_n^1\right)^k$,所以只需要知道 $\omega_n^1$ 即可。

通过一些三角函数知识可以知道:

$$\omega_n^1 = \cos\left(\frac{2\pi}n\right) + \sin\left(\frac{2\pi}n\right)i$$

struct Complex {
  double r, i;
  
  Complex() = default;
  Complex(double x, double y) : r(x), i(y) {}
  
  Complex operator+(Complex const& b) const { return Complex(r + b.r, i + b.i); }
  Complex operator-(Complex const& b) const { return Complex(r - b.r, i - b.i); }
  Complex operator*(Complex const& b) const { return Complex(r * b.r - i * b.i, r * b.i + i * b.r); }
  Complex operator/(Complex const& b) const {
    double t = b.r * b.r + b.i * b.i;
    return Complex((r * b.r + i * b.i) / t, (i * b.r - r * b.i) / t);
  }
  
  Complex operator+=(Complex const& b) { return *this = *this + b; }
  Complex operator-=(Complex const& b) { return *this = *this - b; }
  Complex operator*=(Complex const& b) { return *this = *this * b; }
  Complex operator/=(Complex const& b) { return *this = *this / b; }
};
struct Complex{double r,i;Complex()=default;Complex(double x,double y):r(x),i(y){}Complex operator+(Complex const&b)const{return Complex(r+b.r,i+b.i);}Complex operator-(Complex const&b)const{return Complex(r-b.r,i-b.i);}Complex operator*(Complex const&b)const{return Complex(r*b.r-i*b.i,r*b.i+i*b.r);}Complex operator/(Complex const&b)const{double t=b.r*b.r+b.i*b.i;return Complex((r*b.r+i*b.i)/t,(i*b.r-r*b.i)/t);}Complex operator+=(Complex const&b){return*this=*this+b;}Complex operator-=(Complex const&b){return*this=*this-b;}Complex operator*=(Complex const&b){return*this=*this*b;}Complex operator/=(Complex const&b){return*this=*this/b;}};
vector<Complex> w(n); // omega
w[0] = Complex(1, 0);
w[1] = Complex(cos(2 * acos(-1) / n), sin(2 * acos(-1) / n));
for (tp i = 2; i < n; ++i) w[i] = w[i - 1] * w[1];
for (tp i = 0; i < n; ++i) bin << i << ' ' << w[i].r << ' ' << w[i].i << '\n';
vector<Complex> f, tmp;

void FFT(vector<Complex>::iterator f, tp len) {
  if (len == 1) return;
  vector<Complex>::iterator g = f, h = f + len / 2;
  for (tp k = 0; k < len; ++k) tmp[k] = f[k];
  for (tp k = 0; k < len / 2; ++k) { g[k] = tmp[k * 2]; h[k] = tmp[k * 2 + 1];}
  FFT(g, len / 2);
  FFT(h, len / 2);
  Complex w(cos(2 * acos(-1) / len), sin(2 * acos(-1) / len)), p(1, 0);
  for (tp k = 0; k < len / 2; ++k) {
    tmp[k] = g[k] + p * h[k];
    tmp[k + len / 2] = g[k] - p * h[k];
    p *= w;
  }
  for (tp k = 0; k < len; ++k) f[k] = tmp[k];
}

void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
  tp n, m = 1; bin >> n;
  while (m < n) m *= 2;
  f = tmp = vector<Complex>(m, Complex());
  for (tp i = 0; i < n; ++i) bin >> f[i].r;
  FFT(f.begin(), m);
  for (tp i = 0; i < m; ++i) bin << f[i].r << ' ' << f[i].i << '\n';
}

这样,我们就写完了 FFT。

快速傅里叶逆变换

这里直接给出结论:

把 FFT 中的 $\omega_n^1$ 换成 $\omega_n^{-1}$,做完之后除以 $n$ 即可。

证明平凡。

void IFFT(vector<Complex>::iterator f, tp len) {
  if (len == 1) return;
  vector<Complex>::iterator g = f, h = f + len / 2;
  for (tp k = 0; k < len; ++k) tmp[k] = f[k];
  for (tp k = 0; k < len / 2; ++k) { g[k] = tmp[k * 2]; h[k] = tmp[k * 2 + 1];}
  IFFT(g, len / 2);
  IFFT(h, len / 2);
  Complex w(cos(2 * acos(-1) / len), -sin(2 * acos(-1) / len)), p(1, 0);
  for (tp k = 0; k < len / 2; ++k) {
    tmp[k] = g[k] + p * h[k];
    tmp[k + len / 2] = g[k] - p * h[k];
    p *= w;
  }
  for (tp k = 0; k < len; ++k) f[k] = tmp[k];
}

多项式乘法

vector<Complex> f, tmp;

void FFT(vector<Complex>::iterator f, tp len, tp type) {
  if (len == 1) return;
  vector<Complex>::iterator g = f, h = f + len / 2;
  for (tp k = 0; k < len; ++k) tmp[k] = f[k];
  for (tp k = 0; k < len / 2; ++k) { g[k] = tmp[k * 2]; h[k] = tmp[k * 2 + 1];}
  FFT(g, len / 2, type);
  FFT(h, len / 2, type);
  Complex w(cos(2 * acos(-1) / len), type * sin(2 * acos(-1) / len)), p(1, 0);
  for (tp k = 0; k < len / 2; ++k) {
    tmp[k] = g[k] + p * h[k];
    tmp[k + len / 2] = g[k] - p * h[k];
    p *= w;
  }
  for (tp k = 0; k < len; ++k) f[k] = tmp[k];
}

void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
  tp n, m, k = 1; bin >> n >> m;
  while (k <= n + m) k *= 2;
  f = g = tmp = vector<Complex>(k, Complex());
  for (tp i = 0; i <= n; ++i) bin >> f[i].r;
  for (tp i = 0; i <= m; ++i) bin >> g[i].r;
  FFT(f.begin(), k, 1); FFT(g.begin(), k, 1);
  for (tp i = 0; i < k; ++i) f[i] *= g[i];
  FFT(f.begin(), k, -1);
  for (tp i = 0; i <= n + m; ++i) bin << tp(f[i].r / k + 0.5) << ' ';
  bin << '\n';
}

根据这个方法,我们可以得到一个 $\frac23$ 常数的 FFT:

vector<Complex> f, tmp;

void FFT(vector<Complex>::iterator f, tp len, tp type) {
  if (len == 1) return;
  vector<Complex>::iterator g = f, h = f + len / 2;
  for (tp k = 0; k < len; ++k) tmp[k] = f[k];
  for (tp k = 0; k < len / 2; ++k) { g[k] = tmp[k * 2]; h[k] = tmp[k * 2 + 1];}
  FFT(g, len / 2, type);
  FFT(h, len / 2, type);
  Complex w(cos(2 * acos(-1) / len), type * sin(2 * acos(-1) / len)), p(1, 0);
  for (tp k = 0; k < len / 2; ++k) {
    tmp[k] = g[k] + p * h[k];
    tmp[k + len / 2] = g[k] - p * h[k];
    p *= w;
  }
  for (tp k = 0; k < len; ++k) f[k] = tmp[k];
}

void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
  tp n, m, k = 1; bin >> n >> m;
  while (k <= n + m) k *= 2;
  f = tmp = vector<Complex>(k, Complex());
  for (tp i = 0; i <= n; ++i) bin >> f[i].r;
  for (tp i = 0; i <= m; ++i) bin >> f[i].i;
  FFT(f.begin(), k, 1);
  for (tp i = 0; i < k; ++i) f[i] *= f[i];
  FFT(f.begin(), k, -1);
  for (tp i = 0; i <= n + m; ++i) bin << tp(f[i].i / 2 / k + 0.5) << ' ';
  bin << '\n';
}

再使用蝴蝶变换:

void FFT(vector<Complex>::iterator f, tp len, tp type) {
  if (len == 1) return;
  FFT(f, len / 2, type);
  FFT(f + len / 2, len / 2, type);
  Complex w(cos(2 * acos(-1) / len), type * sin(2 * acos(-1) / len)), p(1, 0);
  for (tp k = 0; k < len / 2; ++k) {
    Complex t = p * f[k + len / 2];
    f[k + len / 2] = f[k] - t;
    f[k] += t;
    p *= w;
  }
}

void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
  tp n, m, k = 1; bin >> n >> m;
  while (k <= n + m) k *= 2;
  vector<Complex> f(k, Complex());
  vector<tp> bf(k, 0);
  for (tp i = 0; i <= n; ++i) bin >> f[i].r;
  for (tp i = 0; i <= m; ++i) bin >> f[i].i;
  for (tp i = 0; i < k; ++i) bf[i] = (bf[i / 2] / 2) | (i % 2 * k / 2);
  for (tp i = 0; i < k; ++i) {
    if (i < bf[i]) swap(f[i], f[bf[i]]);
  }
  FFT(f.begin(), k, 1);
  for (tp i = 0; i < k; ++i) f[i] *= f[i];
  for (tp i = 0; i < k; ++i) {
    if (i < bf[i]) swap(f[i], f[bf[i]]);
  }
  FFT(f.begin(), k, -1);
  for (tp i = 0; i <= n + m; ++i) bin << tp(f[i].i / 2 / k + 0.5) << ' ';
  bin << '\n';
}

写成非递归形式:

void FFT(vector<Complex>::iterator f, tp n, tp type) {
  double pi2 = acos(-1) * 2;
  vector<tp> bf(n, 0);
  for (tp i = 0; i < n; ++i) bf[i] = (bf[i / 2] / 2) | (i % 2 * n / 2);
  for (tp i = 0; i < n; ++i) {
    if (i < bf[i]) swap(f[i], f[bf[i]]);
  }
  for (tp p = 2; p <= n; p *= 2) {
    tp m = p / 2;
    Complex w(cos(pi2 / p), type * sin(pi2 / p));
    for (tp k = 0; k < n; k += p) {
      Complex g(1, 0);
      for (tp l = k; l < k + m; ++l) {
        Complex t = g * f[m + l];
        f[m + l] = f[l] - t;
        f[l] += t;
        g *= w;
      }
    }
  }
  if (type == 1) return;
  for (tp i = 0; i < n; ++i) f[i].r /= n;
  for (tp i = 0; i < n; ++i) f[i].i /= n;
}

void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
  tp n, m, k = 1; bin >> n >> m;
  while (k <= n + m) k *= 2;
  vector<Complex> f(k, Complex());
  for (tp i = 0; i <= n; ++i) bin >> f[i].r;
  for (tp i = 0; i <= m; ++i) bin >> f[i].i;
  FFT(f.begin(), k, 1);
  for (tp i = 0; i < k; ++i) f[i] *= f[i];
  FFT(f.begin(), k, -1);
  for (tp i = 0; i <= n + m; ++i) bin << tp(f[i].i / 2 + 0.5) << ' ';
  bin << '\n';
}

施工中。。。


暂无评论

发表评论