Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lazy segment tree #539

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions pydatastructs/miscellaneous_data_structures/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def __new__(cls, array, func, data_structure='segment_tree', **kwargs):

@classmethod
def methods(cls):
return ['query', 'update']
return ['query', 'update', 'update_range']

def query(start, end):
"""
Expand Down Expand Up @@ -289,6 +289,21 @@ def update(self, index, value):
raise NotImplementedError(
"This is an abstract method.")

def update_range(self, start, end, value):
"""
Method to update [start, end] with a new value.

Parameters
==========

index: int
The index to be update.
value: int
The new value.
"""
raise NotImplementedError(
"This is an abstract method.")

class RangeQueryDynamicArray(RangeQueryDynamic):

__slots__ = ["range_query_static"]
Expand All @@ -312,24 +327,28 @@ def update(self, index, value):

class RangeQueryDynamicSegmentTree(RangeQueryDynamic):

__slots__ = ["segment_tree", "bounds"]
__slots__ = ["segment_tree", "bounds", "is_lazy"]

def __new__(cls, array, func, **kwargs):
raise_if_backend_is_not_python(
cls, kwargs.pop('backend', Backend.PYTHON))
obj = object.__new__(cls)
obj.segment_tree = ArraySegmentTree(array, func, dimensions=1)
is_lazy = kwargs.pop('lazy', False)
obj.segment_tree = ArraySegmentTree(array, func, dimensions=1, is_lazy=is_lazy, **kwargs)
obj.segment_tree.build()
obj.bounds = (0, len(array))
return obj

@classmethod
def methods(cls):
return ['query', 'update']
return ['query', 'update', 'update_range']

def query(self, start, end):
_check_range_query_inputs((start, end + 1), self.bounds)
return self.segment_tree.query(start, end)

def update(self, index, value):
self.segment_tree.update(index, value)
self.segment_tree[index] = value

def update_range(self, start, end, value):
self.segment_tree.update_range(start, end, value)
182 changes: 177 additions & 5 deletions pydatastructs/miscellaneous_data_structures/segment_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .stack import Stack
from pydatastructs import OneDimensionalArray
from pydatastructs.utils.misc_util import (TreeNode,
Backend, raise_if_backend_is_not_python)

Expand Down Expand Up @@ -51,7 +52,7 @@ class ArraySegmentTree(object):
1
>>> s_t.query(1, 3)
2
>>> s_t.update(2, -1)
>>> s_t[2] = -1
>>> s_t.query(1, 3)
-1
>>> arr = OneDimensionalArray(int, [1, 2])
Expand All @@ -65,10 +66,18 @@ class ArraySegmentTree(object):

.. [1] https://cp-algorithms.com/data_structures/segment_tree.html
"""
def __new__(cls, array, func, **kwargs):
def __new__(cls, array, func, is_lazy=False, **kwargs):

dimensions = kwargs.pop("dimensions", 1)
if dimensions == 1:
if is_lazy:
if kwargs.get('neutral_element') is not None:
neutral_element = kwargs.pop('neutral_element')
return OneDimensionalArraySegmentTreeLazy(array, neutral_element,
func, **kwargs)
else:
raise ValueError("ArraySegmentTree with lazy implementation should have a "
"neutral_element argument")
return OneDimensionalArraySegmentTree(array, func, **kwargs)
else:
raise NotImplementedError("ArraySegmentTree do not support "
Expand All @@ -82,7 +91,7 @@ def build(self):
raise NotImplementedError(
"This is an abstract method.")

def update(self, index, value):
def __setitem__(self, index, value):
"""
Updates the value at given index.
"""
Expand All @@ -98,6 +107,15 @@ def query(self, start, end):
raise NotImplementedError(
"This is an abstract method.")

def update_range(self, start, end, value):
"""
Updates [start, end] range according
to the function provided while constructing
`ArraySegmentTree` object.
"""
raise NotImplementedError(
"This is an abstract method.")

def __str__(self):
recursion_stack = Stack(implementation='linked_list')
recursion_stack.push(self._root)
Expand Down Expand Up @@ -131,7 +149,7 @@ def __new__(cls, array, func, **kwargs):

@classmethod
def methods(self):
return ['__new__', 'build', 'update',
return ['__new__', 'build',
'query']

@property
Expand Down Expand Up @@ -171,7 +189,7 @@ def build(self):
node.right = right_node
recursion_stack.push(right_node)

def update(self, index, value):
def __setitem__(self, index, value):
if not self.is_ready:
raise ValueError("{} tree is not built yet. ".format(self) +
"Call .build method to prepare the segment tree.")
Expand Down Expand Up @@ -223,3 +241,157 @@ def query(self, start, end):

return self._query(self._root, 0, len(self._array) - 1,
start, end)


class OneDimensionalArraySegmentTreeLazy(OneDimensionalArraySegmentTree):

__slots__ = ["_func", "_array", "_lazy_node",
"_neutral_element", "_root", "_backend"]

def __new__(cls, array, neutral_element, func, **kwargs):
backend = kwargs.get('backend', Backend.PYTHON)
raise_if_backend_is_not_python(cls, backend)

obj = object.__new__(cls)
obj._func = func
obj._array = array
obj._lazy_node = None
obj._neutral_element = neutral_element
obj._root = None
obj._backend = backend
return obj

@classmethod
def methods(self):
return ['__new__', 'build', 'update_range',
'query']

@property
def is_ready(self):
return self._root is not None

def build(self):
if self.is_ready:
return

recursion_stack = Stack(implementation='linked_list')
node = TreeNode((0, len(self._array) - 1), None, backend=self._backend)
lazy_node = TreeNode((0, len(self._array) - 1), None, backend=self._backend)
node.is_root = True
lazy_node.is_root = True
self._root = node
self._lazy_node = lazy_node
recursion_stack.push((node, lazy_node))

while not recursion_stack.is_empty:
node, lazy_node = recursion_stack.peek.key
start, end = node.key
if start == end:
node.data = self._array[start]
lazy_node.data = self._neutral_element
recursion_stack.pop()
continue

if (node.left is not None and
node.right is not None):
recursion_stack.pop()
node.data = self._func((node.left.data, node.right.data))
lazy_node.data = self._neutral_element
else:
mid = (start + end) // 2
if node.left is None:
left_node = TreeNode((start, mid), None)
node.left = left_node
lazy_left_node = TreeNode((start, mid), None)
lazy_node.left = lazy_left_node
recursion_stack.push((left_node, lazy_left_node))
if node.right is None:
right_node = TreeNode((mid + 1, end), None)
node.right = right_node
lazy_rig_node = TreeNode((mid + 1, end), None)
lazy_node.right = lazy_rig_node
recursion_stack.push((right_node, lazy_rig_node))

def _update_range(self, node, lazy_node, start, end, l, r, value):
if not self.is_ready:
raise ValueError("{} tree is not built yet. ".format(self) +
"Call .build method to prepare the segment tree.")
if lazy_node.data != self._neutral_element:
node.data = self._func((node.data, lazy_node.data))
if start != end:
lazy_node.left.data = self._func((lazy_node.left.data, lazy_node.data))
lazy_node.right.data = self._func((lazy_node.right.data, lazy_node.data))
lazy_node.data = self._neutral_element
if r < start or end < l:
return

if l <= start and end <= r:
node.data = self._func((node.data, value))
if start != end:
lazy_node.left.data = self._func((lazy_node.left.data, value))
lazy_node.right.data = self._func((lazy_node.right.data, value))
return
mid = (start + end) // 2
self._update_range(node.left, lazy_node.left, start, mid, l, r, value)
self._update_range(node.right, lazy_node.right, mid + 1, end, l, r, value)
node.data = self._func((node.left.data, node.right.data))

def update_range(self, start, end, value):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def update_range(self, start, end, value):
def update(self, start, end, value):

if not self.is_ready:
raise ValueError("{} tree is not built yet. ".format(self) +
"Call .build method to prepare the segment tree.")
self._update_range(self._root, self._lazy_node, 0, len(self._array) - 1,
start, end, value)


def _set_single_element(self, node, lazy_node, start, end, index, value):
if lazy_node.data != self._neutral_element:
node.data = self._func((node.data, lazy_node.data))
if start != end:
lazy_node.left.data = self._func((lazy_node.left.data, lazy_node.data))
lazy_node.right.data = self._func((lazy_node.right.data, lazy_node.data))
lazy_node.data = self._neutral_element
if index == start and end == index:
node.data = value
return
mid = (start + end) // 2
if start <= index and index <= mid:
self._set_single_element(node.left, lazy_node.left, start, mid, index, value)
else:
self._set_single_element(node.right, lazy_node.right, mid + 1, end, index, value)
node.data = self._func((node.left.data, node.right.data))


def __setitem__(self, index, value):
if not self.is_ready:
raise ValueError("{} tree is not built yet. ".format(self) +
"Call .build method to prepare the segment tree.")
self._set_single_element(self._root, self._lazy_node, 0, len(self._array) - 1,
index, value)

def _query(self, node, lazy_node, start, end, l, r):
if r < start or end < l:
return None

if lazy_node.data != self._neutral_element:
node.data = self._func((node.data, lazy_node.data))
if start != end:
lazy_node.left.data = self._func((lazy_node.left.data, lazy_node.data))
lazy_node.right.data = self._func((lazy_node.right.data, lazy_node.data))
lazy_node.data = self._neutral_element

if l <= start and end <= r:
return node.data

mid = (start + end) // 2
left_result = self._query(node.left, lazy_node.left, start, mid, l, r)
right_result = self._query(node.right, lazy_node.right, mid + 1, end, l, r)
return self._func((left_result, right_result))

def query(self, start, end):
if not self.is_ready:
raise ValueError("{} tree is not built yet. ".format(self) +
"Call .build method to prepare the segment tree.")

return self._query(self._root, self._lazy_node, 0, len(self._array) - 1,
start, end)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import random, math
from copy import deepcopy

def _test_RangeQueryDynamic_common(func, gen_expected):
def _test_RangeQueryDynamic_common(func, neutral_element, gen_expected,
range_update_possible=True):

array = OneDimensionalArray(int, [])
raises(ValueError, lambda: RangeQueryDynamic(array, func))
Expand All @@ -25,11 +26,13 @@ def _test_RangeQueryDynamic_common(func, gen_expected):
for j in range(i + 1, array_size):
inputs.append((i, j))

data_structures = ["array", "segment_tree"]
for ds in data_structures:
data_structures = {"array":{}, "segment_tree": {},
"segment_tree": {'lazy': True, 'neutral_element': neutral_element}}
for ds, kw in data_structures.items():
range_update_possible = range_update_possible and kw.get('lazy', False)
data = random.sample(range(-2*array_size, 2*array_size), array_size)
array = OneDimensionalArray(int, data)
rmq = RangeQueryDynamic(array, func, data_structure=ds)
rmq = RangeQueryDynamic(array, func, data_structure=ds, **kw)
for input in inputs:
assert rmq.query(input[0], input[1]) == gen_expected(data, input[0], input[1])

Expand All @@ -43,12 +46,26 @@ def _test_RangeQueryDynamic_common(func, gen_expected):
for input in inputs:
assert rmq.query(input[0], input[1]) == gen_expected(data_copy, input[0], input[1])

if range_update_possible:
for _ in range(min(array_size//2, 20)):
start = random.randint(0, array_size - 1)
end = random.randint(0, array_size - 1)
value = random.randint(0, 4 * array_size)
start, end = min(start, end), max(start, end)
for j in range(start, end+1):
data_copy[j] = func((data_copy[j], value))
rmq.update_range(start, end, value)

for input in inputs:
assert rmq.query(input[0], input[1]) == gen_expected(data_copy, input[0], input[1])


def test_RangeQueryDynamic_minimum():

def _gen_minimum_expected(data, i, j):
return min(data[i:j + 1])

_test_RangeQueryDynamic_common(minimum, _gen_minimum_expected)
_test_RangeQueryDynamic_common(minimum, int(1e10), _gen_minimum_expected)

def test_RangeQueryDynamic_greatest_common_divisor():

Expand All @@ -61,11 +78,11 @@ def _gen_gcd_expected(data, i, j):
expected_gcd = math.gcd(expected_gcd, data[idx])
return expected_gcd

_test_RangeQueryDynamic_common(greatest_common_divisor, _gen_gcd_expected)
_test_RangeQueryDynamic_common(greatest_common_divisor, 0, _gen_gcd_expected)

def test_RangeQueryDynamic_summation():

def _gen_summation_expected(data, i, j):
return sum(data[i:j + 1])

return _test_RangeQueryDynamic_common(summation, _gen_summation_expected)
return _test_RangeQueryDynamic_common(summation, 0, _gen_summation_expected, False)
Loading