LCA
让我们先回顾树上 的定义,对于有向树中任意结点对
即 的 是他们的公共祖先,且在树中 的深度最大,当然 可以是 或者 。之所以称为最近公共祖先,是因为 到 距离之和最小
我们关心 间的距离,往往是因为原输入 是无向树, 在 中是可达的,所以有向树 只是在 的基础上随便选了个根节点而已。因此如果 中有简单路径 ,这里并不是任取中间结点 ,而是选取出这样的 满足 ,所以 对应 中两条路径(为了叙述方便,认为结点到本身自然有路径)
则仅当 时 有最短距离
有时我们更关心 的最短带权距离,假设每条边都有一个非负权值,比如运输 网之类的。根据 易知此时最短路径仍是以最近公共祖先 为中转点,因此只要求 ,两条固定路径上的权值和
进一步地,如果要求这是一个在线算法呢,即边输入边查询,边权值动态变化。类似的在数组的 在线 (求定义域为下标区间的函数,如区间和)问题中我们手搓一个树状数组,线段树之类的递归树结构就可以解决,而且单次查询可以达到的时间复杂度。
那么对于一棵树,其实我们把结点都按某个顺序存到
(关于如何求解 ,最易实现的是二分查找,下面给出一个古早的树链剖分求的c
实现,记不清了,修改下文代码中qpath
函数应该就能求LCA
)
重链剖分
#include <bits/stdc++.h>
const int MAXN = 1e5 + 7;
int MOD = 0;
using namespace std;
long long addin;
inline int read();
inline long long add(long long a, long long b) {return (a + b) % MOD;}
inline long long mul(long long a, long long b) {return (a * b) % MOD;}
inline int lson(int i) {return i << 1;}
inline int rson(int i) {return (i << 1) | 1;}
typedef struct node_ {
int l, r, m, ls, rs;
long long laz, sum;
} node;
typedef struct e_ {
int to, next;
} edge;
node ns[MAXN << 2];
edge es[MAXN << 1];
int w[MAXN], cnt, heads[MAXN];
int pa[MAXN], dep[MAXN], hs[MAXN], sz[MAXN], top[MAXN], dfn[MAXN], rnk[MAXN], rear[MAXN];
int t;
long long build_tree(int l, int r, int p) {
node &d = ns[p];
d.l = l, d.r = r, d.m = (l + r) >> 1;
if (l == r) {
d.sum = w[rnk[l]];
return d.sum;
}
d.ls = p << 1;
d.rs = d.ls + 1;
return d.sum = add(build_tree(l, d.m, d.ls), build_tree(d.m + 1, r, d.rs));
}
inline void push_down(int p) {
node &ls = ns[ns[p].ls];
node &rs = ns[ns[p].rs];
long long &laz = ns[p].laz;
if (ns[p].laz) {
ls.laz = add(ls.laz, laz), ls.sum = add(ls.sum, mul(laz, ns[p].m - ns[p].l + 1));
rs.laz = add(rs.laz, laz), rs.sum = add(rs.sum, mul(laz, ns[p].r - ns[p].m));
laz = 0;
}
}
long long update(int x, int y, int p) {
int &l = ns[p].l;
int &r = ns[p].r;
if (x <= r && l <= y) {
if (r <= y && l >= x)
ns[p].laz = add(ns[p].laz, addin), ns[p].sum = add(ns[p].sum, mul(r - l + 1, addin));
else
push_down(p), ns[p].sum = add(update(x, y, ns[p].ls), update(x, y, ns[p].rs));
}
return ns[p].sum;
}
long long qsum(int x, int y, int p) {
int &l = ns[p].l;
int &r = ns[p].r;
long long res = 0;
if (r <= y && l >= x) return ns[p].sum;
push_down(p);
if (x <= ns[p].m) res = add(res, qsum(x, y, ns[p].ls));
if (y >= ns[p].m + 1) res = add(res, qsum(x, y, ns[p].rs));
return res;
}
inline void add_e(int x, int y) {es[++cnt] = {y, heads[x]}, heads[x] = cnt;}
void dfs0(int u, int f) {
sz[u] = 1;
pa[u] = f;
dep[u] = dep[f] + 1;
for (int i = heads[u]; i; i = es[i].next) {
int &v = es[i].to;
if (v != f) {
dfs0(v, u);
sz[u] += sz[v];
if (!hs[u] || sz[v] > sz[hs[u]]) hs[u] = v;
}
}
}
void dfs1(int u, int f, int smit) {
dfn[u] = ++t;
rnk[t] = u;
top[u] = smit;
if (hs[u]) dfs1(hs[u], u, smit);
for (int i = heads[u]; i; i = es[i].next) {
int &v = es[i].to;
if (v != f && v != hs[u])
dfs1(v, u, v);
}
rear[u] = t;
}
long long qpath(int x, int y) {
long long res = 0;
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) swap(x, y);
res = add(res, qsum(dfn[top[y]], dfn[y], 1));
y = pa[top[y]];
}
if (dep[x] > dep[y]) swap(x, y);
return add(res, qsum(dfn[x], dfn[y], 1));
}
void apath(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) swap(x, y);
update(dfn[top[y]], dfn[y], 1);
y = pa[top[y]];
}
if (dep[x] > dep[y]) swap(x, y);
update(dfn[x], dfn[y], 1);
}
int main() {
int n, m, root, x, y, op;
n = read(), m = read(), root = read(), MOD = read();
dep[0] = -1;
for (int i = 1; i <= n; i++) w[i] = read();
for (int i = 0; i < n - 1; i++)
x = read(), y = read(), add_e(x, y), add_e(y, x);
dfs0(root, 0);
dfs1(root, 0, root);
build_tree(1, n, 1);
for (int i = 0; i < m; i++) {
op = read(), x = read();
if (op == 4) printf("%lld\n", qsum(dfn[x], rear[x], 1));
else if (op == 3) addin = read(), update(dfn[x], rear[x], 1);
else {
y = read();
if (op == 2) printf("%lld\n", qpath(x, y));
else addin = read(), apath(x, y);
}
}
}
inline int read() {
int x = 0, f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
return x * f;
}