[BZOJ 1036][ZJOI 2008]树的统计Count【树链剖分】
Problem:
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作
I. CHANGE u t : 把结点u的权值改为t。
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值。
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和。
注意:从点u到点v的路径上的节点包括u和v本身。
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4 1 2 2 3 4 1 4 2 1 3 12 QMAX 3 4 QMAX 3 3 QMAX 3 2 QMAX 2 3 QSUM 3 4 QSUM 2 1 CHANGE 1 5 QMAX 3 4 CHANGE 3 6 QMAX 3 4 QMAX 2 4 QSUM 3 4
Sample Output
4 1 2 2 10 6 5 6 5 16
HINT
Source
Solution:
本题是树链剖分的模板题,相比于 树链剖分模板 中的例题,只需在线段树上多维护一个区间 max 信息即可。
需要注意的是,本题涉及线段树上直接单点修改,由于在求 DFS 序时给每个节点分配了一个在线段树上的位置 id[],修改时不能忘记将读入的节点 u 转化为线段树上位置 id[u] 后再处理。
Code: O(nlogn+qlog2n) [4276K, 1904MS]
#include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<iostream> #include<algorithm> #define MAXN 30005 using namespace std; inline void getint(int &num){ char ch; bool neg = 0; while(!isdigit(ch = getchar())) if(ch == '-') neg = 1; num = ch - '0'; while(isdigit(ch = getchar())) num = num * 10 + ch - '0'; if(neg) num = -num; } int n, w[MAXN], q, tope = 0; char opt[15]; int dep[MAXN], fa[MAXN], sz[MAXN], hson[MAXN]; int dfn[MAXN], dfstime = 0, top[MAXN], nw[MAXN]; int sum[MAXN << 2], mx[MAXN << 2]; #define lc (u << 1) #define rc (u << 1 | 1) inline void update(int u) {sum[u] = sum[lc] + sum[rc], mx[u] = max(mx[lc], mx[rc]);} inline void build(int u, int l, int r){ if(l == r) {sum[u] = mx[u] = nw[l]; return;} int mid = l + r >> 1; if(l <= mid) build(lc, l, mid); if(mid < r) build(rc, mid + 1, r); update(u); } inline void modify(int pos, int v){ int u = 1, l = 1, r = n, *stk = new int[20], tops = 0; pos = dfn[pos]; // Beware that we must use the NEW index on the segment tree !!! while(l < r){ stk[tops++] = u; int mid = l + r >> 1; if(pos <= mid) u = lc, r = mid; else u = rc, l = mid + 1; } sum[u] = mx[u] = v; while(tops) update(stk[--tops]); delete []stk; } inline int query_sum(int u, int l, int r, int ql, int qr){ if(l == ql && r == qr) return sum[u]; int mid = l + r >> 1; if(qr <= mid) return query_sum(lc, l, mid, ql, qr); else if(ql > mid) return query_sum(rc, mid + 1, r, ql, qr); else return query_sum(lc, l, mid, ql, mid) + query_sum(rc, mid + 1, r, mid + 1, qr); } inline int query_max(int u, int l, int r, int ql, int qr){ if(l == ql && r == qr) return mx[u]; int mid = l + r >> 1; if(qr <= mid) return query_max(lc, l, mid, ql, qr); else if(ql > mid) return query_max(rc, mid + 1, r, ql, qr); else return max(query_max(lc, l, mid, ql, mid), query_max(rc, mid + 1, r, mid + 1, qr)); } struct Edge{ int np; Edge *nxt; } E[MAXN << 1], *V[MAXN]; inline void addedge(int u, int v) {E[++tope].np = v, E[tope].nxt = V[u], V[u] = &E[tope];} inline void dfs1(int u, int f, int d){ dep[u] = d, fa[u] = f, sz[u] = 1, hson[u] = 0; for(register Edge *ne = V[u]; ne; ne = ne->nxt){ if(ne->np == f) continue; dfs1(ne->np, u, d + 1); sz[u] += sz[ne->np]; if(sz[ne->np] > sz[hson[u]]) hson[u] = ne->np; } } inline void dfs2(int u, int tp){ top[u] = tp, dfn[u] = ++dfstime, nw[dfstime] = w[u]; if(!hson[u]) return; dfs2(hson[u], tp); for(register Edge *ne = V[u]; ne; ne = ne->nxt) if(ne->np != fa[u] && ne->np != hson[u]) dfs2(ne->np, ne->np); } inline int path_sum(int u, int v){ int res = 0; while(top[u] != top[v]){ if(dep[top[u]] < dep[top[v]]) swap(u, v); res += query_sum(1, 1, n, dfn[top[u]], dfn[u]); u = fa[top[u]]; } if(dep[u] > dep[v]) swap(u, v); res += query_sum(1, 1, n, dfn[u], dfn[v]); return res; } inline int path_max(int u, int v){ int res = 0xc0c0c0c0; while(top[u] != top[v]){ if(dep[top[u]] < dep[top[v]]) swap(u, v); res = max(res, query_max(1, 1, n, dfn[top[u]], dfn[u])); u = fa[top[u]]; } if(dep[u] > dep[v]) swap(u, v); res = max(res, query_max(1, 1, n, dfn[u], dfn[v])); return res; } int main(){ getint(n); for(register int i = 1; i < n; i++){ int u, v; getint(u), getint(v); addedge(u, v), addedge(v, u); } for(register int i = 1; i <= n; i++) getint(w[i]); dfs1(1, 0, 1), dfs2(1, 1); build(1, 1, n), getint(q); while(q--){ int u, v; scanf("%s", opt), getint(u), getint(v); if(opt[1] == 'M') printf("%d\n", path_max(u, v)); else if(opt[1] == 'S') printf("%d\n", path_sum(u, v)); else if(opt[1] == 'H') modify(u, v); } return 0; }
发表评论