from typing import Callable, Any, TypeVar, Iterable, Sized, Generic, Optional

from typing_extensions import Self

# Public API of this module: the line-writer callback type, the fluent
# printer class, and the one-shot convenience function.
__all__ = [
	'LineWriter',
	'TreePrinter',
	'print_tree',
]


_T = TypeVar('_T')
LineWriter = Callable[[str], Any]
_ChildrenGetter = Callable[[_T], Iterable[_T]]
_NameGetter = Callable[[_T], str]


class TreePrinter(Generic[_T]):
	def __init__(self):
		self.__line_writer: LineWriter = print
		self.__children_getter: Optional[_ChildrenGetter] = None
		self.__name_getter: Optional[_NameGetter] = None
		self.__use_tab = False
		self.__use_ascii = False

	def getters(self, children_getter: _ChildrenGetter, name_getter: _NameGetter) -> Self:
		self.__children_getter = children_getter
		self.__name_getter = name_getter
		return self

	def writer(self, line_writer: LineWriter) -> Self:
		self.__line_writer = line_writer
		return self

	def tab(self) -> Self:
		self.__use_tab = True
		return self

	def ascii(self) -> Self:
		self.__use_ascii = True
		return self

	def print(self, root: _T) -> Self:
		def is_root(node: _T) -> bool:
			return node == root

		def get_item_line(node: _T, is_last: bool) -> str:
			if is_root(node):
				return ''
			if self.__use_ascii:
				return '`-- ' if is_last else '+-- '
			else:
				return '└── ' if is_last else '├── '

		def get_parent_line(node: _T, is_last: bool) -> str:
			if is_root(node):
				return ''
			if is_last:
				base = '' if self.__use_tab else ' '
			else:
				base = '|' if self.__use_ascii else '│'
			padding = '\t' if self.__use_tab else '   '
			return base + padding

		def do_print(node: _T, prefix: str, is_last: bool):
			if self.__children_getter is None:
				raise AssertionError('children_getter not assigned')
			if self.__name_getter is None:
				raise AssertionError('name_getter not assigned')

			line = self.__name_getter(node)
			line = get_item_line(node, is_last) + line
			self.__line_writer(prefix + line)

			children = self.__children_getter(node)
			if not isinstance(children, Sized):
				children = tuple(children)
			for i, child in enumerate(children):
				do_print(child, prefix + get_parent_line(node, is_last), i == len(children) - 1)

		do_print(root, '', False)
		return self


def print_tree(
	root: _T,
	children_getter: _ChildrenGetter[_T],
	name_getter: _NameGetter[_T],
	line_writer: LineWriter = print,
	*,
	use_tab: bool = False,
	use_ascii: bool = False,
) -> None:
	"""Render the tree rooted at *root*, one line per node, via *line_writer*.

	Convenience wrapper that configures a :class:`TreePrinter` and prints once.

	:param root: node to start from; printed first, without a connector.
	:param children_getter: returns the children of a node (any iterable).
	:param name_getter: returns the text shown for a node.
	:param line_writer: receives each rendered line; defaults to :func:`print`,
		the same default ``TreePrinter`` uses.
	:param use_tab: indent nested levels with a tab instead of spaces.
	:param use_ascii: draw connectors with ASCII instead of box-drawing characters.
	"""
	tp: TreePrinter[_T] = TreePrinter()
	tp.writer(line_writer).getters(children_getter, name_getter)
	if use_tab:
		tp.tab()
	if use_ascii:
		tp.ascii()
	tp.print(root)
