added BST, but remove not working

This commit is contained in:
2024-02-15 04:40:13 +01:00
parent 4c5d27e8d1
commit c7c1f3ea98
4 changed files with 113 additions and 61 deletions

View File

@ -1,45 +1,49 @@
from typing import Iterator, Optional
from trees import Node, traverse
from dataclasses import InitVar, dataclass
from typing import Optional
from trees import Node
@dataclass
class BinarySearchTree[T]:
def __post_init__(self):
self.__root: Optional[Node[T]] = None
def insert(self, value: T):
prev = None
curr = self.__root
while curr:
prev = curr
if curr.value > value:
curr = curr.left
else:
curr = curr.right
if prev is None:
self.__root = Node(value, None, None)
elif value < prev.value:
prev.left = Node(value, None, None)
type BinarySearchTree[T] = Optional[Node[T]]
def insert[T](node: BinarySearchTree[T], value: T) -> BinarySearchTree[T]:
prev = None
curr = node
while curr:
prev = curr
if curr.value > value:
curr = curr.left
else:
prev.right = Node(value, None, None)
def __delete(prev: Node[T], to_delete: Node[T]) -> Node[T]:
# TODO: implement
pass
def remove(self, value: T):
prev = None
curr = self.__root
while curr:
prev = curr
if curr.value > value:
curr = curr.left
elif curr.value < value:
curr = curr.right
else:
curr = self.__delete(prev, curr)
def __iter__(self) -> Iterator[T]:
yield from traverse(self.__root)
curr = curr.right
if prev is None:
return Node(value, None, None)
elif value < prev.value:
prev.left = Node(value, None, None)
else:
prev.right = Node(value, None, None)
return node
def exists[T](node: Optional[Node[T]], value: T) -> bool:
return node and (exists(node.left, value)
if value < node.value
else (value == node.value or exists(node.right, value)))
def remove[T](node: BinarySearchTree[T], value: T) -> BinarySearchTree[T]:
if node is None:
return node
if value < node.value:
node.left = remove(node.left, value)
return node
if value > node.value:
node.right = remove(node.right, value)
return node
if node.right is None:
return node.left
if node.left is None:
return node.right
min_node = node.right
while min_node.left:
min_node = min_node.left
node.value = min_node.value
node.right = remove(node.right, min_node.value)
return node

View File

@ -1,24 +1,51 @@
from search_trees import BinarySearchTree
from typing import Iterator, Optional
from search_trees import BinarySearchTree, insert, remove, exists
from trees import Node
from random import randint as random
MAX= 1000
def traverse[T](tree: Optional[Node[T]]) -> Iterator[T]:
match tree:
case Node(value, left, right):
yield from traverse(left)
yield value
yield from traverse(right)
case _:
return
MAX = 100
MIN = 0
LENGTH = 100
NUMS = {random(MIN, MAX) for _ in range(0, LENGTH)}
NUMS_FILTER = {random(MIN, MAX) for _ in range(0, LENGTH // 4)}
def test_insert():
nums = [random(MIN, MAX) for _ in range(0, LENGTH)]
bst = BinarySearchTree()
for num in nums:
bst.insert(num)
assert list(iter(bst)) == sorted(nums)
bst: BinarySearchTree = None
for num in NUMS:
bst = insert(bst, num)
assert list(traverse(bst)) == sorted(NUMS)
def test_remove():
nums = {random(MIN, MAX) for _ in range(0, LENGTH)}
fil = {random(MIN, MAX) for _ in range(0, LENGTH // 4)}
bst = BinarySearchTree()
for num in nums:
bst.insert(num)
for num in fil:
bst.remove(num)
assert list(iter(bst)) == sorted(filter(lambda x: x not in fil, nums))
bst: BinarySearchTree = None
for num in NUMS:
bst = insert(bst, num)
assert list(traverse(bst)) == sorted(NUMS)
for num in NUMS_FILTER:
remove(bst, num)
assert list(traverse(bst)) == sorted(filter(lambda x: x not in NUMS_FILTER, NUMS))
def test_exists():
bst: BinarySearchTree = None
for num in NUMS:
bst = insert(bst, num)
for num in NUMS_FILTER:
remove(bst, num)
assert all(map(lambda x: exists(bst, x),
filter(lambda x: x not in NUMS_FILTER, NUMS)))
assert all(map(lambda x: not exists(bst, x), NUMS_FILTER))