import asyncio
import copy
from functools import partial
from typing import AsyncIterable, cast

import wandelbots_api_client as wb

from nova.actions import Action, CombinedActions, MovementController, MovementControllerContext
from nova.actions.mock import WaitAction
from nova.actions.motions import CollisionFreeMotion, Motion
from nova.api import models
from nova.cell.robot_cell import AbstractRobot
from nova.config import ENABLE_TRAJECTORY_TUNING
from nova.core import logger
from nova.core.exceptions import InconsistentCollisionScenes
from nova.core.gateway import ApiGateway
from nova.core.movement_controller import move_forward
from nova.core.tuner import TrajectoryTuner
from nova.types import InitialMovementStream, LoadPlanResponse, MovementResponse, Pose, RobotState
from nova.utils import StreamExtractor

# Velocity limit applied to every joint during the joint PTP move that brings the robot to the
# start of a loaded trajectory (MotionGroup._execute builds one entry per joint from it).
# NOTE(review): unit is whatever the API's joint velocity limits use — presumably rad/s; confirm.
MAX_JOINT_VELOCITY_PREPARE_MOVE: float = 0.2
# Location on the trajectory at which a motion starts.
# NOTE(review): not referenced in this module; move_to_start_position passes a literal 0 instead.
START_LOCATION_OF_MOTION: float = 0.0


def compare_collision_scenes(scene1: wb.models.CollisionScene, scene2: wb.models.CollisionScene):
    """
    Tell whether two collision scenes describe the same setup.

    The scenes are considered equal when neither their colliders nor their motion groups
    differ. The motion groups are only compared when the colliders already match.

    Returns:
        bool: True if both colliders and motion groups match, otherwise False.
    """
    colliders_differ = scene1.colliders != scene2.colliders
    return not (colliders_differ or scene1.motion_groups != scene2.motion_groups)


# TODO: when the collision scene differs between motions,
#  we should plan them in separate batches
def split_actions_into_batches(actions: list[Action]) -> list[list[Action]]:
    """
    Group consecutive actions into planning batches, preserving their order.

    Every CollisionFreeMotion and every WaitAction ends up alone in a batch of its own, and
    each uninterrupted run of other actions forms one shared batch.
    Shared batches are sent to the plan_trajectory API, collision free motions to the
    plan_collision_free_ptp API, and waits generate a trajectory with the same start and
    end position.
    """
    batches: list[list[Action]] = []

    def opens_new_batch(candidate: Action) -> bool:
        # The very first action always opens a batch.
        if not batches:
            return True
        standalone_kinds = (CollisionFreeMotion, WaitAction)
        previous = batches[-1][-1]
        return isinstance(candidate, standalone_kinds) or isinstance(previous, standalone_kinds)

    for candidate in actions:
        if opens_new_batch(candidate):
            batches.append([candidate])
        else:
            batches[-1].append(candidate)
    return batches


def combine_trajectories(
    trajectories: list[wb.models.JointTrajectory],
) -> wb.models.JointTrajectory:
    """
    Concatenate consecutive joint trajectories into one trajectory.

    Every trajectory after the first is appended without its first sample. That sample is
    assumed to coincide with the last sample of the preceding trajectory, because the batches
    are planned one after another from the previous end joints. The appended times and
    locations are re-based on the trajectory's own first sample, so that this sample would
    land exactly on the current end time / end location of the combined result.

    The inputs are not modified. The result is a shallow copy of the first trajectory whose
    ``times``, ``joint_positions`` and ``locations`` are new lists.

    Args:
        trajectories: Trajectories in execution order; must not be empty.

    Returns:
        The combined trajectory.

    Raises:
        ValueError: If ``trajectories`` is empty.
    """
    if not trajectories:
        raise ValueError("No trajectories to combine")

    first = trajectories[0]
    # Fresh lists so that the caller's first trajectory is not extended in place.
    times = list(first.times)
    joint_positions = list(first.joint_positions)
    locations = list(first.locations)

    for trajectory in trajectories[1:]:
        if not trajectory.times:
            # Nothing to append from an empty trajectory.
            continue
        # Shift so that the (skipped) first sample maps onto the current end point.
        # When a trajectory starts at time 0 / location 0, the offsets equal the current end values.
        time_offset = times[-1] - trajectory.times[0]
        location_offset = locations[-1] - trajectory.locations[0]

        times.extend(t + time_offset for t in trajectory.times[1:])
        joint_positions.extend(trajectory.joint_positions[1:])
        locations.extend(location + location_offset for location in trajectory.locations[1:])

    combined = copy.copy(first)
    combined.times = times
    combined.joint_positions = joint_positions
    combined.locations = locations
    return combined


def validate_collision_scenes(actions: list[Action]) -> list[models.CollisionScene]:
    """
    Check that a batch of actions carries consistent collision scene data.

    The RAE V1 API offers two ways of planning actions:

    * Collision-free planning produces a joint trajectory that avoids collisions.
    * Collision-checked planning checks the planned trajectory for collisions and returns
      an error if one is found.

    The action list the Python SDK accepts from the user can mix several kinds of actions:

    1. Collision free motions
    2. Normal motions
    3. Waits
    4. Write actions -> a special write on the path that is supported by the API

    This function expects a batch as produced by ``split_actions_into_batches`` and only
    looks at its ``Motion`` items. Either none or all of them must have a collision scene,
    and all given scenes must compare equal via ``compare_collision_scenes``.

    Returns:
        The collision scenes of the motions, in order (empty when none were given).

    Raises:
        InconsistentCollisionScenes: If only some motions have a collision scene, or if the
            given scenes differ from each other.
    """
    motions = [action for action in actions if isinstance(action, Motion)]
    collision_scenes = [
        motion.collision_scene for motion in motions if motion.collision_scene is not None
    ]

    if collision_scenes and len(collision_scenes) != len(motions):
        raise InconsistentCollisionScenes(
            "Only some of the actions have collision scene. Either specify it for all or none."
        )

    # Every provided scene has to match the first one.
    if len(collision_scenes) > 1:
        reference_scene, *other_scenes = collision_scenes
        if any(not compare_collision_scenes(reference_scene, scene) for scene in other_scenes):
            raise InconsistentCollisionScenes(
                "All actions must use the same collision scene but some are different"
            )

    return collision_scenes


class MotionGroup(AbstractRobot):
    """Manages motion planning and execution within a specified motion group."""

    def __init__(self, api_gateway: ApiGateway, cell: str, motion_group_id: str):
        """
        Initializes a new MotionGroup instance.

        Args:
            api_gateway (ApiGateway): The API gateway through which motion commands are sent.
            cell (str): The name or identifier of the robotic cell.
            motion_group_id (str): The identifier of the motion group. ``ensure_virtual_tcp``
                expects it in the form ``"<index>@<controller>"``.
        """
        self._api_gateway = api_gateway
        self._cell = cell
        self._motion_group_id = motion_group_id
        # Motion id of the trajectory most recently loaded by _execute; used by stop().
        self._current_motion: str | None = None
        # Optimizer setup cached for the most recently requested TCP (see _get_optimizer_setup).
        self._optimizer_setup: wb.models.OptimizerSetup | None = None
        super().__init__(id=motion_group_id)

    async def open(self):
        """Activate the motion group through the API gateway and return this instance."""
        await self._api_gateway.activate_motion_group(
            cell=self._cell, motion_group_id=self._motion_group_id
        )
        return self

    async def close(self):
        """Intentionally a no-op: the motion group is not deactivated (see comment below)."""
        # RPS-1174: when a motion group is deactivated, RAE closes all open connections
        #           this behaviour is not desired in some cases,
        #           so for now we will not deactivate for the user
        pass

    @property
    def motion_group_id(self) -> str:
        """
        Returns:
            str: The unique identifier for this motion group.
        """
        return self._motion_group_id

    @property
    def current_motion(self) -> str | None:
        """Id of the motion most recently loaded for execution, or None if there is none yet."""
        return self._current_motion

    @staticmethod
    def _collision_setup(
        collision_scene: wb.models.CollisionScene | None, motion_group_type: str
    ) -> tuple:
        """
        Extract the planning inputs for collision handling from a collision scene.

        Returns:
            ``(static_colliders, collision_motion_group)``. Both are None when no scene is
            given. The motion group entry is None when the scene has no motion group entry
            for ``motion_group_type``.
        """
        if collision_scene is None:
            return None, None

        motion_groups = collision_scene.motion_groups
        collision_motion_group = None
        if motion_groups and motion_group_type in motion_groups:
            collision_motion_group = motion_groups[motion_group_type]
        return collision_scene.colliders, collision_motion_group

    async def _plan_with_collision_check(
        self,
        actions: list[Action],
        tcp: str,
        start_joint_position: tuple[float, ...] | None = None,
        optimizer_setup: wb.models.OptimizerSetup | None = None,
    ) -> wb.models.JointTrajectory:
        """
        This method plans a trajectory and checks for collisions.
        The collision check only happens if the actions have collision scene data.

        You must provide the exact same collision data to all the actions,
        because the underlying API supports collision checks for the whole trajectory only.

        Raises:
            InconsistentCollisionScenes: If the collision scene data is not consistent across
                all actions. Your actions are consistent if either
                1- they all have the same collision scene data, or
                2- none of them has collision scene data.

            Errors reported by the plan_trajectory call of the API gateway are propagated
            (e.g. when planning or the collision check fails).

        For more information about this API, please refer to the plan_trajectory in the API documentation.

        Args:
            actions: list of actions to plan, current supported actions are Motion and WriteActions
                     WriteAction you specify on your path is handled in a performant way.
                     Please check execute_trajectory.motion_command.set_io for more information.
            tcp:     The tool to use
            start_joint_position: The starting joint position, if none provided, current position of the robot is used
            optimizer_setup: The optimizer setup, fetched for ``tcp`` if not provided

        Returns: planned joint trajectory
        """
        # PREPARE THE REQUEST
        collision_scenes = validate_collision_scenes(actions)
        start_joint_position = start_joint_position or await self.joints()
        robot_setup = optimizer_setup or await self._get_optimizer_setup(tcp=tcp)

        motion_commands = CombinedActions(items=tuple(actions)).to_motion_command()  # type: ignore

        # validate_collision_scenes guarantees that all scenes are equal,
        # so the first one stands for the whole batch.
        static_colliders, collision_motion_group = self._collision_setup(
            collision_scenes[0] if collision_scenes else None, robot_setup.motion_group_type
        )

        request = wb.models.PlanTrajectoryRequest(
            robot_setup=robot_setup,
            start_joint_position=list(start_joint_position),
            motion_commands=motion_commands,
            static_colliders=static_colliders,
            collision_motion_group=collision_motion_group,
        )

        return await self._api_gateway.plan_trajectory(
            cell=self._cell, motion_group_id=self.motion_group_id, request=request
        )

    # TODO: we get the optimizer setup from as an input because
    #  it has a velocity setting which is used in collision free movement, I need to double check this
    async def _plan_collision_free(
        self,
        action: CollisionFreeMotion,
        tcp: str,
        start_joint_position: list[float],
        optimizer_setup: wb.models.OptimizerSetup | None = None,
    ) -> wb.models.JointTrajectory:
        """
        This method plans a trajectory and avoids collisions.
        This means if there is a collision along the way to the target pose or joint positions,
        it will adjust the trajectory to avoid the collision.

        Collision data is only sent if the action has a collision scene. In that case, the
        scene's colliders and, if present, the scene's collision model for this motion group
        type are passed on. The handling matches _plan_with_collision_check.

        For more information about this API, please refer to the plan_collision_free_ptp in the API documentation.

        Args:
            action: The target pose or joint positions to reach
            tcp:     The tool to use
            start_joint_position: The joint position the planned trajectory starts from
            optimizer_setup: The optimizer setup, fetched for ``tcp`` if not provided

        Returns: planned joint trajectory
        """
        target = wb.models.PlanCollisionFreePTPRequestTarget(**action.to_api_model().model_dump())
        robot_setup = optimizer_setup or await self._get_optimizer_setup(tcp=tcp)

        static_colliders, collision_motion_group = self._collision_setup(
            action.collision_scene, robot_setup.motion_group_type
        )

        request: wb.models.PlanCollisionFreePTPRequest = wb.models.PlanCollisionFreePTPRequest(
            robot_setup=robot_setup,
            start_joint_position=start_joint_position,
            target=target,
            static_colliders=static_colliders,
            collision_motion_group=collision_motion_group,
        )

        return await self._api_gateway.plan_collision_free_ptp(
            cell=self._cell, motion_group_id=self.motion_group_id, request=request
        )

    @staticmethod
    def _wait_trajectory(
        joints: tuple[float, ...], wait_time: float, timestep: float = 0.050
    ) -> wb.models.JointTrajectory:
        """
        Build a trajectory that holds ``joints`` for ``wait_time`` seconds.

        Samples are placed every ``timestep`` seconds (50ms by default). The last sample is
        set exactly to ``wait_time``, and at least two samples are produced. All locations are
        0.0, so the location does not advance during the wait.
        NOTE(review): a wait_time of 0 yields two samples with identical timestamps.
        """
        num_steps = max(2, int(wait_time / timestep) + 1)  # Ensure at least 2 points
        times = [i * timestep for i in range(num_steps)]
        # Ensure the last timestep is exactly the wait duration
        times[-1] = wait_time
        return wb.models.JointTrajectory(
            joint_positions=[wb.models.Joints(joints=list(joints)) for _ in range(num_steps)],
            times=times,
            locations=[0.0] * num_steps,
        )

    async def _plan(
        self,
        actions: list[Action],
        tcp: str,
        start_joint_position: tuple[float, ...] | None = None,
        optimizer_setup: wb.models.OptimizerSetup | None = None,
    ) -> wb.models.JointTrajectory:
        """
        Plan ``actions`` batch by batch and return one combined joint trajectory.

        The actions are split with ``split_actions_into_batches``:
        - collision free motions are planned via _plan_collision_free,
        - waits become a stationary trajectory (_wait_trajectory),
        - every other batch is planned via _plan_with_collision_check.

        Each batch starts from the last joint position of the previous batch. The first batch
        starts from ``start_joint_position``, or from the current joints if it is not given.

        Raises:
            ValueError: If ``actions`` is empty.
        """
        if not actions:
            raise ValueError("No actions provided")

        current_joints = start_joint_position or await self.joints()
        robot_setup = optimizer_setup or await self._get_optimizer_setup(tcp=tcp)

        all_trajectories = []
        for batch in split_actions_into_batches(actions):
            if len(batch) == 0:
                raise ValueError("Empty batch of actions")

            first_action = batch[0]
            if isinstance(first_action, CollisionFreeMotion):
                motion: CollisionFreeMotion = cast(CollisionFreeMotion, first_action)
                trajectory = await self._plan_collision_free(
                    action=motion,
                    tcp=tcp,
                    start_joint_position=list(current_joints),
                    optimizer_setup=robot_setup,
                )
            elif isinstance(first_action, WaitAction):
                trajectory = self._wait_trajectory(current_joints, first_action.wait_for_in_seconds)
            else:
                trajectory = await self._plan_with_collision_check(
                    actions=batch,
                    tcp=tcp,
                    start_joint_position=current_joints,
                    optimizer_setup=robot_setup,
                )

            all_trajectories.append(trajectory)
            # the last joint position of this trajectory is the starting point for the next one
            current_joints = tuple(trajectory.joint_positions[-1].joints)

        return combine_trajectories(all_trajectories)

    # TODO: refactor and simplify code, tests are already there
    # TODO: split into batches when the collision scene changes in a batch of collision free motions

    async def _execute(
        self,
        joint_trajectory: wb.models.JointTrajectory,
        tcp: str,
        actions: list[Action],
        movement_controller: MovementController | None,
        start_on_io: wb.models.StartOnIO | None = None,
    ) -> AsyncIterable[MovementResponse]:
        """
        Load ``joint_trajectory``, move to its start and execute it.

        If ENABLE_TRAJECTORY_TUNING is set, execution is delegated to _tune_trajectory.
        Otherwise the trajectory is loaded, and a joint PTP move to the trajectory start is
        streamed with every joint limited to MAX_JOINT_VELOCITY_PREPARE_MOVE. The trajectory
        is then executed through the execute_trajectory API, driven by ``movement_controller``
        (``move_forward`` when None).

        Yields:
            Move-to-start responses that carry motion group state and a current location on
            the trajectory, followed by the execute_trajectory responses.
        """
        # This is the entrypoint for the trajectory tuning mode
        if ENABLE_TRAJECTORY_TUNING:
            logger.info("Entering trajectory tuning mode...")
            async for execute_response in self._tune_trajectory(joint_trajectory, tcp, actions):
                yield execute_response
            return

        if movement_controller is None:
            movement_controller = move_forward

        # Load planned trajectory
        load_plan_response = await self._load_planned_motion(joint_trajectory, tcp)
        # Remember the loaded motion so that stop() has a motion id to stop.
        self._current_motion = load_plan_response.motion

        # Move to start position
        number_of_joints = await self._api_gateway.get_joint_number(
            cell=self._cell, motion_group_id=self.motion_group_id
        )
        joints_velocities = [MAX_JOINT_VELOCITY_PREPARE_MOVE] * number_of_joints
        movement_stream = await self.move_to_start_position(joints_velocities, load_plan_response)

        # Forward only the move-to-start responses that carry state and a trajectory location
        async for move_to_response in movement_stream:
            # TODO: refactor
            if (
                move_to_response.state is None
                or move_to_response.state.motion_groups is None
                or len(move_to_response.state.motion_groups) == 0
                or move_to_response.move_response is None
                or move_to_response.move_response.current_location_on_trajectory is None
            ):
                continue

            yield move_to_response

        controller = movement_controller(
            MovementControllerContext(
                combined_actions=CombinedActions(items=tuple(actions)),  # type: ignore
                motion_id=load_plan_response.motion,
                start_on_io=start_on_io,
            )
        )

        def stop_condition(response: wb.models.ExecuteTrajectoryResponse) -> bool:
            instance = response.actual_instance
            # Stop when standstill indicates motion ended
            return (
                isinstance(instance, wb.models.Standstill)
                and instance.standstill.reason == wb.models.StandstillReason.REASON_MOTION_ENDED
            )

        execute_response_streaming_controller = StreamExtractor(controller, stop_condition)
        execution_task = asyncio.create_task(
            self._api_gateway.motion_api.execute_trajectory(
                cell=self._cell, client_request_generator=execute_response_streaming_controller
            )
        )

        try:
            async for execute_response in execute_response_streaming_controller:
                yield execute_response
            await execution_task
        finally:
            # If iteration ended early (consumer stopped, error, cancellation), do not leave the
            # execute_trajectory request running unobserved in the background.
            if not execution_task.done():
                execution_task.cancel()

    async def _tune_trajectory(
        self, joint_trajectory: wb.models.JointTrajectory, tcp: str, actions: list[Action]
    ) -> AsyncIterable[MovementResponse]:
        """
        Run ``actions`` through the TrajectoryTuner and forward its responses.

        NOTE(review): ``joint_trajectory`` is not used here; the tuner re-plans through plan_fn.
        """
        start_joints = await self.joints()

        async def plan_fn(actions: list[Action]) -> tuple[str, wb.models.JointTrajectory]:
            # we fix the start joints here because the tuner might call plan multiple times whilst tuning
            # and the start joints would change to the respective joint positions at the time of planning
            # which is not what we want
            planned_trajectory = await self._plan(actions, tcp, start_joints)
            load_planned_motion_response = await self._load_planned_motion(planned_trajectory, tcp)
            return load_planned_motion_response.motion, planned_trajectory

        execute_fn = partial(self._api_gateway.motion_api.execute_trajectory, cell=self._cell)
        tuner = TrajectoryTuner(actions, plan_fn, execute_fn)
        async for response in tuner.tune():
            yield response

    async def _get_optimizer_setup(self, tcp: str) -> wb.models.OptimizerSetup:
        """
        Return the optimizer setup for ``tcp``.

        The setup is fetched from the API when nothing is cached yet or when the cached setup
        belongs to a different TCP.
        """
        # TODO: mypy failed on main branch, need to check
        if self._optimizer_setup is None or self._optimizer_setup.tcp != tcp:  # type: ignore
            self._optimizer_setup = await self._api_gateway.get_optimizer_config(
                cell=self._cell, motion_group_id=self.motion_group_id, tcp=tcp
            )

        return self._optimizer_setup

    async def _load_planned_motion(
        self, joint_trajectory: wb.models.JointTrajectory, tcp: str
    ) -> wb.models.PlanSuccessfulResponse:
        """Load ``joint_trajectory`` for ``tcp`` via the API gateway and return its response."""
        return await self._api_gateway.load_planned_motion(
            cell=self._cell,
            motion_group_id=self.motion_group_id,
            joint_trajectory=joint_trajectory,
            tcp=tcp,
        )

    async def move_to_start_position(
        self, joint_velocities: list[float] | None, load_plan_response: LoadPlanResponse
    ) -> InitialMovementStream:
        """
        Return a stream that moves the robot via joint PTP to the start of the loaded motion.

        Args:
            joint_velocities: Per-joint velocity limits for the move, or None for no override.
            load_plan_response: Response of loading the motion; its ``motion`` id is used.
        """
        limit_override = wb.models.LimitsOverride()
        if joint_velocities is not None:
            limit_override.joint_velocity_limits = wb.models.Joints(joints=joint_velocities)

        return self._api_gateway.stream_move_to_trajectory_via_join_ptp(
            cell=self._cell,
            motion_id=load_plan_response.motion,
            location_on_trajectory=0,
            joint_velocity_limits=limit_override.joint_velocity_limits,
        )

    async def stop(self):
        """
        Stop the current motion, if there is one.

        Best effort: when no motion is known, or when the stop call raises a ValueError, this
        is only logged at debug level.
        """
        logger.debug(f"Stopping motion of {self}...")
        if self._current_motion is None:
            logger.debug(f"No motion to stop for {self}: No motion to stop")
            return
        try:
            await self._api_gateway.stop_motion(cell=self._cell, motion_id=self._current_motion)
            logger.debug(f"Motion {self.current_motion} stopped.")
        except ValueError as e:
            logger.debug(f"No motion to stop for {self}: {e}")

    async def get_state(self, tcp: str | None = None) -> RobotState:
        """
        Returns the motion group state.
        Args:
            tcp (str | None): The reference TCP for the cartesian pose part of the robot state. Defaults to None.
                                        If None, the current active/selected TCP of the motion group is used.
        """
        response = await self._api_gateway.get_motion_group_state(
            cell=self._cell, motion_group_id=self.motion_group_id, tcp=tcp
        )
        pose = Pose(response.tcp_pose or response.state.tcp_pose)
        return RobotState(pose=pose, joints=tuple(response.state.joint_position.joints))

    async def joints(self) -> tuple:
        """
        Returns the current joint positions of the motion group.

        Raises:
            ValueError: If the state contains no joint positions.
        """
        state = await self.get_state()
        if state.joints is None:
            raise ValueError(
                f"No joint positions available for motion group {self._motion_group_id}"
            )
        return state.joints

    async def tcp_pose(self, tcp: str | None = None) -> Pose:
        """
        Returns the current TCP pose of the motion group.
        Args:
            tcp (str | None): The reference TCP for the returned pose. Defaults to None.
                                        If None, the current active/selected TCP of the motion group is used.
        """
        state = await self.get_state(tcp=tcp)
        return state.pose

    async def tcps(self) -> list[wb.models.RobotTcp]:
        """Return the TCPs of this motion group (an empty list if the API reports none)."""
        response = await self._api_gateway.list_tcps(
            cell=self._cell, motion_group_id=self.motion_group_id
        )
        return response.tcps if response.tcps else []

    async def tcp_names(self) -> list[str]:
        """Return the ids of this motion group's TCPs."""
        return [tcp.id for tcp in await self.tcps()]

    async def active_tcp(self) -> wb.models.RobotTcp:
        """Return the currently active TCP of this motion group."""
        active_tcp = await self._api_gateway.get_active_tcp(
            cell=self._cell, motion_group_id=self.motion_group_id
        )
        return active_tcp

    async def active_tcp_name(self) -> str:
        """Return the id of the currently active TCP."""
        active_tcp = await self.active_tcp()
        return active_tcp.id

    async def ensure_virtual_tcp(self, tcp: models.RobotTcp, timeout: int = 12) -> models.RobotTcp:
        """
        Ensure that a virtual TCP with the expected configuration exists on this motion group.

        If a TCP with the same id and an equal configuration already exists, it is returned
        as is. Otherwise ``tcp`` is submitted through the virtual robot setup API.
        The method then polls about once per second until a TCP with ``tcp.id`` is listed.
        NOTE(review): the poll only checks the id, so when an existing TCP with a different
        configuration is being replaced, this may return before the update is visible.

        Args:
            tcp (models.RobotTcp): The expected TCP configuration
            timeout (int): Approximate number of seconds to wait for the TCP to appear.

        Returns:
            models.RobotTcp: The TCP configuration

        Raises:
            TimeoutError: If the TCP did not appear within ``timeout`` seconds.
        """
        existing_tcps = await self.tcps()

        existing_tcp = next((tcp_ for tcp_ in existing_tcps if tcp_.id == tcp.id), None)
        if existing_tcp and existing_tcp == tcp:
            return existing_tcp

        # The motion group id has the form "<index>@<controller>".
        id_parts = self._motion_group_id.split("@")
        controller_name = id_parts[1]
        motion_group_index = int(id_parts[0])

        await self._api_gateway.virtual_robot_setup_api.add_virtual_robot_tcp(
            cell=self._cell, controller=controller_name, id=motion_group_index, robot_tcp=tcp
        )

        # TODO: this is a workaround to wait for the TCP to be created
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            existing_tcps = await self.tcps()
            if any(tcp_.id == tcp.id for tcp_ in existing_tcps):
                return tcp
            await asyncio.sleep(1)

        raise TimeoutError(f"Failed to create TCP '{tcp.id}' within {timeout} seconds")
