概述
矩阵乘法一般可以用来优化 DP。
公式
$n \times m$ 的矩阵 $a$ 与 $p \times q$ 的矩阵 $b$ 相乘,要求 $m = p$。乘出来得到的结果是一个 $n \times q$ 的矩阵 $c$。
$$c_{i, j} = \sum\limits_{i = 0}^m a_{i, k} b_{k, j}$$
送大家一个封装好的矩阵模版(应该用不到矩阵除吧):
template <unsigned long long Mod>
struct Matrix : vector<vector<long long>> {
Matrix() = default;
Matrix(size_t n, size_t m) {
resize(n);
for (size_t i = 0; i < n; ++i) {
for (size_t j = 0; j < m; ++j) operator[](i).push_back(0);
}
}
Matrix operator*(const Matrix& b) {
Matrix c(size(), b[0].size());
for (size_t i = 0; i < size(); ++i) {
for (size_t j = 0; j < b[0].size(); ++j) {
for (size_t k = 0; k < b.size(); ++k)
c[i][j] = (c[i][j] + operator[](i)[k] * b[k][j]) % Mod;
}
}
return c;
}
Matrix operator+(const Matrix& b) {
Matrix c(size(), b[0].size());
for (size_t i = 0; i < size(); ++i) {
for (size_t j = 0; j < b[0].size(); ++j) c[i][j] = operator[](i)[j] + b[i][j];
}
return c;
}
Matrix operator-(const Matrix& b) {
Matrix c(size(), b[0].size());
for (size_t i = 0; i < size(); ++i) {
for (size_t j = 0; j < b[0].size(); ++j) c[i][j] = operator[](i)[j] - b[i][j];
}
return c;
}
Matrix qpow(unsigned long long p) {
Matrix tar(size(), size()), self = *this;
for (size_t i = 0; i < size(); ++i) tar[i][i] = 1;
while (p) {
if (p & 1) tar *= self;
self *= self;
p /= 2;
}
return tar;
}
Matrix operator*=(const Matrix& b) { return *this = *this * b; }
Matrix operator+=(const Matrix& b) { return *this = *this + b; }
Matrix operator-=(const Matrix& b) { return *this = *this - b; }
};
例题 1
题意
定义函数 $f\left(x\right)$:
$$f\left(x\right) = \begin{cases}1 & 1 \leq x \leq 3 \\f\left(x - 1\right) + f\left(x - 3\right) & x \geq 4 \end{cases}$$$T$ 组数据。对于每组数据:
给定 $n$,求 $f\left(n\right)$ 对 $10^9 + 7$ 取模后的值。$1 \leq T \leq 100, 1 \leq n \leq 2 \times 10^9$
题解
矩阵加速的特征就是乘法的次数特别多。我们考虑用矩阵加速来解决这题。
第一步:确定目标矩阵
非常显然:
$$ \begin{bmatrix} f\left(i\right) \\ f\left(i - 1\right) \\ f\left(i - 2\right) \end{bmatrix} $$
第二步:确定系数矩阵
$$ \begin{aligned} f\left(i\right) &= f\left(i - 1\right) \times 1 + f\left(i - 2\right) \times 0 + f\left(i - 3\right) \times 1 \\ f\left(i - 1\right) &= f\left(i - 1\right) \times 1 + f\left(i - 2\right) \times 0 + f\left(i - 3\right) \times 0 \\ f\left(i - 2\right) &= f\left(i - 1\right) \times 0 + f\left(i - 2\right) \times 1 + f\left(i - 3\right) \times 0 \end{aligned} $$
所以,我们的系数矩阵就确定出来了:
$$ \begin{bmatrix} 1 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix} $$
然后我们就可以通过矩阵快速幂进行求解。
代码
void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
tp n; bin >> n;
Matrix<1000000007> mul(3, 3);
if (n <= 3) { bin << "1\n"; return; }
mul[0][0] = mul[0][2] = mul[1][0] = mul[2][1] = 1;
bin << mul.qpow(n - 1)[0][0] << '\n';
}
例题 2
题意
$k$ 阶 Sam 数被定义为一个长度为 $k$ 的数,满足相邻两位的数字之差不超过 $2$。
求 $k$ 阶 Sam 数的个数,对 $10^9 + 7$ 取模。$1 \leq k \leq 10^{18}$
题解
首先先写暴力:
$f\left(i, j\right)$ 表示第 $i$ 位为 $j$ 的 Sam 数个数。则:
$f\left(i, j\right) = \sum\limits_{k = \max\left\{j - 2, 0\right\}}^{\min\left\{j + 2, 9\right\}} f\left(i - 1, k\right)$
接下来考虑矩阵加速
第一步:确定目标矩阵:
$$ \begin{bmatrix} f\left(i, 0\right) \\ f\left(i, 1\right) \\ f\left(i, 2\right) \\ f\left(i, 3\right) \\ f\left(i, 4\right) \\ f\left(i, 5\right) \\ f\left(i, 6\right) \\ f\left(i, 7\right) \\ f\left(i, 8\right) \\ f\left(i, 9\right) \end{bmatrix} $$
第二步:确定系数矩阵
$$ \begin{aligned} f\left(i, 0\right) &= f\left(i - 1, 0\right) + f\left(i - 1, 1\right) + f\left(i - 1, 2\right) \\ f\left(i, 1\right) &= f\left(i - 1, 0\right) + f\left(i - 1, 1\right) + f\left(i - 1, 2\right) + f\left(i - 1, 3\right) \\ f\left(i, 2\right) &= f\left(i - 1, 0\right) + f\left(i - 1, 1\right) + f\left(i - 1, 2\right) + f\left(i - 1, 3\right) + f\left(i - 1, 4\right) \\ f\left(i, 3\right) &= f\left(i - 1, 1\right) + f\left(i - 1, 2\right) + f\left(i - 1, 3\right) + f\left(i - 1, 4\right) + f\left(i - 1, 5\right) \\ f\left(i, 4\right) &= f\left(i - 1, 2\right) + f\left(i - 1, 3\right) + f\left(i - 1, 4\right) + f\left(i - 1, 5\right) + f\left(i - 1, 6\right) \\ f\left(i, 5\right) &= f\left(i - 1, 3\right) + f\left(i - 1, 4\right) + f\left(i - 1, 5\right) + f\left(i - 1, 6\right) + f\left(i - 1, 7\right) \\ f\left(i, 6\right) &= f\left(i - 1, 4\right) + f\left(i - 1, 5\right) + f\left(i - 1, 6\right) + f\left(i - 1, 7\right) + f\left(i - 1, 8\right) \\ f\left(i, 7\right) &= f\left(i - 1, 5\right) + f\left(i - 1, 6\right) + f\left(i - 1, 7\right) + f\left(i - 1, 8\right) + f\left(i - 1, 9\right) \\ f\left(i, 8\right) &= f\left(i - 1, 6\right) + f\left(i - 1, 7\right) + f\left(i - 1, 8\right) + f\left(i - 1, 9\right) \\ f\left(i, 9\right) &= f\left(i - 1, 7\right) + f\left(i - 1, 8\right) + f\left(i - 1, 9\right) \end{aligned} $$
则得到系数矩阵为:
$$ \begin{bmatrix} 1 & 1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\ 1 & 1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 1 & 1 & 1 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 1 & 1 & 1 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 1 & 1 & 1 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 1 & 1 & 1 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 & 1 & 1 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 1 & 1 \\ \end{bmatrix} $$
矩阵快速幂求解即可,注意特判 $n = 1$ 的情况。
代码
void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
tp n, tar = 0; bin >> n;
Matrix<1000000007> g(10, 10), mul(10, 10);
if (n == 1) { bin << "10\n"; return; }
for (tp i = 0; i < 10; ++i) {
for (tp j = max(ZERO, i - 2); j < min((tp)10, i + 3); ++j) mul[i][j] = 1;
}
for (tp i = 1; i < 10; ++i) g[0][i] = 1;
mul = g * mul.qpow(n - 1);
for (tp i = 0; i < 10; ++i) tar += mul[0][i];
bin << tar % 1000000007 << '\n';
}
例题 3
题意
有一个 $n$ 个点 $m$ 条的无重边有向图。每条边有一个边权 $w$,表示从起点到终点需要的时间。问从 $1$ 出发 $s$ 秒刚好到达 $n$ 的方案数(每秒不可以不动)。答案对 $10^9 + 7$ 取模。
$1 \leq n \leq 8, 0 \leq m \leq 40, 0 \leq s \leq 2 \times 10^9, 1 \leq w \leq 9$
题解
我们直接考虑矩阵加速。
这题有边权,不好在矩阵中表示。我们可以设置一些虚拟的点,比如从 $u$ 到 $v$ 有一条边权为 $2$ 的边,那么我们就建 $u \to virtual \to v$。
然后直接矩阵快速幂求解即可。
代码
/*
* Please submit with C++14! It's best to use C++20 or higher version.
* By rbtree (https://rbtr.ee)
* Apparition (n@rbtr.ee)
* DO OR DIE
*/
#include <cstdio>
#include <utility>
#include <vector>
#define AS 3 +
using namespace std;
using tp = long long;
vector<vector<tp>> mul(vector<vector<tp>> a, vector<vector<tp>> b, tp mod) {
vector<vector<tp>> c(a.size(), vector<tp>(b[0].size(), 0));
for (tp i = 0; i < c.size(); ++i) {
for (tp j = 0; j < c[i].size(); ++j) {
for (tp k = 0; k < b.size(); ++k)
c[i][j] = (c[i][j] + a[i][k] * b[k][j]) % mod;
}
}
return c;
}
void qpow(vector<vector<tp>> &c, vector<vector<tp>> a, tp time, tp mod) {
while (time) {
if (time & 1) c = mul(c, a, mod);
a = mul(a, a, mod);
time >>= 1;
}
}
signed main() {
tp n, m, s;
scanf("%lld%lld%lld", &n, &m, &s);
vector<vector<tp>> a(1, vector<tp>(n * 11, 0)), c(n * 11, vector<tp>(n * 11, 0));
a[0][0] = 1;
while (m--) {
tp u, v, w, lst;
scanf("%lld%lld%lld", &u, &v, &w);
for (tp i = 0, t = --w, _t = lst = u - 1; i < w; ++i)
c[exchange(lst, u * 9 + i)][u * 9 + i] = 1;
c[lst][v - 1] = 1;
}
qpow(a, c, s, 1e9 + 7);
printf("%lld\n", a[0][n - 1]);
return 0;
}
//*/