[POJ 1845] Sumdiv【逆元】
Problem:
Time Limit: 1000MS | Memory Limit: 30000K |
Description
Input
Output
Sample Input
2 3
Sample Output
15
Hint
The natural divisors of 8 are: 1,2,4,8. Their sum is 15.
15 modulo 9901 is 15 (that should be output).
Source
Solution:
这是一道经典的数论题,题意是求 AB 的约数和 mod 9901 的值。
直接套用约数和定理即可:
- 设 A = ∏1≤i≤n pici,则
- 约数和 s(AB) = ∏1≤i≤n ∑0≤j≤Bci pij
但是等比数列求和公式 ∑0≤j≤Bci pij = (piBci+1-1) / (pi-1) 需要在模意义下计算除法。
而除法在模意义下并不封闭,所以需要引入乘法逆元来将其转化为乘法。
- 记 a 关于模 p (0 < a < p) 的乘法逆元为 a-1,则有 a · a-1 ≡ 1 (mod p)。
- 那么 (a / b) % p == a · b-1 % p,即可避免除法运算。
求逆元的方法:
- 费马小定理 + 快速幂:
- 若 p 为素数,由费马小定理得 ap-1 ≡ 1 (mod p),即 a · ap-2 ≡ 1 (mod p),则
- a-1 = ap-2
- 用快速幂求解即可,时间复杂度 O(logp)
- 扩展欧几里得:
- 将 a · a-1 ≡ 1 (mod p) 化为 a-1 · a + k · p = 1。
- 若 gcd(a, p) % p == 1,则 a 存在逆元,否则不存在
- 用 exgcd 求出 a 的逆元,时间复杂度 O(log max{a,p})
- 线性递推逆元表:
- 若 p 为素数,则设 p = k * q + r,可得 k * q + r ≡ 0 (mod p)
- 两边同乘 q-1 · r-1 ,得 k * r-1 + q-1 ≡ 0 (mod p)
- q-1 = - k * r-1 (mod p)
- 将 k = ⌊p / q⌋,r = p % q 代入上式得
- q-1 = - ⌊p / q⌋ * (p % q)-1 (mod p)
- 由于 p % q < q,从小到大递推即可,边界为 1-1 = 1,时间复杂度 O(n)
- 线性递推阶乘逆元表:
- 从小到大递推求出 a! (1 ≤ a < p)
- 用方法 1 求出 (p-1)!-1
- 从大到小递推求出 a!-1
- 阶乘逆元表可以和逆元表相互转化,此方法时间复杂度 O(n)
由于当 p | a 时 a 的逆元不存在,有一种方法可以避开求逆元:
- (a / b) % p == a % (b * p) / b % p
此时 b * p 可能会溢出。
注意本题套用等比数列求和公式时可能出现负数,需要特判将其转化为正数!!!
Code: O(A0.5logp), 其中p=9901 [660K, 16MS]
#include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<iostream> #include<algorithm> #define MOD 9901LL using namespace std; typedef long long ll; ll A, B, sqrtA, _A, cnt, ans; inline ll mul(ll a, ll b, ll mod){ a %= mod, b %= mod; ll res = 0; while(b){ if(b & 1) res = (res + a < mod) ? res + a : res + a - mod; a = ((a << 1) < mod) ? (a << 1) : (a << 1) - mod; b >>= 1; } return res; } inline ll fastpow(ll bas, ll ex, ll mod){ ll res = 1; while(ex){ if(ex & 1) res = mul(res, bas, mod); bas = mul(bas, bas, mod); ex >>= 1; } return res; } inline ll inv(const ll &x) {return fastpow(x, MOD - 2, MOD);} int main(){ scanf("%lld%lld", &A, &B); sqrtA = (int)sqrt(A), ans = 1; for(register int i = 2; i <= sqrtA; i++) if(A % i == 0){ cnt = 0; while(A % i == 0) A /= i, cnt++; if((i - 1) % MOD) ans = ans * (fastpow(i, cnt * B + 1, MOD) - 1) * inv(i - 1) % MOD; else ans = ans * ((fastpow(i, cnt * B + 1, MOD * (i - 1)) - 1) / (i - 1)) % MOD; } if(A > 1){ if((A - 1) % MOD) ans = ans * (fastpow(A, B + 1, MOD) - 1) * inv(A - 1) % MOD; else ans = ans * ((fastpow(A, B + 1, MOD * (A - 1)) - 1) / (A - 1)) % MOD; } if(ans < 0) ans += MOD; // If MOD | i, then i^x % MOD - 1 == -1 !!! printf("%lld\n", ans); return 0; }
发表评论