CSES 2176 Counting Bishops

Counting Bishops

Your task is to count the number of ways $k$ bishops can be placed on an $n \times n$ chessboard so that no two bishops attack each other. Since the number can be large, print it modulo $10^9 + 7$.

Two bishops attack each other if they are on the same diagonal.

$1 \leq n \leq 500$
$1 \leq k \leq n^2$

Tutorial

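If you play chess often, you know that the squares of a chessboard come in two colors. A bishop placed on a light square can never reach a dark square, and vice versa: square $(r, c)$ has color $(r + c) \bmod 2$, and a diagonal step changes $r + c$ by $0$ or $\pm 2$, which preserves that parity.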

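Therefore, we can count placements on the light squares and on the dark squares separately. If $L_i$ and $D_i$ denote the number of ways to place $i$ non-attacking bishops on the light and on the dark squares respectively, the answer is the convolution $\sum_{i=0}^{k} L_i \cdot D_{k - i}$.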

#@#@#@#
@#@#@#@
#@#@#@#
@#@#@#@
#@#@#@#
@#@#@#@
#@#@#@#

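Rotating the board by 45 degrees makes the diagonal structure clearer: one family of diagonals becomes the rows of the picture and the other becomes its columns, so two bishops attack each other exactly when they share a row or a column, just like rooks: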

      #
     @ @
    # # #
   @ @ @ @
  # # # # #
 @ @ @ @ @ @
# # # # # # #
 @ @ @ @ @ @
  # # # # #
   @ @ @ @
    # # #
     @ @
      #

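Taking each color on its own, we can reorder its rows by length, shortest first; the numbers on the right are the row indices after sorting. Every square keeps its column, so two bishops attack each other after the reordering exactly when they did before: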

   #            #    1
  ###           #    2
 #####         ###   3
#######  ==>   ###   4
 #####        #####  5
  ###         #####  6
   #         ####### 7
  
  @@            @@   1
 @@@@           @@   2
@@@@@@   ==>   @@@@  3
@@@@@@         @@@@  4
 @@@@         @@@@@@ 5
  @@          @@@@@@ 6

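Fix one color and number its rows $1, 2, \ldots$ in the sorted order shown on the right. Let $f_{i, j}$ denote the number of ways to place $j$ non-attacking bishops within the first $i$ rows, with $f_{0, 0} = 1$ and $f_{0, j} = 0$ for $j > 0$. Row $i$ either stays empty or receives exactly one bishop:

$$f_{i, j} = f_{i - 1, j} + \left(cnt_i - j + 1\right) \cdot f_{i - 1, j - 1}$$

where $cnt_i$ is the number of squares in the $i$-th row. The factor $cnt_i - j + 1$ holds because all rows of one color are centered on the same axis, so a shorter row only uses columns that every longer row also contains. The $j - 1$ bishops already placed in rows $1, \ldots, i - 1$ therefore occupy $j - 1$ distinct columns of row $i$, leaving $cnt_i - (j - 1)$ free squares.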

Code

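#include <iostream>
#include <utility>
#include <vector>

using i64 = long long;
constexpr i64 MOD = 1'000'000'007;  // the answer is printed modulo 1e9 + 7

int main() {
  int n, k; std::cin >> n >> k;
  if (n == 1) {
    // Single square: exactly one placement when k == 1.
    std::cout << (k == 1) << "\n";
    return 0;
  }
  // f[j]: ways to put j bishops on squares with r + c even (the '#' squares),
  //       whose sorted rows have lengths 1, 1, 3, 3, 5, 5, ... (n rows).
  // g[j]: same for r + c odd (the '@' squares), rows 2, 2, 4, 4, ... (n - 1 rows).
  std::vector<i64> f(n + 1), g(n + 1);
  f[0] = f[1] = 1;  // state after the first row of length 1
  g[0] = 1;
  g[1] = 2;         // state after the first row of length 2
  // Step i adds sorted row i + 1, which has cnt = i / 2 * 2 + 1 squares.
  for (int i = 1; i < n; ++i) {
    std::vector<i64> h(n + 1);
    h[0] = 1;
    for (int j = 1; j <= n; ++j) {
      i64 slots = i / 2 * 2 + 2 - j;  // cnt - j + 1 free squares
      h[j] = f[j];
      if (slots > 0) h[j] = (h[j] + f[j - 1] * slots) % MOD;
    }
    f = std::move(h);
  }
  // Step i adds sorted row i + 1, which has cnt = i / 2 * 2 + 2 squares.
  for (int i = 1; i + 1 < n; ++i) {
    std::vector<i64> h(n + 1);
    h[0] = 1;
    for (int j = 1; j <= n; ++j) {
      i64 slots = i / 2 * 2 + 3 - j;  // cnt - j + 1 free squares
      h[j] = g[j];
      if (slots > 0) h[j] = (h[j] + g[j - 1] * slots) % MOD;
    }
    g = std::move(h);
  }
  // Convolution: i bishops on '#' squares and k - i on '@' squares.
  i64 tar = 0;
  for (int i = 0; i <= k; ++i) {
    if (i <= n && k - i <= n) tar = (tar + f[i] * g[k - i]) % MOD;
  }
  std::cout << tar << "\n";
}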
