[BZOJ 5248][九省联考 2018] 一双木棋【轮廓线DP+对抗搜索】
Problem:
Description
Input
Output
Sample Input
2 3 2 7 3 9 1 2 3 7 2 2 3 1
Sample Output
2
HINT

Source
Solution:
这是一道神奇的轮廓线 DP 题。所谓轮廓线 DP,就是一类以当前已取部分和未取部分的轮廓线为状态的状压 DP。
在这道题中,我们假定左下角为起始点,轮廓线延伸的方向只能向右或向上。
记向右延伸为 0,向上延伸为 1,则轮廓线可以描述为一个长为 n + m 的二进制串。那么
- 初始状态可表示为 11...100...0 (n 个 1, m 个 0) = 2n + m - 2m
- 终止状态可表示为 00...011...1 (m 个 0, n 个 1) = 2n - 1。
考虑 DP 的状态转移方程,令 f[S][1] 为轮廓线状态为 S 时,选取轮廓线以下部分所能得到的最大收益,f[S][0] 则表示同样条件下所能得到的最小收益,nxtS 表示 S 的后继状态,x, y 为该后继状态与 S 相比少选取的点的坐标。则不难得到
- f[S][1] = max{f[nxtS][0]} + A[x][y]
- f[S][0] = min{f[nxtS][1]} - B[x][y]
通过记忆化搜索可以较方便地求解上式。这种一方最大化答案,另一方最小化答案的搜索被称为对抗搜索。
最后的问题是如何求解后继状态。
通过观察可知,只有连续两位为 "10" 处是合法的选取位置,而选择后轮廓线会转化为 "01"。
枚举每一位检查是否符合条件,并进行转移即可。
显然,该算法的时间复杂度为 O(状态数),而由于只能向右或向上走的限制,状态二进制串中的 0 和 1 个数保持不变,所以最多有 C(n + m, n) ≤ 184756 种不同状态,可以通过。
Code: O(状态数) [9480K, 420MS]
#include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<iostream> #include<algorithm> #define INF 0x3f3f3f3f using namespace std; inline void getint(int &num){ int ch; bool neg = 0; while(!isdigit(ch = getchar())) if(ch == '-') neg = 1; num = ch & 15; while(isdigit(ch = getchar())) num = num * 10 + (ch & 15); if(neg) num = -num; } inline void checkmax(int &a, const int &b) {if(b > a) a = b;} inline void checkmin(int &a, const int &b) {if(b < a) a = b;} int n, m, A[12][12], B[12][12], f[1 << 20][2]; // 轮廓线: 1 表示向上, 0 表示向右 int maxex, sumB = 0; int dfs(int S, bool flag){ // 返回轮廓线 S 以下部分的最大或最小值 (取决于 flag) if(f[S][flag] < INF) return f[S][flag]; int x = n + 1, y = 0, res = flag ? 0xc0c0c0c0 : 0x3f3f3f3f; for(register int ex = maxex - 1; ~ex; ex--) if(S & (1 << ex)) x--; else{ y++; if(S & (1 << ex + 1)){ int nxtS = S ^ (1 << ex) ^ (1 << ex + 1); if(flag) checkmax(res, dfs(nxtS, !flag) + A[x][y]); else checkmin(res, dfs(nxtS, !flag) - B[x][y]); } } return f[S][flag] = res; } int main(){ getint(n), getint(m); for(register int i = 1; i <= n; i++) for(register int j = 1; j <= m; j++) getint(A[i][j]); for(register int i = 1; i <= n; i++) for(register int j = 1; j <= m; j++) getint(B[i][j]); memset(f, INF, sizeof(f)), f[(1 << n) - 1][0] = f[(1 << n) - 1][1] = 0; maxex = n + m, printf("%d\n", dfs((1 << n + m) - (1 << m), 1)); return 0; }
发表评论