from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from omegaconf import OmegaConf


@dataclass
class PendulumEnvConfig:
    """Settings that select and parameterize the wrapped Gymnasium environment.

    Instances (or equivalent YAML/OmegaConf nodes) can be supplied to
    customize which environment is created and how.
    """

    # Gymnasium environment ID.
    env_id: str = field(default="Pendulum-v1")

    # Extra keyword arguments passed to ``gym.make``; each instance gets
    # its own fresh dict.
    env_kwargs: Dict[str, Any] = field(default_factory=dict)


class PendulumCompatEnv(gym.Env):
    """
    Wrapper around Gymnasium's `Pendulum-v1` that exposes the same
    observation / action interface as the robot-arm environment:

    - Action space: 1-D continuous value in [-1, 1], linearly rescaled to
      the underlying environment's own action bounds.
    - Observation space: 4-D continuous vector [q0, q1, dq0, dq1]
        * the pendulum has a single joint, so q1 and dq1 are always 0
        * q0 is recovered from the underlying observation as
          atan2(sin(theta), cos(theta)); dq0 is the angular velocity
    """

    def __init__(
        self,
        config: Optional[PendulumEnvConfig] = None,
        env_id: Optional[str] = None,
        env_kwargs: Optional[Dict[str, Any]] = None,
        **extra_kwargs: Any,
    ) -> None:
        """Create the underlying environment and set up the compatible spaces.

        Args:
            config: Optional configuration merged over the
                ``PendulumEnvConfig`` defaults.
            env_id: If given, overrides ``config.env_id``.
            env_kwargs: If given, replaces ``config.env_kwargs``.
            **extra_kwargs: Additional keyword arguments merged into
                ``env_kwargs`` (taking precedence on key clashes) and
                forwarded to ``gym.make``.

        Raises:
            TypeError: If the underlying environment's observation or
                action space is not a ``spaces.Box``.
        """
        super().__init__()

        base_cfg = OmegaConf.structured(PendulumEnvConfig)
        if config is None:
            merged_cfg = base_cfg
        else:
            merged_cfg = OmegaConf.merge(base_cfg, config)

        cfg_dict = OmegaConf.to_container(merged_cfg, resolve=True)

        if env_id is not None:
            cfg_dict["env_id"] = env_id
        if env_kwargs is not None:
            cfg_dict["env_kwargs"] = env_kwargs
        if extra_kwargs:
            # Keyword arguments passed directly are folded into env_kwargs.
            cfg_dict.setdefault("env_kwargs", {})
            cfg_dict["env_kwargs"] = {**cfg_dict["env_kwargs"], **extra_kwargs}

        self.config = OmegaConf.create(cfg_dict)
        env_kwargs_dict = OmegaConf.to_container(self.config.env_kwargs, resolve=True)
        if env_kwargs_dict is None:
            env_kwargs_dict = {}

        self._base_env = gym.make(self.config.env_id, **env_kwargs_dict)

        if not isinstance(self._base_env.observation_space, spaces.Box):
            raise TypeError(
                f"{self.config.env_id} observation_space must be Box, "
                f"got {type(self._base_env.observation_space)}"
            )
        if not isinstance(self._base_env.action_space, spaces.Box):
            raise TypeError(
                f"{self.config.env_id} action_space must be Box, "
                f"got {type(self._base_env.action_space)}"
            )

        # Keep the underlying action bounds so normalized actions can be rescaled.
        self._base_action_low = np.asarray(
            self._base_env.action_space.low, dtype=np.float32
        )
        self._base_action_high = np.asarray(
            self._base_env.action_space.high, dtype=np.float32
        )

        # Action / observation spaces matching the robot-arm environment.
        self.action_space = spaces.Box(
            low=np.array([-1.0], dtype=np.float32),
            high=np.array([1.0], dtype=np.float32),
            dtype=np.float32,
        )
        high = np.array([np.inf, np.inf, np.inf, np.inf], dtype=np.float32)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        # Forward the underlying metadata (if any) and render mode so that
        # callers inspecting this env see how the base env actually renders.
        self.metadata = getattr(self._base_env, "metadata", {}).copy()
        self.render_mode = getattr(self._base_env, "render_mode", None)

    # ---------- Gymnasium API ----------
    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Reset the underlying environment.

        Args:
            seed: Optional seed, used for this wrapper's RNG and forwarded
                to the underlying environment.
            options: Optional reset options forwarded unchanged.

        Returns:
            The converted 4-D observation and the underlying ``info`` dict.
        """
        # Gymnasium convention: let the base class seed this env's own RNG.
        super().reset(seed=seed)
        obs, info = self._base_env.reset(seed=seed, options=options)
        return self._convert_obs(obs), info

    def step(
        self,
        action: np.ndarray,
    ) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """Advance the underlying environment by one step.

        Args:
            action: Normalized action with exactly one element; values
                outside [-1, 1] are clipped.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation is the converted 4-D vector.

        Raises:
            ValueError: If ``action`` does not contain exactly one element.
        """
        scaled_action = self._scale_action(action)
        obs, reward, terminated, truncated, info = self._base_env.step(scaled_action)
        # Cast to builtin types so callers never receive numpy scalars here.
        return (
            self._convert_obs(obs),
            float(reward),
            bool(terminated),
            bool(truncated),
            info,
        )

    def render(self):
        """Render via the underlying environment and return its result."""
        return self._base_env.render()

    def close(self) -> None:
        """Close the underlying environment."""
        self._base_env.close()

    # ---------- Helpers ----------
    def _scale_action(self, action: np.ndarray) -> np.ndarray:
        """Map a normalized action in [-1, 1] to the underlying action range.

        Args:
            action: Array-like holding exactly one element.

        Returns:
            float32 array: the clipped action mapped linearly onto
            ``[base_low, base_high]``.

        Raises:
            ValueError: If ``action`` does not contain exactly one element.
        """
        arr = np.asarray(action, dtype=np.float32).reshape(-1)
        if arr.size != 1:
            raise ValueError(f"Expected action of shape (1,), got {arr.shape}")
        clipped = np.clip(arr[0], self.action_space.low[0], self.action_space.high[0])

        # Linear map [-1, 1] -> [low, high].
        span = self._base_action_high - self._base_action_low
        scaled = (clipped + 1.0) * 0.5 * span + self._base_action_low
        return np.array(scaled, dtype=np.float32)

    def _convert_obs(self, obs: np.ndarray) -> np.ndarray:
        """Convert a ``[cos(theta), sin(theta), theta_dot]`` observation.

        Args:
            obs: Underlying observation; flattened, only the first three
                elements are read.

        Returns:
            float32 array ``[q0, q1, dq0, dq1]`` with ``q0 = theta``,
            ``dq0 = theta_dot`` and ``q1 = dq1 = 0``.

        Raises:
            ValueError: If ``obs`` has fewer than three elements.
        """
        arr = np.asarray(obs, dtype=np.float32).reshape(-1)
        if arr.size < 3:
            raise ValueError(
                f"Expected Pendulum observation with 3 elements, got {arr}"
            )
        cos_th, sin_th, theta_dot = float(arr[0]), float(arr[1]), float(arr[2])
        theta = math.atan2(sin_th, cos_th)
        # The single pendulum joint occupies the q0 / dq0 slots; the unused
        # second joint (q1 / dq1) stays at zero, per the documented layout.
        return np.array([theta, 0.0, theta_dot, 0.0], dtype=np.float32)
