from __future__ import annotations

import logging
import math
from pathlib import Path
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import pybullet as p
import pybullet_data
from gymnasium import spaces
from omegaconf import OmegaConf

from .config import GymEnvConfig

logger = logging.getLogger(__name__)


class GymEnv(gym.Env):
    """
    一个将 PyBullet 机械臂封装为 Gymnasium 环境的最小可用版本。

    环境特点：
    - 关节0：力矩控制（动作 ∈ [-1,1] 缩放为 [-max_torque, +max_torque]）
    - 关节1：被动（仅受 URDF 阻尼/摩擦/重力作用）
    - 观测空间：[q0, q1, dq0, dq1]（两个关节的位置和速度）
    """

    def __init__(
        self,
        config: Optional[GymEnvConfig] = None,
    ):
        """
        初始化 Gymnasium 机械臂环境。

        参数：
            config: 环境配置，GymEnvConfig 实例。
                如果为 None，则使用默认配置。
        """
        super().__init__()

        # 解析配置
        if config is None:
            self.config = OmegaConf.structured(GymEnvConfig)
        else:
            self.config = OmegaConf.structured(config)

        # 转换为容器以便访问
        self.config = OmegaConf.to_container(self.config, resolve=True)
        self.config = OmegaConf.create(self.config)

        # 设置元数据
        self.metadata = {
            "render_modes": self.config.render.render_modes,
            "render_fps": self.config.render.render_fps,
        }

        self.current_step = 0  # 当前步数计数器

        # ---- 定义动作和观测空间 ----
        # 动作空间：1维连续值，范围由配置决定
        self.action_space = spaces.Box(
            low=np.array([self.config.control.action_low], dtype=np.float32),
            high=np.array([self.config.control.action_high], dtype=np.float32),
            dtype=np.float32,
        )
        # 观测空间：4维连续值 [q0, q1, dq0, dq1]，每个维度无限制
        high = np.array([np.inf, np.inf, np.inf, np.inf], dtype=np.float32)
        self.observation_space = spaces.Box(-high, +high, dtype=np.float32)

        # ---- PyBullet 物理引擎连接和初始化 ----
        # 连接到 DIRECT 模式（无图形渲染，仅进行物理计算）
        self._cid: int = -1
        self._connect_physics()

        # 以 -1 作为"未创建"的无效 id，避免 int|None 的类型告警
        self.plane_id: int = -1  # 平面的 id
        self.robot_id: int = -1  # 机械臂模型的 id

        self._build_world()

    def _connect_physics(self) -> None:
        """Ensure the environment has an active PyBullet connection."""
        if self._cid != -1 and p.isConnected(self._cid):
            return

        self._cid = p.connect(p.DIRECT)
        p.setTimeStep(self.config.physics.time_step, physicsClientId=self._cid)
        p.setGravity(*self.config.physics.gravity, physicsClientId=self._cid)
        p.setAdditionalSearchPath(pybullet_data.getDataPath(), physicsClientId=self._cid)

    # ---------- 物理世界构建 ----------
    def _build_world(self) -> None:
        """
        初始化/重置 PyBullet 物理世界。

        该方法会：
        1. 移除已存在的平面和机械臂模型
        2. 创建新的平面（ground）
        3. 从 URDF 文件加载机械臂模型
        4. 释放所有关节的电机控制，以便进行力矩控制
        5. 创建约束来固定机械臂的基座位置
        """
        self._connect_physics()

        # 清理旧的平面和机械臂
        if self.plane_id != -1:
            p.removeBody(self.plane_id, physicsClientId=self._cid)
            self.plane_id = -1
        if self.robot_id != -1:
            p.removeBody(self.robot_id, physicsClientId=self._cid)
            self.robot_id = -1

        # 创建平面（地面）
        plane_shape = p.createCollisionShape(
            p.GEOM_PLANE, physicsClientId=self._cid
        )
        self.plane_id = p.createMultiBody(
            baseMass=0,
            baseCollisionShapeIndex=plane_shape,
            physicsClientId=self._cid,
        )

        # 验证 URDF 文件存在，然后加载机械臂模型
        urdf_path = self.config.robot.urdf_path
        if not Path(urdf_path).exists():
            raise FileNotFoundError(f"URDF not found: {urdf_path}")

        self.robot_id = p.loadURDF(
            urdf_path,
            [0, 0, 0.001],
            useFixedBase=self.config.robot.use_fixed_base,
            physicsClientId=self._cid,
        )

        # 释放关节（关闭默认速度电机），之后才能使用 TORQUE_CONTROL 进行力矩控制
        for j in self.config.robot.joint_indices:
            p.setJointMotorControl2(
                self.robot_id,
                j,
                p.VELOCITY_CONTROL,
                force=0.0,
                physicsClientId=self._cid,
            )

        # 轻微固定基座，避免数值漂移（基座会在锚点周围轻微振动，但不会剧烈运动）
        p.createConstraint(
            parentBodyUniqueId=self.robot_id,
            parentLinkIndex=-1,  # -1 表示基座
            childBodyUniqueId=-1,  # -1 表示世界坐标系
            childLinkIndex=-1,
            jointType=p.JOINT_FIXED,
            jointAxis=[0, 0, 0],
            parentFramePosition=[0, 0, -0.0005],
            childFramePosition=list(self.config.robot.anchor_position),
            physicsClientId=self._cid,
        )

    # ---------- 辅助函数 ----------
    def _get_joint_state(self) -> Tuple[float, float, float, float]:
        """
        获取两个关节的位置和速度。

        返回值：
            (q0, q1, dq0, dq1) - 关节0和1的位置及速度
        """
        assert self.robot_id != -1, "robot not loaded"
        # 从 PyBullet 获取关节0的状态：[位置, 速度, ...]
        s0 = p.getJointState(
            self.robot_id,
            self.config.robot.joint_indices[0],
            physicsClientId=self._cid,
        )
        # 从 PyBullet 获取关节1的状态：[位置, 速度, ...]
        s1 = p.getJointState(
            self.robot_id,
            self.config.robot.joint_indices[1],
            physicsClientId=self._cid,
        )
        q0, dq0 = float(s0[0]), float(s0[1])
        q1, dq1 = float(s1[0]), float(s1[1])
        return q0, q1, dq0, dq1

    def _get_ee_pos(self) -> Tuple[float, float, float]:
        """
        获取末端执行器（end effector）的位置（笛卡尔坐标系下的 x, y, z）。

        优先策略：
        1. 尝试查找名为 'ball' 的 link（末端球形执行器）
        2. 如果没有找到，则使用最后一个关节对应的 link

        返回值：
            (x, y, z) - 末端执行器在世界坐标系下的位置
        """
        assert self.robot_id != -1, "robot not loaded"
        ee_link_index = None
        # 遍历所有关节，查找末端 link 名为 'ball' 的关节
        n = p.getNumJoints(self.robot_id, physicsClientId=self._cid)
        for i in range(n):
            # getJointInfo 返回的第 [12] 个元素是 child link 名称（字节字符串）
            child_link_name = p.getJointInfo(
                self.robot_id, i, physicsClientId=self._cid
            )[12].decode("utf-8")
            if child_link_name == "ball":
                ee_link_index = i
                break
        # 如果找不到名为 'ball' 的 link，则使用最后一个关节对应的 link
        if ee_link_index is None:
            ee_link_index = self.config.robot.joint_indices[-1]

        # 获取指定 link 的位置（world frame）
        pos = p.getLinkState(
            self.robot_id,
            ee_link_index,
            computeForwardKinematics=True,
            physicsClientId=self._cid,
        )[4]
        return float(pos[0]), float(pos[1]), float(pos[2])

    def _get_obs(self) -> np.ndarray:
        """
        获取当前观测（observation）。

        返回值：
            np.ndarray - 形状为 (4,) 的观测向量 [q0, q1, dq0, dq1]
        """
        q0, q1, dq0, dq1 = self._get_joint_state()
        # 归一化 q0 和 q1 到 [-pi, pi] 以处理 2pi 周期
        q0_normalized = ((q0 + np.pi) % (2 * np.pi)) - np.pi
        q1_normalized = ((q1 + np.pi) % (2 * np.pi)) - np.pi
        return np.array([q0_normalized, q1_normalized, dq0, dq1], dtype=np.float32)

    # ---------- Gymnasium API 接口 ----------
    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        重置环境到初始状态。符合 Gymnasium 规范。

        参数：
            seed: 随机数生成器种子
            options: 环境特定选项（未使用）

        返回值：
            obs: np.ndarray - 初始观测
            info: dict - 环境信息（当前为空字典）
        """
        # Gymnasium 规范：用 super().reset(seed=seed) 初始化内部 RNG
        super().reset(seed=seed)

        self._connect_physics()

        self.current_step = 0  # 重置步数计数器
        # 重置 PyBullet 模拟的全局状态
        p.resetSimulation(physicsClientId=self._cid)
        p.setTimeStep(self.config.physics.time_step, physicsClientId=self._cid)
        p.setGravity(*self.config.physics.gravity, physicsClientId=self._cid)
        p.setAdditionalSearchPath(
            pybullet_data.getDataPath(), physicsClientId=self._cid
        )

        self.plane_id = -1
        self.robot_id = -1
        self._build_world()

        # 随机初始角度（范围由配置决定）
        rng = np.random.default_rng(seed)
        noise = self.config.episode.initial_joint_noise
        for j in self.config.robot.joint_indices:
            # 在 [pi - noise, pi + noise] 范围内随机初始化
            angle = math.pi + float(rng.uniform(-noise, noise))
            # 随机初始速度，增加多样性
            velocity = float(rng.uniform(-1.0, 1.0))
            p.resetJointState(
                self.robot_id,
                j,
                targetValue=angle,
                targetVelocity=velocity,
                physicsClientId=self._cid,
            )

        obs = self._get_obs()
        return obs, {}

    def step(self, action: np.ndarray):
        """
        执行一个环境步骤。

        参数：
            action: np.ndarray - 形状为 (1,) 的动作，范围 [-1, 1]

        返回值：
            obs: np.ndarray - 执行动作后的观测
            reward: float - 标量奖励
            terminated: bool - episode 是否终止（达到目标）
            truncated: bool - episode 是否被截断（超时或出错）
            info: dict - 额外信息
        """
        assert self.robot_id != -1, "robot not loaded"

        self.current_step += 1

        # 将动作从配置范围缩放为力矩 [-max_torque, +max_torque]
        # 首先裁剪动作到有效范围
        action_clipped = np.clip(
            float(action[0]),
            self.config.control.action_low,
            self.config.control.action_high
        )
        # 线性映射：[action_low, action_high] -> [-max_torque, +max_torque]
        action_range = self.config.control.action_high - self.config.control.action_low
        action_normalized = (action_clipped - self.config.control.action_low) / action_range
        tau = (action_normalized * 2.0 - 1.0) * self.config.control.max_torque

        # 对关节0施加力矩控制（关节1保持被动）
        p.setJointMotorControl2(
            self.robot_id,
            self.config.robot.joint_indices[0],
            p.TORQUE_CONTROL,
            force=tau,
            physicsClientId=self._cid,
        )

        # 执行 frame_skip 步的物理模拟
        for _ in range(self.config.physics.frame_skip):
            p.stepSimulation(physicsClientId=self._cid)

        # 获取新状态
        obs = self._get_obs()
        q0, q1, dq0, dq1 = map(float, obs[:4])
        
        # 使用归一化角度计算奖励（与官方 Pendulum 一致）
        # q1 已经在 _get_obs 中归一化到 [-pi, pi]
        reward = -(q1**2 + 0.1 * dq1**2 + 0.001 * tau**2)

        # 判断 episode 终止条件（基于新状态）
        terminated = False
        truncated = False
        if not np.isfinite(obs).all():  # 观测值包含 NaN 或 inf
            truncated = True
            logger.warning("Observation contains NaN or inf; truncating episode.")
            info = {}
            return obs, float(reward), terminated, truncated, info

        # 检查关节0角度（如果配置了阈值）
        q0_threshold = self.config.episode.q0_termination_threshold
        if q0_threshold is not None and abs(obs[0]) > q0_threshold:
            terminated = True
            logger.info(f"Episode terminated: |q0|={abs(obs[0]):.3f} > {q0_threshold}")
            info = {}
            return obs, float(reward), terminated, truncated, info

        # 检查关节1角度（如果配置了阈值）
        q1_threshold = self.config.episode.q1_termination_threshold
        if q1_threshold is not None and abs(obs[1]) > q1_threshold:
            terminated = True
            logger.info(f"Episode terminated: |q1|={abs(obs[1]):.3f} > {q1_threshold}")
            info = {}
            return obs, float(reward), terminated, truncated, info

        # 检查是否超过最大步数
        if self.current_step >= self.config.episode.max_episode_steps:
            truncated = True
            logger.info(
                f"Episode truncated: current_step={self.current_step} >= max_episode_steps={self.config.episode.max_episode_steps}"
            )
            info = {}
            return obs, float(reward), terminated, truncated, info

        info = {}
        return obs, float(reward), terminated, truncated, info

    def render(self):
        """
        渲染环境（不支持）。

        由于环境运行在 DIRECT 模式（无图形渲染），该方法会抛出异常。
        """
        raise RuntimeError("GymEnv is headless-only; rendering is not supported.")

    def close(self):
        """
        关闭环境，释放 PyBullet 连接和资源。
        """
        if self._cid != -1 and p.isConnected(self._cid):
            p.disconnect(self._cid)
        self._cid = -1
        self.robot_id = -1
        self.plane_id = -1
