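"""Source code for xuance.common.segtree_tool.

Segment trees supporting point updates and range reductions: a generic
``SegmentTree`` plus ``SumSegmentTree`` and ``MinSegmentTree`` specializations.
"""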

import operator


class SegmentTree(object):
    """
    A data structure for efficient range queries and point updates using a binary tree
    representation.

    The tree lives in a flat list of length ``2 * capacity``: index 1 is the root, node ``i``
    has children ``2 * i`` and ``2 * i + 1``, and leaf ``capacity + k`` holds element ``k``.
    Index 0 is unused.

    Attributes:
        _capacity (int): The number of elements in the tree, must be a power of 2.
        _value (list): Internal array to store the tree nodes.
        _operation (Callable): A binary operation (e.g., addition, min, max) for range queries.
            It must be associative, because ranges are split at arbitrary midpoints.
        _neutral_element (Any): The neutral element for the operation (e.g., 0 for addition,
            infinity for min). It is also the result of reducing an empty range.

    Methods:
        __init__(capacity, operation, neutral_element): Initializes the segment tree with a
            specified capacity, operation, and neutral element.
        reduce(start=0, end=None): Computes the result of the operation over a range
            [start, end).
        __setitem__(idx, val): Updates the value at a specific index and propagates changes.
        __getitem__(idx): Retrieves the value at a specific index.
    """

    def __init__(self, capacity, operation, neutral_element):
        """
        Initialize a SegmentTree.

        Args:
            capacity (int): Number of elements in the tree, must be a power of 2.
            operation (Callable): Binary operation (e.g., lambda x, y: x + y) for combining
                elements.
            neutral_element (Any): Neutral element for the operation (e.g., 0 for addition,
                float('inf') for min).

        Raises:
            AssertionError: If capacity is not positive or not a power of 2.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
        self._capacity = capacity
        # Every node starts as the neutral element, so an untouched tree reduces to the
        # neutral element over any range.
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation
        # Kept so that empty ranges can be answered without walking the tree.
        self._neutral_element = neutral_element

    def _reduce_helper(self, start, end, node, node_start, node_end):
        """
        Recursively computes the result of the operation over a range.

        Args:
            start (int): Start of the query range (inclusive).
            end (int): End of the query range (inclusive).
            node (int): Current node index in the tree.
            node_start (int): Start of the range represented by the current node.
            node_end (int): End of the range represented by the current node.

        Returns:
            Any: The result of the operation over the specified range.

        Note:
            Requires node_start <= start <= end <= node_end; otherwise no node ever matches
            and the recursion does not terminate. ``reduce`` guarantees this.
        """
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # Query lies entirely in the left child.
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                # Query lies entirely in the right child.
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                # Query straddles the midpoint: combine both halves.
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
                )

    def reduce(self, start=0, end=None):
        """
        Computes the result of the operation over a range [start, end).

        Negative ``start``/``end`` are interpreted relative to the capacity, as in Python
        slicing (e.g. ``end=-1`` excludes the last element). An empty range
        (``start >= end``) yields the neutral element.

        Args:
            start (int, optional): Start of the range (default is 0).
            end (int, optional): End of the range (default is the tree's capacity).

        Returns:
            Any: The result of the operation over the specified range.

        Raises:
            AssertionError: If the normalized range falls outside [0, capacity].
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        if start < 0:
            start += self._capacity
        assert 0 <= start <= self._capacity and 0 <= end <= self._capacity, (
            "reduce range out of bounds."
        )
        if start >= end:
            # Without this guard _reduce_helper would never match a node and recurse forever.
            return self._neutral_element
        # _reduce_helper uses an inclusive end index.
        return self._reduce_helper(start, end - 1, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        """
        Updates the value at a specific index and propagates the changes.

        Args:
            idx (int): Index to update.
            val (Any): New value to set.

        Raises:
            AssertionError: If the index is out of range. Without this check a negative index
                would silently overwrite an internal node and corrupt the tree.
        """
        assert 0 <= idx < self._capacity
        # index of the leaf
        idx += self._capacity
        self._value[idx] = val
        idx //= 2
        # Recompute every ancestor up to and including the root (index 1).
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx],
                self._value[2 * idx + 1]
            )
            idx //= 2

    def __getitem__(self, idx):
        """
        Retrieves the value at a specific index.

        Args:
            idx (int): Index to query.

        Returns:
            Any: The value at the specified index.

        Raises:
            AssertionError: If the index is out of range.
        """
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]
class SumSegmentTree(SegmentTree):
    """
    A specialized implementation of a Segment Tree for summation queries and prefix-sum
    searches.

    Attributes:
        _capacity (int): The size of the underlying array, must be a power of 2.
        _value (list): The tree representation of the segment tree, storing intermediate sums.
        _operation (callable): The operation to be performed (addition in this case).
        _neutral_element (float): The neutral element for the operation (0.0 for addition).
    """

    def __init__(self, capacity):
        """
        Initialize a SumSegmentTree with the given capacity.

        Parameters:
            capacity (int): The capacity of the segment tree, must be a power of 2.
        """
        super(SumSegmentTree, self).__init__(
            capacity=capacity,
            operation=operator.add,
            neutral_element=0.0
        )

    def sum(self, start=0, end=None):
        """
        Compute the sum of elements in the range [start, end).

        Returns arr[start] + ... + arr[end - 1]

        Parameters:
            start (int): The starting index of the range (inclusive).
            end (int, optional): The ending index of the range (exclusive). Defaults to the
                full range.

        Returns:
            float: The sum of elements in the specified range.
        """
        return super(SumSegmentTree, self).reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """
        Find the highest index ``i`` such that ``arr[i] > 0`` and
        ``arr[0] + ... + arr[i - 1] <= prefixsum``.

        This is the element whose cumulative-weight interval contains ``prefixsum``. It is
        typically used to sample indices proportionally to their stored values. Stored values
        are assumed to be non-negative.

        For ``prefixsum`` strictly below the total, this is identical to a plain top-down
        descent. When ``prefixsum`` reaches the total (allowed up to a 1e-5 tolerance for
        floating-point round-off), the descent avoids zero-weight subtrees. It therefore
        returns the last element with positive weight rather than a trailing zero-weight
        (e.g. never-filled) slot. If every value is zero, index 0 is returned.

        Parameters:
            prefixsum (float): The target prefix sum.

        Returns:
            int: The index corresponding to the target prefix sum.

        Raises:
            AssertionError: If prefixsum is not within the valid range [0, total sum].
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        idx = 1
        while idx < self._capacity:  # while non-leaf
            left = self._value[2 * idx]
            right = self._value[2 * idx + 1]
            # Go left if the target falls inside the left subtree, or if the right subtree
            # carries no weight (only possible when prefixsum >= this subtree's total).
            if left > prefixsum or right <= 0:
                idx = 2 * idx
            else:
                prefixsum -= left
                idx = 2 * idx + 1
        return idx - self._capacity
class MinSegmentTree(SegmentTree):
    """
    A specialized implementation of a Segment Tree for range minimum queries.

    Attributes:
        _capacity (int): The size of the underlying array, must be a power of 2.
        _value (list): The tree representation of the segment tree, storing intermediate
            minimums.
        _operation (callable): The operation to be performed (minimum in this case).
        _neutral_element (float): The neutral element for the operation (infinity for
            minimum).
    """

    def __init__(self, capacity):
        """
        Initialize a MinSegmentTree with the given capacity.

        Parameters:
            capacity (int): The capacity of the segment tree, must be a power of 2.
        """
        # `min` here is the builtin: class-level names (including the `min` method below)
        # are not in scope inside method bodies.
        super(MinSegmentTree, self).__init__(
            capacity=capacity,
            operation=min,
            neutral_element=float('inf')
        )

    def min(self, start=0, end=None):
        """
        Compute the minimum value in the range [start, end).

        Returns min(arr[start], ..., arr[end - 1])

        Parameters:
            start (int): The starting index of the range (inclusive).
            end (int, optional): The ending index of the range (exclusive). Defaults to the
                full range.

        Returns:
            float: The minimum value in the specified range.
        """
        return super(MinSegmentTree, self).reduce(start, end)