# interpreter.py
from .parser import (
    NumNode, StringNode, VarAccessNode, BinOpNode, 
    AssignNode, PrintNode, IfNode, WhileNode
)

class Interpreter:
    """Tree-walking evaluator for the AST produced by the parser module.

    Every node is dispatched to a ``visit_<NodeClassName>`` method based on
    the node's class name. Variables live in one flat ``symbol_table`` dict;
    there is no block or function scoping.
    """

    def __init__(self):
        # Variable name -> current value, shared by every statement executed.
        self.symbol_table = {}

    def _is_truthy(self, val):
        """Return the boolean meaning of a runtime value.

        Numbers are truthy when non-zero and strings when non-empty. Any other
        value (including None) is falsy.
        """
        # bool must be tested before int/float because bool subclasses int.
        if isinstance(val, bool):
            return val
        if isinstance(val, (int, float)):
            return val != 0
        if isinstance(val, str):
            return len(val) > 0
        return False

    def visit(self, node):
        """Visit a node, execute it, and return its value.

        Any exception raised while evaluating *node* is re-raised as a plain
        ``Exception`` whose message is prefixed with
        ``"Runtime Error on Line <lineno>: "``. The prefix uses ``"Unknown"``
        when the node has no ``lineno`` attribute.

        The innermost failing node adds the prefix. Enclosing visits see the
        marker text and propagate the error untouched, so the reported line is
        the deepest one. The original exception is chained as ``__cause__``.
        """
        method_name = f'visit_{type(node).__name__}'
        method = getattr(self, method_name, self.no_visit_method)

        try:
            return method(node)
        except Exception as e:
            # Already annotated by a deeper visit: re-raise as-is (bare
            # ``raise`` keeps the original traceback).
            if "Runtime Error on Line" in str(e):
                raise

            # Otherwise, add the line number context and keep the cause.
            line = getattr(node, 'lineno', 'Unknown')
            raise Exception(f"Runtime Error on Line {line}: {e}") from e

    def no_visit_method(self, node):
        """Fallback used by ``visit`` for node types with no handler."""
        raise Exception(f'No visit_{type(node).__name__} method defined')

    def visit_NumNode(self, node):
        """Return the literal numeric value stored on the node."""
        return node.value

    def visit_StringNode(self, node):
        """Return the literal string value stored on the node."""
        return node.value

    def visit_VarAccessNode(self, node):
        """Return the current value of a variable.

        Raises:
            Exception: if the variable has never been assigned.
        """
        name = node.name
        if name in self.symbol_table:
            return self.symbol_table[name]
        raise Exception(f"Variable '{name}' is not defined.")

    def visit_AssignNode(self, node):
        """Evaluate the right-hand side and bind it to ``node.name``.

        Assignment is a statement and evaluates to None.
        """
        value = self.visit(node.value)
        self.symbol_table[node.name] = value

    def visit_PrintNode(self, node):
        """Evaluate the expression and write its string form to stdout."""
        value = self.visit(node.value)
        # str() already renders bools as "True"/"False", so no special case
        # is needed.
        print(str(value))

    def visit_BinOpNode(self, node):
        """Evaluate a binary operation and return its result.

        Both operands are evaluated (left first) before the operator is
        inspected.

        - Arithmetic (PLUS, MINUS, MUL, DIV, MODULO) is allowed between two
          numbers. For strings, only PLUS (concatenation) is allowed.
        - Ordering comparisons are allowed between numbers only.
        - == and != are allowed between any two values. Operands of different
          kinds are never equal.

        Raises:
            Exception: on division/modulo by zero, a type mismatch, an
                unsupported string operation, or an unknown operator token.
        """
        left = self.visit(node.left)
        right = self.visit(node.right)
        op = node.op_token.type

        # NOTE: bool subclasses int, so comparison results pass these numeric
        # checks and take part in arithmetic (e.g. True + 1 == 2).
        both_numbers = (isinstance(left, (int, float))
                        and isinstance(right, (int, float)))
        both_strings = isinstance(left, str) and isinstance(right, str)

        # Arithmetic
        if op in ('PLUS', 'MINUS', 'MUL', 'DIV', 'MODULO'):
            if both_numbers:
                if op == 'PLUS':
                    return left + right
                if op == 'MINUS':
                    return left - right
                if op == 'MUL':
                    return left * right
                if op == 'DIV':
                    if right == 0:
                        raise Exception("Division by zero")
                    return left / right
                # op == 'MODULO'
                if right == 0:
                    raise Exception("Modulo by zero")
                return left % right
            if both_strings:
                if op == 'PLUS':
                    return left + right
                raise Exception(f"Unsupported operation '{op}' for strings")
            raise Exception(f"Type mismatch: {type(left)} {op} {type(right)}")

        # Comparison
        if op in ('GREATER', 'LESS', 'GREATER_EQ', 'LESS_EQ', 'EQUALTO', 'NOT_EQUAL'):
            if both_numbers:
                if op == 'GREATER':
                    return left > right
                if op == 'LESS':
                    return left < right
                if op == 'GREATER_EQ':
                    return left >= right
                if op == 'LESS_EQ':
                    return left <= right
                if op == 'EQUALTO':
                    return left == right
                # op == 'NOT_EQUAL'
                return left != right
            if both_strings:
                if op == 'EQUALTO':
                    return left == right
                if op == 'NOT_EQUAL':
                    return left != right
                raise Exception("Strings only support == and !=")
            # Allow mismatch types equality (always false)
            if op == 'EQUALTO':
                return False
            if op == 'NOT_EQUAL':
                return True
            raise Exception("Type mismatch in comparison")

        # Without this, an unrecognised operator token would silently make the
        # whole expression evaluate to None.
        raise Exception(f"Unsupported operator '{op}'")

    def visit_IfNode(self, node):
        """Run the body of the first case whose condition is truthy.

        ``node.cases`` is iterated as (condition, statements) pairs. If no
        condition matches and ``node.else_case`` is non-empty, its statements
        run instead.
        """
        for condition_node, statements in node.cases:
            if self._is_truthy(self.visit(condition_node)):
                for stmt in statements:
                    self.visit(stmt)
                return
        if node.else_case:
            for stmt in node.else_case:
                self.visit(stmt)

    def visit_WhileNode(self, node):
        """Re-run the body while the condition stays truthy.

        The condition is re-evaluated before every iteration. There is no
        iteration limit, so a condition that never becomes falsy loops forever.
        """
        while self._is_truthy(self.visit(node.condition_node)):
            for stmt in node.statements:
                self.visit(stmt)