题意
给定 $n$ 和 $k$,生成一个长度为 $n$ 的数组 $a_i = i - 1$,如果一段区间 $\left[l, r\right]$ 是好的,仅当对于任意 $l \leq i \leq r$,都满足 $a_i$ 二进制表示中为 $1$ 的位的个数不超过 $k$。
找到 $a$ 中好的区间的个数,对 $10^9 + 7$ 取模。
$1 \leq n \leq 10^{18}, 1 \leq k \leq 60$
题解
我们对于一个区间维护以下变量。
tar
表示此区间中有多少个好区间l
表示从区间左边数起,有多少个数的二进制表示中为 $1$ 的位的个数不超过 $k$。r
表示从区间右边数起,有多少个数的二进制表示中为 $1$ 的位的个数不超过 $k$。len
表示区间长度
这样,我们可以合并两个区间。l
,r
和 len
的维护是简单的,tar
需要把两边相加,再加上跨两个区间的好区间个数 l.r * r.l
。
struct foo {
tp l, r, tar, len;
foo() = default;
foo(tp l, tp r, tp t, tp len) : l(l), r(r), tar(t), len(len) {}
foo operator+(foo o) {
if (len == 0) return o;
if (o.len == 0) return *this;
tp bar = l, baz = o.r;
if (l == len) bar += o.l;
if (o.r == o.len) baz += r;
return foo(bar, baz, (tar + o.tar + r % mod * (o.l % mod)) % mod, len + o.len);
}
};
这样我们可以计算出 $n$ 为二的幂时的答案。
如果 $n$ 不是二的幂,我们断言可以用 $\mathcal O\left(\log n\right)$ 个长度为二的幂的区间拼出来。
考虑当前 $a$ 数组的二进制中最高有效位,一定是
$$0, 0, \ldots, 0, 1, 1, \ldots 1$$
,其中 $0$ 的个数为二的幂,且长度至少有一半。我们再考虑右边的 $1$。忽略最高位的 $1$,然后将 $k$ 减一,就是一个完全相同的子问题。
这样我们就用一个区间,将 $n$ 除以了 $2$。
最终时间复杂度为 $\mathcal O\left(k\log n\right)$。
代码
/* Please submit with C++17! It's best to use C++20 or higher version.
* No header file and no RBLIB (https://git.rbtr.ee/root/Template).
* By Koicy (https://koicy.ly)
* Email n@rbtr.ee
* I've reached the end of my fantasy.
__ __ __ _
_____/ /_ / /_________ ___ / /______ (_)______ __
/ ___/ __ \/ __/ ___/ _ \/ _ \ ______ / //_/ __ \/ / ___/ / / /
/ / / /_/ / /_/ / / __/ __/ /_____/ / ,< / /_/ / / /__/ /_/ /
/_/ /_.___/\__/_/ \___/\___/ /_/|_|\____/_/\___/\__, /
SIGN /___*/
#ifndef XCODE
constexpr bool _CONSOLE = false;
#else
constexpr bool _CONSOLE = true;
#endif
#define __NO_MAIN__ false
#define __ENABLE_RBLIB__ true
constexpr bool _MTS = true, SPC_MTS = false;
constexpr char EFILE[] = "";
#define FULL(arg) arg.begin(), arg.end()
#define dor(i, s, e) for (tp i = s, $##i =(s)<(e)?1:-1,$e##i=e;i!=$e##i;i+=$##i)
#define gor(i, s,e)for(tp i=s,$##i=(s)<(e)?1:-1,$e##i=(e)+$##i;i!=$e##i;i+=$##i)
// :/
signed STRUGGLING([[maybe_unused]] unsigned long TEST_NUMBER) {
constexpr tp mod = 1e9 + 7;
struct foo {
tp l, r, tar, len;
foo() = default;
foo(tp l, tp r, tp t, tp len) : l(l), r(r), tar(t), len(len) {}
foo operator+(foo o) {
if (len == 0) return o;
if (o.len == 0) return *this;
tp bar = l, baz = o.r;
if (l == len) bar += o.l;
if (o.r == o.len) baz += r;
return foo(bar, baz, (tar + o.tar + r % mod * (o.l % mod)) % mod, len + o.len);
}
};
tp n, k; bin >> n >> k;
tp lg = 1;
while ((ONE << lg) < n) ++lg;
vector f(lg + 1, vector<foo>(k + 1));
gor(i, 0, k) f[0][i] = foo(1, 1, 1, 1);
gor(i, 1, lg) f[i][0] = foo(1, 0, 1, ONE << i);
gor(i, 1, lg) {
gor(j, 1, k) f[i][j] = f[i - 1][j] + f[i - 1][j - 1];
}
foo tar = foo(0, 0, 0, 0);
gor(i, lg, 0) {
if (n >> i & 1 && k >= 0) tar = tar + f[i][k--];
}
bin << tar.tar << '\n';
return 0;
}
void MIST() {
}
// :\ */