矩阵加速 学习笔记

发布于

概述

矩阵乘法一般可以用来优化 DP。

公式

$n \times m$ 的矩阵 $a$ 与 $p \times q$ 的矩阵 $b$ 相乘,要求 $m = p$。乘出来得到的结果是一个 $n \times q$ 的矩阵 $c$。

$$c_{i, j} = \sum\limits_{i = 0}^m a_{i, k} b_{k, j}$$

送大家一个封装好的矩阵模版(应该用不到矩阵除吧):

template <unsigned long long Mod>
struct Matrix : vector<vector<long long>> {
  Matrix() = default;
  Matrix(size_t n, size_t m) {
    resize(n);
    for (size_t i = 0; i < n; ++i) {
      for (size_t j = 0; j < m; ++j) operator[](i).push_back(0);
    }
  }

  Matrix operator*(const Matrix& b) {
    Matrix c(size(), b[0].size());
    for (size_t i = 0; i < size(); ++i) {
      for (size_t j = 0; j < b[0].size(); ++j) {
        for (size_t k = 0; k < b.size(); ++k)
          c[i][j] = (c[i][j] + operator[](i)[k] * b[k][j]) % Mod;
      }
    }
    return c;
  }

  Matrix operator+(const Matrix& b) {
    Matrix c(size(), b[0].size());
    for (size_t i = 0; i < size(); ++i) {
      for (size_t j = 0; j < b[0].size(); ++j) c[i][j] = operator[](i)[j] + b[i][j];
    }
    return c;
  }

  Matrix operator-(const Matrix& b) {
    Matrix c(size(), b[0].size());
    for (size_t i = 0; i < size(); ++i) {
      for (size_t j = 0; j < b[0].size(); ++j) c[i][j] = operator[](i)[j] - b[i][j];
    }
    return c;
  }

  Matrix qpow(unsigned long long p) {
    Matrix tar(size(), size()), self = *this;
    for (size_t i = 0; i < size(); ++i) tar[i][i] = 1;
    while (p) {
      if (p & 1) tar *= self;
      self *= self;
      p /= 2;
    }
    return tar;
  }
  
  Matrix operator*=(const Matrix& b) { return *this = *this * b; }
  Matrix operator+=(const Matrix& b) { return *this = *this + b; }
  Matrix operator-=(const Matrix& b) { return *this = *this - b; }
};

例题 1

题意

定义函数 $f\left(x\right)$:
$$f\left(x\right) = \begin{cases}1 & 1 \leq x \leq 3 \\f\left(x - 1\right) + f\left(x - 3\right) & x \geq 4 \end{cases}$$

$T$ 组数据。对于每组数据:
给定 $n$,求 $f\left(n\right)$ 对 $10^9 + 7$ 取模后的值。

$1 \leq T \leq 100, 1 \leq n \leq 2 \times 10^9$

题解

矩阵加速的特征就是乘法的次数特别多。我们考虑用矩阵加速来解决这题。

第一步:确定目标矩阵

非常显然:

$$ \begin{bmatrix} f\left(i\right) \\ f\left(i - 1\right) \\ f\left(i - 2\right) \end{bmatrix} $$

第二步:确定系数矩阵

$$ \begin{aligned} f\left(i\right) &= f\left(i - 1\right) \times 1 + f\left(i - 2\right) \times 0 + f\left(i - 3\right) \times 1 \\ f\left(i - 1\right) &= f\left(i - 1\right) \times 1 + f\left(i - 2\right) \times 0 + f\left(i - 3\right) \times 0 \\ f\left(i - 2\right) &= f\left(i - 1\right) \times 0 + f\left(i - 2\right) \times 1 + f\left(i - 3\right) \times 0 \end{aligned} $$

所以,我们的系数矩阵就确定出来了:

$$ \begin{bmatrix} 1 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix} $$

然后我们就可以通过矩阵快速幂进行求解。

代码

void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
  tp n; bin >> n;
  Matrix<1000000007> mul(3, 3);
  if (n <= 3) { bin << "1\n"; return; }
  mul[0][0] = mul[0][2] = mul[1][0] = mul[2][1] = 1;
  bin << mul.qpow(n - 1)[0][0] << '\n';
}

例题 2

题意

$k$ 阶 Sam 数被定义为一个长度为 $k$ 的数,满足相邻两位的数字之差不超过 $2$。
求 $k$ 阶 Sam 数的个数,对 $10^9 + 7$ 取模。

$1 \leq k \leq 10^{18}$

题解

首先先写暴力:
$f\left(i, j\right)$ 表示第 $i$ 位为 $j$ 的 Sam 数个数。则:
$f\left(i, j\right) = \sum\limits_{k = \max\left\{j - 2, 0\right\}}^{\min\left\{j + 2, 9\right\}} f\left(i - 1, k\right)$

接下来考虑矩阵加速

第一步:确定目标矩阵:

$$ \begin{bmatrix} f\left(i, 0\right) \\ f\left(i, 1\right) \\ f\left(i, 2\right) \\ f\left(i, 3\right) \\ f\left(i, 4\right) \\ f\left(i, 5\right) \\ f\left(i, 6\right) \\ f\left(i, 7\right) \\ f\left(i, 8\right) \\ f\left(i, 9\right) \end{bmatrix} $$

第二步:确定系数矩阵

$$ \begin{aligned} f\left(i, 0\right) &= f\left(i - 1, 0\right) + f\left(i - 1, 1\right) + f\left(i - 1, 2\right) \\ f\left(i, 1\right) &= f\left(i - 1, 0\right) + f\left(i - 1, 1\right) + f\left(i - 1, 2\right) + f\left(i - 1, 3\right) \\ f\left(i, 2\right) &= f\left(i - 1, 0\right) + f\left(i - 1, 1\right) + f\left(i - 1, 2\right) + f\left(i - 1, 3\right) + f\left(i - 1, 4\right) \\ f\left(i, 3\right) &= f\left(i - 1, 1\right) + f\left(i - 1, 2\right) + f\left(i - 1, 3\right) + f\left(i - 1, 4\right) + f\left(i - 1, 5\right) \\ f\left(i, 4\right) &= f\left(i - 1, 2\right) + f\left(i - 1, 3\right) + f\left(i - 1, 4\right) + f\left(i - 1, 5\right) + f\left(i - 1, 6\right) \\ f\left(i, 5\right) &= f\left(i - 1, 3\right) + f\left(i - 1, 4\right) + f\left(i - 1, 5\right) + f\left(i - 1, 6\right) + f\left(i - 1, 7\right) \\ f\left(i, 6\right) &= f\left(i - 1, 4\right) + f\left(i - 1, 5\right) + f\left(i - 1, 6\right) + f\left(i - 1, 7\right) + f\left(i - 1, 8\right) \\ f\left(i, 7\right) &= f\left(i - 1, 5\right) + f\left(i - 1, 6\right) + f\left(i - 1, 7\right) + f\left(i - 1, 8\right) + f\left(i - 1, 9\right) \\ f\left(i, 8\right) &= f\left(i - 1, 6\right) + f\left(i - 1, 7\right) + f\left(i - 1, 8\right) + f\left(i - 1, 9\right) \\ f\left(i, 9\right) &= f\left(i - 1, 7\right) + f\left(i - 1, 8\right) + f\left(i - 1, 9\right) \end{aligned} $$

则得到系数矩阵为:

$$ \begin{bmatrix} 1 & 1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\ 1 & 1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 1 & 1 & 1 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 1 & 1 & 1 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 1 & 1 & 1 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 1 & 1 & 1 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 1 & 1 & 1 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 & 1 & 1 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 1 & 1 \\ \end{bmatrix} $$

矩阵快速幂求解即可,注意特判 $n = 1$ 的情况。

代码

void STRUGGLING([[maybe_unused]] unsigned TEST_NUMBER) {
  tp n, tar = 0; bin >> n;
  Matrix<1000000007> g(10, 10), mul(10, 10);
  if (n == 1) { bin << "10\n"; return; }
  for (tp i = 0; i < 10; ++i) {
    for (tp j = max(ZERO, i - 2); j < min((tp)10, i + 3); ++j) mul[i][j] = 1;
  }
  for (tp i = 1; i < 10; ++i) g[0][i] = 1;
  mul = g * mul.qpow(n - 1);
  for (tp i = 0; i < 10; ++i) tar += mul[0][i];
  bin << tar % 1000000007 << '\n';
}

例题 3

题意

有一个 $n$ 个点 $m$ 条的无重边有向图。每条边有一个边权 $w$,表示从起点到终点需要的时间。问从 $1$ 出发 $s$ 秒刚好到达 $n$ 的方案数(每秒不可以不动)。答案对 $10^9 + 7$ 取模。

$1 \leq n \leq 8, 0 \leq m \leq 40, 0 \leq s \leq 2 \times 10^9, 1 \leq w \leq 9$

题解

我们直接考虑矩阵加速。
这题有边权,不好在矩阵中表示。我们可以设置一些虚拟的点,比如从 $u$ 到 $v$ 有一条边权为 $2$ 的边,那么我们就建 $u \to virtual \to v$。
然后直接矩阵快速幂求解即可。

代码

/*
 * Please submit with C++14! It's best to use C++20 or higher version.
 * By rbtree (https://rbtr.ee)
 * Apparition (n@rbtr.ee)
 * DO OR DIE
 */

#include <cstdio>
#include <utility>
#include <vector>
#define AS 3 +

using namespace std;
using tp = long long;

vector<vector<tp>> mul(vector<vector<tp>> a, vector<vector<tp>> b, tp mod) {
  vector<vector<tp>> c(a.size(), vector<tp>(b[0].size(), 0));
  for (tp i = 0; i < c.size(); ++i) {
    for (tp j = 0; j < c[i].size(); ++j) {
      for (tp k = 0; k < b.size(); ++k)
        c[i][j] = (c[i][j] + a[i][k] * b[k][j]) % mod;
    }
  }
  return c;
}

void qpow(vector<vector<tp>> &c, vector<vector<tp>> a, tp time, tp mod) {
  while (time) {
    if (time & 1) c = mul(c, a, mod);
    a = mul(a, a, mod);
    time >>= 1;
  }
}

signed main() {
  tp n, m, s;
  scanf("%lld%lld%lld", &n, &m, &s);
  vector<vector<tp>> a(1, vector<tp>(n * 11, 0)), c(n * 11, vector<tp>(n * 11, 0));
  a[0][0] = 1;
  while (m--) {
    tp u, v, w, lst;
    scanf("%lld%lld%lld", &u, &v, &w);
    for (tp i = 0, t = --w, _t = lst = u - 1; i < w; ++i)
      c[exchange(lst, u * 9 + i)][u * 9 + i] = 1;
    c[lst][v - 1] = 1;
  }
  qpow(a, c, s, 1e9 + 7);
  printf("%lld\n", a[0][n - 1]);
  return 0;
}

//*/

更多例题


暂无评论

发表评论