import asyncio
from asyncio import FIRST_COMPLETED, CancelledError, Task, wait_for
from dataclasses import dataclass

from bluesky.protocols import Movable
from ophyd_async.core import (
    AsyncStatus,
    Device,
    StandardReadable,
    StandardReadableFormat,
    StrictEnum,
    set_and_wait_for_value,
    wait_for_value,
)
from ophyd_async.epics.core import (
    epics_signal_r,
    epics_signal_rw,
    epics_signal_rw_rbv,
    epics_signal_x,
)

from dodal.log import LOGGER

# Log messages emitted by BartRobot while loading, before waiting for the gonio
# pin sensor to report the old pin removed / the new pin mounted. Kept at module
# level, presumably so callers or tests can match on them — confirm with users.
WAIT_FOR_OLD_PIN_MSG = "Waiting on old pin unloaded"
WAIT_FOR_NEW_PIN_MSG = "Waiting on new pin loaded"


class RobotLoadError(Exception):
    """Raised when the robot fails to load a sample.

    Holds the numeric error code alongside the human-readable message; the
    message alone is what ``str()`` of the exception produces.
    """

    error_code: int
    error_string: str

    def __init__(self, error_code: int, error_string: str) -> None:
        # Only the message is passed to Exception, so ``args == (error_string,)``.
        super().__init__(error_string)
        self.error_code = error_code
        self.error_string = error_string

    def __str__(self) -> str:
        message = self.error_string
        return message


@dataclass()
class SampleLocation:
    """A sample address for the robot: which puck, and which pin within it."""

    # Puck number; written to the robot's NEXT_PUCK PV when loading.
    puck: int
    # Pin number; written to the robot's NEXT_PIN PV when loading.
    pin: int


class PinMounted(StrictEnum):
    """Values reported by the robot's ``PIN_MOUNTED`` (gonio pin sensor) PV."""

    # Sensor reports an empty goniometer.
    NO_PIN_MOUNTED = "No Pin Mounted"
    # Sensor reports a pin present on the goniometer.
    PIN_MOUNTED = "Pin Mounted"


class ErrorStatus(Device):
    """An error message / error code PV pair exposed by the robot.

    The PV names are ``prefix`` followed by ``_ERR_MSG`` and ``_ERR_CODE``.
    """

    def __init__(self, prefix: str) -> None:
        message_pv = prefix + "_ERR_MSG"
        code_pv = prefix + "_ERR_CODE"
        # Child order (message first, then code) matches the original device.
        self.str = epics_signal_r(str, message_pv)
        self.code = epics_signal_r(int, code_pv)
        super().__init__()

    async def raise_if_error(self, raise_from: Exception):
        """Raise if the robot is currently reporting an error code.

        Args:
            raise_from: Exception to chain as the ``__cause__`` of the raised error.

        Raises:
            RobotLoadError: When the code PV holds a truthy (non-zero) value,
                built from that code and the current message PV.
        """
        code = await self.code.get_value()
        if not code:
            # A zero/falsy code is treated as "no error": nothing to raise.
            return
        message = await self.str.get_value()
        raise RobotLoadError(int(code), message) from raise_from


class BartRobot(StandardReadable, Movable[SampleLocation]):
    """The sample changing robot.

    ``set(SampleLocation(...))`` asks the robot to load the pin at that puck/pin
    onto the goniometer and completes once the gonio pin sensor reports a pin.
    Failures (no pin found, robot errors, timeouts) are raised as
    ``RobotLoadError``.
    """

    # How long to wait for the robot if it is busy soaking/drying (seconds)
    NOT_BUSY_TIMEOUT = 5 * 60

    # How long to wait for the actual load to happen (seconds)
    LOAD_TIMEOUT = 60

    # Error codes that we do special things on
    NO_PIN_ERROR_CODE = 25
    LIGHT_CURTAIN_TRIPPED = 40

    # How far the gonio position can be out before loading will fail
    LOAD_TOLERANCE_MM = 0.02

    def __init__(self, prefix: str, name: str = "") -> None:
        # Signals included (as hinted) whenever the robot is read.
        with self.add_children_as_readables(StandardReadableFormat.HINTED_SIGNAL):
            self.barcode = epics_signal_r(str, prefix + "BARCODE")
            self.gonio_pin_sensor = epics_signal_r(PinMounted, prefix + "PIN_MOUNTED")

            self.current_puck = epics_signal_r(float, prefix + "CURRENT_PUCK_RBV")
            self.current_pin = epics_signal_r(float, prefix + "CURRENT_PIN_RBV")

        # Target of the next load; written by _load_pin_and_puck.
        self.next_pin = epics_signal_rw_rbv(float, prefix + "NEXT_PIN")
        self.next_puck = epics_signal_rw_rbv(float, prefix + "NEXT_PUCK")

        self.sample_id = epics_signal_r(int, prefix + "CURRENT_ID_RBV")
        self.next_sample_id = epics_signal_rw_rbv(int, prefix + "NEXT_ID")

        self.load = epics_signal_x(prefix + "LOAD.PROC")
        self.unload = epics_signal_x(prefix + "UNLD.PROC")
        self.program_running = epics_signal_r(bool, prefix + "PROGRAM_RUNNING")
        self.program_name = epics_signal_r(str, prefix + "PROGRAM_NAME")

        self.prog_error = ErrorStatus(prefix + "PRG")
        self.controller_error = ErrorStatus(prefix + "CNTL")

        self.reset = epics_signal_x(prefix + "RESET.PROC")
        self.abort = epics_signal_x(prefix + "ABORT.PROC")
        self.init = epics_signal_x(prefix + "INIT.PROC")
        self.soak = epics_signal_x(prefix + "SOAK.PROC")
        self.home = epics_signal_x(prefix + "GOHM.PROC")
        # NOTE: ``unload`` is defined once above; it was previously re-created
        # here for the same PV, which only replaced it with a duplicate signal.
        self.dry = epics_signal_x(prefix + "DRY.PROC")
        self.open = epics_signal_x(prefix + "COLO.PROC")
        self.close = epics_signal_x(prefix + "COLC.PROC")
        self.cryomode_rbv = epics_signal_r(float, prefix + "CRYO_MODE_RBV")
        self.cryomode = epics_signal_rw(str, prefix + "CRYO_MODE_CTRL")
        self.gripper_temp = epics_signal_r(float, prefix + "GRIPPER_TEMP")
        self.dewar_lid_temperature = epics_signal_rw(
            float, prefix + "DW_1_TEMP", prefix + "DW_1_SET_POINT"
        )
        super().__init__(name=name)

    async def pin_mounted_or_no_pin_found(self):
        """This co-routine will finish when either a pin is detected or the robot gives
        an error saying no pin was found (whichever happens first). In the case where no
        pin was found a RobotLoadError error is raised.

        Neither wait has its own timeout; callers are expected to bound this
        with an outer timeout (as ``set`` does).
        """

        async def raise_if_no_pin():
            await wait_for_value(
                self.prog_error.code, self.NO_PIN_ERROR_CODE, timeout=None
            )
            raise RobotLoadError(self.NO_PIN_ERROR_CODE, "Pin was not detected")

        async def wait_for_pin_mounted():
            await wait_for_value(
                self.gonio_pin_sensor, PinMounted.PIN_MOUNTED, timeout=None
            )

        # asyncio.create_task is the documented way to spawn tasks; instantiating
        # Task(...) directly is discouraged by the asyncio docs.
        tasks: list[Task[None]] = [
            asyncio.create_task(raise_if_no_pin()),
            asyncio.create_task(wait_for_pin_mounted()),
        ]
        try:
            finished, unfinished = await asyncio.wait(
                tasks,
                return_when=FIRST_COMPLETED,
            )
            # The race is decided; stop whichever watcher is still running.
            for task in unfinished:
                task.cancel()
            # Awaiting re-raises RobotLoadError if the no-pin branch won.
            for task in finished:
                await task
        except CancelledError:
            # If the outer enclosing task cancels after a timeout, this causes CancelledError to be raised
            # in the current task, when it propagates to here we should cancel all pending tasks before bubbling up
            for task in tasks:
                task.cancel()

            raise

    async def _load_pin_and_puck(self, sample_location: SampleLocation):
        """Drive the robot through one load of ``sample_location``.

        Steps: reset if the light curtain error is active; wait (up to
        NOT_BUSY_TIMEOUT) for any running robot program to finish; set the next
        puck/pin; trigger the load; if a pin was already mounted, wait for it to
        be removed; then wait for the new pin or a no-pin error.
        """
        if await self.controller_error.code.get_value() == self.LIGHT_CURTAIN_TRIPPED:
            LOGGER.info("Light curtain tripped, trying again")
            await self.reset.trigger()
        LOGGER.info(f"Loading pin {sample_location}")
        if await self.program_running.get_value():
            LOGGER.info(
                f"Waiting on robot to finish {await self.program_name.get_value()}"
            )
            await wait_for_value(
                self.program_running, False, timeout=self.NOT_BUSY_TIMEOUT
            )
        await asyncio.gather(
            set_and_wait_for_value(self.next_puck, sample_location.puck),
            set_and_wait_for_value(self.next_pin, sample_location.pin),
        )
        await self.load.trigger()
        # The old pin must be seen leaving first, otherwise the sensor would
        # already read PIN_MOUNTED and the load would appear to finish instantly.
        if await self.gonio_pin_sensor.get_value() == PinMounted.PIN_MOUNTED:
            LOGGER.info(WAIT_FOR_OLD_PIN_MSG)
            await wait_for_value(
                self.gonio_pin_sensor, PinMounted.NO_PIN_MOUNTED, timeout=None
            )
        LOGGER.info(WAIT_FOR_NEW_PIN_MSG)

        await self.pin_mounted_or_no_pin_found()

    @AsyncStatus.wrap
    async def set(self, value: SampleLocation):
        """Load the sample at ``value`` onto the goniometer.

        The whole load is bounded by ``LOAD_TIMEOUT + NOT_BUSY_TIMEOUT`` seconds.

        Raises:
            RobotLoadError: If no pin was detected, or on timeout. On timeout the
                program error and then the controller error are checked and
                raised if set; otherwise a code-0 "Robot timed out" is raised.
        """
        try:
            await wait_for(
                self._load_pin_and_puck(value),
                timeout=self.LOAD_TIMEOUT + self.NOT_BUSY_TIMEOUT,
            )
        except (asyncio.TimeoutError, TimeoutError) as e:
            # Before Python 3.11 asyncio.TimeoutError is not the builtin
            # TimeoutError; catch both so a timeout is always translated.
            await self.prog_error.raise_if_error(e)
            await self.controller_error.raise_if_error(e)
            raise RobotLoadError(0, "Robot timed out") from e
