51 lines
1.3 KiB
Python
51 lines
1.3 KiB
Python
from typing import Iterator, Optional
|
|
from search_trees import BinarySearchTree, insert, remove, exists
|
|
from trees import Node
|
|
from random import randint as random
|
|
|
|
|
|
def traverse[T](tree: Optional[Node[T]]) -> Iterator[T]:
|
|
match tree:
|
|
case Node(value, left, right):
|
|
yield from traverse(left)
|
|
yield value
|
|
yield from traverse(right)
|
|
case _:
|
|
return
|
|
|
|
|
|
# Inclusive bounds handed to randint (imported here under the name ``random``).
MIN = 0
MAX = 100
# Number of draws used to build NUMS; NUMS_FILTER uses a quarter as many draws.
LENGTH = 100

# Values inserted into the tree by the tests. Repeated draws collapse because
# this is a set, so it usually holds fewer than LENGTH distinct values.
NUMS = set(random(MIN, MAX) for _ in range(LENGTH))
# Values the tests remove from the tree. They are drawn independently of NUMS,
# so some of them may never have been inserted.
NUMS_FILTER = set(random(MIN, MAX) for _ in range(LENGTH // 4))
|
|
|
|
def test_insert():
    """Inserting every value of NUMS yields an in-order traversal equal to sorted(NUMS).

    Values are inserted in shuffled order. In CPython, iterating a set of small
    non-negative ints yields them in ascending order. Inserting in that order
    would always build a degenerate right-leaning chain, so left subtrees would
    never be created or traversed.
    """
    from random import sample

    bst: BinarySearchTree = None
    for num in sample(sorted(NUMS), len(NUMS)):
        bst = insert(bst, num)

    assert list(traverse(bst)) == sorted(NUMS)
|
|
|
|
|
|
def test_remove():
    """Removing NUMS_FILTER leaves exactly the remaining values, still in sorted order.

    Insertion and removal both happen in shuffled order. In CPython, iterating a
    set of small non-negative ints yields ascending order, which would always
    build a degenerate right-leaning chain. That chain never reaches the removal
    cases for nodes with a left child or with two children.
    """
    from random import sample

    bst: BinarySearchTree = None
    for num in sample(sorted(NUMS), len(NUMS)):
        bst = insert(bst, num)

    assert list(traverse(bst)) == sorted(NUMS)

    # NOTE(review): insert's result is rebound but remove's is discarded, as in
    # the original test. If search_trees.remove returns a new root (e.g. when
    # the root itself is removed), rebind it here; confirm its contract.
    for num in sample(sorted(NUMS_FILTER), len(NUMS_FILTER)):
        remove(bst, num)

    assert list(traverse(bst)) == sorted(NUMS - NUMS_FILTER)
|
|
|
|
|
|
def test_exists():
    """``exists`` finds every value still in the tree and none that were removed.

    Insertion and removal both happen in shuffled order. In CPython, iterating a
    set of small non-negative ints yields ascending order, which would always
    build a degenerate right-leaning chain. Searches would then never descend
    into a left subtree.
    """
    from random import sample

    bst: BinarySearchTree = None
    for num in sample(sorted(NUMS), len(NUMS)):
        bst = insert(bst, num)

    # Before any removal, every inserted value must be found, and values
    # outside the generated range [MIN, MAX] must not be.
    assert all(exists(bst, x) for x in NUMS)
    assert not exists(bst, MIN - 1)
    assert not exists(bst, MAX + 1)

    # NOTE(review): remove's return value is discarded, matching test_remove;
    # confirm search_trees.remove mutates the tree in place.
    for num in sample(sorted(NUMS_FILTER), len(NUMS_FILTER)):
        remove(bst, num)

    assert all(exists(bst, x) for x in NUMS - NUMS_FILTER)
    assert not any(exists(bst, x) for x in NUMS_FILTER)