【算法】最近公共祖先

LCA

让我们先回顾树上LCA(Lowest Common Ancestor)LCA(Lowest\ Common\ Ancestor) 的定义,对于有向树TT中任意结点对<x,y><x, y>

(1)A(x)={a  a=x  P(P: ax)}A(x) = \{a\ |\ a = x \ \lor\ \exists P(P:\ a \rightarrow \cdots \rightarrow x)\} \tag{1}

(2)z=LCA(x,y)=arg maxfA(x)A(y) DEPTH(f)z = LCA(x, y) \\ = \mathop{arg\ max}\limits_{f\in A(x) \cap A(y)}\ \mathrm{DEPTH}(f ) \tag{2}

x,yx, yLCA(z)LCA(z) 是他们的公共祖先,且在树中 的深度最大,当然 zz 可以是 xx 或者 yy 。之所以称为最近公共祖先,是因为 zzx,yx, y 距离之和最小

我们关心 x,yx, y 间的距离,往往是因为原输入 T0T_0 是无向树, x,yx, yT0T_0 中是可达的,所以有向树 TT 只是在 T0T_0 的基础上随便选了个根节点而已。因此如果 T0T_0 中有简单路径 Pxy: xtyP_{xy}:\ x \leftrightarrow \cdots t \cdots \leftrightarrow y ,这里并不是任取中间结点 tt ,而是选取出这样的 tPxyt \in P_{xy} 满足 t(tPxyDEPTH(t)DEPTH(t))\forall t^{\prime}(t^{\prime} \in P_{xy} \rightarrow \mathrm{DEPTH}(t) \le \mathrm{DEPTH}(t^{\prime})) ,所以 PxyP_{xy} 对应 TT 中两条路径(为了叙述方便,认为结点到本身自然有路径)

(3)Ptx: txPty: tytA(x)A(y)P_{tx}:\ t \rightarrow \cdots \rightarrow x\quad P_{ty}:\ t \rightarrow \cdots \rightarrow y\\ \Rightarrow t \in A(x) \cap A(y) \tag{3}

则仅当 z=tz = tx,yx, y 有最短距离

(4)DIS(x,y)=DIS(z,x)+DIS(z,y)=(DEPTH(x)DEPTH(z))+(DEPTH(y)DEPTH(z))=DEPTH(x)+DEPTH(y)2×DEPTH(z)\mathrm{DIS}(x, y) = \mathrm{DIS}(z, x) + \mathrm{DIS}(z, y)\\ =(\mathrm{DEPTH}(x) - \mathrm{DEPTH}(z)) + (\mathrm{DEPTH}(y) - \mathrm{DEPTH}(z))\\ =\mathrm{DEPTH}(x)+ \mathrm{DEPTH}(y) - 2\times \mathrm{DEPTH}(z) \tag{4}

有时我们更关心 x,yx, y 的最短带权距离,假设每条边都有一个非负权值,比如运输 网之类的。根据 (3)(3) 易知此时最短路径仍是以最近公共祖先 zz 为中转点,因此只要求 DIS(z,x),DIS(z,y)\mathrm{DIS}(z, x), \mathrm{DIS}(z, y),两条固定路径上的权值和

进一步地,如果要求这是一个在线算法呢,即边输入边查询,边权值动态变化。类似的在数组的 在线 RMQRMQ (求定义域为下标区间的函数,如区间和)问题中我们手搓一个树状数组,线段树之类的递归树结构就可以解决,而且单次查询可以达到的时间复杂度。

那么对于一棵树,其实我们把结点都按某个顺序存到

(关于如何求解 LCALCA ,最易实现的是二分查找,下面给出一个古早的树链剖分求LCALCAc实现,记不清了,修改下文代码中qpath函数应该就能求LCA)

重链剖分

#include <bits/stdc++.h>
const int MAXN = 1e5 + 7;
int MOD = 0;
using namespace std;

long long addin;
inline int read();
inline long long add(long long a, long long b) {return (a + b) % MOD;}
inline long long mul(long long a, long long b) {return (a * b) % MOD;}
inline int lson(int i) {return i << 1;}
inline int rson(int i) {return (i << 1) | 1;}

typedef struct node_ {
    int l, r, m, ls, rs;
    long long laz, sum;
} node;

typedef struct e_ {
    int to, next;
} edge;

node ns[MAXN << 2];
edge es[MAXN << 1];
int w[MAXN], cnt, heads[MAXN];
int pa[MAXN], dep[MAXN], hs[MAXN], sz[MAXN], top[MAXN], dfn[MAXN], rnk[MAXN], rear[MAXN];
int t;

long long build_tree(int l, int r, int p) {
    node &d = ns[p];
    d.l = l, d.r = r, d.m = (l + r) >> 1;
    if (l == r) {
        d.sum = w[rnk[l]];
        return d.sum;
    }
    d.ls = p << 1;
    d.rs = d.ls + 1;
    return d.sum = add(build_tree(l, d.m, d.ls), build_tree(d.m + 1, r, d.rs));
}

inline void push_down(int p) {
    node &ls = ns[ns[p].ls];
    node &rs = ns[ns[p].rs];
    long long &laz = ns[p].laz;
    if (ns[p].laz) {
        ls.laz = add(ls.laz, laz), ls.sum = add(ls.sum, mul(laz, ns[p].m - ns[p].l + 1));
        rs.laz = add(rs.laz, laz), rs.sum = add(rs.sum, mul(laz, ns[p].r - ns[p].m));
        laz = 0;
    }
}

long long update(int x, int y, int p) {
    int &l = ns[p].l;
    int &r = ns[p].r;
    if (x <= r && l <= y) {
        if (r <= y && l >= x) 
            ns[p].laz = add(ns[p].laz, addin), ns[p].sum = add(ns[p].sum, mul(r - l + 1, addin));
        else 
            push_down(p), ns[p].sum = add(update(x, y, ns[p].ls), update(x, y, ns[p].rs));
    }
    return ns[p].sum;
}

long long qsum(int x, int y, int p) {
    int &l = ns[p].l;
    int &r = ns[p].r;
    long long res = 0;
    if (r <= y && l >= x) return ns[p].sum;
    push_down(p);
    if (x <= ns[p].m) res = add(res, qsum(x, y, ns[p].ls));
    if (y >= ns[p].m + 1) res = add(res, qsum(x, y, ns[p].rs));
    return res;
}

inline void add_e(int x, int y) {es[++cnt] = {y, heads[x]}, heads[x] = cnt;}

void dfs0(int u, int f) {
    sz[u] = 1;
    pa[u] = f;
    dep[u] = dep[f] + 1;
    for (int i = heads[u]; i; i = es[i].next) {
        int &v = es[i].to;
        if (v != f) {
            dfs0(v, u);
            sz[u] += sz[v];
            if (!hs[u] || sz[v] > sz[hs[u]]) hs[u] = v;
        }
    }
}

void dfs1(int u, int f, int smit) {
    dfn[u] = ++t;
    rnk[t] = u;
    top[u] = smit;
    if (hs[u]) dfs1(hs[u], u, smit);
    for (int i = heads[u]; i; i = es[i].next) {
        int &v = es[i].to;
        if (v != f && v != hs[u])
            dfs1(v, u, v);
    }
    rear[u] = t;
}

long long qpath(int x, int y) {
    long long res = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) swap(x, y);
        res = add(res, qsum(dfn[top[y]], dfn[y], 1));
        y = pa[top[y]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    return add(res, qsum(dfn[x], dfn[y], 1));
}

void apath(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) swap(x, y);
        update(dfn[top[y]], dfn[y], 1);
        y = pa[top[y]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    update(dfn[x], dfn[y], 1);
}

int main() {
    int n, m, root, x, y, op;
    n = read(), m = read(), root = read(), MOD = read();
    dep[0] = -1;
    for (int i = 1; i <= n; i++) w[i] = read();
    for (int i = 0; i < n - 1; i++) 
        x = read(), y = read(), add_e(x, y), add_e(y, x);
    dfs0(root, 0);
    dfs1(root, 0, root);
    build_tree(1, n, 1);  
    for (int i = 0; i < m; i++) {
        op = read(), x = read();
        if (op == 4) printf("%lld\n", qsum(dfn[x], rear[x], 1));
        else if (op == 3) addin = read(), update(dfn[x], rear[x], 1);
        else {
            y = read();
            if (op == 2) printf("%lld\n", qpath(x, y));
            else addin = read(), apath(x, y);
        }
    }
}

inline int read() {
    int x = 0, f = 1; char c = getchar();
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = (x << 3) + (x << 1) + c - '0', c = getchar();
    return x * f;
}