IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    实现一个友好的堆

    smallnest发表于 2024-03-07 14:50:56
    love 0

    在上一篇文章中,我吐槽了Go标准库的堆实现,基于“you can you up, no can no BB”的理论,这篇文章我来实现一个友好的堆。

    我们使用堆的时候,一般希望有Heap这样一个对象,并且能指定它是“小根堆”或者"大根堆"。我们希望这个类型有Push和Pop方法,可以加入一个元素或者弹出(最小的)元素。

    我们期望这个Heap支持泛型的,任何可以比较的类型都可以使用。

    处于简化的考虑,我们实现的Heap不考虑线程安全。如果要保证线程安全,可以使用sync.Mutex来保护Heap的操作。

    我们实现的Heap类型的操作基于标准库的操作,只不过我们封装了一下,让它更加友好。

    我们能够基于既有的一个slice创建Heap,也可以基于一个空的Heap创建一个新的Heap。

    最终我们实现了一个友好的堆,你可以在github上查看它的代码binheap。

    首先定义一个binHeap,这是一个泛型的slice,用来保存堆的元素,这样用户就不用定义这样一个类型了,简化了用户的使用。默认是小根堆。所有元素类型需要满足cmp.Ordered接口,可以进行大小比较。这个接口是标准库中的接口,如果你还不知道cmp包,那么需要刷新刷新Go新的变化了:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    package heap
    import (
    "cmp"
    )
    type binHeap[V cmp.Ordered] []V
    func (h binHeap[V]) Len() int { return len(h) }
    func (h binHeap[V]) Less(i, j int) bool { return h[i] < h[j] }
    func (h binHeap[V]) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
    func (h *binHeap[V]) Push(x V) {
    *h = append(*h, x)
    }
    func (h *binHeap[V]) Pop() V {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[0 : n-1]
    return x
    }

    这样我们就可以定义一个BinHeap类型,它是binHeap的封装,它有maxHeap字段,用来表示是小根堆还是大根堆。BinHeap类型有Push和Pop方法,可以加入一个元素或者弹出(最小的)元素。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    type BinHeap[V cmp.Ordered] struct {
    maxHeap bool
    binHeap binHeap[V]
    }
    // NewBinHeap returns a new binary heap.
    func NewBinHeap[V cmp.Ordered](opts ...BinHeapOption[V]) *BinHeap[V] {
    h := &BinHeap[V]{}
    for _, opt := range opts {
    opt(h)
    }
    return h
    }
    // Len returns the number of elements in the heap.
    func (h *BinHeap[V]) Push(x V) {
    h.binHeap.Push(x)
    sift_up[V](&h.binHeap, h.binHeap.Len()-1, h.maxHeap)
    }
    // Push pushes the element x onto the heap.
    func (h *BinHeap[V]) Pop() V {
    n := h.binHeap.Len() - 1
    h.binHeap.Swap(0, n)
    sift_down[V](&h.binHeap, 0, n, h.maxHeap)
    return h.binHeap.Pop()
    }

    另外还附送了两个常用的方法Len和Peek,Len返回堆的大小,Peek返回堆顶元素但是并不会从堆中移除它,在和堆的最小值做比较的时候很有用。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    // Len returns the number of elements in the heap.
    func (h *BinHeap[V]) Len() int {
    return h.binHeap.Len()
    }
    // Peek returns the element at the top of the heap without removing it.
    func (h *BinHeap[V]) Peek() (V, bool) {
    var v V
    if h.Len() == 0 {
    return v, false
    }
    return h.binHeap[0], true
    }

    最后,我们还提供了一个BinHeapOption类型,用来设置BinHeap的属性,比如是小根堆还是大根堆;为了提高性能,如果预先已经知道堆的大小,可以在初始化的时候就进行设置。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    // WithMaxHeap returns a BinHeapOption that configures a binary heap to be a max heap.
    func WithMaxHeap[V cmp.Ordered](h *BinHeap[V]) *BinHeap[V] {
    h.maxHeap = true
    return h
    }
    // WithMinHeap returns a BinHeapOption that configures a binary heap to be a min heap.
    func WithCapacity[V cmp.Ordered](n int) BinHeapOption[V] {
    return func(h *BinHeap[V]) *BinHeap[V] {
    if h.binHeap == nil {
    h.binHeap = make(binHeap[V], 0, n)
    }
    return h
    }
    }

    这样我们就实现了一个友好的堆。

    当然,如果你已经有了一个slice: []V, 想把它转换成堆,并且在这个slice上进行堆操作,那么你可以使用下面的方法:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    // NewBinHeapWithInitial returns a new binary heap with the given initial slice.
    func NewBinHeapWithInitial[V cmp.Ordered](s []V, opts ...BinHeapOption[V]) *BinHeap[V] {
    h := &BinHeap[V]{}
    h.binHeap = binHeap[V](s)
    for _, opt := range opts {
    opt(h)
    }
    n := len(s)
    for i := n/2 - 1; i >= 0; i-- {
    sift_down[V](&h.binHeap, i, n, h.maxHeap)
    }
    return h
    }

    堆的操作sift_down和sift_up的堆和核心操作,也来源子标准库的代码,只不过我把它们改成成泛型的函数了:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    func sift_up[V cmp.Ordered](h *binHeap[V], j int, maxHeap bool) {
    less := h.Less
    if maxHeap {
    less = func(i, j int) bool { return !h.Less(i, j) }
    }
    for {
    i := (j - 1) / 2 // parent
    if i == j || !less(j, i) {
    break
    }
    h.Swap(i, j)
    j = i
    }
    }
    func sift_down[V cmp.Ordered](h *binHeap[V], i0, n int, maxHeap bool) bool {
    less := h.Less
    if maxHeap {
    less = func(i, j int) bool { return !h.Less(i, j) }
    }
    i := i0
    for {
    j1 := 2*i + 1
    if j1 >= n || j1 < 0 { // j1 < 0 after int overflow
    break
    }
    j := j1 // left child
    if j2 := j1 + 1; j2 < n && less(j2, j1) {
    j = j2 // = 2*i + 2 // right child
    }
    if !less(j, i) {
    break
    }
    h.Swap(i, j)
    i = j
    }
    return i > i0
    }


沪ICP备19023445号-2号
友情链接