import logging
import signal
import sys
from concurrent import futures
from pathlib import Path
from enum import Enum

import grpc
import typer

from parnassus.envs.arm.config import GymEnvConfig
from parnassus.envs.arm.gym_env import GymEnv
from parnassus.envs import PendulumCompatEnv, PendulumEnvConfig
from parnassus.protos.arm_stream import arm_stream_pb2_grpc
from parnassus.servers.arm_stream.servicer import ArmEnvServicer

# CLI application object; commands in this module are registered on it via
# the `@app.command()` decorator.
app = typer.Typer(
    help="gRPC server for environment streaming service",
)


class EnvType(str, Enum):
    """Environment kinds selectable with the ``--env`` / ``-e`` CLI option.

    Mixing in ``str`` makes each member compare equal to its lower-case value,
    which is what the command line accepts.
    """

    # Robot-arm environment.
    ARM = "arm"
    # Pendulum environment.
    PENDULUM = "pendulum"


def _resolve_log_level(log_level: str) -> int:
    """Translate a level name such as ``"info"`` into its numeric ``logging`` value.

    Raises:
        typer.BadParameter: if *log_level* does not name a ``logging`` level,
            so the user gets a CLI usage error instead of a traceback.
    """
    level = getattr(logging, log_level.upper(), None)
    # Upper-case module attributes are not all levels (e.g. BASIC_FORMAT is a
    # str); only integer constants are valid levels.
    if not isinstance(level, int):
        raise typer.BadParameter(
            f"unknown logging level {log_level!r}", param_hint="--log-level"
        )
    return level


def _install_shutdown_handlers(server: "grpc.Server", grace: float) -> None:
    """Make SIGINT and SIGTERM begin a graceful stop of *server*.

    The handler only *initiates* shutdown: ``server.stop`` returns immediately,
    and the main thread's ``server.wait_for_termination()`` returns once
    in-flight RPCs have finished or *grace* seconds have elapsed.
    """

    def _handle_signal(signum, frame):
        logging.info("Received shutdown signal")
        server.stop(grace=grace)

    signal.signal(signal.SIGINT, _handle_signal)
    signal.signal(signal.SIGTERM, _handle_signal)


@app.command()
def serve(
    port: int = typer.Option(50051, help="Port to listen on"),
    max_workers: int = typer.Option(10, help="Maximum number of worker threads"),
    log_level: str = typer.Option("INFO", help="Logging level"),
    env_type: EnvType = typer.Option(
        EnvType.ARM,
        "--env",
        "-e",
        case_sensitive=False,
        # Help text: "Select the environment type to run: arm or pendulum".
        help="选择运行的环境类型：arm 或 pendulum",
    ),
    config_path: Path = typer.Argument(..., help="Path to the configuration file (YAML)"),
    grace: float = typer.Option(
        5.0, help="Seconds to let in-flight RPCs finish when shutting down"
    ),
):
    """Start the gRPC server for the selected environment and block until it stops."""
    logging.basicConfig(
        level=_resolve_log_level(log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Maps each CLI environment choice to the env factory and config class
    # handed to the servicer, plus a display name for log messages.
    env_registry = {
        EnvType.ARM: {
            "env_factory": GymEnv,
            "config_cls": GymEnvConfig,
            "name": "ArmEnv",
        },
        EnvType.PENDULUM: {
            "env_factory": PendulumCompatEnv,
            "config_cls": PendulumEnvConfig,
            "name": "PendulumEnv",
        },
    }
    selected_env = env_registry[env_type]
    env_name = selected_env["name"]

    # Report the effective configuration before serving so it is visible
    # while the server is running.
    logging.info("Starting %s gRPC server with config:", env_name)
    logging.info("  Port: %s", port)
    logging.info("  Max workers: %s", max_workers)
    logging.info("  Log level: %s", log_level)
    logging.info("  Config file: %s", config_path)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))

    # Register the streaming service.
    servicer = ArmEnvServicer(
        config_path=config_path,
        env_factory=selected_env["env_factory"],
        config_cls=selected_env["config_cls"],
    )
    arm_stream_pb2_grpc.add_ArmEnvServicer_to_server(servicer, server)

    # Listen on all interfaces (IPv6 wildcard).
    listen_addr = f"[::]:{port}"
    bound_port = server.add_insecure_port(listen_addr)
    # Some grpcio versions report a failed bind by returning 0 rather than
    # raising; refuse to run a server that listens nowhere.
    if bound_port == 0:
        logging.error("Failed to bind gRPC server to %s", listen_addr)
        raise typer.Exit(code=1)

    server.start()
    # bound_port differs from `port` only when port 0 asks for an ephemeral port.
    logging.info("%s gRPC server started on [::]:%s", env_name, bound_port)

    try:
        _install_shutdown_handlers(server, grace)
        server.wait_for_termination()
    except KeyboardInterrupt:
        # Reachable only if Ctrl-C lands before the signal handlers are installed.
        logging.info("Server interrupted")
    finally:
        # stop() is idempotent. Waiting on the returned event ensures shutdown
        # has completed before the process exits.
        server.stop(grace=grace).wait()

    logging.info("%s gRPC server stopped", env_name)


def main():
    """Console entry point: dispatch the command line to the Typer app."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
