from typing import Iterator, Optional
from trees import Node, traverse
from dataclasses import InitVar, dataclass

@dataclass
class BinarySearchTree[T]:
    
    def __post_init__(self):
        self.__root: Optional[Node[T]] = None
        
        
    def insert(self, value: T):
        prev = None
        curr = self.__root
        while curr:
            prev = curr
            if curr.value > value:
                curr = curr.left
            else:
                curr = curr.right
        if prev is None:
            self.__root = Node(value, None, None)
        elif value < prev.value:
            prev.left = Node(value, None, None)
        else:
            prev.right = Node(value, None, None)
            
    def __delete(prev: Node[T], to_delete: Node[T]) -> Node[T]:
        # TODO: implement
        pass
    
    def remove(self, value: T):
        prev = None
        curr = self.__root
        while curr:
            prev = curr
            if curr.value > value:
                curr = curr.left 
            elif curr.value < value:
                curr = curr.right
            else:
                curr = self.__delete(prev, curr)
            
    def __iter__(self) -> Iterator[T]:
        yield from traverse(self.__root)