import logging
import uuid
from collections.abc import Callable
from typing import Iterable, Iterator, Optional

import grpc
import numpy as np
from google.protobuf import struct_pb2

from parnassus.protos.arm_stream import arm_stream_pb2 as pb2
from parnassus.protos.arm_stream import arm_stream_pb2_grpc as pb2_grpc
from parnassus.utils.tensor import numpy_to_tensor, tensor_to_numpy

logger = logging.getLogger(__name__)


class ArmEnvServicer(pb2_grpc.ArmEnvServicer):
    """ArmEnv gRPC service implementation.

    Each ``StreamEnv`` call is an independent session owning at most one
    environment instance. The instance is created lazily by ``env_factory``
    on the first reset request and closed when the session ends.
    """

    def __init__(self, env_factory: Callable):
        """Initialize the service.

        Args:
            env_factory: Zero-argument callable that creates a new environment
                instance. It is called at most once per session, on the
                session's first reset request.
        """
        self.env_factory = env_factory

    def StreamEnv(
        self,
        request_iterator: Iterable[pb2.EnvRequest],
        context: grpc.ServicerContext,
    ) -> Iterator[pb2.EnvReply]:
        """Handle one bidirectional env stream (one session).

        Requests are processed in order: ``reset`` creates (if needed) and
        resets the environment, ``step`` advances it, and ``close`` closes it
        and ends the stream. Requests with no recognised field are logged and
        produce no reply. Any exception ends the stream with ``INTERNAL``
        status. A still-open environment is always closed on exit.
        """
        session_id = uuid.uuid4().hex
        logger.info("[%s] StreamEnv session started", session_id)
        env = None
        try:
            for request in request_iterator:
                if request.HasField("reset"):
                    env, reply = self._handle_reset(request.reset, env, session_id)
                    yield reply
                elif request.HasField("step"):
                    yield self._handle_step(request.step, env, session_id)
                elif request.HasField("close"):
                    env, reply = self._handle_close(env, session_id)
                    yield reply
                    break
                else:
                    logger.warning("[%s] Unknown request type received", session_id)
        except Exception as exc:  # pragma: no cover - gRPC will surface the error
            # Single place where handler errors are logged (with traceback).
            logger.exception("[%s] Error in StreamEnv", session_id)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Internal server error: {exc}")
        finally:
            if env is not None:
                try:
                    env.close()
                    logger.info("[%s] Environment closed (cleanup)", session_id)
                except Exception:  # pragma: no cover - best effort cleanup
                    logger.exception(
                        "[%s] Error closing environment during cleanup", session_id
                    )
            logger.info("[%s] StreamEnv session ended", session_id)

    def _handle_reset(
        self,
        reset_request: pb2.Reset,
        env,
        session_id: str,
    ):
        """Handle a reset request.

        Args:
            reset_request: Reset message; its repeated ``seed`` field may hold
                zero seeds (no seed) or exactly one.
            env: The session's current environment, or ``None`` if none has
                been created yet.
            session_id: Identifier used in log messages.

        Returns:
            ``(env, reply)``: the (possibly newly created) environment and the
            ``EnvReply`` carrying the initial observation as float32.

        Raises:
            ValueError: If more than one seed is supplied.
        """
        seed_list = list(reset_request.seed)
        if len(seed_list) > 1:
            raise ValueError(
                "Vectorized resets are no longer supported. "
                "Please open separate sessions for each environment."
            )
        seed_arg: Optional[int] = seed_list[0] if seed_list else None

        created_here = env is None
        if created_here:
            env = self.env_factory()
            logger.info("[%s] Environment created with loaded config", session_id)

        try:
            obs, _info = env.reset(seed=seed_arg)
            # Cast to float32 so reset observations match the encoding used
            # by step replies.
            obs_arr = np.asarray(obs, dtype=np.float32)
            reply = pb2.EnvReply(
                reset=pb2.ResetReply(observation=numpy_to_tensor(obs_arr))
            )
        except Exception:
            if created_here:
                # The env is never handed back to the caller on this path,
                # so nothing else could close it: release it here.
                try:
                    env.close()
                except Exception:  # best effort; keep the original error
                    logger.exception(
                        "[%s] Error closing environment after failed reset",
                        session_id,
                    )
            raise

        logger.info(
            "[%s] Environment reset with observation shape: %s",
            session_id,
            obs_arr.shape,
        )
        return env, reply

    def _handle_step(
        self, step_request: pb2.Step, env, session_id: str
    ) -> pb2.EnvReply:
        """Handle a step request.

        The observation is sent as float32. Reward, ``terminated`` and
        ``truncated`` are sent as at-least-1-d float32/bool tensors. ``info``
        is sent as a Struct holding whichever entries could be serialised.

        Raises:
            RuntimeError: If no environment exists yet (no reset received).
        """
        if env is None:
            raise RuntimeError("Environment not initialized. Call reset first.")

        action = tensor_to_numpy(step_request.action)
        obs, reward, terminated, truncated, info = env.step(action)
        obs_arr = np.asarray(obs, dtype=np.float32)

        # Lazy %-args: nothing is formatted unless DEBUG is enabled, which
        # matters because this runs on every step.
        logger.debug(
            "[%s] Step executed: obs_shape=%s, reward=%s, "
            "terminated=%s, truncated=%s",
            session_id,
            obs_arr.shape,
            reward,
            terminated,
            truncated,
        )

        return pb2.EnvReply(
            step=pb2.StepReply(
                observation=numpy_to_tensor(obs_arr),
                reward=numpy_to_tensor(self._as_1d(reward, np.float32)),
                terminated=numpy_to_tensor(self._as_1d(terminated, np.bool_)),
                truncated=numpy_to_tensor(self._as_1d(truncated, np.bool_)),
                info=self._info_to_struct(info, session_id),
            )
        )

    @staticmethod
    def _as_1d(value, dtype) -> np.ndarray:
        """Return ``value`` as an array of ``dtype``; scalars become shape (1,)."""
        return np.atleast_1d(np.asarray(value, dtype=dtype))

    @staticmethod
    def _info_to_struct(info, session_id: str) -> struct_pb2.Struct:
        """Best-effort conversion of an env ``info`` mapping to a Struct.

        Values exposing ``tolist()`` (e.g. numpy arrays/scalars) are converted
        first. Entries that still cannot be stored in a Struct are omitted and
        logged at DEBUG level; the remaining entries are kept.
        """
        struct_info = struct_pb2.Struct()
        try:
            entries = dict(info)
        except Exception:
            logger.debug(
                "[%s] Failed to serialise info dict; sending empty Struct",
                session_id,
            )
            return struct_info

        skipped = []
        for key, value in entries.items():
            try:
                plain = value.tolist() if hasattr(value, "tolist") else value
                # Struct.update() writes entries one at a time, so a failure
                # can leave a half-written entry behind; convert into a
                # scratch Struct and copy only on success.
                scratch = struct_pb2.Struct()
                scratch.update({key: plain})
            except Exception:  # deliberate best effort: drop only this entry
                skipped.append(key)
                continue
            struct_info.fields[key].CopyFrom(scratch.fields[key])

        if skipped:
            logger.debug(
                "[%s] Failed to serialise info keys %r; omitted from reply",
                session_id,
                skipped,
            )
        return struct_info

    def _handle_close(self, env, session_id: str):
        """Handle a close request.

        Closes ``env`` if one exists and returns ``(None, reply)``. ``None``
        becomes the session's new env, so ``StreamEnv`` will not close it
        again. Exceptions from ``env.close()`` propagate to ``StreamEnv``.
        """
        if env is not None:
            env.close()
            logger.info("[%s] Environment closed", session_id)

        return None, pb2.EnvReply(close=pb2.CloseReply())
