HDU5628 Clarke and math 题解 狄利克雷卷积+快速幂
题目链接:https://acm.hdu.edu.cn/showproblem.php?pid=5628
题目大意:
输入 \(n\) 和 \(k\),对于每个 \(i(1 \le i \le n)\),求
\[f(i) = \sum_{i_1 \mid i} \sum_{i_2 \mid i_1} \sum_{i_3 \mid i_2} \cdots \sum_{i_k \mid i_{k-1}} f_{i_k}
\]
答案对 \(10^9 + 7\) 取模。
解题思路:
数论函数 \(f(n)\) 和 \(g(n)\) 的 狄利克雷卷积 \(f \ast g\) 定义为
\[(f \ast g)(n) = \sum_{d \mid n} f(d) g(\frac{n}{d}) = \sum_{de = n} f(d) g(e)
\]
定义一个单位函数 \(I(n) = 1\)。有
\[f(i) = \sum_{i_1 \mid i} \sum_{i_2 \mid i_1} \sum_{i_3 \mid i_2} \cdots \sum_{i_{k-1} \mid i_{k-2}} \left( \sum_{i_k \mid i_{k-1}} f_{i_k} I(\frac{i_{k-1}}{i_k}) \right)
\]
\[= \sum_{i_1 \mid i} \sum_{i_2 \mid i_1} \sum_{i_3 \mid i_2} \cdots \sum_{i_{k-1} \mid i_{k-2}} (f \ast I)(i_{k-1})
\]
\[= \sum_{i_1 \mid i} \sum_{i_2 \mid i_1} \sum_{i_3 \mid i_2} \cdots \sum_{i_{k-3} \mid i_{k-2}} \sum_{i_{k-1} \mid i_{k-2}} (f \ast I)(i_{k-1})
\]
\[= \sum_{i_1 \mid i} \sum_{i_2 \mid i_1} \sum_{i_3 \mid i_2} \cdots \sum_{i_{k-3} \mid i_{k-2}} \left( \sum_{i_{k-1} \mid i_{k-2}} (f \ast I)(i_{k-1}) I(\frac{i_{k-2}}{i_{k-1}}) \right)
\]
\[= \sum_{i_1 \mid i} \sum_{i_2 \mid i_1} \sum_{i_3 \mid i_2} \cdots \sum_{i_{k-3} \mid i_{k-2}} \left( \sum_{i_{k-1} \mid i_{k-2}} (f \ast I \ast I)(i_{k-2}) \right)
\]
\[= ...
\]
\[= (f \ast I \ast \cdot \ast I)(i)
\]
\[= (f \ast I^k)(i)
\]
其中,\(I^k\) 表示 \(k\) 个 \(I\) 做狄利克雷卷积的结果。\(\rightarrow\) 我们可以用 快速幂 的方式优化这个卷积。
然后再拿 \(f\) 和 \(I^k\) 做一次卷积就可以了。
示例程序:
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const long long mod = 1e9 + 7;
int T, n, k;
vector<ll> dc(vector<ll> a, vector<ll> b, int n) {
vector<ll> c(n+1, 0);
for (int i = 1; i <= n; i++) {
for (int j = 1; i * j <= n; j++) {
(c[i * j] += a[i] * b[j]) %= mod;
}
}
return c;
}
vector<ll> dpow(vector<ll> a, int b, int n) {
vector<ll> res(n+1, 0);
res[1] = 1;
for (; b; b >>= 1, a = dc(a, a, n)) {
if (b & 1)
res = dc(res, a, n);
}
return res;
}
int main() {
scanf("%d", &T);
while (T--) {
scanf("%d%d", &n, &k);
vector<ll> f(n+1, 0);
for (int i = 1; i <= n; i++)
scanf("%lld", &f[i]);
vector<ll> I(n+1, 1);
I[0] = 0;
f = dc(f, dpow(I, k, n), n);
for (int i = 1; i <= n; i++) {
if (i > 1) putchar(' ');
printf("%lld", f[i]);
}
putchar('\n');
}
return 0;
}
浙公网安备 33010602011771号