Skip to content
本站總訪問量
本站訪客數 人次

線段樹

線段樹是一種高效的數據結構,通常用於解決「單點更新、區間查詢」類型的問題。

線段樹的主要思想是:

  1. 將數組分割成多個區間,在樹中每個節點代表一個區間。
  2. 節點的值代表其對應區間內的「信息總和」(比如區間內數值的和)。
  3. 對於查詢某個區間,可以自上而下搜尋,找到表示該區間的節點。
  4. 更新某個元素,需要自下而上修改被影響的節點值。

線段樹的優點是:

  • 構建線段樹的時間複雜度為 O(N)
  • 查詢的時間為 O(logN)
  • 更新的時間為 O(logN)

空間複雜度為 O(N)。

線段樹經常用於區間最大值/最小值查詢、區間總和查詢等算法題中。需要預先構建樹結構,然後就可以高效查詢和更新了。

給您舉一個使用線段樹的範例:

假設我們有一個数组 a = [1, 2, 3, 4, 5],要構建一個線段樹,可以這樣建:

                      [1, 5]  
                   /        \
               [1, 3]        [4, 5]  
            /     \         /     \
        [1, 2]   [3, 3]  [4, 4]   [5, 5]
      /    \
    [1, 1] [2, 2]

根節點 [1, 5] 表示整個原数组的範圍。 左子節點 [1, 3] 表示左半邊範圍 [1, 3]。 右子節點 [4, 5] 表示右半邊範圍 [4, 5]。 以此類推每個節點表示一個範圍。

現在如果要查詢範圍 [2, 4] 的總和:

從根開始,代表的範圍是 [1, 5],不完全包含[2, 4],繼續查詢。 左子節點 [1, 3] 也不包含[2, 4],排除。 右子節點 [4, 5] 包含了查询範圍,直接返回該節點的值 6(4 + 5) 即可。 這樣就可以快速定位到表示某個範圍的節點,避免全數組遍歷,實現了高效查詢。

範例程式碼

python
class Node:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.total = 0
        self.left = None
        self.right = None


# 建樹
def build(arr, node, start, end):
    if start == end:
        node.total = arr[start]
        return

    mid = (start + end)  // 2
    node.left = Node(start, mid)
    node.right = Node(mid + 1, end)

    build(arr, node.left, start, mid)
    build(arr, node.right, mid + 1, end)

    node.total = node.left.total + node.right.total


# 修改值
def update(node, index, value):
    if node.start == node.end:
        node.total = value
        return

    mid = (node.start + node.end) // 2
    if index <= mid:
        update(node.left, index, value)
    else:
        update(node.right, index, value)

    node.total = node.left.total + node.right.total


# 查詢
def query(node, start, end):
    if node.start == start and node.end == end:
        return node.total

    mid = (node.start + node.end) // 2
    if end <= mid:
        return query(node.left, start, end)
    elif start >= mid + 1:
        return query(node.right, start, end)
    else:
        return query(node.left, start, mid) + query(node.right, mid + 1, end)
    
arr = [1, 2, 3, 4, 5]
root = Node(0, len(arr)-1)
build(arr, root, 0, len(arr)-1)

# 進行查詢操作
query_range = (1, 3) # 查詢範圍
result = query(root, query_range[0], query_range[1])
print('r1',result)

# 進行更新操作
update_index = 3
update_value = 6
update(root, update_index, update_value)

# 再次查詢以驗證
query_range = (1, 3)
result = query(root, query_range[0], query_range[1])
print('r2',result)

Contributors

The avatar of contributor named as lucashsu95 lucashsu95

Changelog