added BST, but remove not working
This commit is contained in:
@ -1,45 +1,49 @@
|
||||
from typing import Iterator, Optional
|
||||
from trees import Node, traverse
|
||||
from dataclasses import InitVar, dataclass
|
||||
from typing import Optional
|
||||
from trees import Node
|
||||
|
||||
@dataclass
|
||||
class BinarySearchTree[T]:
|
||||
|
||||
def __post_init__(self):
|
||||
self.__root: Optional[Node[T]] = None
|
||||
|
||||
|
||||
def insert(self, value: T):
|
||||
prev = None
|
||||
curr = self.__root
|
||||
while curr:
|
||||
prev = curr
|
||||
if curr.value > value:
|
||||
curr = curr.left
|
||||
else:
|
||||
curr = curr.right
|
||||
if prev is None:
|
||||
self.__root = Node(value, None, None)
|
||||
elif value < prev.value:
|
||||
prev.left = Node(value, None, None)
|
||||
type BinarySearchTree[T] = Optional[Node[T]]
|
||||
|
||||
|
||||
def insert[T](node: BinarySearchTree[T], value: T) -> BinarySearchTree[T]:
|
||||
prev = None
|
||||
curr = node
|
||||
while curr:
|
||||
prev = curr
|
||||
if curr.value > value:
|
||||
curr = curr.left
|
||||
else:
|
||||
prev.right = Node(value, None, None)
|
||||
|
||||
def __delete(prev: Node[T], to_delete: Node[T]) -> Node[T]:
|
||||
# TODO: implement
|
||||
pass
|
||||
|
||||
def remove(self, value: T):
|
||||
prev = None
|
||||
curr = self.__root
|
||||
while curr:
|
||||
prev = curr
|
||||
if curr.value > value:
|
||||
curr = curr.left
|
||||
elif curr.value < value:
|
||||
curr = curr.right
|
||||
else:
|
||||
curr = self.__delete(prev, curr)
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
yield from traverse(self.__root)
|
||||
curr = curr.right
|
||||
if prev is None:
|
||||
return Node(value, None, None)
|
||||
elif value < prev.value:
|
||||
prev.left = Node(value, None, None)
|
||||
else:
|
||||
prev.right = Node(value, None, None)
|
||||
return node
|
||||
|
||||
|
||||
def exists[T](node: Optional[Node[T]], value: T) -> bool:
|
||||
return node and (exists(node.left, value)
|
||||
if value < node.value
|
||||
else (value == node.value or exists(node.right, value)))
|
||||
|
||||
|
||||
def remove[T](node: BinarySearchTree[T], value: T) -> BinarySearchTree[T]:
|
||||
if node is None:
|
||||
return node
|
||||
if value < node.value:
|
||||
node.left = remove(node.left, value)
|
||||
return node
|
||||
if value > node.value:
|
||||
node.right = remove(node.right, value)
|
||||
return node
|
||||
if node.right is None:
|
||||
return node.left
|
||||
if node.left is None:
|
||||
return node.right
|
||||
min_node = node.right
|
||||
while min_node.left:
|
||||
min_node = min_node.left
|
||||
node.value = min_node.value
|
||||
node.right = remove(node.right, min_node.value)
|
||||
return node
|
||||
|
@ -1,24 +1,51 @@
|
||||
from search_trees import BinarySearchTree
|
||||
from typing import Iterator, Optional
|
||||
from search_trees import BinarySearchTree, insert, remove, exists
|
||||
from trees import Node
|
||||
from random import randint as random
|
||||
|
||||
MAX= 1000
|
||||
|
||||
def traverse[T](tree: Optional[Node[T]]) -> Iterator[T]:
|
||||
match tree:
|
||||
case Node(value, left, right):
|
||||
yield from traverse(left)
|
||||
yield value
|
||||
yield from traverse(right)
|
||||
case _:
|
||||
return
|
||||
|
||||
|
||||
MAX = 100
|
||||
MIN = 0
|
||||
LENGTH = 100
|
||||
|
||||
NUMS = {random(MIN, MAX) for _ in range(0, LENGTH)}
|
||||
NUMS_FILTER = {random(MIN, MAX) for _ in range(0, LENGTH // 4)}
|
||||
|
||||
def test_insert():
|
||||
nums = [random(MIN, MAX) for _ in range(0, LENGTH)]
|
||||
bst = BinarySearchTree()
|
||||
for num in nums:
|
||||
bst.insert(num)
|
||||
assert list(iter(bst)) == sorted(nums)
|
||||
|
||||
bst: BinarySearchTree = None
|
||||
for num in NUMS:
|
||||
bst = insert(bst, num)
|
||||
assert list(traverse(bst)) == sorted(NUMS)
|
||||
|
||||
|
||||
def test_remove():
|
||||
nums = {random(MIN, MAX) for _ in range(0, LENGTH)}
|
||||
fil = {random(MIN, MAX) for _ in range(0, LENGTH // 4)}
|
||||
bst = BinarySearchTree()
|
||||
for num in nums:
|
||||
bst.insert(num)
|
||||
for num in fil:
|
||||
bst.remove(num)
|
||||
|
||||
assert list(iter(bst)) == sorted(filter(lambda x: x not in fil, nums))
|
||||
bst: BinarySearchTree = None
|
||||
for num in NUMS:
|
||||
bst = insert(bst, num)
|
||||
assert list(traverse(bst)) == sorted(NUMS)
|
||||
for num in NUMS_FILTER:
|
||||
remove(bst, num)
|
||||
|
||||
assert list(traverse(bst)) == sorted(filter(lambda x: x not in NUMS_FILTER, NUMS))
|
||||
|
||||
|
||||
def test_exists():
|
||||
bst: BinarySearchTree = None
|
||||
for num in NUMS:
|
||||
bst = insert(bst, num)
|
||||
for num in NUMS_FILTER:
|
||||
remove(bst, num)
|
||||
|
||||
assert all(map(lambda x: exists(bst, x),
|
||||
filter(lambda x: x not in NUMS_FILTER, NUMS)))
|
||||
assert all(map(lambda x: not exists(bst, x), NUMS_FILTER))
|
Reference in New Issue
Block a user