線段樹
線段樹是一種高效的數據結構,通常用於解決「單點更新、區間查詢」類型的問題。
線段樹的主要思想是:
- 將數組分割成多個區間,在樹中每個節點代表一個區間。
- 節點的值代表其對應區間內的「信息總和」(比如區間內數值的和)。
- 對於查詢某個區間,可以自上而下搜尋,找到表示該區間的節點。
- 更新某個元素,需要自下而上修改被影響的節點值。
線段樹的優點是:
- 構建線段樹的時間複雜度為 O(N)
- 查詢的時間為 O(logN)
- 更新的時間為 O(logN)
空間複雜度為 O(N)。
線段樹經常用於區間最大值/最小值查詢、區間總和查詢等算法題中。需要預先構建樹結構,然後就可以高效查詢和更新了。
給您舉一個使用線段樹的範例:
假設我們有一個数组 a = [1, 2, 3, 4, 5],要構建一個線段樹,可以這樣建:
[1, 5]
/ \
[1, 3] [4, 5]
/ \ / \
[1, 2] [3, 3] [4, 4] [5, 5]
/ \
[1, 1] [2, 2]
根節點 [1, 5] 表示整個原数组的範圍。 左子節點 [1, 3] 表示左半邊範圍 [1, 3]。 右子節點 [4, 5] 表示右半邊範圍 [4, 5]。 以此類推每個節點表示一個範圍。
現在如果要查詢範圍 [2, 4] 的總和:
從根開始,代表的範圍是 [1, 5],不完全包含[2, 4],繼續查詢。 左子節點 [1, 3] 也不包含[2, 4],排除。 右子節點 [4, 5] 包含了查询範圍,直接返回該節點的值 6(4 + 5) 即可。 這樣就可以快速定位到表示某個範圍的節點,避免全數組遍歷,實現了高效查詢。
範例程式碼
python
class Node:
def __init__(self, start, end):
self.start = start
self.end = end
self.total = 0
self.left = None
self.right = None
# 建樹
def build(arr, node, start, end):
if start == end:
node.total = arr[start]
return
mid = (start + end) // 2
node.left = Node(start, mid)
node.right = Node(mid + 1, end)
build(arr, node.left, start, mid)
build(arr, node.right, mid + 1, end)
node.total = node.left.total + node.right.total
# 修改值
def update(node, index, value):
if node.start == node.end:
node.total = value
return
mid = (node.start + node.end) // 2
if index <= mid:
update(node.left, index, value)
else:
update(node.right, index, value)
node.total = node.left.total + node.right.total
# 查詢
def query(node, start, end):
if node.start == start and node.end == end:
return node.total
mid = (node.start + node.end) // 2
if end <= mid:
return query(node.left, start, end)
elif start >= mid + 1:
return query(node.right, start, end)
else:
return query(node.left, start, mid) + query(node.right, mid + 1, end)
arr = [1, 2, 3, 4, 5]
root = Node(0, len(arr)-1)
build(arr, root, 0, len(arr)-1)
# 進行查詢操作
query_range = (1, 3) # 查詢範圍
result = query(root, query_range[0], query_range[1])
print('r1',result)
# 進行更新操作
update_index = 3
update_value = 6
update(root, update_index, update_value)
# 再次查詢以驗證
query_range = (1, 3)
result = query(root, query_range[0], query_range[1])
print('r2',result)