[POJ 3378] Crazy Thairs【离散化+DP+树状数组】
Problem:
Time Limit: 3000MS | Memory Limit: 65536K |
Description
These days, Sempr is crazed on one problem named Crazy Thair. Given N (1 ≤ N ≤ 50000) numbers, which are no more than 109, Crazy Thair is a group of 5 numbers {i, j, k, l,m} satisfying:
1. 1 ≤ i < j < k < l < m ≤ N
2. Ai < Aj < Ak < Al < Am
For example, in the sequence {2, 1, 3, 4, 5, 7, 6},there are four Crazy Thair groups: {1, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {1, 3, 4, 5, 7} and {2, 3, 4, 5, 7}.
Could you help Sempr to count how many Crazy Thairs in the sequence?
Input
Input contains several test cases. Each test case begins with a line containing a number N, followed by a line containing N numbers.
Output
Output the amount of Crazy Thairs in each sequence.
Sample Input
5 1 2 3 4 5 7 2 1 3 4 5 7 6 7 1 2 3 4 5 6 7
Sample Output
1 4 21
Source
Solution:
这道题描述很简洁,思路也很显而易见。
令 dp[i][j] 为前 i 个数中选取 j 个严格升序的数的方案数,则
我们发现 ∑ 的条件 k < i 只要从小到大枚举即可,而 a[k] < a[i] 用树状数组维护求和即可。
由于 N ≤ 50000,答案的最大值出现在 1 ~ 50000 顺序排列时,可达 C(50000, 5) = 2603645869790625010000,超出 unsigned long long (ull) 范围。由于 C(50000, 4) = 260385417812487500 不会溢出,我们可以对 j = 1 ~ 4 的 dp 数组用 ull 存储,而 j = 5 时开两个 ull,分别存答案的后 11 位和前 10 位,模拟高精度加法即可。这样就避免了繁琐的高精度运算。
Code: O(TNlogN) [3404K, 250MS]
#include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<iostream> #include<algorithm> using namespace std; typedef unsigned long long ull; int N, a[50005]; int disc[50005]; ull dp[50005][5], ans[2]; struct BIT{ ull node[50005]; inline void clear() {memset(node, 0, sizeof(node));} inline void add(int u, const ull &v) {while(u < 50005) node[u] += v, u += u & -u;} inline ull query(int u) {ull res = 0; while(u) res += node[u], u -= u & -u; return res;} } b; int main(){ while(scanf("%d", &N) != EOF){ for(register int i = 1; i <= N; i++) scanf("%d", a + i), disc[i] = a[i]; sort(disc + 1, disc + N + 1); int dsize = unique(disc + 1, disc + N + 1) - disc - 1; for(register int i = 1; i <= N; i++) a[i] = lower_bound(disc + 1, disc + dsize + 1, a[i]) - disc; // Discretization for(register int i = 1; i <= N; i++) dp[i][1] = 1; for(register int j = 2; j <= 4; j++){ b.clear(); for(register int i = j; i <= N; i++){ b.add(a[i - 1], dp[i - 1][j - 1]); dp[i][j] = b.query(a[i] - 1); } } b.clear(), ans[0] = ans[1] = 0; for(register int i = 5; i <= N; i++){ b.add(a[i - 1], dp[i - 1][4]); ans[0] += b.query(a[i] - 1); ans[1] += ans[0] / 100000000000ULL; ans[0] %= 100000000000ULL; } if(ans[1]) printf("%llu", ans[1]); printf("%llu\n", ans[0]); } return 0; }
发表评论