from walt.browser_use.dom.service import DomService, DOMState, DOMElementNode, SelectorMap, DOMBaseNode, DOMTextNode
from walt.browser_use.custom.dom_element_zoo import DomElementNodeBugFix
from typing import Any, Optional
import logging
from walt.browser_use.utils import time_execution_async
from importlib import resources
import json
import gc
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class ViewportInfo:
  """Viewport size attached to a parsed DOM element node.

  Built from the `viewport` entry (its `width` and `height` keys) of a
  serialized node returned by the DOM-extraction script.
  """

  width: int  # value of the script's `viewport.width` (units not stated; presumably CSS px — confirm)
  height: int  # value of the script's `viewport.height` (same source/units as `width`)
class DomServiceBugFix(DomService):
  """DomService variant that injects this package's own `buildDomTree.js`.

  The script is loaded from the `walt.browser_use.custom.js` package when the
  service is constructed. It is evaluated on `self.page` (presumably set up by
  the parent `DomService` — this class never assigns it). Its flat,
  id-keyed node map is rebuilt into a tree of `DOMElementNode` /
  `DOMTextNode` objects.
  """

  def __init__(
    self,
    *args: Any,
    **kwargs: Any,
  ):
    super().__init__(*args, **kwargs)
    # `resources.files()` replaces the legacy `resources.read_text(pkg, name)`
    # helper, which emits a DeprecationWarning on Python 3.11/3.12.
    self.js_code = (
      resources.files('walt.browser_use.custom.js')
      .joinpath('buildDomTree.js')
      .read_text(encoding='utf-8')
    )

  @time_execution_async('--get_clickable_elements')
  async def get_clickable_elements(
    self,
    highlight_elements: bool = True,
    focus_element: int = -1,
    viewport_expansion: int = 0,
  ) -> DOMState:
    """Extract the current page's DOM tree and its highlighted-element index.

    Args:
      highlight_elements: forwarded to the script as `doHighlightElements`.
      focus_element: forwarded as `focusHighlightIndex` (default -1).
      viewport_expansion: forwarded as `viewportExpansion`.

    Returns:
      A `DOMState` with the root element node and the
      highlight-index -> element-node map.
    """
    element_tree, selector_map = await self._build_dom_tree(highlight_elements, focus_element, viewport_expansion)
    return DOMState(element_tree=element_tree, selector_map=selector_map)

  @time_execution_async('--build_dom_tree')
  async def _build_dom_tree(
    self,
    highlight_elements: bool,
    focus_element: int,
    viewport_expansion: int,
  ) -> tuple[DOMElementNode, SelectorMap]:
    """Run the DOM-extraction script in the page and rebuild its result.

    Raises:
      ValueError: if the page fails a trivial JavaScript sanity check, or the
        script returns something other than an object/dict.
      Exception: any error from `page.evaluate` is logged and re-raised as-is.
    """
    # Cheap sanity check before injecting the much larger extraction script.
    if await self.page.evaluate('1+1') != 2:
      raise ValueError('The page cannot evaluate javascript code properly')

    # NOTE: We execute JS code in the browser to extract important DOM information.
    #       The returned hash map contains information about the DOM tree and the
    #       relationship between the DOM elements.
    debug_mode = logger.getEffectiveLevel() == logging.DEBUG
    args = {
      'doHighlightElements': highlight_elements,
      'focusHighlightIndex': focus_element,
      'viewportExpansion': viewport_expansion,
      'debugMode': debug_mode,
    }

    try:
      eval_page = await self.page.evaluate(self.js_code, args)
    except Exception as e:
      logger.error('Error evaluating JavaScript: %s', e)
      raise

    # Fail with a clear message instead of an incidental TypeError further
    # down (`in` / `[...]` on None or a non-mapping result).
    if not isinstance(eval_page, dict):
      raise ValueError(f'DOM tree script returned {type(eval_page).__name__}, expected an object')

    # Only log performance metrics in debug mode
    if debug_mode and 'perfMetrics' in eval_page:
      logger.debug('DOM Tree Building Performance Metrics:\n%s', json.dumps(eval_page['perfMetrics'], indent=2))

    return await self._construct_dom_tree(eval_page)

  @time_execution_async('--construct_dom_tree')
  async def _construct_dom_tree(self, eval_page: dict) -> tuple[DOMElementNode, SelectorMap]:
    """Rebuild node objects from the script's flat `{id: node_data}` map.

    Args:
      eval_page: script result; must contain `map` (node id -> serialized
        node) and `rootId` (id of the root element).

    Returns:
      (root element node, highlight_index -> DOMElementNode map).

    Raises:
      ValueError: if `rootId` is not in the map or does not refer to an
        element node.
    """
    js_node_map = eval_page['map']
    js_root_id = eval_page['rootId']

    selector_map = {}
    node_map: dict[str, DOMBaseNode] = {}
    pending_links: list[tuple[DOMElementNode, list]] = []

    # Pass 1: materialize every node. Ids are normalised with str() so that
    # root and child lookups match regardless of how the script typed them.
    for node_id, node_data in js_node_map.items():
      node, children_ids = self._parse_node(node_data)
      if node is None:
        continue

      node_map[str(node_id)] = node

      if isinstance(node, DOMElementNode):
        if node.highlight_index is not None:
          selector_map[node.highlight_index] = node
        pending_links.append((node, children_ids))

    # Pass 2: link parents and children. Doing this after every node exists
    # means the tree no longer silently loses children when the script
    # emits a parent before its children in the map.
    for parent, children_ids in pending_links:
      for child_id in children_ids:
        child_node = node_map.get(str(child_id))
        if child_node is None:
          continue
        child_node.parent = parent
        parent.children.append(child_node)

    # `.get` (not `[...]`) so a missing root reaches the ValueError below
    # instead of escaping as a KeyError.
    root = node_map.get(str(js_root_id))

    # Drop the intermediate containers and force a collection, as the
    # original implementation did (presumably to bound peak memory on large
    # pages — note this runs on every DOM build).
    del node_map, pending_links, js_node_map
    gc.collect()

    if not isinstance(root, DOMElementNode):
      raise ValueError('Failed to parse HTML to dictionary')

    return root, selector_map

  def _parse_node(self, node_data: dict) -> tuple[Optional[DOMBaseNode], list]:
    """Convert one serialized node from the script into a DOM node object.

    Returns:
      (node, children_ids). `node` is None for empty input. Text nodes
      always come back with an empty children list. Element nodes carry the
      script's `children` ids (compared via str() by the caller).
    """
    if not node_data:
      return None, []

    # Text nodes have no children and need no further processing.
    if node_data.get('type') == 'TEXT_NODE':
      text_node = DOMTextNode(
        text=node_data['text'],
        is_visible=node_data['isVisible'],
        parent=None,
      )
      return text_node, []

    viewport_info = None
    if 'viewport' in node_data:
      viewport_info = ViewportInfo(
        width=node_data['viewport']['width'],
        height=node_data['viewport']['height'],
      )

    element_node = DOMElementNode(
      tag_name=node_data['tagName'],
      xpath=node_data['xpath'],
      attributes=node_data.get('attributes', {}),
      children=[],
      is_visible=node_data.get('isVisible', False),
      is_interactive=node_data.get('isInteractive', False),
      is_top_element=node_data.get('isTopElement', False),
      is_in_viewport=node_data.get('isInViewport', False),
      highlight_index=node_data.get('highlightIndex'),
      shadow_root=node_data.get('shadowRoot', False),
      parent=None,
      viewport_info=viewport_info,
      inner_text=node_data.get('innerText'),
      page_coordinates=node_data.get('pageCoordinates', {}),
      viewport_coordinates=node_data.get('viewportCoordinates', {}),
    )

    children_ids = node_data.get('children', [])

    return element_node, children_ids
  
class DomServiceBlackTransparent(DomServiceBugFix):
  """DomServiceBugFix that injects `buildDomTreeBlackTransparent.js` instead.

  Only the extraction script differs from the parent. Judging by the file
  name, it presumably renders highlights in a black/transparent style; this
  is unverified here.
  """

  def __init__(
    self,
    *args: Any,
    **kwargs: Any,
  ):
    super().__init__(*args, **kwargs)
    # Overrides the default script the parent constructor just loaded.
    # `resources.files()` replaces the legacy `resources.read_text(pkg, name)`
    # helper, which emits a DeprecationWarning on Python 3.11/3.12.
    self.js_code = (
      resources.files('walt.browser_use.custom.js')
      .joinpath('buildDomTreeBlackTransparent.js')
      .read_text(encoding='utf-8')
    )

# class DomServiceShopping(DomServiceBugFix):
#   def __init__(
#     self,
#     *args: Any,
#     **kwargs: Any,
#   ):
#     super().__init__(*args, **kwargs)
#     self.js_code = resources.read_text('walt.browser_use.custom.js', 'buildDomTree_shopping.js')

# class DomServiceShoppingBlackTransparent(DomServiceShopping):
#   def __init__(
#     self,
#     *args: Any,
#     **kwargs: Any,
#   ):
#     super().__init__(*args, **kwargs)
#     # Use shopping-specific JS with black transparent styling
#     self.js_code = resources.read_text('walt.browser_use.custom.js', 'buildDomTree_shopping_black_transparent.js')