import errno
import logging
import random
import select
import socket
import struct
import threading
import time
from ctypes import pointer
from ipaddress import IPv4Address

from vnet.lib.rdup_struct import ACK, WINDOW_SIZE, RUDPConn, RUDPPacket, SenderWnd

from .host import Server

# Module-wide logging setup: everything logged through ``srv_logger`` at INFO
# level or above is appended to "srv_cli.log" (path relative to the current
# working directory) as "<timestamp> - <level> - <message>".

# Build the fully configured file handler first ...
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler = logging.FileHandler("srv_cli.log")
file_handler.setFormatter(formatter)

# ... then attach it to the named logger used throughout this module.
srv_logger = logging.getLogger("rudp_server")
srv_logger.setLevel(logging.INFO)
srv_logger.addHandler(file_handler)


class RDTServer(Server):
    """UDP server that answers every received datagram with an RUDP ACK.

    The instance owns one non-blocking UDP socket bound to
    ``(ipv4_addr, port)``.  After :meth:`start`, a daemon thread polls that
    socket and replies to each datagram with an ACK packet, optionally
    dropping or delaying datagrams to simulate a bad network.
    :meth:`create_conn_with_srv` builds client-side ``RUDPConn`` structures,
    each backed by its own UDP socket.
    """

    # fd -> socket object for the client sockets made by create_conn_with_srv.
    # The dict lives on the class, so it is shared by every instance.
    # NOTE(review): presumably it exists to keep the Python socket objects
    # referenced (and therefore open) while only the raw fd is stored in the
    # RUDPConn -- confirm against the code that consumes ``conn.sockfd``.
    _socket_registry = {}

    def __init__(
        self, ipv4_addr: str, host_name="RDT_Server", port=8000, packet_loss_rate=0.0, delay_ms=0
    ):
        """Create the server socket and bind it to ``(ipv4_addr, port)``.

        Args:
            ipv4_addr: Dotted-quad IPv4 address to bind to.
            host_name: Name forwarded to ``Server.__init__``.
            port: UDP port to bind to.
            packet_loss_rate: Forwarded to ``Server.__init__``; ``_run``
                compares ``self.packet_loss_rate`` against ``random.random()``,
                so it acts as a drop probability in ``[0, 1]``.
            delay_ms: Forwarded to ``Server.__init__``; ``_run`` sleeps
                ``self.delay_ms`` milliseconds before acknowledging.

        Raises:
            ValueError: If ``ipv4_addr`` is not a valid IPv4 address
                (``ipaddress.AddressValueError``).
            OSError: If the socket cannot be configured or bound.
        """
        super().__init__(IPv4Address(ipv4_addr), host_name, port, packet_loss_rate, delay_ms)
        # Client sockets created by *this* instance (fd -> socket), so that
        # __del__ closes only what this instance owns.
        self._owned_socks = {}
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # Allow address reuse
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.sock.bind((ipv4_addr, port))
            # Non-blocking: _run polls with select() before calling recvfrom()
            self.sock.setblocking(False)
        except OSError:
            # Do not leave the descriptor open when construction fails.
            self.sock.close()
            raise
        self.running = False

    def start(self):
        """Start the receive/ACK loop (:meth:`_run`) in a daemon thread."""
        self.running = True
        self.thread = threading.Thread(target=self._run)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        """Stop the receive loop, wait for its thread and close the server socket.

        Safe to call more than once and on a partially constructed instance.
        """
        self.running = False
        worker = getattr(self, "thread", None)
        # A thread cannot join itself (RuntimeError); this can happen when the
        # finalizer runs on the worker thread.
        if worker is not None and worker is not threading.current_thread():
            worker.join()
        # ``sock`` is missing when __init__ failed before creating it.
        server_sock = getattr(self, "sock", None)
        if server_sock is not None:
            server_sock.close()

    def set_network_conditions(self, packet_loss_rate=0, delay_ms=0):
        """Update the simulated drop probability and per-datagram delay (ms)."""
        self.packet_loss_rate = packet_loss_rate
        self.delay_ms = delay_ms

    def _run(self):
        """Receive loop: ACK each datagram until ``self.running`` becomes false.

        Each iteration waits up to 0.1 s for the socket to become readable, so
        a change of ``self.running`` is noticed within roughly that time (plus
        any simulated delay in progress).  Errors are logged and the loop
        continues.
        """
        srv_logger.info("服务器线程启动")
        while self.running:
            try:
                readable, _, _ = select.select([self.sock], [], [], 0.1)
                if not readable:
                    continue

                data, addr = self.sock.recvfrom(2048)
                # NOTE(review): the 6 subtracted here is presumably the RUDP
                # header size -- confirm against RUDPPacket.
                srv_logger.info(
                    f"[SRV] 接收到 {addr} 发送的数据：{data}，长度:{len(data) - 6} byte"
                )

                # Simulated loss: drop the datagram without acknowledging it.
                if random.random() < self.packet_loss_rate:
                    srv_logger.info("[SRV] 模拟丢包")
                    continue

                # Simulated latency; the sleep runs on this thread, so it also
                # delays handling of every datagram queued behind this one.
                if self.delay_ms > 0:
                    time.sleep(self.delay_ms / 1000.0)

                ack_packet = RUDPPacket()
                ack_packet.header.flag = ACK
                # The ACK number is the second byte of the datagram.
                # NOTE(review): presumably that byte is the sender's sequence
                # number -- verify against the RUDPPacket header layout.  A
                # datagram shorter than 2 bytes raises IndexError here, which
                # is logged by the handler below.
                ack_packet.header.ack_num = data[1]
                ack_packet.header.rwnd = WINDOW_SIZE

                ack_bytes = bytes(ack_packet)
                self.sock.sendto(ack_bytes, addr)
                srv_logger.info(f"[SRV] 发送 ACK {ack_packet.header.ack_num}")

            except Exception as e:
                # EWOULDBLOCK just means nothing was ready on the non-blocking
                # socket after all; retry quietly.
                if isinstance(e, socket.error) and e.errno == errno.EWOULDBLOCK:
                    continue
                srv_logger.error(f"服务器错误: {e}")

    def create_conn_with_srv(
        self,
        srv_port=8000,
        init_base_seq=0,
        next_seq=0,
        unacked_packet_count=0,
        peer_rwnd=WINDOW_SIZE,
    ):
        """Create a client-side ``RUDPConn`` backed by a fresh UDP socket.

        The new socket is bound to ``self.ipv4_addr`` on an OS-chosen port and
        recorded both in the class-level ``_socket_registry`` and in this
        instance's own bookkeeping, so the instance can close it later.

        Args:
            srv_port: Port stored (network byte order) as the peer port.
            init_base_seq: First positional argument passed to ``SenderWnd``.
            next_seq: Second positional argument passed to ``SenderWnd``.
            unacked_packet_count: Third positional argument to ``SenderWnd``.
            peer_rwnd: Fourth positional argument passed to ``SenderWnd``.

        Returns:
            A ctypes pointer to the populated ``RUDPConn``.

        Raises:
            OSError: If the client socket cannot be configured or bound.
        """
        conn = RUDPConn()

        # Create UDP socket
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((str(self.ipv4_addr), 0))
        except OSError:
            # Do not leak the descriptor when setup fails.
            sock.close()
            raise
        srv_logger.info(f"[CLI] {sock.getsockname()}")

        fd = sock.fileno()
        self._socket_registry[fd] = sock
        self._owned_socks[fd] = sock
        conn.sockfd = fd
        conn.connected = 1

        # Set the server (peer) address.
        # NOTE(review): the peer IP is hard-coded to loopback rather than the
        # address this server is bound to -- confirm that is intended.
        conn.peer.sin_family = socket.AF_INET
        conn.peer.sin_port = socket.htons(srv_port)
        conn.peer.sin_addr.s_addr = socket.htonl(socket.INADDR_LOOPBACK)  # Use network byte order
        srv_logger.info(
            f"[SRV] {socket.inet_ntoa(struct.pack('!L', socket.ntohl(conn.peer.sin_addr.s_addr)))}:{socket.ntohs(conn.peer.sin_port)}"
        )

        window = SenderWnd(init_base_seq, next_seq, unacked_packet_count, peer_rwnd)
        # window state
        srv_logger.info(
            f"[CLI] 初始化发送窗口: base_seq={window.base_seq}, "
            f"next_seq={window.next_seq}, "
            f"unacked={window.unacked_packet_count}, "
            f"peer_rwnd={window.peer_rwnd}"
        )

        # Point the connection at its sender window
        window_ptr = pointer(window)
        conn.send_window = window_ptr

        return pointer(conn)

    def __del__(self):
        """Stop the server and close the client sockets this instance created.

        Only sockets created through this instance's
        :meth:`create_conn_with_srv` are closed; sockets registered by other
        instances in the shared ``_socket_registry`` are left untouched.
        """
        try:
            self.stop()
        finally:
            # ``_owned_socks`` is missing when __init__ failed early.
            owned = getattr(self, "_owned_socks", {})
            for fd, sock in list(owned.items()):
                # Remove the shared entry only if it is still ours; the fd
                # number may have been reused by another instance's socket.
                if self._socket_registry.get(fd) is sock:
                    del self._socket_registry[fd]
                sock.close()
            owned.clear()
        srv_logger.debug("RDTServer 实例销毁")
