IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    kth value in a subarray

    MaskRay发表于 2022-12-27 07:55:54
    love 0

    UNDER CONSTRUCTION.

    本文总结经典的区间第k小值数据结构题。 给定一个长为n的数组。有m个询问:求区间[l,r)中第k小的元素。

    一些方法支持扩展问题:有m个操作,或者修改某个位置上的元素,或者询问区间[l,r)中第k小的元素。

    归并树(merge sort tree)

    用O(n*log(n))时间构建线段树,每个节点存储对应区间的有序数组。 对于一个询问,二分搜索答案ans转化为计数问题:区间[l,r)内小于ans的元素个数是否大于等于k。 对于这个计数问题,把区间[l,r)解构为不超过log(n)个线段树节点。对于每个节点,二分查找这个节点存储的有序数组里小于ans的元素数。

    • static: O(n*log(n)+m*log(n)^3), space complexity: O(n*log(n)), not recommended

    若要支持修改元素,把每个节点存储的有序数组改成一棵binary search tree,这种嵌套树形解构俗称树套树。

    描述值域的线段树

    用O(n*log(n))时间构建一棵线段树,每个节点描述一个值域区间,存储出现的元素的位置序列。 对于静态问题,位置序列可以是一个有序数组。若要支持修改元素,位置序列得是线段树或binary search tree。

    • static: O(n*log(n)+m*log(n)^2), space complexity: O(n*log(n)), not recommended

    划分树(range tree with functional cascading)

    这是描述值域的线段树的一种优化。用O(n*log(n))时间构建一个描述值域的线段树,每个节点存储值域区间里按顺序出现的元素数组,和一个辅助数组表示分到左孩子的元素个数。 对于一个询问,可以O(1)知道[l,r)中落在左孩子值域的元素个数,判断要在左孩子或在右孩子找答案。

    • static: O((n+m)*log(n)), space complexity: O(n*log(n))

    整体二分(parallel binary search)

    有多组修改和询问。每个询问会受到时间序之前的修改的影响,询问目标可以二分搜索。 这类算法将二分答案应用到多组修改和询问上。

    在二分答案后,单点修改的影响为commutative monoid,区间询问的目标也是一个commutative monoid。

    • static: O((n+m)*log(n)^2), space complexity: O(n+m)
    • dynamic: O((n+m)*log(n+m)*log(n)), space complexity: O(n+m)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    #include <algorithm>
    #include <cstdio>
    #include <utility>
    using namespace std;

    #define FOR(i, a, b) for (int i = (a); i < (b); i++)
    #define REP(i, n) for (int i = 0; i < (n); i++)

    const int N = 200000, M = 200000;

    namespace {
    int ri() {
    int m = 0, s = 0; unsigned c;
    while ((c = getchar())-'0' >= 10u) m = c == '-';
    for (; c-'0' < 10u; c = getchar()) s = s*10+c-'0';
    return m ? -s : s;
    }

    pair<int, int> a[N];
    int ans[M], fenwick[N], n;
    struct Query { int id, l, r, k; } q[M], qq[M];

    void add(int i, int d) {
    for (; i < n; i |= i+1)
    fenwick[i] += d;
    }

    int get_sum(int i) {
    int sum = 0;
    for (; i; i &= i-1)
    sum += fenwick[i-1];
    return sum;
    }

    void conquer(int ml, int mh, int l, int h) {
    if (ml == mh-1) {
    FOR(i, l, h)
    ans[q[i].id] = a[ml].first;
    return;
    }
    int mm = ml+mh >> 1, nl = 0, nh = h-l;
    FOR(i, ml, mm)
    add(a[i].second, 1);
    FOR(i, l, h) {
    int t = get_sum(q[i].r)-get_sum(q[i].l);
    if (q[i].k <= t)
    qq[nl++] = q[i];
    else
    qq[--nh] = q[i], qq[nh].k -= t;
    }
    FOR(i, ml, mm)
    add(a[i].second, -1);
    copy_n(qq, nl, q+l);
    copy(qq+nh, qq+h-l, q+l+nl);
    if (nl) conquer(ml, mm, l, l+nl);
    if (l+nl < h) conquer(mm, mh, l+nl, h);
    }
    }

    int main() {
    n = ri();
    int m = ri();
    REP(i, n)
    a[i] = {ri(), i};
    REP(i, m)
    q[i] = {i, ri()-1, ri(), ri()};
    sort(a, a+n);
    conquer(0, n, 0, m);
    REP(i, m)
    printf("%d\n", ans[i]);
    }

    可持久化线段树(persistent segment tree)

    用O(n*log(n))时间构建n+1棵描述值域的线段树。每棵线段树表示一个原数组的一个前缀(共n+1个)。在每棵线段树中,每个节点存储一个值域区间里的元素数。 相邻两棵线段树描述的区间只相差一个元素,它们可以共用大部分节点,只有ceil(log(n))个节点有差异。

    • static: O((n+m)*log(n)), space complexity: O(n*log(n))
    • dynamic: O((n+m)*log(n)^2)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    // persistent segment tree
    #include <algorithm>
    #include <cstdio>
    using namespace std;

    #define REP(i, n) for (int i = 0; i < (n); i++)

    int ri() {
    int m = 0, s = 0; unsigned c;
    while ((c = getchar())-'0' >= 10u) m = c == '-';
    for (; c-'0' < 10u; c = getchar()) s = s*10+c-'0';
    return m ? -s : s;
    }

    const int N = 200000, M = 200000, LOG2N = 32-__builtin_clz(N-1);
    int a[N], b[N], roots[N+1], allo;
    struct Segment { int ch[2], cnt; } seg[N*2+M*LOG2N];

    void build(int &t, int l, int r) {
    t = ++allo;
    if (l < r-1) {
    int m = l+r >> 1;
    build(seg[t].ch[0], l, m);
    build(seg[t].ch[1], m, r);
    }
    }

    void add(int *t, int u, int l, int r, int v) {
    while (l < r-1) {
    *t = ++allo;
    seg[*t].cnt = seg[u].cnt+1;
    int m = l+r >> 1, d = v >= m;
    if (d) l = m;
    else r = m;
    seg[*t].ch[d^1] = seg[u].ch[d^1];
    t = &seg[*t].ch[d];
    u = seg[u].ch[d];
    }
    *t = ++allo;
    seg[*t].cnt = seg[u].cnt+1;
    }

    int kth(int t, int u, int l, int r, int k) {
    while (l < r-1) {
    int m = l+r >> 1, lcnt = seg[seg[t].ch[0]].cnt-seg[seg[u].ch[0]].cnt, d = k >= lcnt;
    if (d) l = m, k -= lcnt;
    else r = m;
    t = seg[t].ch[d];
    u = seg[u].ch[d];
    }
    return l;
    }

    int main() {
    int n = ri(), m = ri();
    REP(i, n) {
    a[i] = ri();
    b[i] = a[i];
    }
    sort(b, b+n);
    int nn = unique(b, b+n) - b;
    build(roots[0], 0, nn);
    REP(i, n) {
    int v = lower_bound(b, b+nn, a[i]) - b;
    add(&roots[i+1], roots[i], 0, nn, v);
    }
    while (m--) {
    int l = ri(), r = ri(), k = ri();
    printf("%d\n", b[kth(roots[r], roots[l-1], 0, nn, k-1)]);
    }
    }

    莫涛算法(Mo's algorithm)

    • static: O(n*log(n)+m*sqrt(n)+m*log(m))
    • dynamic (binary search on the value, 二分答案): O(n*log(n)+m*sqrt(n)*log(n)*log(n+m))
    • dynamic (区间 [l,r] 内所有的 x 变成 y, P4119):

    静态情形:维护两个频度数组,一个表示元素x的频度,另一个表示元素区间(如[i,i+block_size))的频度。区间长度加减一时,O(1)修改频度。

    1
    2
    c1[a[i]] += d;
    c2[block[a[i]]] += d;

    询问时O(sqrt(n))扫描频度数组得到答案。

    1
    2
    3
    4
    5
    6
    7
    int x = 0, k = qs[i].k;
    while (c2[x] < k) k -= c2[x++];
    for (int j = x*block_size; ; j++)
    if ((k -= c1[j]) <= 0) {
    qs[i].ans = j;
    break;
    }

    要点在于不要用有序数据结构维护区间内的元素,会不必要增大修改的时间复杂度。

    要支持修改元素,可在每个分块里里维护一个有序数组。 修改时重建有序数组。 询问时二分答案ans。在包含的分块里二分搜索小于ans的元素数。在分块外线性遍历至多2*block_size个元素

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    // Mo's algorithm
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    using namespace std;

    #define REP(i, n) for (int i = 0; i < (n); i++)

    int ri() {
    int m = 0, s = 0; unsigned c;
    while ((c = getchar())-'0' >= 10u) m = c == '-';
    for (; c-'0' < 10u; c = getchar()) s = s*10+c-'0';
    return m ? -s : s;
    }

    const int N = 200000, M = 200000;
    int a[N], b[N], block[N], c1[N], c2[N], block_size;
    struct Query {
    int l, r, k, id, ans;
    bool operator<(const Query &o) const {
    int i = block[l], j = block[o.l];
    if (i != j) return i < j;
    return i & 1 ? r < o.r : r > o.r;
    }
    } qs[M];

    static void add(int i, int d) {
    c1[a[i]] += d;
    c2[block[a[i]]] += d;
    }

    int main() {
    int n = ri(), m = ri();
    REP(i, n) {
    a[i] = ri();
    b[i] = a[i];
    }
    sort(b, b+n);
    int nn = unique(b, b+n) - b;
    REP(i, n)
    a[i] = lower_bound(b, b+nn, a[i]) - b;
    REP(i, m) {
    qs[i].l = ri()-1;
    qs[i].r = ri();
    qs[i].k = ri();
    qs[i].id = i;
    }
    block_size = sqrt(n);
    REP(i, n)
    block[i] = i/block_size;
    sort(qs, qs+m);

    int l = 0, r = 0;
    REP(i, m) {
    while (qs[i].l < l) add(--l, 1);
    while (r < qs[i].r) add(r++, 1);
    while (l < qs[i].l) add(l++, -1);
    while (qs[i].r < r) add(--r, -1);
    int x = 0, k = qs[i].k;
    while (c2[x] < k) k -= c2[x++];
    for (int j = x*block_size; ; j++)
    if ((k -= c1[j]) <= 0) {
    qs[i].ans = j;
    break;
    }
    }
    REP(i, m)
    a[qs[i].id] = b[qs[i].ans];
    REP(i, m)
    printf("%d\n", a[i]);
    }


沪ICP备19023445号-2号
友情链接