from .packet import Packet
from .firewall import Firewall
from .router import Router


class _SimulationResult:
    def __init__(self):
        self.action = None


class Simulator:
    """Walk a packet through the netfilter hooks a host would apply to it.

    A packet is classified as either locally generated (output -> postrouting)
    or received (prerouting -> input, or prerouting -> forward -> postrouting).
    ``self.router`` routes it, and ``self.firewall`` evaluates each hook in
    order. A "drop" or "reject" verdict stops traversal early.

    See https://wiki.nftables.org/wiki-nftables/index.php/Netfilter_hooks
    """

    # Hook verdicts that end traversal: no later hook is evaluated.
    _TERMINAL_VERDICTS = ("drop", "reject")

    def __init__(self, firewall: Firewall = None, router: Router = None):
        """Create a simulator.

        :param firewall: ruleset to evaluate; a new ``Firewall()`` when omitted.
        :param router: routing state and interfaces; a new ``Router()`` when
            omitted.
        """
        self.firewall = firewall if firewall is not None else Firewall()
        self.router = router if router is not None else Router()

    def _is_local_address(self, ip) -> bool:
        """Return True if *ip* equals an address configured on any router interface."""
        return any(
            ip == address["address"]
            for interface in self.router.interfaces
            for address in interface["addresses"]
        )

    def _infer_input_interface(self, packet: Packet) -> None:
        """Set ``packet.iiface`` to an interface whose network contains ``packet.source``."""
        # NOTE(review): each matching interface overwrites the previous one, so
        # the last match wins; a longest-prefix match may be what is intended.
        # If no network matches, iiface is left unchanged.
        for interface in self.router.interfaces:
            for address in interface["addresses"]:
                if packet.source in address["network"]:
                    packet.iiface = interface["iface"]

    def simulate(self, packet: Packet) -> _SimulationResult:
        """Run *packet* through the firewall hooks and return the final verdict.

        A packet is treated as locally generated, and sent through
        :meth:`resolve_outgoing`, in two cases: it has no source address, or
        its source equals one of the router's own addresses. Every other
        packet is treated as received and sent through
        :meth:`resolve_incoming`.

        :returns: whatever the last evaluated firewall hook returned.
            NOTE(review): the annotation says ``_SimulationResult``, but this
            code returns ``firewall.resolve_hook`` results directly — confirm
            the intended return type.
        """
        # Source address selection: http://linux-ip.net/html/routing-saddr-selection.html
        # TODO: guess the source IP address based solely on the input interface name.
        # TODO: a packet without an input interface (but with a source address)
        #       was presumably generated by the host.

        # TODO: the outgoing path should also ensure the packet has an oiface
        #       AND does not have an iiface.
        # TODO: remember about loopback for locally sourced packets.
        if packet.source is None or self._is_local_address(packet.source):
            return self.resolve_outgoing(packet)

        # Packets destined for one of the host's own addresses are handled by
        # the input branch inside resolve_incoming.
        return self.resolve_incoming(packet)

    def resolve_outgoing(self, packet: Packet):
        """Route a locally generated packet and evaluate output -> postrouting.

        :returns: the output verdict if it is "drop" or "reject", otherwise
            the postrouting verdict.
        """
        self.router.route(packet)

        output_result = self.firewall.resolve_hook("output", packet)
        print(f"[simulator] output_result -> {output_result}")
        if output_result in self._TERMINAL_VERDICTS:
            return output_result

        postrouting_result = self.firewall.resolve_hook("postrouting", packet)
        print(f"[simulator] postrouting_result -> {postrouting_result}")
        return postrouting_result

    def resolve_incoming(self, packet: Packet):
        """Evaluate prerouting, route, then either input or forward -> postrouting.

        If ``packet.iiface`` is unset, it is inferred from the router interface
        whose network contains ``packet.source``.

        :returns: the prerouting verdict if it is terminal; otherwise the input
            verdict when the destination is one of the router's addresses;
            otherwise the forward verdict if it is terminal, else the
            postrouting verdict.
        """
        if packet.iiface is None:
            self._infer_input_interface(packet)

        prerouting_result = self.firewall.resolve_hook("prerouting", packet)
        print(f"[simulator] prerouting_result -> {prerouting_result}")
        if prerouting_result in self._TERMINAL_VERDICTS:
            return prerouting_result

        self.router.route(packet)

        # Destined for the host itself: the input hook is the last one evaluated.
        if self._is_local_address(packet.destination):
            input_result = self.firewall.resolve_hook("input", packet)
            print(f"[simulator] input_result -> {input_result}")
            return input_result

        forwarding_result = self.firewall.resolve_hook("forward", packet)
        print(f"[simulator] forward -> {forwarding_result}")
        if forwarding_result in self._TERMINAL_VERDICTS:
            return forwarding_result

        # TODO: simulate loopback/lo traffic
        # TODO: support multiple loopback devices
        postrouting_result = self.firewall.resolve_hook("postrouting", packet)
        print(f"[simulator] postrouting_result -> {postrouting_result}")
        return postrouting_result
