跳转至

8 数据结构

8.1 扩展树状数组

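注意,插入、删除和查询的值都必须落在 \([1, N]\) 内:大于 \(N\) 会越界,小于等于 \(0\) 则会让 `add` 陷入死循环(此时 `x & -x` 为 \(0\))。值域过大或含负数时,可先离散化(或整体加一个偏移量);如果离散化不可用(例如强制在线且值域未知),则需要改用平衡树。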

// Value-indexed Fenwick tree used as an ordered multiset of integers in [1, n].
struct BIT {
    int n;              // largest representable value
    std::vector<int> w; // w[i] = number of stored values in (i - lowbit(i), i]; w[0] unused
    BIT(int size) : n(size), w(size + 1) {}

    // Add v copies of value x (v = -1 removes one copy). Requires 1 <= x <= n.
    void add(int x, int v) {
        while (x <= n) {
            w[x] += v;
            x += x & -x;
        }
    }

    // k-th smallest stored value (k is 1-based), found by binary lifting.
    // The search never steps onto index n, so the result is capped at n.
    int kth(int x) {
        int step = 1;
        while (step <= n / 2) step *= 2; // highest power of two not exceeding n
        int pos = 0;
        for (; step > 0; step >>= 1) {
            const int next = pos + step;
            if (next < n && w[next] < x) {
                x -= w[next];
                pos = next;
            }
        }
        return pos + 1;
    }

    // Rank of x: 1 + number of stored values strictly less than x.
    int get(int x) {
        int rank = 1;
        for (int i = x - 1; i != 0; i -= i & -i) rank += w[i];
        return rank;
    }

    // Largest stored value < x (predecessor).
    int pre(int x) {
        const int smaller = get(x) - 1;
        return kth(smaller);
    }

    // Smallest stored value > x (successor).
    int suf(int x) {
        const int atMostX = get(x + 1);
        return kth(atMostX);
    }
};
// Handles every balanced-BST operation online, provided all values lie in [1, N].
constexpr int N = 10000000;
signed main() {
    BIT bit(N + 1); // online input cannot be discretized, so size it above the largest value
    int n;
    cin >> n;
    for (int done = 0; done < n; ++done) {
        int op, x;
        cin >> op >> x;
        switch (op) {
        case 1: bit.add(x, 1); break;                // insert x
        case 2: bit.add(x, -1); break;               // erase one copy of x
        case 3: cout << bit.get(x) << "\n"; break;   // rank of x
        case 4: cout << bit.kth(x) << "\n"; break;   // value with rank x
        case 5: cout << bit.pre(x) << "\n"; break;   // largest value < x (predecessor)
        case 6: cout << bit.suf(x) << "\n"; break;   // smallest value > x (successor)
        default: break;
        }
    }
}

8.2 可持久化线段树(主席树)

可持久化数组:每次操作基于某个历史版本做单点修改或单点查询,并生成一个新版本。

#include <bits/stdc++.h>
#define int long long
#define mid ((l+r)>>1)
using namespace std;
constexpr int N = 1e6 + 5; // capacity for the array length n and the number of versions m
// Node pool: build() uses 2n - 1 nodes and each change() adds one node per tree level
// (about log2(n) + 1), so 25 * N covers n and m up to about 1e6.
// With `#define int long long` every element below is 8 bytes, so the three pool
// arrays alone take roughly 3 * 25e6 * 8 B = 600 MB.
constexpr int M = 25 * N;
int a[N];                  // initial array, 1-indexed
int n;                     // array length
int m;                     // number of operations
int ls[M];                 // left child of each node (0 = none)
int rs[M];                 // right child of each node (0 = none)
int val[M];                // value stored at leaf nodes
int root[N];               // root[i]: root node of version i
int tot;                   // number of nodes allocated so far
// Build version 0 over a[l..r]; u receives the index of the new subtree root.
void build(int &u, int l, int r) {
    u = ++tot;
    if (l != r) {
        const int half = (l + r) >> 1;
        build(ls[u], l, half);
        build(rs[u], half + 1, r);
        return;
    }
    val[u] = a[l]; // leaf: store the array value
}
// Path copying: u becomes a copy of version-node v with position p set to c.
// Only the nodes on the root-to-leaf path of p are cloned; everything else is shared.
void change(int &u, int v, int l, int r, int p, int c) {
    u = ++tot;
    ls[u] = ls[v];
    rs[u] = rs[v];
    val[u] = val[v];
    if (l == r) {
        val[u] = c;
        return;
    }
    const int half = (l + r) >> 1;
    if (p > half) change(rs[u], rs[v], half + 1, r, p, c);
    else change(ls[u], ls[v], l, half, p, c);
}
// Value at position p in the version rooted at u: walk down to the leaf covering p.
int query(int u, int l, int r, int p) {
    while (l != r) {
        const int half = (l + r) >> 1;
        if (p <= half) {
            u = ls[u];
            r = half;
        } else {
            u = rs[u];
            l = half + 1;
        }
    }
    return val[u];
}
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    build(root[0], 1, n); // version 0 is the initial array
    for (int i = 1; i <= m; i++) {
        int v, op, p;
        cin >> v >> op >> p;
        switch (op) {
        case 1: { // version i = version v with a[p] set to c
            int c;
            cin >> c;
            change(root[i], root[v], 1, n, p, c);
            break;
        }
        case 2:   // version i is an unchanged copy of version v; print a[p] there
            root[i] = root[v];
            cout << query(root[i], 1, n, p) << "\n";
            break;
        }
    }
}

按前缀逐个单点插入建立版本,查询时在两个版本之差上二分(静态区间第 \(k\) 小)。

#include <bits/stdc++.h>
#define int long long
#define ls(u) w[u].l
#define rs(u) w[u].r
#define mid ((l+r)>>1)
using namespace std;
constexpr int N = 1e6 + 5; // capacity for n and m
int n, m;                  // sequence length, number of queries
int a[N];                  // input values, 1-indexed
int b[N];                  // sorted unique copy of a, used for discretization
// One node of the persistent value-segment-tree: children and the count of values in range.
struct Node {
    int l, r, sum;
};
Node w[N * 25];            // node pool; w[0] is the shared empty node
int root[N];               // root[i]: version containing a[1..i]
int tot;                   // number of nodes allocated so far
void pushup(int u) {
    w[u].sum = w[ls(u)].sum + w[rs(u)].sum;
}
// New version u = version v with one more occurrence of discretized value p.
void change(int &u, int v, int l, int r, int p) {
    u = ++tot;
    w[u] = w[v]; // clone the previous version's node, then adjust along the path
    if (l == r) {
        ++w[u].sum;
        return;
    }
    const int half = (l + r) >> 1;
    if (p > half) change(w[u].r, w[v].r, half + 1, r, p);
    else change(w[u].l, w[v].l, l, half, p);
    pushup(u);
}
// k-th smallest discretized value in version R minus version L.
// Walks both versions in lock-step; their count difference on the left half decides the side.
int query(int L, int R, int l, int r, int k) {
    while (l != r) {
        const int half = (l + r) >> 1;
        const int inLeft = w[w[R].l].sum - w[w[L].l].sum;
        if (inLeft >= k) {
            L = w[L].l;
            R = w[R].l;
            r = half;
        } else {
            k -= inLeft;
            L = w[L].r;
            R = w[R].r;
            l = half + 1;
        }
    }
    return l;
}
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    copy(a + 1, a + n + 1, b + 1);
    sort(b + 1, b + n + 1);
    const auto uniqEnd = unique(b + 1, b + n + 1);
    // Version i = version i-1 plus the discretized rank of a[i].
    for (int i = 1; i <= n; i++) {
        const int rk = lower_bound(b + 1, uniqEnd, a[i]) - b;
        change(root[i], root[i - 1], 1, n, rk);
    }
    for (int qi = 0; qi < m; qi++) {
        int L, R, k;
        cin >> L >> R >> k;
        // Version R minus version L-1 contains exactly a[L..R].
        cout << b[query(root[L - 1], root[R], 1, n, k)] << "\n";
    }
}

8.3 重链剖分 (by Jiangly)

// Heavy-light decomposition (HLD) of a tree whose vertices are numbered 0..n-1.
// Usage: HLD h(n); h.addEdge(u, v) for every edge; h.work(root); then query.
// After work():
//   siz[u]        - size of u's subtree
//   dep[u]        - depth of u (the root has depth 0)
//   parent[u]     - parent of u (-1 for the root)
//   top[u]        - topmost vertex of the heavy chain containing u
//   in[u], out[u] - DFS entry / exit times; u's subtree is exactly seq[in[u] .. out[u]-1]
//   seq[t]        - the vertex whose entry time is t (inverse of in)
// dfs2 visits the heavy child first, so every heavy chain is a contiguous range of seq.
// NOTE: work() modifies adj (the parent edge is erased and the heaviest child is moved to
// adj[u][0]) and cur is only reset by init(), so call init() again before rebuilding.
// NOTE(review): dfs1/dfs2 are recursive; a very deep tree (e.g. a long path) may
// overflow the call stack.
struct HLD {
    int n;
    vector<int> siz, top, dep, parent, in, out, seq;
    vector<vector<int>> adj;
    int cur; // DFS timer used by dfs2

    HLD() {}
    HLD(int n) {
        init(n);
    }
    // (Re)initialize for n vertices with no edges.
    void init(int n) {
        this->n = n;
        siz.resize(n);
        top.resize(n);
        dep.resize(n);
        parent.resize(n);
        in.resize(n);
        out.resize(n);
        seq.resize(n);
        cur = 0;
        adj.assign(n, {});
    }
    // Add an undirected edge u - v.
    void addEdge(int u, int v) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // Root the tree at `root` and build the decomposition.
    void work(int root = 0) {
        top[root] = root;
        dep[root] = 0;
        parent[root] = -1;
        dfs1(root);
        dfs2(root);
    }
    // Computes siz / dep / parent, removes the parent from adj[u], and moves the child
    // with the largest subtree to adj[u][0] (the heavy child).
    void dfs1(int u) {
        if (parent[u] != -1) {
            adj[u].erase(find(adj[u].begin(), adj[u].end(), parent[u]));
        }

        siz[u] = 1;
        for (auto &v : adj[u]) {
            parent[v] = u;
            dep[v] = dep[u] + 1;
            dfs1(v);
            siz[u] += siz[v];
            if (siz[v] > siz[adj[u][0]]) {
                // v is a reference into adj[u], so this swaps the two list entries.
                swap(v, adj[u][0]);
            }
        }
    }
    // Assigns DFS entry/exit times (heavy child first) and chain tops.
    void dfs2(int u) {
        in[u] = cur++;
        seq[in[u]] = u;
        for (auto v : adj[u]) {
            // The heavy child continues u's chain; every light child starts a new one.
            top[v] = v == adj[u][0] ? top[u] : v;
            dfs2(v);
        }
        out[u] = cur;
    }
    // Lowest common ancestor: repeatedly lift the endpoint whose chain top is deeper.
    int lca(int u, int v) {
        while (top[u] != top[v]) {
            if (dep[top[u]] > dep[top[v]]) {
                u = parent[top[u]];
            } else {
                v = parent[top[v]];
            }
        }
        return dep[u] < dep[v] ? u : v;
    }

    // Number of edges on the path between u and v.
    int dist(int u, int v) {
        return dep[u] + dep[v] - 2 * dep[lca(u, v)];
    }

    // Ancestor of u that is k levels above it, or -1 if k > dep[u].
    int jump(int u, int k) {
        if (dep[u] < k) {
            return -1;
        }

        int d = dep[u] - k; // depth of the answer

        // Climb whole chains until the target depth lies on u's current chain.
        while (dep[top[u]] > d) {
            u = parent[top[u]];
        }

        // Along one chain, entry times increase by exactly 1 per level.
        return seq[in[u] - dep[u] + d];
    }

    // True if u is an ancestor of v (a vertex counts as its own ancestor).
    bool isAncester(int u, int v) {
        return in[u] <= in[v] && in[v] < out[u];
    }

    // Parent of v when the tree is re-rooted at u; returns v itself when u == v.
    int rootedParent(int u, int v) {
        swap(u, v); // from here on: u is the queried vertex, v is the new root
        if (u == v) {
            return u;
        }
        if (!isAncester(u, v)) {
            // The new root lies outside u's subtree, so u's parent is unchanged.
            return parent[u];
        }
        // The new root lies below u: the answer is the child of u whose subtree contains v.
        // adj[u] is in increasing order of in[] (dfs2 visits it in order), so binary search.
        auto it = upper_bound(adj[u].begin(), adj[u].end(), v, [&](int x, int y) {
            return in[x] < in[y];
        }) - 1;
        return *it;
    }

    // Size of v's subtree when the tree is rooted at u.
    int rootedSize(int u, int v) {
        if (u == v) {
            return n;
        }
        if (!isAncester(v, u)) {
            return siz[v];
        }
        // v is above u: everything except the branch of v that leads toward u.
        return n - siz[rootedParent(u, v)];
    }

    // LCA of b and c when the tree is rooted at a. Relies on the standard fact that two of
    // the three pairwise LCAs coincide and the third is the answer, so XOR isolates it.
    int rootedLca(int a, int b, int c) {
        return lca(a, b) ^ lca(b, c) ^ lca(c, a);
    }
};

8.4 ST表

// Sparse table for static range-maximum queries: O(n log n) build, O(1) query.
struct ST {
    vector<vector<int>> st; // st[j][i] = max of arr[i .. i + 2^j - 1]; level index first
    ST (const vector<int>& arr) {
        const int n = arr.size();
        // __lg(0) is undefined behaviour, so an empty array gets exactly one (empty) level.
        const int levels = n > 0 ? __lg(n) + 1 : 1;
        st.assign(levels, vector<int>(n));
        for (int i = 0; i < n; ++i) st[0][i] = arr[i];
        for (int j = 1; j < levels; ++j) {
            const int half = 1 << (j - 1);
            for (int i = 0; i + (1 << j) <= n; ++i) {
                st[j][i] = max(st[j - 1][i], st[j - 1][i + half]);
            }
        }
    }
    // Maximum of arr[l..r] (inclusive). Requires 0 <= l <= r < arr.size().
    // Two overlapping power-of-two windows cover [l, r] exactly.
    int query(int l, int r) {
        const int k = __lg(r - l + 1);
        return max(st[k][l], st[k][r - (1 << k) + 1]);
    }
};

8.5 字典树 Trie

字符串的插入与查询(统计某串的出现次数、统计以某串为前缀的字符串个数)。把字符换成二进制位即得 01 字典树,可用于维护异或极值(本节未给出其代码)。

// Trie over an arbitrary alphabet described by a mapping f: char -> [0, n).
// Counts exact occurrences of a string and the number of inserted strings that have a
// given string as a prefix.
class Trie
{
private:
    int amount;                  // alphabet size: children per node
    vector<vector<int>> son;     // son[p][c]: child of node p via symbol c (0 = none); node 0 is the root
    vector<int> cnt, prefix_cnt; // cnt[p]: strings ending at p; prefix_cnt[p]: strings passing through p
    // flag == true : number of inserted copies of str
    // flag == false: number of inserted strings having str as a prefix
    int query(const string &str, bool flag)
    {
        int p = 0;
        for (char ch : str)
        {
            const int next = son[p][f(ch)];
            if (!next)
                return 0;
            p = next;
        }
        return flag ? cnt[p] : prefix_cnt[p];
    }
    function<int(char)> f;       // symbol mapping; must return a value in [0, amount)

public:
    // mapper must send every character that is inserted or queried into [0, n).
    Trie(function<int(char)> mapper, int n) : amount(n), son(1, vector<int>(n)), cnt(1), prefix_cnt(1), f(mapper) {}
    void insert(const string &str)
    {
        int p = 0;
        // The empty string is a prefix of every string, so the root counts every insertion;
        // without this, queryPrefixAmount("") would always return 0.
        prefix_cnt[p]++;
        for (char ch : str)
        {
            const int image = f(ch);
            if (!son[p][image])
            {
                son[p][image] = son.size();
                son.emplace_back(amount); // new node with all children empty
                cnt.push_back(0);
                prefix_cnt.push_back(0);
            }
            p = son[p][image];
            prefix_cnt[p]++;
        }
        cnt[p]++;
    }
    int queryAmount(const string &str) { return query(str, true); }
    int queryPrefixAmount(const string &str) { return query(str, false); }
};

// Example alphabet mapping for Trie: 'a'-'z' -> 0..25, 'A'-'Z' -> 26..51, '0'-'9' -> 52..61.
// Any other character trips the assertion.
int f(char ch)
{
    if ('a' <= ch && ch <= 'z')
        return ch - 'a';
    if ('A' <= ch && ch <= 'Z')
        return 26 + (ch - 'A');
    assert('0' <= ch && ch <= '9');
    return 52 + (ch - '0');
}

8.6 根号分块

// Square-root decomposition: O(1) point assignment, O(sqrt n) prefix / range sums.
struct Sqrt {
    int block_size;           // elements per block (must be > 0)
    vector<int> nums;         // copy of the array
    vector<long long> blocks; // blocks[k]: sum of nums in block k
    // Split arr into blocks of `sqrtn` elements. At least `sqrtn` blocks are allocated,
    // and more when arr.size() > sqrtn * sqrtn, so every index always has a block.
    Sqrt(int sqrtn, vector<int> &arr) : block_size(sqrtn), nums(arr) {
        const int n = static_cast<int>(nums.size());
        const int needed = (n + block_size - 1) / block_size;
        blocks.assign(needed > sqrtn ? needed : sqrtn, 0);
        for (int i = 0; i < n; i++) { blocks[i / block_size] += nums[i]; }
    }

    /** O(1) update to set nums[x] to v */
    void update(int x, int v) {
        blocks[x / block_size] -= nums[x];
        nums[x] = v;
        blocks[x / block_size] += nums[x];
    }

    /** O(sqrt(n)) query for sum of [0, r) */
    long long query(int r) {
        long long res = 0;
        for (int i = 0; i < r / block_size; i++) { res += blocks[i]; }
        for (int i = (r / block_size) * block_size; i < r; i++) { res += nums[i]; }
        return res;
    }

    /** O(sqrt(n)) query for sum of [l, r): prefix [0, r) minus prefix [0, l) */
    long long query(int l, int r) { return query(r) - query(l); }
};

8.7 莫队

如果区间端点移动一步(加入或删除一个元素)能在 \(O(1)\) 内完成,并且询问可以离线处理,就可以考虑莫队。取块长约为 \(\sqrt n\) 时复杂度为 \(O((n + m)\sqrt n)\),其中 \(m\) 为询问数。

// Mo's algorithm (offline): number of distinct values in aa[l..r] for each query.
// Usage: fill n and aa[1..n] (values must lie in [0, MAXN); discretize first otherwise),
// fill m and q[1..m] with q[i].id = i, call solve(), then print ans[1..m] in order.
constexpr int MAXN = 200000 + 5; // adjust to the problem's limits on n, m and values

struct query {
    int l, r, id; // 1-based inclusive range and the query's original position
};

int n, m, now;    // sequence length, number of queries, answer for the current window
int aa[MAXN], cnt[MAXN], belong[MAXN], ans[MAXN];
query q[MAXN];

// Order by the block of l; inside a block sort r ascending on odd blocks and descending on
// even ones (odd-even trick) so the right pointer sweeps back and forth instead of resetting.
bool cmp(const query &a, const query &b) {
    if (belong[a.l] != belong[b.l]) return belong[a.l] < belong[b.l];
    return (belong[a.l] & 1) ? a.r < b.r : a.r > b.r;
}

// Extend the current window by position pos.
void add(int pos) {
    if (!cnt[aa[pos]]) ++now;
    ++cnt[aa[pos]];
}

// Shrink the current window by position pos.
void del(int pos) {
    --cnt[aa[pos]];
    if (!cnt[aa[pos]]) --now;
}

// Answer every query; ans[q[i].id] receives the answer of q[i].
// Queries must all be read BEFORE sorting - reading them inside the loop defeats Mo's order.
void solve() {
    int block = 1; // ceil(sqrt(n)), at least 1
    while (1LL * block * block < n) ++block;
    for (int i = 1; i <= n; ++i) belong[i] = (i - 1) / block + 1;
    std::sort(q + 1, q + m + 1, cmp);
    int l = 1, r = 0; // current window aa[l..r], initially empty
    now = 0;
    for (int i = 1; i <= m; ++i) {
        const int ql = q[i].l, qr = q[i].r;
        // Grow before shrinking so the window never becomes "negative" (l > r + 1).
        while (l > ql) add(--l);
        while (r < qr) add(++r);
        while (l < ql) del(l++);
        while (r > qr) del(r--);
        ans[q[i].id] = now;
    }
    // Empty the window so cnt[] is all zero again for the next call (multiple test cases).
    while (l <= r) del(l++);
}

8.8 笛卡尔树

这是一种键(下标)满足二叉搜索树性质、值满足堆性质的二叉树,可以用单调栈在 \(O(n)\) 内构建,常用于序列上的计数问题。

int stk[N], tp;
int p[N], ls[N], rs[N], n;
void build_tree() { // 小根堆
    for (int i = 1; i <= n; i++) {
        int j = 0;
        while (tp and p[stk[tp]] > p[i]) j = stk[tp--];
        ls[i] = j;
        if (tp) rs[stk[tp]] = i;
        stk[++tp] = i;
    }
}

8.9 可撤销并查集 (by Jiangly)

放弃路径压缩,只保留按大小(启发式)合并,因此 find 的复杂度为 \(O(\log n)\);revert 每撤销一次合并为 \(O(1)\),撤销 \(k\) 次合并共 \(O(k)\)。常用于线段树分治。

// Rollback DSU: union by size and no path compression, so every merge can be undone.
// find is O(log n); revert(t) costs O(number of merges undone).
struct DSU {
    vector<int> siz;           // siz[r]: size of the set rooted at r (meaningful for roots)
    vector<int> f;             // f[x]: parent of x (f[x] == x for roots)
    vector<array<int, 2>> his; // merge history: {surviving root, absorbed root}

    // Elements 0..n (n + 1 slots, so 1-based indices work), each in its own set.
    DSU(int n) : siz(n + 1, 1), f(n + 1) {
        iota(f.begin(), f.end(), 0);
    }

    // Root of x's set. No path compression, so links stay exactly as merge() made them.
    int find(int x) {
        while (f[x] != x) {
            x = f[x];
        }
        return x;
    }

    // Union the sets of x and y (smaller under larger). Returns false if already joined.
    bool merge(int x, int y) {
        x = find(x);
        y = find(y);
        if (x == y) {
            return false;
        }
        if (siz[x] < siz[y]) {
            swap(x, y);
        }
        his.push_back({x, y});
        siz[x] += siz[y];
        f[y] = x;
        return true;
    }

    // Current history length; pass it to revert() to roll back to this moment.
    int time() {
        return his.size();
    }

    // Undo merges in LIFO order until only the first tm remain.
    void revert(int tm) {
        while (his.size() > tm) {
            auto [x, y] = his.back();
            his.pop_back();
            f[y] = y;
            siz[x] -= siz[y];
        }
    }
};

8.10 李超线段树

插入线段,查询某一点处的最小值。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

constexpr ll INF_LL = static_cast<ll>(4e18); // "+infinity": empty slots in the min version hold y = INF_LL
constexpr ll X_L = -1000000000LL;            // leftmost x covered by the tree (adjust per problem)
constexpr ll X_R = 1000000000LL;             // rightmost x covered by the tree (adjust per problem)

// A line y = m * x + b. The default-constructed line is the constant y = +INF (min version).
struct Line {
    ll m, b;
    Line(ll slope = 0, ll intercept = INF_LL) : m(slope), b(intercept) {}

    // Value at x, computed in 128-bit arithmetic and clamped to [-INF_LL, INF_LL].
    ll eval(ll x) const {
        const __int128 y = static_cast<__int128>(m) * x + b;
        const __int128 cap = INF_LL;
        if (y > cap) return INF_LL;
        if (y < -cap) return -INF_LL;
        return static_cast<ll>(y);
    }
};

// Dynamically allocated Li Chao tree node: the line kept here plus two children.
struct Node {
    Line ln;
    Node *l = nullptr, *r = nullptr;
    Node(const Line& line) : ln(line) {}
};

Node *root{nullptr}; // root of the whole tree over [X_L, X_R]

// -----------------
// 在区间 [l,r] 的子树 node 中插入一条线(全域/段插入会调用它)
// 最小值版:比较使用 '<'
// -----------------
void add_line(Line nw, Node*& node, ll l = X_L, ll r = X_R){
    // Walk down instead of recursing; `slot` is the pointer field that may receive a new node.
    Node** slot = &node;
    while (*slot != nullptr) {
        Node* cur = *slot;
        const ll half = (l + r) >> 1;
        // Min version. For a max version change both '<' to '>'.
        const bool winsLeft = nw.eval(l) < cur->ln.eval(l);
        const bool winsMid  = nw.eval(half) < cur->ln.eval(half);
        // Keep the line that is better at the midpoint in this node; push the other one down.
        if (winsMid) swap(nw, cur->ln);
        if (l == r) return;
        // The pushed-down line can only win on the side where the comparison flips.
        if (winsLeft != winsMid) {
            slot = &cur->l;
            r = half;
        } else {
            slot = &cur->r;
            l = half + 1;
        }
    }
    *slot = new Node(nw);
}

// Wrapper: insert a line valid on the whole domain [X_L, X_R].
inline void add_line(Line ln){
    add_line(ln, root, X_L, X_R);
}

// 在区间 [Lq,Rq] 上插入线段(线只在该区间生效)
void add_segment(Line nw, Node*& node, ll Lq, ll Rq, ll l = X_L, ll r = X_R){
    // [l, r] does not meet [Lq, Rq]: nothing to do.
    if (r < Lq || Rq < l) return;
    // [l, r] lies inside [Lq, Rq]: the line is valid on this whole node, insert it here.
    if (Lq <= l && r <= Rq) {
        add_line(nw, node, l, r);
        return;
    }
    // Partial overlap: make sure a node exists to descend through. The placeholder line
    // y = +INF never wins a min comparison (a max version needs a -INF placeholder here,
    // and its query() must return -INF for missing nodes).
    if (node == nullptr) node = new Node(Line(0, INF_LL));
    const ll half = (l + r) >> 1;
    add_segment(nw, node->l, Lq, Rq, l, half);
    add_segment(nw, node->r, Lq, Rq, half + 1, r);
}
// Wrapper: insert a line that is only valid for x in [Lq, Rq].
inline void add_segment(Line nw, ll Lq, ll Rq){
    add_segment(nw, root, Lq, Rq, X_L, X_R);
}

// 查询点 x 的最小值(若节点为空返回 INF)
ll query(ll x, Node* node, ll l = X_L, ll r = X_R){
    // Minimum over every line stored on the root-to-leaf path of x; INF if the path is empty.
    // For a max version start from -INF and use max(...) instead.
    ll best = INF_LL;
    while (node != nullptr) {
        best = min(best, node->ln.eval(x));
        if (l == r) break;
        const ll half = (l + r) >> 1;
        if (x <= half) {
            node = node->l;
            r = half;
        } else {
            node = node->r;
            l = half + 1;
        }
    }
    return best;
}
// Wrapper: minimum over all inserted lines/segments at x.
inline ll query(ll x){
    return query(x, root, X_L, X_R);
}

// 释放树(若需要)
void clear_tree(Node* node){
    // Explicit stack instead of recursion; children are queued before their parent is freed.
    vector<Node*> pending;
    if (node) pending.push_back(node);
    while (!pending.empty()) {
        Node* cur = pending.back();
        pending.pop_back();
        if (cur->l) pending.push_back(cur->l);
        if (cur->r) pending.push_back(cur->r);
        delete cur;
    }
}

// -----------------
// 示例 main(演示用法)
// -----------------
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // 示例:插入三条直线(最小值查询)
    add_line(Line(2, 3));    // y = 2x + 3
    add_line(Line(-1, 10));  // y = -x + 10
    add_line(Line(0, 5));    // y = 5

    cout << "query(0) = " << query(0) << "\n";   // 期望 min(3,10,5) = 3
    cout << "query(2) = " << query(2) << "\n";   // 期望 min(7,8,5) = 5
    cout << "query(10) = " << query(10) << "\n"; // 期望 min(23,0,5) = 0

    // 在区间 [1, 5] 插入一条只在此区间生效的线
    add_segment(Line(-3, 50), 1, 5); // y = -3x + 50 在 [1,5] 有效
    cout << "query(3) = " << query(3) << "\n";

    // 程序结束前释放内存(比赛中一般不需要)
    clear_tree(root);
    root = nullptr;
    return 0;
}

8.11 平衡树 fhq-Treap

8.11.1 按值分裂

普通平衡树,维护集合。

#include <bits/stdc++.h>
using namespace std;
mt19937 rnd(chrono::high_resolution_clock::now().time_since_epoch().count());
// fhq-Treap (split/merge treap) keyed by value: an ordered multiset of ints.
// Node 0 is the null sentinel (siz[0] == 0). Nodes come from fixed pools (~40 MB in total),
// so keep the Treap object in static storage (a global), not on the stack.
// Erased nodes are not recycled.
struct Treap {
    static const int N = 2000000 + 5;
    int ls[N], rs[N], key[N], val[N], siz[N], root = 0, tot = 0, T1, T2, T3;

    // Allocate a leaf holding v with a random heap priority.
    int newNode(int v) {
        const int id = ++tot;
        val[id] = v;
        key[id] = rnd();
        siz[id] = 1;
        ls[id] = rs[id] = 0;
        return id;
    }

    // Recompute u's subtree size from its children.
    void push_up(int u) { siz[u] = 1 + siz[ls[u]] + siz[rs[u]]; }

    // Join x and y where every value in x <= every value in y.
    // The root with the larger key wins; on equal keys y becomes the root.
    int merge(int x, int y) {
        if (x == 0) return y;
        if (y == 0) return x;
        if (key[x] <= key[y]) {
            ls[y] = merge(x, ls[y]);
            push_up(y);
            return y;
        }
        rs[x] = merge(rs[x], y);
        push_up(x);
        return x;
    }

    // Split u into x (values <= v) and y (values > v).
    void split(int u, int v, int &x, int &y) {
        if (u == 0) {
            x = 0;
            y = 0;
            return;
        }
        if (val[u] > v) {
            y = u;
            split(ls[u], v, x, ls[u]);
        } else {
            x = u;
            split(rs[u], v, rs[u], y);
        }
        push_up(u);
    }

    // Insert one copy of v.
    void insert(int v) {
        split(root, v, T1, T2);
        const int fresh = newNode(v);
        root = merge(merge(T1, fresh), T2);
    }

    // Remove one copy of v (no-op if absent).
    void erase(int v) {
        split(root, v - 1, T1, T2);              // T1: < v, T2: >= v
        split(T2, v, T2, T3);                    // T2: == v, T3: > v
        const int rest = merge(ls[T2], rs[T2]);  // drop T2's root, i.e. one copy of v
        root = merge(merge(T1, rest), T3);
    }

    // 1 + number of stored values < v.
    int rank(int v) {
        split(root, v - 1, T1, T2);
        const int below = siz[T1];
        root = merge(T1, T2);
        return below + 1;
    }

    // k-th smallest value (1-based); an out-of-range k yields val[0].
    int kth(int k) {
        int u = root;
        while (u != 0) {
            const int leftPlusSelf = siz[ls[u]] + 1;
            if (k < leftPlusSelf) {
                u = ls[u];
            } else if (k > leftPlusSelf) {
                k -= leftPlusSelf;
                u = rs[u];
            } else {
                return val[u];
            }
        }
        return val[u];
    }

    // Largest value < v, or INT_MIN if none.
    int pre(int v) {
        int best = INT_MIN;
        for (int u = root; u != 0;) {
            if (val[u] >= v) {
                u = ls[u];
                continue;
            }
            if (val[u] > best) best = val[u];
            u = rs[u];
        }
        return best;
    }

    // Smallest value > v, or INT_MAX if none.
    int nex(int v) {
        int best = INT_MAX;
        for (int u = root; u != 0;) {
            if (val[u] <= v) {
                u = rs[u];
                continue;
            }
            if (val[u] < best) best = val[u];
            u = ls[u];
        }
        return best;
    }
};

8.11.2 按排名分裂

文艺平衡树,维护序列,这里给出的是实现序列的区间翻转操作。

#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e5 + 5; // node pool capacity (index 0 is the null sentinel)
mt19937 rnd(time(nullptr));
// fhq-Treap split by size (implicit keys): maintains a sequence and supports range reversal
// through a lazy "reverse this subtree" flag.
struct Treap {
    int ls[N], rs[N], key[N], val[N], siz[N];
    int root = 0, tot = 0, T1, T2, T3;
    bool tag[N]; // pending reversal: u's children are not yet swapped

    // Allocate a leaf holding v with a random heap priority.
    int node(int v) {
        const int id = ++tot;
        ls[id] = 0;
        rs[id] = 0;
        key[id] = rnd();
        val[id] = v;
        siz[id] = 1;
        return id;
    }

    void pushup(int u) {
        siz[u] = 1 + siz[ls[u]] + siz[rs[u]];
    }

    // Apply u's pending reversal: swap its children and hand the flag down to them.
    void pushdown(int u) {
        if (!tag[u]) return;
        tag[u] = false;
        swap(ls[u], rs[u]);
        if (ls[u] != 0) tag[ls[u]] = !tag[ls[u]];
        if (rs[u] != 0) tag[rs[u]] = !tag[rs[u]];
    }

    // Split u into x (first k elements) and y (the rest).
    void split(int u, int k, int &x, int &y) {
        if (u == 0) {
            x = 0;
            y = 0;
            return;
        }
        pushdown(u);
        const int leftSize = siz[ls[u]];
        if (k <= leftSize) {
            y = u;
            split(ls[u], k, x, ls[u]);
        } else {
            x = u;
            split(rs[u], k - leftSize - 1, rs[u], y);
        }
        pushup(u);
    }

    // Concatenate sequences x and y; the root with the larger key wins (ties go to y).
    int merge(int x, int y) {
        if (x == 0) return y;
        if (y == 0) return x;
        if (key[x] > key[y]) {
            pushdown(x);
            rs[x] = merge(rs[x], y);
            pushup(x);
            return x;
        }
        pushdown(y);
        ls[y] = merge(x, ls[y]);
        pushup(y);
        return y;
    }

    // Reverse positions l..r (1-based, inclusive).
    void reverse(int l, int r) {
        split(root, l - 1, T1, T2);
        split(T2, r - l + 1, T2, T3);
        tag[T2] = !tag[T2];
        root = merge(T1, merge(T2, T3));
    }

    // Print the sequence in order, each value followed by a space.
    void output(int u) {
        if (u != 0) {
            pushdown(u);
            output(ls[u]);
            cout << val[u] << " ";
            output(rs[u]);
        }
    }
} s;
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    // Start from the sequence 1, 2, ..., n.
    for (int i = 1; i <= n; ++i) {
        const int fresh = s.node(i);
        s.root = s.merge(s.root, fresh);
    }
    for (; m; --m) {
        int l, r;
        cin >> l >> r;
        s.reverse(l, r);
    }
    s.output(s.root);
}